Diffstat (limited to 'sqlglot/dialects/spark.py')
 sqlglot/dialects/spark.py | 47 ++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 44 insertions(+), 3 deletions(-)
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index e828b9b..0212352 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -41,6 +41,21 @@ def _build_datediff(args: t.List) -> exp.Expression:
)
+def _build_dateadd(args: t.List) -> exp.Expression:
+ expression = seq_get(args, 1)
+
+ if len(args) == 2:
+ # DATE_ADD(startDate, numDays INTEGER)
+ # https://docs.databricks.com/en/sql/language-manual/functions/date_add.html
+ return exp.TsOrDsAdd(
+ this=seq_get(args, 0), expression=expression, unit=exp.Literal.string("DAY")
+ )
+
+ # DATE_ADD / DATEADD / TIMESTAMPADD(unit, value integer, expr)
+ # https://docs.databricks.com/en/sql/language-manual/functions/date_add3.html
+ return exp.TimestampAdd(this=seq_get(args, 2), expression=expression, unit=seq_get(args, 0))
+
+
def _normalize_partition(e: exp.Expression) -> exp.Expression:
"""Normalize the expressions in PARTITION BY (<expression>, <expression>, ...)"""
if isinstance(e, str):
@@ -50,6 +65,30 @@ def _normalize_partition(e: exp.Expression) -> exp.Expression:
return e
+def _dateadd_sql(self: Spark.Generator, expression: exp.TsOrDsAdd | exp.TimestampAdd) -> str:
+ if not expression.unit or (
+ isinstance(expression, exp.TsOrDsAdd) and expression.text("unit").upper() == "DAY"
+ ):
+ # Coming from Hive/Spark2 DATE_ADD or roundtripping the 2-arg version of Spark3/DB
+ return self.func("DATE_ADD", expression.this, expression.expression)
+
+ this = self.func(
+ "DATE_ADD",
+ unit_to_var(expression),
+ expression.expression,
+ expression.this,
+ )
+
+ if isinstance(expression, exp.TsOrDsAdd):
+ # The 3 arg version of DATE_ADD produces a timestamp in Spark3/DB but possibly not
+ # in other dialects
+ return_type = expression.return_type
+ if not return_type.is_type(exp.DataType.Type.TIMESTAMP, exp.DataType.Type.DATETIME):
+ this = f"CAST({this} AS {return_type})"
+
+ return this
+
+
class Spark(Spark2):
class Tokenizer(Spark2.Tokenizer):
RAW_STRINGS = [
@@ -62,6 +101,9 @@ class Spark(Spark2):
FUNCTIONS = {
**Spark2.Parser.FUNCTIONS,
"ANY_VALUE": _build_with_ignore_nulls(exp.AnyValue),
+ "DATE_ADD": _build_dateadd,
+ "DATEADD": _build_dateadd,
+ "TIMESTAMPADD": _build_dateadd,
"DATEDIFF": _build_datediff,
"TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"),
"TIMESTAMP_NTZ": _build_as_cast("TIMESTAMP_NTZ"),
@@ -111,9 +153,8 @@ class Spark(Spark2):
exp.PartitionedByProperty: lambda self,
e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}",
exp.StartsWith: rename_func("STARTSWITH"),
- exp.TimestampAdd: lambda self, e: self.func(
- "DATEADD", unit_to_var(e), e.expression, e.this
- ),
+ exp.TsOrDsAdd: _dateadd_sql,
+ exp.TimestampAdd: _dateadd_sql,
exp.TryCast: lambda self, e: (
self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e)
),