diff options
Diffstat (limited to 'sqlglot/dialects/spark.py')
-rw-r--r-- | sqlglot/dialects/spark.py | 28 |
1 files changed, 27 insertions, 1 deletions
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 44bd12d..c662ab5 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -5,8 +5,14 @@ import typing as t
 from sqlglot import exp
 from sqlglot.dialects.dialect import rename_func
 from sqlglot.dialects.hive import _parse_ignore_nulls
-from sqlglot.dialects.spark2 import Spark2
+from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider
 from sqlglot.helper import seq_get
+from sqlglot.transforms import (
+    ctas_with_tmp_tables_to_create_tmp_view,
+    remove_unique_constraints,
+    preprocess,
+    move_partitioned_by_to_schema_columns,
+)
 
 
 def _parse_datediff(args: t.List) -> exp.Expression:
@@ -35,6 +41,15 @@ def _parse_datediff(args: t.List) -> exp.Expression:
     )
 
 
+def _normalize_partition(e: exp.Expression) -> exp.Expression:
+    """Normalize the expressions in PARTITION BY (<expression>, <expression>, ...)"""
+    if isinstance(e, str):
+        return exp.to_identifier(e)
+    if isinstance(e, exp.Literal):
+        return exp.to_identifier(e.name)
+    return e
+
+
 class Spark(Spark2):
     class Tokenizer(Spark2.Tokenizer):
         RAW_STRINGS = [
@@ -72,6 +87,17 @@ class Spark(Spark2):
 
         TRANSFORMS = {
             **Spark2.Generator.TRANSFORMS,
+            exp.Create: preprocess(
+                [
+                    remove_unique_constraints,
+                    lambda e: ctas_with_tmp_tables_to_create_tmp_view(
+                        e, temporary_storage_provider
+                    ),
+                    move_partitioned_by_to_schema_columns,
+                ]
+            ),
+            exp.PartitionedByProperty: lambda self,
+            e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}",
             exp.StartsWith: rename_func("STARTSWITH"),
             exp.TimestampAdd: lambda self, e: self.func(
                 "DATEADD", e.args.get("unit") or "DAY", e.expression, e.this