summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/spark.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/spark.py')
-rw-r--r--sqlglot/dialects/spark.py37
1 files changed, 20 insertions, 17 deletions
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 572f411..4e404b8 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -1,8 +1,9 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, parser
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func
from sqlglot.dialects.hive import Hive
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
+from sqlglot.helper import seq_get
def _create_sql(self, e):
@@ -46,36 +47,36 @@ def _unix_to_time(self, expression):
class Spark(Hive):
class Parser(Hive.Parser):
FUNCTIONS = {
- **Hive.Parser.FUNCTIONS,
+ **Hive.Parser.FUNCTIONS, # type: ignore
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
"LEFT": lambda args: exp.Substring(
- this=list_get(args, 0),
+ this=seq_get(args, 0),
start=exp.Literal.number(1),
- length=list_get(args, 1),
+ length=seq_get(args, 1),
),
"SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
- this=list_get(args, 0),
- expression=list_get(args, 1),
+ this=seq_get(args, 0),
+ expression=seq_get(args, 1),
),
"SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
- this=list_get(args, 0),
- expression=list_get(args, 1),
+ this=seq_get(args, 0),
+ expression=seq_get(args, 1),
),
"RIGHT": lambda args: exp.Substring(
- this=list_get(args, 0),
+ this=seq_get(args, 0),
start=exp.Sub(
- this=exp.Length(this=list_get(args, 0)),
- expression=exp.Add(this=list_get(args, 1), expression=exp.Literal.number(1)),
+ this=exp.Length(this=seq_get(args, 0)),
+ expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
),
- length=list_get(args, 1),
+ length=seq_get(args, 1),
),
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
"IIF": exp.If.from_arg_list,
}
FUNCTION_PARSERS = {
- **Parser.FUNCTION_PARSERS,
+ **parser.Parser.FUNCTION_PARSERS,
"BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
"BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
"MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
@@ -88,14 +89,14 @@ class Spark(Hive):
class Generator(Hive.Generator):
TYPE_MAPPING = {
- **Hive.Generator.TYPE_MAPPING,
+ **Hive.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "BYTE",
exp.DataType.Type.SMALLINT: "SHORT",
exp.DataType.Type.BIGINT: "LONG",
}
TRANSFORMS = {
- **{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort, exp.ILike}},
+ **Hive.Generator.TRANSFORMS, # type: ignore
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}",
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
@@ -114,6 +115,8 @@ class Spark(Hive):
exp.VariancePop: rename_func("VAR_POP"),
exp.DateFromParts: rename_func("MAKE_DATE"),
}
+ TRANSFORMS.pop(exp.ArraySort)
+ TRANSFORMS.pop(exp.ILike)
WRAP_DERIVED_VALUES = False