summary refs log tree commit diff stats
path: root/sqlglot/dialects/spark.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/spark.py')
-rw-r--r-- sqlglot/dialects/spark.py 28
1 files changed, 27 insertions, 1 deletions
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 44bd12d..c662ab5 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -5,8 +5,14 @@ import typing as t
from sqlglot import exp
from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.hive import _parse_ignore_nulls
-from sqlglot.dialects.spark2 import Spark2
+from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider
from sqlglot.helper import seq_get
+from sqlglot.transforms import (
+ ctas_with_tmp_tables_to_create_tmp_view,
+ remove_unique_constraints,
+ preprocess,
+ move_partitioned_by_to_schema_columns,
+)
def _parse_datediff(args: t.List) -> exp.Expression:
@@ -35,6 +41,15 @@ def _parse_datediff(args: t.List) -> exp.Expression:
)
def _normalize_partition(e: exp.Expression) -> exp.Expression:
    """Coerce one PARTITIONED BY entry into an identifier where possible.

    Plain strings and ``exp.Literal`` nodes are converted with
    ``exp.to_identifier`` (a literal contributes its ``name``); any other
    value is handed back unchanged.

    Note: although the parameter is annotated as ``exp.Expression``, a bare
    ``str`` is explicitly handled as well.
    """
    if isinstance(e, (str, exp.Literal)):
        # A raw string is already the identifier text; a literal carries it
        # in its ``name`` attribute.
        identifier_text = e if isinstance(e, str) else e.name
        return exp.to_identifier(identifier_text)
    return e
+
+
class Spark(Spark2):
class Tokenizer(Spark2.Tokenizer):
RAW_STRINGS = [
@@ -72,6 +87,17 @@ class Spark(Spark2):
TRANSFORMS = {
**Spark2.Generator.TRANSFORMS,
+ exp.Create: preprocess(
+ [
+ remove_unique_constraints,
+ lambda e: ctas_with_tmp_tables_to_create_tmp_view(
+ e, temporary_storage_provider
+ ),
+ move_partitioned_by_to_schema_columns,
+ ]
+ ),
+ exp.PartitionedByProperty: lambda self,
+ e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}",
exp.StartsWith: rename_func("STARTSWITH"),
exp.TimestampAdd: lambda self, e: self.func(
"DATEADD", e.args.get("unit") or "DAY", e.expression, e.this