diff options
Diffstat (limited to 'sqlglot/dialects/spark.py')
-rw-r--r-- | sqlglot/dialects/spark.py | 54 |
1 files changed, 46 insertions, 8 deletions
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index c271f6f..a3e4cce 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -1,13 +1,15 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, parser from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql from sqlglot.dialects.hive import Hive from sqlglot.helper import seq_get -def _create_sql(self, e): - kind = e.args.get("kind") +def _create_sql(self: Hive.Generator, e: exp.Create) -> str: + kind = e.args["kind"] properties = e.args.get("properties") if kind.upper() == "TABLE" and any( @@ -18,13 +20,13 @@ def _create_sql(self, e): return create_with_partitions_sql(self, e) -def _map_sql(self, expression): +def _map_sql(self: Hive.Generator, expression: exp.Map) -> str: keys = self.sql(expression.args["keys"]) values = self.sql(expression.args["values"]) return f"MAP_FROM_ARRAYS({keys}, {values})" -def _str_to_date(self, expression): +def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format == Hive.date_format: @@ -32,7 +34,7 @@ def _str_to_date(self, expression): return f"TO_DATE({this}, {time_format})" -def _unix_to_time(self, expression): +def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str: scale = expression.args.get("scale") timestamp = self.sql(expression, "this") if scale is None: @@ -75,7 +77,11 @@ class Spark(Hive): length=seq_get(args, 1), ), "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, + "BOOLEAN": lambda args: exp.Cast( + this=seq_get(args, 0), to=exp.DataType.build("boolean") + ), "IIF": exp.If.from_arg_list, + "INT": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("int")), "AGGREGATE": exp.Reduce.from_arg_list, "DAYOFWEEK": lambda args: exp.DayOfWeek( this=exp.TsOrDsToDate(this=seq_get(args, 0)), @@ -89,11 +95,16 @@ class Spark(Hive): "WEEKOFYEAR": lambda args: exp.WeekOfYear( this=exp.TsOrDsToDate(this=seq_get(args, 0)), ), + "DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")), "DATE_TRUNC": lambda args: exp.TimestampTrunc( this=seq_get(args, 1), unit=exp.var(seq_get(args, 0)), ), + "STRING": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("string")), "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)), + "TIMESTAMP": lambda args: exp.Cast( + this=seq_get(args, 0), to=exp.DataType.build("timestamp") + ), } FUNCTION_PARSERS = { @@ -108,16 +119,43 @@ class Spark(Hive): "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"), } - def _parse_add_column(self): + def _parse_add_column(self) -> t.Optional[exp.Expression]: return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema() - def _parse_drop_column(self): + def _parse_drop_column(self) -> t.Optional[exp.Expression]: return self._match_text_seq("DROP", "COLUMNS") and self.expression( exp.Drop, this=self._parse_schema(), kind="COLUMNS", ) + def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]: + # Spark doesn't add a suffix to the pivot columns when there's a single aggregation + if len(pivot_columns) == 1: + return [""] + + names = [] + for agg in pivot_columns: + if isinstance(agg, exp.Alias): + names.append(agg.alias) + else: + """ + This case corresponds to aggregations without aliases being used as suffixes + (e.g. col_avg(foo)). We need to unquote identifiers because they're going to + be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. + Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). + + Moreover, function names are lowercased in order to mimic Spark's naming scheme. + """ + agg_all_unquoted = agg.transform( + lambda node: exp.Identifier(this=node.name, quoted=False) + if isinstance(node, exp.Identifier) + else node + ) + names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower")) + + return names + class Generator(Hive.Generator): TYPE_MAPPING = { **Hive.Generator.TYPE_MAPPING, # type: ignore @@ -145,7 +183,7 @@ class Spark(Hive): exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */", exp.StrToDate: _str_to_date, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", - exp.UnixToTime: _unix_to_time, + exp.UnixToTime: _unix_to_time_sql, exp.Create: _create_sql, exp.Map: _map_sql, exp.Reduce: rename_func("AGGREGATE"), |