diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-08 08:11:53 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-08 08:12:02 +0000 |
commit | 8d36f5966675e23bee7026ba37ae0647fbf47300 (patch) | |
tree | df4227bbb3b07cb70df87237bcff03c8efd7822d /sqlglot/dialects/spark.py | |
parent | Releasing debian version 22.2.0-1. (diff) | |
download | sqlglot-8d36f5966675e23bee7026ba37ae0647fbf47300.tar.xz sqlglot-8d36f5966675e23bee7026ba37ae0647fbf47300.zip |
Merging upstream version 23.7.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/spark.py')
-rw-r--r-- | sqlglot/dialects/spark.py | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 20c0fce..88b5ddc 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -3,7 +3,7 @@ from __future__ import annotations import typing as t from sqlglot import exp -from sqlglot.dialects.dialect import rename_func +from sqlglot.dialects.dialect import rename_func, unit_to_var from sqlglot.dialects.hive import _build_with_ignore_nulls from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider from sqlglot.helper import seq_get @@ -78,6 +78,8 @@ class Spark(Spark2): return this class Generator(Spark2.Generator): + SUPPORTS_TO_NUMBER = True + TYPE_MAPPING = { **Spark2.Generator.TYPE_MAPPING, exp.DataType.Type.MONEY: "DECIMAL(15, 4)", @@ -100,7 +102,7 @@ class Spark(Spark2): e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}", exp.StartsWith: rename_func("STARTSWITH"), exp.TimestampAdd: lambda self, e: self.func( - "DATEADD", e.args.get("unit") or "DAY", e.expression, e.this + "DATEADD", unit_to_var(e), e.expression, e.this ), exp.TryCast: lambda self, e: ( self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e) @@ -117,11 +119,10 @@ class Spark(Spark2): return self.function_fallback_sql(expression) def datediff_sql(self, expression: exp.DateDiff) -> str: - unit = self.sql(expression, "unit") end = self.sql(expression, "this") start = self.sql(expression, "expression") - if unit: - return self.func("DATEDIFF", unit, start, end) + if expression.unit: + return self.func("DATEDIFF", unit_to_var(expression), start, end) return self.func("DATEDIFF", end, start) |