Diffstat (limited to 'sqlglot/dialects/spark.py')
-rw-r--r--  sqlglot/dialects/spark.py  54
1 file changed, 46 insertions(+), 8 deletions(-)
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index c271f6f..a3e4cce 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -1,13 +1,15 @@
from __future__ import annotations
+import typing as t
+
from sqlglot import exp, parser
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get
-def _create_sql(self, e):
-    kind = e.args.get("kind")
+def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
+    kind = e.args["kind"]
    properties = e.args.get("properties")

    if kind.upper() == "TABLE" and any(
@@ -18,13 +20,13 @@ def _create_sql(self, e):
    return create_with_partitions_sql(self, e)
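A minimal sketch of the two CREATE paths (the TEMPORARY VIEW rewrite is only partially visible in this hunk, so the exact output is an assumption):

    import sqlglot

    # A temporary table in Spark syntax is rewritten to a temporary view;
    # everything else falls through to create_with_partitions_sql.
    sql = "CREATE TEMPORARY TABLE v AS SELECT * FROM t"
    print(sqlglot.transpile(sql, read="spark", write="spark")[0])
    # Expected shape (assumption): CREATE TEMPORARY VIEW v AS SELECT * FROM t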
-def _map_sql(self, expression):
+def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
    keys = self.sql(expression.args["keys"])
    values = self.sql(expression.args["values"])
    return f"MAP_FROM_ARRAYS({keys}, {values})"
-def _str_to_date(self, expression):
+def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
    this = self.sql(expression, "this")
    time_format = self.format_time(expression)
    if time_format == Hive.date_format:
@@ -32,7 +34,7 @@ def _str_to_date(self, expression):
return f"TO_DATE({this}, {time_format})"
-def _unix_to_time(self, expression):
+def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
    scale = expression.args.get("scale")
    timestamp = self.sql(expression, "this")
    if scale is None:
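Only the start of the scale handling is visible above; as a hedged sketch of how the renamed helper gets exercised (the Presto input and the output shape are assumptions):

    import sqlglot

    # Presto's FROM_UNIXTIME parses into exp.UnixToTime, which Spark's
    # TRANSFORMS routes through _unix_to_time_sql (see the mapping below).
    print(sqlglot.transpile("SELECT FROM_UNIXTIME(ts)", read="presto", write="spark")[0])
    # With no scale argument, expect a FROM_UNIXTIME-based expression (assumption).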
@@ -75,7 +77,11 @@ class Spark(Hive):
                length=seq_get(args, 1),
            ),
            "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
+            "BOOLEAN": lambda args: exp.Cast(
+                this=seq_get(args, 0), to=exp.DataType.build("boolean")
+            ),
            "IIF": exp.If.from_arg_list,
+            "INT": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("int")),
            "AGGREGATE": exp.Reduce.from_arg_list,
            "DAYOFWEEK": lambda args: exp.DayOfWeek(
                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
@@ -89,11 +95,16 @@ class Spark(Hive):
"WEEKOFYEAR": lambda args: exp.WeekOfYear(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
+ "DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1),
unit=exp.var(seq_get(args, 0)),
),
+ "STRING": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("string")),
"TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
+ "TIMESTAMP": lambda args: exp.Cast(
+ this=seq_get(args, 0), to=exp.DataType.build("timestamp")
+ ),
}
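A minimal sketch of what the new cast-shorthand entries enable (the Presto target and the exact type spellings are assumptions):

    import sqlglot

    # Spark's BOOLEAN/INT/DATE/STRING/TIMESTAMP call forms parse into exp.Cast,
    # so they can be re-rendered in dialects that only know CAST syntax.
    print(sqlglot.transpile("SELECT STRING(a), INT(b), DATE(c)", read="spark", write="presto")[0])
    # Expected shape (assumption): SELECT CAST(a AS VARCHAR), CAST(b AS INTEGER), CAST(c AS DATE)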
        FUNCTION_PARSERS = {
@@ -108,16 +119,43 @@ class Spark(Hive):
            "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
        }
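A hedged sketch of the join-hint round trip these parsers support (the rest of the hint list is elided in this hunk):

    import sqlglot

    # Hints inside /*+ ... */ are parsed into exp.Hint and re-emitted by the
    # Generator's exp.Hint transform further down in this file.
    sql = "SELECT /*+ SHUFFLE_REPLICATE_NL(t1) */ * FROM t1 JOIN t2 ON t1.id = t2.id"
    print(sqlglot.transpile(sql, read="spark", write="spark")[0])
    # Expected shape (assumption): the hint survives as /*+ SHUFFLE_REPLICATE_NL(t1) */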
-        def _parse_add_column(self):
+        def _parse_add_column(self) -> t.Optional[exp.Expression]:
            return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()

-        def _parse_drop_column(self):
+        def _parse_drop_column(self) -> t.Optional[exp.Expression]:
            return self._match_text_seq("DROP", "COLUMNS") and self.expression(
                exp.Drop,
                this=self._parse_schema(),
                kind="COLUMNS",
            )
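As a sketch of the statements these parsers handle (the round-trip output is an assumption):

    import sqlglot

    # Spark's multi-column ALTER TABLE forms should now parse and round-trip.
    print(sqlglot.parse_one("ALTER TABLE t ADD COLUMNS (x INT, y STRING)", read="spark").sql("spark"))
    print(sqlglot.parse_one("ALTER TABLE t DROP COLUMNS (x, y)", read="spark").sql("spark"))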
+        def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
+            # Spark doesn't add a suffix to the pivot columns when there's a single aggregation
+            if len(pivot_columns) == 1:
+                return [""]
+
+            names = []
+            for agg in pivot_columns:
+                if isinstance(agg, exp.Alias):
+                    names.append(agg.alias)
+                else:
+                    # This case corresponds to aggregations without aliases being
+                    # used as suffixes (e.g. col_avg(foo)). We need to unquote
+                    # identifiers, because they're going to be quoted in the base
+                    # parser's `_parse_pivot` method, due to `to_identifier`.
+                    # Otherwise, we'd end up with `col_avg(`foo`)` (note the nested
+                    # backticks around foo).
+                    #
+                    # Function names are also lowercased to mimic Spark's naming scheme.
+                    agg_all_unquoted = agg.transform(
+                        lambda node: exp.Identifier(this=node.name, quoted=False)
+                        if isinstance(node, exp.Identifier)
+                        else node
+                    )
+                    names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
+
+            return names
+
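A sketch of the naming scheme being mimicked (the suffix spellings follow Spark's own pivot output; the exact sqlglot rendering is an assumption):

    import sqlglot

    # One aggregation: pivoted columns are just the IN values (x, y).
    # Several unaliased aggregations: each value gets lowercased suffixes,
    # e.g. x_avg(b) and x_sum(c).
    sql = "SELECT * FROM t PIVOT (AVG(b), SUM(c) FOR col IN ('x', 'y'))"
    print(sqlglot.parse_one(sql, read="spark").sql("spark"))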
    class Generator(Hive.Generator):
        TYPE_MAPPING = {
            **Hive.Generator.TYPE_MAPPING,  # type: ignore
@@ -145,7 +183,7 @@ class Spark(Hive):
            exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
            exp.StrToDate: _str_to_date,
            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
-            exp.UnixToTime: _unix_to_time,
+            exp.UnixToTime: _unix_to_time_sql,
            exp.Create: _create_sql,
            exp.Map: _map_sql,
            exp.Reduce: rename_func("AGGREGATE"),
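Finally, a hedged sketch of the exp.Reduce -> AGGREGATE rename in action (the Presto input and the output shape are assumptions):

    import sqlglot

    # Presto's REDUCE parses into exp.Reduce, which Spark renders as AGGREGATE.
    sql = "SELECT REDUCE(arr, 0, (acc, x) -> acc + x, acc -> acc)"
    print(sqlglot.transpile(sql, read="presto", write="spark")[0])
    # Expected shape (assumption): SELECT AGGREGATE(arr, 0, (acc, x) -> acc + x, acc -> acc)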