Diffstat (limited to 'sqlglot/dialects/spark.py')
-rw-r--r-- | sqlglot/dialects/spark.py | 240
1 file changed, 36 insertions(+), 204 deletions(-)
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index a3e4cce..939f2fd 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -2,222 +2,54 @@ from __future__ import annotations
 
 import typing as t
 
-from sqlglot import exp, parser
-from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
-from sqlglot.dialects.hive import Hive
+from sqlglot import exp
+from sqlglot.dialects.spark2 import Spark2
 from sqlglot.helper import seq_get
 
 
-def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
-    kind = e.args["kind"]
-    properties = e.args.get("properties")
+def _parse_datediff(args: t.Sequence) -> exp.Expression:
+    """
+    Although Spark docs don't mention the "unit" argument, Spark3 added support for
+    it at some point. Databricks also supports this variation (see below).
 
-    if kind.upper() == "TABLE" and any(
-        isinstance(prop, exp.TemporaryProperty)
-        for prop in (properties.expressions if properties else [])
-    ):
-        return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
-    return create_with_partitions_sql(self, e)
+    For example, in spark-sql (v3.3.1):
+        - SELECT DATEDIFF('2020-01-01', '2020-01-05') results in -4
+        - SELECT DATEDIFF(day, '2020-01-01', '2020-01-05') results in 4
 
+    See also:
+        - https://docs.databricks.com/sql/language-manual/functions/datediff3.html
+        - https://docs.databricks.com/sql/language-manual/functions/datediff.html
+    """
+    unit = None
+    this = seq_get(args, 0)
+    expression = seq_get(args, 1)
 
-def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
-    keys = self.sql(expression.args["keys"])
-    values = self.sql(expression.args["values"])
-    return f"MAP_FROM_ARRAYS({keys}, {values})"
+    if len(args) == 3:
+        unit = this
+        this = args[2]
 
+    return exp.DateDiff(
+        this=exp.TsOrDsToDate(this=this), expression=exp.TsOrDsToDate(this=expression), unit=unit
+    )
 
-def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
-    this = self.sql(expression, "this")
-    time_format = self.format_time(expression)
-    if time_format == Hive.date_format:
-        return f"TO_DATE({this})"
-    return f"TO_DATE({this}, {time_format})"
-
-
-def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
-    scale = expression.args.get("scale")
-    timestamp = self.sql(expression, "this")
-    if scale is None:
-        return f"FROM_UNIXTIME({timestamp})"
-    if scale == exp.UnixToTime.SECONDS:
-        return f"TIMESTAMP_SECONDS({timestamp})"
-    if scale == exp.UnixToTime.MILLIS:
-        return f"TIMESTAMP_MILLIS({timestamp})"
-    if scale == exp.UnixToTime.MICROS:
-        return f"TIMESTAMP_MICROS({timestamp})"
-
-    raise ValueError("Improper scale for timestamp")
-
-
-class Spark(Hive):
-    class Parser(Hive.Parser):
+class Spark(Spark2):
+    class Parser(Spark2.Parser):
         FUNCTIONS = {
-            **Hive.Parser.FUNCTIONS,  # type: ignore
-            "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
-            "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
-            "LEFT": lambda args: exp.Substring(
-                this=seq_get(args, 0),
-                start=exp.Literal.number(1),
-                length=seq_get(args, 1),
-            ),
-            "SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
-                this=seq_get(args, 0),
-                expression=seq_get(args, 1),
-            ),
-            "SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
-                this=seq_get(args, 0),
-                expression=seq_get(args, 1),
-            ),
-            "RIGHT": lambda args: exp.Substring(
-                this=seq_get(args, 0),
-                start=exp.Sub(
-                    this=exp.Length(this=seq_get(args, 0)),
-                    expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
-                ),
-                length=seq_get(args, 1),
-            ),
-            "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
-            "BOOLEAN": lambda args: exp.Cast(
-                this=seq_get(args, 0), to=exp.DataType.build("boolean")
-            ),
-            "IIF": exp.If.from_arg_list,
-            "INT": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("int")),
-            "AGGREGATE": exp.Reduce.from_arg_list,
-            "DAYOFWEEK": lambda args: exp.DayOfWeek(
-                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
-            ),
-            "DAYOFMONTH": lambda args: exp.DayOfMonth(
-                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
-            ),
-            "DAYOFYEAR": lambda args: exp.DayOfYear(
-                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
-            ),
-            "WEEKOFYEAR": lambda args: exp.WeekOfYear(
-                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
-            ),
-            "DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
-            "DATE_TRUNC": lambda args: exp.TimestampTrunc(
-                this=seq_get(args, 1),
-                unit=exp.var(seq_get(args, 0)),
-            ),
-            "STRING": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("string")),
-            "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
-            "TIMESTAMP": lambda args: exp.Cast(
-                this=seq_get(args, 0), to=exp.DataType.build("timestamp")
-            ),
-        }
-
-        FUNCTION_PARSERS = {
-            **parser.Parser.FUNCTION_PARSERS,  # type: ignore
-            "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
-            "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
-            "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
-            "MERGE": lambda self: self._parse_join_hint("MERGE"),
-            "SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"),
-            "MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"),
-            "SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"),
-            "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
-        }
-
-        def _parse_add_column(self) -> t.Optional[exp.Expression]:
-            return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
-
-        def _parse_drop_column(self) -> t.Optional[exp.Expression]:
-            return self._match_text_seq("DROP", "COLUMNS") and self.expression(
-                exp.Drop,
-                this=self._parse_schema(),
-                kind="COLUMNS",
-            )
-
-        def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
-            # Spark doesn't add a suffix to the pivot columns when there's a single aggregation
-            if len(pivot_columns) == 1:
-                return [""]
-
-            names = []
-            for agg in pivot_columns:
-                if isinstance(agg, exp.Alias):
-                    names.append(agg.alias)
-                else:
-                    """
-                    This case corresponds to aggregations without aliases being used as suffixes
-                    (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
-                    be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
-                    Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
-
-                    Moreover, function names are lowercased in order to mimic Spark's naming scheme.
-                    """
-                    agg_all_unquoted = agg.transform(
-                        lambda node: exp.Identifier(this=node.name, quoted=False)
-                        if isinstance(node, exp.Identifier)
-                        else node
-                    )
-                    names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
-
-            return names
-
-    class Generator(Hive.Generator):
-        TYPE_MAPPING = {
-            **Hive.Generator.TYPE_MAPPING,  # type: ignore
-            exp.DataType.Type.TINYINT: "BYTE",
-            exp.DataType.Type.SMALLINT: "SHORT",
-            exp.DataType.Type.BIGINT: "LONG",
-        }
-
-        PROPERTIES_LOCATION = {
-            **Hive.Generator.PROPERTIES_LOCATION,  # type: ignore
-            exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
-            exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
-            exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
-            exp.CollateProperty: exp.Properties.Location.UNSUPPORTED,
-        }
-
-        TRANSFORMS = {
-            **Hive.Generator.TRANSFORMS,  # type: ignore
-            exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
-            exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
-            exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
-            exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
-            exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
-            exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
-            exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
-            exp.StrToDate: _str_to_date,
-            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
-            exp.UnixToTime: _unix_to_time_sql,
-            exp.Create: _create_sql,
-            exp.Map: _map_sql,
-            exp.Reduce: rename_func("AGGREGATE"),
-            exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
-            exp.TimestampTrunc: lambda self, e: self.func(
-                "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
-            ),
-            exp.Trim: trim_sql,
-            exp.VariancePop: rename_func("VAR_POP"),
-            exp.DateFromParts: rename_func("MAKE_DATE"),
-            exp.LogicalOr: rename_func("BOOL_OR"),
-            exp.LogicalAnd: rename_func("BOOL_AND"),
-            exp.DayOfWeek: rename_func("DAYOFWEEK"),
-            exp.DayOfMonth: rename_func("DAYOFMONTH"),
-            exp.DayOfYear: rename_func("DAYOFYEAR"),
-            exp.WeekOfYear: rename_func("WEEKOFYEAR"),
-            exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
+            **Spark2.Parser.FUNCTIONS,  # type: ignore
+            "DATEDIFF": _parse_datediff,
         }
 
-        TRANSFORMS.pop(exp.ArraySort)
-        TRANSFORMS.pop(exp.ILike)
-
-        WRAP_DERIVED_VALUES = False
-        CREATE_FUNCTION_RETURN_AS = False
+    class Generator(Spark2.Generator):
+        TRANSFORMS = Spark2.Generator.TRANSFORMS.copy()
+        TRANSFORMS.pop(exp.DateDiff)
 
-        def cast_sql(self, expression: exp.Cast) -> str:
-            if isinstance(expression.this, exp.Cast) and expression.this.is_type(
-                exp.DataType.Type.JSON
-            ):
-                schema = f"'{self.sql(expression, 'to')}'"
-                return self.func("FROM_JSON", expression.this.this, schema)
-            if expression.to.is_type(exp.DataType.Type.JSON):
-                return self.func("TO_JSON", expression.this)
+        def datediff_sql(self, expression: exp.DateDiff) -> str:
+            unit = self.sql(expression, "unit")
+            end = self.sql(expression, "this")
+            start = self.sql(expression, "expression")
 
-            return super(Spark.Generator, self).cast_sql(expression)
+            if unit:
+                return self.func("DATEDIFF", unit, start, end)
 
-    class Tokenizer(Hive.Tokenizer):
-        HEX_STRINGS = [("X'", "'")]
+            return self.func("DATEDIFF", end, start)