Diffstat:
-rw-r--r-- sqlglot/dialects/spark2.py | 96
1 file changed, 63 insertions(+), 33 deletions(-)
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index 584671f..912b86b 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -3,7 +3,12 @@ from __future__ import annotations
 import typing as t
 
 from sqlglot import exp, parser, transforms
-from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
+from sqlglot.dialects.dialect import (
+    create_with_partitions_sql,
+    pivot_column_names,
+    rename_func,
+    trim_sql,
+)
 from sqlglot.dialects.hive import Hive
 from sqlglot.helper import seq_get
 
@@ -26,7 +31,7 @@ def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
     return f"MAP_FROM_ARRAYS({keys}, {values})"
 
 
-def _parse_as_cast(to_type: str) -> t.Callable[[t.Sequence], exp.Expression]:
+def _parse_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]:
     return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type))
 
 
@@ -53,10 +58,56 @@ def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
     raise ValueError("Improper scale for timestamp")
 
 
+def _unalias_pivot(expression: exp.Expression) -> exp.Expression:
+    """
+    Spark doesn't allow PIVOT aliases, so we need to remove them and possibly wrap a
+    pivoted source in a subquery with the same alias to preserve the query's semantics.
+
+    Example:
+        >>> from sqlglot import parse_one
+        >>> expr = parse_one("SELECT piv.x FROM tbl PIVOT (SUM(a) FOR b IN ('x')) piv")
+        >>> print(_unalias_pivot(expr).sql(dialect="spark"))
+        SELECT piv.x FROM (SELECT * FROM tbl PIVOT(SUM(a) FOR b IN ('x'))) AS piv
+    """
+    if isinstance(expression, exp.From) and expression.this.args.get("pivots"):
+        pivot = expression.this.args["pivots"][0]
+        if pivot.alias:
+            alias = pivot.args["alias"].pop()
+            return exp.From(
+                this=expression.this.replace(
+                    exp.select("*").from_(expression.this.copy()).subquery(alias=alias)
+                )
+            )
+
+    return expression
+
+
+def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
+    """
+    Spark doesn't allow the column referenced in the PIVOT's field to be qualified,
+    so we need to unqualify it.
+
+    Example:
+        >>> from sqlglot import parse_one
+        >>> expr = parse_one("SELECT * FROM tbl PIVOT (SUM(tbl.sales) FOR tbl.quarter IN ('Q1', 'Q2'))")
+        >>> print(_unqualify_pivot_columns(expr).sql(dialect="spark"))
+        SELECT * FROM tbl PIVOT(SUM(tbl.sales) FOR quarter IN ('Q1', 'Q2'))
+    """
+    if isinstance(expression, exp.Pivot):
+        expression.args["field"].transform(
+            lambda node: exp.column(node.output_name, quoted=node.this.quoted)
+            if isinstance(node, exp.Column)
+            else node,
+            copy=False,
+        )
+
+    return expression
+
+
 class Spark2(Hive):
     class Parser(Hive.Parser):
         FUNCTIONS = {
-            **Hive.Parser.FUNCTIONS,  # type: ignore
+            **Hive.Parser.FUNCTIONS,
             "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
             "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
             "LEFT": lambda args: exp.Substring(
@@ -110,7 +161,7 @@ class Spark2(Hive):
         }
 
         FUNCTION_PARSERS = {
-            **parser.Parser.FUNCTION_PARSERS,  # type: ignore
+            **parser.Parser.FUNCTION_PARSERS,
             "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
             "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
             "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
@@ -131,43 +182,21 @@ class Spark2(Hive):
                 kind="COLUMNS",
             )
 
-        def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
-            # Spark doesn't add a suffix to the pivot columns when there's a single aggregation
-            if len(pivot_columns) == 1:
+        def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
+            if len(aggregations) == 1:
                 return [""]
-
-            names = []
-            for agg in pivot_columns:
-                if isinstance(agg, exp.Alias):
-                    names.append(agg.alias)
-                else:
-                    """
-                    This case corresponds to aggregations without aliases being used as suffixes
-                    (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
-                    be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
-                    Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
-
-                    Moreover, function names are lowercased in order to mimic Spark's naming scheme.
-                    """
-                    agg_all_unquoted = agg.transform(
-                        lambda node: exp.Identifier(this=node.name, quoted=False)
-                        if isinstance(node, exp.Identifier)
-                        else node
-                    )
-                    names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
-
-            return names
+            return pivot_column_names(aggregations, dialect="spark")
 
     class Generator(Hive.Generator):
         TYPE_MAPPING = {
-            **Hive.Generator.TYPE_MAPPING,  # type: ignore
+            **Hive.Generator.TYPE_MAPPING,
             exp.DataType.Type.TINYINT: "BYTE",
             exp.DataType.Type.SMALLINT: "SHORT",
             exp.DataType.Type.BIGINT: "LONG",
         }
 
         PROPERTIES_LOCATION = {
-            **Hive.Generator.PROPERTIES_LOCATION,  # type: ignore
+            **Hive.Generator.PROPERTIES_LOCATION,
             exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
             exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
             exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
@@ -175,7 +204,7 @@ class Spark2(Hive):
         }
 
         TRANSFORMS = {
-            **Hive.Generator.TRANSFORMS,  # type: ignore
+            **Hive.Generator.TRANSFORMS,
             exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
             exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
             exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
@@ -188,11 +217,12 @@ class Spark2(Hive):
             exp.DayOfWeek: rename_func("DAYOFWEEK"),
             exp.DayOfYear: rename_func("DAYOFYEAR"),
             exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
+            exp.From: transforms.preprocess([_unalias_pivot]),
             exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
             exp.LogicalAnd: rename_func("BOOL_AND"),
             exp.LogicalOr: rename_func("BOOL_OR"),
             exp.Map: _map_sql,
-            exp.Pivot: transforms.preprocess([transforms.unqualify_pivot_columns]),
+            exp.Pivot: transforms.preprocess([_unqualify_pivot_columns]),
             exp.Reduce: rename_func("AGGREGATE"),
             exp.StrToDate: _str_to_date,
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
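Note: the two new transforms are registered as generator preprocessors (exp.From and exp.Pivot in TRANSFORMS), so they also fire through sqlglot's public API. The snippet below is a minimal sketch mirroring the doctests added in this commit; it assumes a sqlglot build that includes this change.

from sqlglot import parse_one

# PIVOT aliases are removed and the pivoted source is wrapped in a subquery
# carrying the same alias, so references like `piv.x` keep resolving.
sql = "SELECT piv.x FROM tbl PIVOT (SUM(a) FOR b IN ('x')) piv"
print(parse_one(sql).sql(dialect="spark"))
# SELECT piv.x FROM (SELECT * FROM tbl PIVOT(SUM(a) FOR b IN ('x'))) AS piv

# The qualified column in the PIVOT's FOR clause is unqualified on generation.
sql = "SELECT * FROM tbl PIVOT (SUM(tbl.sales) FOR tbl.quarter IN ('Q1', 'Q2'))"
print(parse_one(sql).sql(dialect="spark"))
# SELECT * FROM tbl PIVOT(SUM(tbl.sales) FOR quarter IN ('Q1', 'Q2'))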