summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/spark.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/spark.py')
-rw-r--r--sqlglot/dialects/spark.py11
1 files changed, 6 insertions, 5 deletions
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 20c0fce..88b5ddc 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -3,7 +3,7 @@ from __future__ import annotations
import typing as t
from sqlglot import exp
-from sqlglot.dialects.dialect import rename_func
+from sqlglot.dialects.dialect import rename_func, unit_to_var
from sqlglot.dialects.hive import _build_with_ignore_nulls
from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider
from sqlglot.helper import seq_get
@@ -78,6 +78,8 @@ class Spark(Spark2):
return this
class Generator(Spark2.Generator):
+ SUPPORTS_TO_NUMBER = True
+
TYPE_MAPPING = {
**Spark2.Generator.TYPE_MAPPING,
exp.DataType.Type.MONEY: "DECIMAL(15, 4)",
@@ -100,7 +102,7 @@ class Spark(Spark2):
e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}",
exp.StartsWith: rename_func("STARTSWITH"),
exp.TimestampAdd: lambda self, e: self.func(
- "DATEADD", e.args.get("unit") or "DAY", e.expression, e.this
+ "DATEADD", unit_to_var(e), e.expression, e.this
),
exp.TryCast: lambda self, e: (
self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e)
@@ -117,11 +119,10 @@ class Spark(Spark2):
return self.function_fallback_sql(expression)
def datediff_sql(self, expression: exp.DateDiff) -> str:
- unit = self.sql(expression, "unit")
end = self.sql(expression, "this")
start = self.sql(expression, "expression")
- if unit:
- return self.func("DATEDIFF", unit, start, end)
+ if expression.unit:
+ return self.func("DATEDIFF", unit_to_var(expression), start, end)
return self.func("DATEDIFF", end, start)