From beba715b97dd2349e01dde9b077d2535680ebdca Mon Sep 17 00:00:00 2001
From: Daniel Baumann
Date: Wed, 10 May 2023 08:44:58 +0200
Subject: Merging upstream version 12.2.0.

Signed-off-by: Daniel Baumann
---
 docs/sqlglot/dialects/spark.html | 740 +++++++-------------------------------
 1 file changed, 122 insertions(+), 618 deletions(-)

(limited to 'docs/sqlglot/dialects/spark.html')

diff --git a/docs/sqlglot/dialects/spark.html b/docs/sqlglot/dialects/spark.html
index bc53a56..d29015b 100644
--- a/docs/sqlglot/dialects/spark.html
+++ b/docs/sqlglot/dialects/spark.html
@@ -43,17 +43,11 @@
   Spark.Generator
-  Spark.Tokenizer
@@ -80,229 +74,61 @@
-  1from __future__ import annotations
    -  2
    -  3import typing as t
    -  4
    -  5from sqlglot import exp, parser
    -  6from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
    -  7from sqlglot.dialects.hive import Hive
    -  8from sqlglot.helper import seq_get
    -  9
    - 10
    - 11def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
    - 12    kind = e.args["kind"]
    - 13    properties = e.args.get("properties")
    - 14
    - 15    if kind.upper() == "TABLE" and any(
    - 16        isinstance(prop, exp.TemporaryProperty)
    - 17        for prop in (properties.expressions if properties else [])
    - 18    ):
    - 19        return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
    - 20    return create_with_partitions_sql(self, e)
    - 21
    - 22
    - 23def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
    - 24    keys = self.sql(expression.args["keys"])
    - 25    values = self.sql(expression.args["values"])
    - 26    return f"MAP_FROM_ARRAYS({keys}, {values})"
    - 27
    - 28
    - 29def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
    - 30    this = self.sql(expression, "this")
    - 31    time_format = self.format_time(expression)
    - 32    if time_format == Hive.date_format:
    - 33        return f"TO_DATE({this})"
    - 34    return f"TO_DATE({this}, {time_format})"
    - 35
    - 36
    - 37def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
    - 38    scale = expression.args.get("scale")
    - 39    timestamp = self.sql(expression, "this")
    - 40    if scale is None:
    - 41        return f"FROM_UNIXTIME({timestamp})"
    - 42    if scale == exp.UnixToTime.SECONDS:
    - 43        return f"TIMESTAMP_SECONDS({timestamp})"
    - 44    if scale == exp.UnixToTime.MILLIS:
    - 45        return f"TIMESTAMP_MILLIS({timestamp})"
    - 46    if scale == exp.UnixToTime.MICROS:
    - 47        return f"TIMESTAMP_MICROS({timestamp})"
    - 48
    - 49    raise ValueError("Improper scale for timestamp")
    - 50
    - 51
    - 52class Spark(Hive):
    - 53    class Parser(Hive.Parser):
    - 54        FUNCTIONS = {
    - 55            **Hive.Parser.FUNCTIONS,  # type: ignore
    - 56            "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
    - 57            "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
    - 58            "LEFT": lambda args: exp.Substring(
    - 59                this=seq_get(args, 0),
    - 60                start=exp.Literal.number(1),
    - 61                length=seq_get(args, 1),
    - 62            ),
    - 63            "SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
    - 64                this=seq_get(args, 0),
    - 65                expression=seq_get(args, 1),
    - 66            ),
    - 67            "SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
    - 68                this=seq_get(args, 0),
    - 69                expression=seq_get(args, 1),
    - 70            ),
    - 71            "RIGHT": lambda args: exp.Substring(
    - 72                this=seq_get(args, 0),
    - 73                start=exp.Sub(
    - 74                    this=exp.Length(this=seq_get(args, 0)),
    - 75                    expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
    - 76                ),
    - 77                length=seq_get(args, 1),
    - 78            ),
    - 79            "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
    - 80            "BOOLEAN": lambda args: exp.Cast(
    - 81                this=seq_get(args, 0), to=exp.DataType.build("boolean")
    - 82            ),
    - 83            "IIF": exp.If.from_arg_list,
    - 84            "INT": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("int")),
    - 85            "AGGREGATE": exp.Reduce.from_arg_list,
    - 86            "DAYOFWEEK": lambda args: exp.DayOfWeek(
    - 87                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
    - 88            ),
    - 89            "DAYOFMONTH": lambda args: exp.DayOfMonth(
    - 90                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
    - 91            ),
    - 92            "DAYOFYEAR": lambda args: exp.DayOfYear(
    - 93                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
    - 94            ),
    - 95            "WEEKOFYEAR": lambda args: exp.WeekOfYear(
    - 96                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
    - 97            ),
    - 98            "DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
    - 99            "DATE_TRUNC": lambda args: exp.TimestampTrunc(
    -100                this=seq_get(args, 1),
    -101                unit=exp.var(seq_get(args, 0)),
    -102            ),
    -103            "STRING": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("string")),
    -104            "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
    -105            "TIMESTAMP": lambda args: exp.Cast(
    -106                this=seq_get(args, 0), to=exp.DataType.build("timestamp")
    -107            ),
    -108        }
    -109
    -110        FUNCTION_PARSERS = {
    -111            **parser.Parser.FUNCTION_PARSERS,  # type: ignore
    -112            "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
    -113            "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
    -114            "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
    -115            "MERGE": lambda self: self._parse_join_hint("MERGE"),
    -116            "SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"),
    -117            "MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"),
    -118            "SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"),
    -119            "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
    -120        }
    -121
    -122        def _parse_add_column(self) -> t.Optional[exp.Expression]:
    -123            return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
    -124
    -125        def _parse_drop_column(self) -> t.Optional[exp.Expression]:
    -126            return self._match_text_seq("DROP", "COLUMNS") and self.expression(
    -127                exp.Drop,
    -128                this=self._parse_schema(),
    -129                kind="COLUMNS",
    -130            )
    -131
    -132        def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
    -133            # Spark doesn't add a suffix to the pivot columns when there's a single aggregation
    -134            if len(pivot_columns) == 1:
    -135                return [""]
    -136
    -137            names = []
    -138            for agg in pivot_columns:
    -139                if isinstance(agg, exp.Alias):
    -140                    names.append(agg.alias)
    -141                else:
    -142                    """
    -143                    This case corresponds to aggregations without aliases being used as suffixes
    -144                    (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
    -145                    be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
    -146                    Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
    -147
    -148                    Moreover, function names are lowercased in order to mimic Spark's naming scheme.
    -149                    """
    -150                    agg_all_unquoted = agg.transform(
    -151                        lambda node: exp.Identifier(this=node.name, quoted=False)
    -152                        if isinstance(node, exp.Identifier)
    -153                        else node
    -154                    )
    -155                    names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
    -156
    -157            return names
    -158
    -159    class Generator(Hive.Generator):
    -160        TYPE_MAPPING = {
    -161            **Hive.Generator.TYPE_MAPPING,  # type: ignore
    -162            exp.DataType.Type.TINYINT: "BYTE",
    -163            exp.DataType.Type.SMALLINT: "SHORT",
    -164            exp.DataType.Type.BIGINT: "LONG",
    -165        }
    -166
    -167        PROPERTIES_LOCATION = {
    -168            **Hive.Generator.PROPERTIES_LOCATION,  # type: ignore
    -169            exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
    -170            exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
    -171            exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
    -172            exp.CollateProperty: exp.Properties.Location.UNSUPPORTED,
    -173        }
    -174
    -175        TRANSFORMS = {
    -176            **Hive.Generator.TRANSFORMS,  # type: ignore
    -177            exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
    -178            exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
    -179            exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
    -180            exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
    -181            exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
    -182            exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
    -183            exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
    -184            exp.StrToDate: _str_to_date,
    -185            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
    -186            exp.UnixToTime: _unix_to_time_sql,
    -187            exp.Create: _create_sql,
    -188            exp.Map: _map_sql,
    -189            exp.Reduce: rename_func("AGGREGATE"),
    -190            exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
    -191            exp.TimestampTrunc: lambda self, e: self.func(
    -192                "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
    -193            ),
    -194            exp.Trim: trim_sql,
    -195            exp.VariancePop: rename_func("VAR_POP"),
    -196            exp.DateFromParts: rename_func("MAKE_DATE"),
    -197            exp.LogicalOr: rename_func("BOOL_OR"),
    -198            exp.LogicalAnd: rename_func("BOOL_AND"),
    -199            exp.DayOfWeek: rename_func("DAYOFWEEK"),
    -200            exp.DayOfMonth: rename_func("DAYOFMONTH"),
    -201            exp.DayOfYear: rename_func("DAYOFYEAR"),
    -202            exp.WeekOfYear: rename_func("WEEKOFYEAR"),
    -203            exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
    -204        }
    -205        TRANSFORMS.pop(exp.ArraySort)
    -206        TRANSFORMS.pop(exp.ILike)
    -207
    -208        WRAP_DERIVED_VALUES = False
    -209        CREATE_FUNCTION_RETURN_AS = False
    -210
    -211        def cast_sql(self, expression: exp.Cast) -> str:
    -212            if isinstance(expression.this, exp.Cast) and expression.this.is_type(
    -213                exp.DataType.Type.JSON
    -214            ):
    -215                schema = f"'{self.sql(expression, 'to')}'"
    -216                return self.func("FROM_JSON", expression.this.this, schema)
    -217            if expression.to.is_type(exp.DataType.Type.JSON):
    -218                return self.func("TO_JSON", expression.this)
    -219
    -220            return super(Spark.Generator, self).cast_sql(expression)
    -221
    -222    class Tokenizer(Hive.Tokenizer):
    -223        HEX_STRINGS = [("X'", "'")]
+ 1from __future__ import annotations
    + 2
    + 3import typing as t
    + 4
    + 5from sqlglot import exp
    + 6from sqlglot.dialects.spark2 import Spark2
    + 7from sqlglot.helper import seq_get
    + 8
    + 9
    +10def _parse_datediff(args: t.Sequence) -> exp.Expression:
    +11    """
    +12    Although Spark docs don't mention the "unit" argument, Spark3 added support for
    +13    it at some point. Databricks also supports this variation (see below).
    +14
    +15    For example, in spark-sql (v3.3.1):
    +16    - SELECT DATEDIFF('2020-01-01', '2020-01-05') results in -4
    +17    - SELECT DATEDIFF(day, '2020-01-01', '2020-01-05') results in 4
    +18
    +19    See also:
    +20    - https://docs.databricks.com/sql/language-manual/functions/datediff3.html
    +21    - https://docs.databricks.com/sql/language-manual/functions/datediff.html
    +22    """
    +23    unit = None
    +24    this = seq_get(args, 0)
    +25    expression = seq_get(args, 1)
    +26
    +27    if len(args) == 3:
    +28        unit = this
    +29        this = args[2]
    +30
    +31    return exp.DateDiff(
    +32        this=exp.TsOrDsToDate(this=this), expression=exp.TsOrDsToDate(this=expression), unit=unit
    +33    )
    +34
    +35
    +36class Spark(Spark2):
    +37    class Parser(Spark2.Parser):
    +38        FUNCTIONS = {
    +39            **Spark2.Parser.FUNCTIONS,  # type: ignore
    +40            "DATEDIFF": _parse_datediff,
    +41        }
    +42
    +43    class Generator(Spark2.Generator):
    +44        TRANSFORMS = Spark2.Generator.TRANSFORMS.copy()
    +45        TRANSFORMS.pop(exp.DateDiff)
    +46
    +47        def datediff_sql(self, expression: exp.DateDiff) -> str:
    +48            unit = self.sql(expression, "unit")
    +49            end = self.sql(expression, "this")
    +50            start = self.sql(expression, "expression")
    +51
    +52            if unit:
    +53                return self.func("DATEDIFF", unit, start, end)
    +54
    +55            return self.func("DATEDIFF", end, start)
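A minimal usage sketch of the DATEDIFF handling added above, exercised through sqlglot's public transpile API. This snippet is illustrative only (it is not part of the patch) and assumes a sqlglot build that contains this change; the exact wrapping of the date arguments in the generated SQL may differ between versions.

import sqlglot

# Two-argument Spark form: DATEDIFF(end, start). _parse_datediff wraps both
# arguments in TsOrDsToDate, and datediff_sql emits them back in the same order.
print(sqlglot.transpile(
    "SELECT DATEDIFF('2020-01-05', '2020-01-01')",
    read="spark",
    write="spark",
)[0])

# Three-argument (Databricks-style) form: the leading unit is stored on the
# DateDiff node and emitted first again, so the query keeps its meaning
# across the round trip.
print(sqlglot.transpile(
    "SELECT DATEDIFF(day, '2020-01-01', '2020-01-05')",
    read="spark",
    write="spark",
)[0])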
     
    @@ -312,184 +138,32 @@
-class Spark(sqlglot.dialects.hive.Hive):
+class Spark(sqlglot.dialects.spark2.Spark2):
     
    @@ -498,7 +172,11 @@
    Inherited Members
 sqlglot.dialects.dialect.Dialect
    get_or_raise
    format_time
    parse
    @@ -518,117 +196,17 @@
-class Spark.Parser(sqlglot.dialects.hive.Hive.Parser):
+class Spark.Parser(sqlglot.dialects.spark2.Spark2.Parser):
     
    @@ -679,74 +257,25 @@ Default: "nulls_are_small"
-class Spark.Generator(sqlglot.dialects.hive.Hive.Generator):
+class Spark.Generator(sqlglot.dialects.spark2.Spark2.Generator):
     
    @@ -794,27 +323,26 @@ Default: True
-def cast_sql(self, expression: sqlglot.expressions.Cast) -> str:
+def datediff_sql(self, expression: sqlglot.expressions.DateDiff) -> str:
     
    @@ -943,7 +471,7 @@ Default: True
    where_sql
    window_sql
    partition_by_sql
-window_spec_sql
+windowspec_sql
    withingroup_sql
    between_sql
    bracket_sql
    @@ -952,6 +480,7 @@ Default: True
    exists_sql
    case_sql
    constraint_sql
+nextvaluefor_sql
    extract_sql
    trim_sql
    concat_sql
    @@ -1047,6 +576,10 @@ Default: True
    merge_sql
    tochar_sql
+sqlglot.dialects.hive.Hive.Generator
+arrayagg_sql
    @@ -1057,35 +590,6 @@ Default: True
-class Spark.Tokenizer(sqlglot.dialects.hive.Hive.Tokenizer):
-
-223    class Tokenizer(Hive.Tokenizer):
-224        HEX_STRINGS = [("X'", "'")]
-
-Inherited Members