Diffstat (limited to 'sqlglot/dialects/spark.py')
-rw-r--r-- | sqlglot/dialects/spark.py | 47
1 file changed, 44 insertions, 3 deletions
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index e828b9b..0212352 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -41,6 +41,21 @@ def _build_datediff(args: t.List) -> exp.Expression:
     )
 
 
+def _build_dateadd(args: t.List) -> exp.Expression:
+    expression = seq_get(args, 1)
+
+    if len(args) == 2:
+        # DATE_ADD(startDate, numDays INTEGER)
+        # https://docs.databricks.com/en/sql/language-manual/functions/date_add.html
+        return exp.TsOrDsAdd(
+            this=seq_get(args, 0), expression=expression, unit=exp.Literal.string("DAY")
+        )
+
+    # DATE_ADD / DATEADD / TIMESTAMPADD(unit, value integer, expr)
+    # https://docs.databricks.com/en/sql/language-manual/functions/date_add3.html
+    return exp.TimestampAdd(this=seq_get(args, 2), expression=expression, unit=seq_get(args, 0))
+
+
 def _normalize_partition(e: exp.Expression) -> exp.Expression:
     """Normalize the expressions in PARTITION BY (<expression>, <expression>, ...)"""
     if isinstance(e, str):
@@ -50,6 +65,30 @@ def _normalize_partition(e: exp.Expression) -> exp.Expression:
     return e
 
 
+def _dateadd_sql(self: Spark.Generator, expression: exp.TsOrDsAdd | exp.TimestampAdd) -> str:
+    if not expression.unit or (
+        isinstance(expression, exp.TsOrDsAdd) and expression.text("unit").upper() == "DAY"
+    ):
+        # Coming from Hive/Spark2 DATE_ADD or roundtripping the 2-arg version of Spark3/DB
+        return self.func("DATE_ADD", expression.this, expression.expression)
+
+    this = self.func(
+        "DATE_ADD",
+        unit_to_var(expression),
+        expression.expression,
+        expression.this,
+    )
+
+    if isinstance(expression, exp.TsOrDsAdd):
+        # The 3 arg version of DATE_ADD produces a timestamp in Spark3/DB but possibly not
+        # in other dialects
+        return_type = expression.return_type
+        if not return_type.is_type(exp.DataType.Type.TIMESTAMP, exp.DataType.Type.DATETIME):
+            this = f"CAST({this} AS {return_type})"
+
+    return this
+
+
 class Spark(Spark2):
     class Tokenizer(Spark2.Tokenizer):
         RAW_STRINGS = [
@@ -62,6 +101,9 @@ class Spark(Spark2):
         FUNCTIONS = {
             **Spark2.Parser.FUNCTIONS,
             "ANY_VALUE": _build_with_ignore_nulls(exp.AnyValue),
+            "DATE_ADD": _build_dateadd,
+            "DATEADD": _build_dateadd,
+            "TIMESTAMPADD": _build_dateadd,
             "DATEDIFF": _build_datediff,
             "TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"),
             "TIMESTAMP_NTZ": _build_as_cast("TIMESTAMP_NTZ"),
@@ -111,9 +153,8 @@ class Spark(Spark2):
             exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}",
             exp.StartsWith: rename_func("STARTSWITH"),
-            exp.TimestampAdd: lambda self, e: self.func(
-                "DATEADD", unit_to_var(e), e.expression, e.this
-            ),
+            exp.TsOrDsAdd: _dateadd_sql,
+            exp.TimestampAdd: _dateadd_sql,
             exp.TryCast: lambda self, e: (
                 self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e)
            ),
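
A minimal sketch of the behaviour this change enables, assuming a sqlglot build that includes it. sqlglot.transpile, exp.column, exp.Literal, and exp.var are standard sqlglot API; the outputs in the comments are illustrative, not captured test output:

import sqlglot
from sqlglot import exp

# 2-arg form: parsed by _build_dateadd into TsOrDsAdd with an implicit DAY
# unit, then round-tripped by _dateadd_sql back to the 2-arg DATE_ADD.
print(sqlglot.transpile("SELECT DATE_ADD(d, 5)", read="spark", write="spark")[0])
# SELECT DATE_ADD(d, 5)

# 3-arg form (and the DATEADD / TIMESTAMPADD aliases): parsed into
# TimestampAdd and emitted as the 3-arg DATE_ADD.
print(sqlglot.transpile("SELECT DATEADD(MONTH, 3, d)", read="spark", write="spark")[0])
# SELECT DATE_ADD(MONTH, 3, d)

# A TsOrDsAdd with a non-DAY unit (built here programmatically, as another
# dialect might produce it) takes the CAST branch of _dateadd_sql, assuming
# TsOrDsAdd.return_type defaults to DATE as in sqlglot versions around this change.
node = exp.TsOrDsAdd(
    this=exp.column("d"),
    expression=exp.Literal.number(2),
    unit=exp.var("MONTH"),
)
print(node.sql(dialect="spark"))
# CAST(DATE_ADD(MONTH, 2, d) AS DATE)

The last case shows why _dateadd_sql inspects return_type: the 3-arg DATE_ADD yields a timestamp in Spark 3/Databricks, so when the TsOrDsAdd node expects a non-timestamp result the generator wraps the call in a CAST back to that type.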