path: root/sqlglot/dialects/spark.py
Diffstat (limited to 'sqlglot/dialects/spark.py')
-rw-r--r--   sqlglot/dialects/spark.py   240
1 file changed, 36 insertions(+), 204 deletions(-)
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index a3e4cce..939f2fd 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -2,222 +2,54 @@ from __future__ import annotations
import typing as t
-from sqlglot import exp, parser
-from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
-from sqlglot.dialects.hive import Hive
+from sqlglot import exp
+from sqlglot.dialects.spark2 import Spark2
from sqlglot.helper import seq_get
-def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
- kind = e.args["kind"]
- properties = e.args.get("properties")
+def _parse_datediff(args: t.Sequence) -> exp.Expression:
+ """
+ Although Spark docs don't mention the "unit" argument, Spark3 added support for
+ it at some point. Databricks also supports this variation (see below).
- if kind.upper() == "TABLE" and any(
- isinstance(prop, exp.TemporaryProperty)
- for prop in (properties.expressions if properties else [])
- ):
- return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
- return create_with_partitions_sql(self, e)
+ For example, in spark-sql (v3.3.1):
+ - SELECT DATEDIFF('2020-01-01', '2020-01-05') results in -4
+ - SELECT DATEDIFF(day, '2020-01-01', '2020-01-05') results in 4
+ See also:
+ - https://docs.databricks.com/sql/language-manual/functions/datediff3.html
+ - https://docs.databricks.com/sql/language-manual/functions/datediff.html
+ """
+ unit = None
+ this = seq_get(args, 0)
+ expression = seq_get(args, 1)
-def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
- keys = self.sql(expression.args["keys"])
- values = self.sql(expression.args["values"])
- return f"MAP_FROM_ARRAYS({keys}, {values})"
+ if len(args) == 3:
+ unit = this
+ this = args[2]
+ return exp.DateDiff(
+ this=exp.TsOrDsToDate(this=this), expression=exp.TsOrDsToDate(this=expression), unit=unit
+ )
-def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
- this = self.sql(expression, "this")
- time_format = self.format_time(expression)
- if time_format == Hive.date_format:
- return f"TO_DATE({this})"
- return f"TO_DATE({this}, {time_format})"
-
-def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
- scale = expression.args.get("scale")
- timestamp = self.sql(expression, "this")
- if scale is None:
- return f"FROM_UNIXTIME({timestamp})"
- if scale == exp.UnixToTime.SECONDS:
- return f"TIMESTAMP_SECONDS({timestamp})"
- if scale == exp.UnixToTime.MILLIS:
- return f"TIMESTAMP_MILLIS({timestamp})"
- if scale == exp.UnixToTime.MICROS:
- return f"TIMESTAMP_MICROS({timestamp})"
-
- raise ValueError("Improper scale for timestamp")
-
-
-class Spark(Hive):
- class Parser(Hive.Parser):
+class Spark(Spark2):
+ class Parser(Spark2.Parser):
FUNCTIONS = {
- **Hive.Parser.FUNCTIONS, # type: ignore
- "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
- "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
- "LEFT": lambda args: exp.Substring(
- this=seq_get(args, 0),
- start=exp.Literal.number(1),
- length=seq_get(args, 1),
- ),
- "SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
- this=seq_get(args, 0),
- expression=seq_get(args, 1),
- ),
- "SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
- this=seq_get(args, 0),
- expression=seq_get(args, 1),
- ),
- "RIGHT": lambda args: exp.Substring(
- this=seq_get(args, 0),
- start=exp.Sub(
- this=exp.Length(this=seq_get(args, 0)),
- expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
- ),
- length=seq_get(args, 1),
- ),
- "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
- "BOOLEAN": lambda args: exp.Cast(
- this=seq_get(args, 0), to=exp.DataType.build("boolean")
- ),
- "IIF": exp.If.from_arg_list,
- "INT": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("int")),
- "AGGREGATE": exp.Reduce.from_arg_list,
- "DAYOFWEEK": lambda args: exp.DayOfWeek(
- this=exp.TsOrDsToDate(this=seq_get(args, 0)),
- ),
- "DAYOFMONTH": lambda args: exp.DayOfMonth(
- this=exp.TsOrDsToDate(this=seq_get(args, 0)),
- ),
- "DAYOFYEAR": lambda args: exp.DayOfYear(
- this=exp.TsOrDsToDate(this=seq_get(args, 0)),
- ),
- "WEEKOFYEAR": lambda args: exp.WeekOfYear(
- this=exp.TsOrDsToDate(this=seq_get(args, 0)),
- ),
- "DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
- "DATE_TRUNC": lambda args: exp.TimestampTrunc(
- this=seq_get(args, 1),
- unit=exp.var(seq_get(args, 0)),
- ),
- "STRING": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("string")),
- "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
- "TIMESTAMP": lambda args: exp.Cast(
- this=seq_get(args, 0), to=exp.DataType.build("timestamp")
- ),
- }
-
- FUNCTION_PARSERS = {
- **parser.Parser.FUNCTION_PARSERS, # type: ignore
- "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
- "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
- "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
- "MERGE": lambda self: self._parse_join_hint("MERGE"),
- "SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"),
- "MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"),
- "SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"),
- "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
- }
-
- def _parse_add_column(self) -> t.Optional[exp.Expression]:
- return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
-
- def _parse_drop_column(self) -> t.Optional[exp.Expression]:
- return self._match_text_seq("DROP", "COLUMNS") and self.expression(
- exp.Drop,
- this=self._parse_schema(),
- kind="COLUMNS",
- )
-
- def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
- # Spark doesn't add a suffix to the pivot columns when there's a single aggregation
- if len(pivot_columns) == 1:
- return [""]
-
- names = []
- for agg in pivot_columns:
- if isinstance(agg, exp.Alias):
- names.append(agg.alias)
- else:
- """
- This case corresponds to aggregations without aliases being used as suffixes
- (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
- be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
- Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
-
- Moreover, function names are lowercased in order to mimic Spark's naming scheme.
- """
- agg_all_unquoted = agg.transform(
- lambda node: exp.Identifier(this=node.name, quoted=False)
- if isinstance(node, exp.Identifier)
- else node
- )
- names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
-
- return names
-
- class Generator(Hive.Generator):
- TYPE_MAPPING = {
- **Hive.Generator.TYPE_MAPPING, # type: ignore
- exp.DataType.Type.TINYINT: "BYTE",
- exp.DataType.Type.SMALLINT: "SHORT",
- exp.DataType.Type.BIGINT: "LONG",
- }
-
- PROPERTIES_LOCATION = {
- **Hive.Generator.PROPERTIES_LOCATION, # type: ignore
- exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
- exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
- exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
- exp.CollateProperty: exp.Properties.Location.UNSUPPORTED,
- }
-
- TRANSFORMS = {
- **Hive.Generator.TRANSFORMS, # type: ignore
- exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
- exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
- exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
- exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
- exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
- exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
- exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
- exp.StrToDate: _str_to_date,
- exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
- exp.UnixToTime: _unix_to_time_sql,
- exp.Create: _create_sql,
- exp.Map: _map_sql,
- exp.Reduce: rename_func("AGGREGATE"),
- exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
- exp.TimestampTrunc: lambda self, e: self.func(
- "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
- ),
- exp.Trim: trim_sql,
- exp.VariancePop: rename_func("VAR_POP"),
- exp.DateFromParts: rename_func("MAKE_DATE"),
- exp.LogicalOr: rename_func("BOOL_OR"),
- exp.LogicalAnd: rename_func("BOOL_AND"),
- exp.DayOfWeek: rename_func("DAYOFWEEK"),
- exp.DayOfMonth: rename_func("DAYOFMONTH"),
- exp.DayOfYear: rename_func("DAYOFYEAR"),
- exp.WeekOfYear: rename_func("WEEKOFYEAR"),
- exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
+ **Spark2.Parser.FUNCTIONS, # type: ignore
+ "DATEDIFF": _parse_datediff,
}
- TRANSFORMS.pop(exp.ArraySort)
- TRANSFORMS.pop(exp.ILike)
- WRAP_DERIVED_VALUES = False
- CREATE_FUNCTION_RETURN_AS = False
+ class Generator(Spark2.Generator):
+ TRANSFORMS = Spark2.Generator.TRANSFORMS.copy()
+ TRANSFORMS.pop(exp.DateDiff)
- def cast_sql(self, expression: exp.Cast) -> str:
- if isinstance(expression.this, exp.Cast) and expression.this.is_type(
- exp.DataType.Type.JSON
- ):
- schema = f"'{self.sql(expression, 'to')}'"
- return self.func("FROM_JSON", expression.this.this, schema)
- if expression.to.is_type(exp.DataType.Type.JSON):
- return self.func("TO_JSON", expression.this)
+ def datediff_sql(self, expression: exp.DateDiff) -> str:
+ unit = self.sql(expression, "unit")
+ end = self.sql(expression, "this")
+ start = self.sql(expression, "expression")
- return super(Spark.Generator, self).cast_sql(expression)
+ if unit:
+ return self.func("DATEDIFF", unit, start, end)
- class Tokenizer(Hive.Tokenizer):
- HEX_STRINGS = [("X'", "'")]
+ return self.func("DATEDIFF", end, start)
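
For context (not part of this commit): a minimal sketch of how the reworked DATEDIFF handling can be exercised through sqlglot's public transpile API. The input SQL strings are the examples from the docstring above; the exact rendered output is an assumption based on the parser and generator shown in the diff, so it is only described in comments rather than asserted.

    import sqlglot

    # Two-argument form, DATEDIFF(end, start): _parse_datediff builds an
    # exp.DateDiff with no unit, and datediff_sql emits DATEDIFF(end, start).
    print(sqlglot.transpile(
        "SELECT DATEDIFF('2020-01-05', '2020-01-01')", read="spark", write="spark"
    )[0])

    # Three-argument form, DATEDIFF(unit, start, end): the leading unit is
    # preserved, matching the Spark3/Databricks variant documented above.
    print(sqlglot.transpile(
        "SELECT DATEDIFF(day, '2020-01-01', '2020-01-05')", read="spark", write="spark"
    )[0])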