Diffstat
-rw-r--r-- | sqlglot/dialects/spark2.py | 29 |
1 file changed, 22 insertions, 7 deletions
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index 9378d99..fa55b51 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -13,6 +13,12 @@ from sqlglot.dialects.dialect import (
 )
 from sqlglot.dialects.hive import Hive
 from sqlglot.helper import seq_get
+from sqlglot.transforms import (
+    preprocess,
+    remove_unique_constraints,
+    ctas_with_tmp_tables_to_create_tmp_view,
+    move_schema_columns_to_partitioned_by,
+)
 
 
 def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str:
@@ -95,6 +101,13 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
     return expression
 
 
+def temporary_storage_provider(expression: exp.Expression) -> exp.Expression:
+    # spark2, spark, Databricks require a storage provider for temporary tables
+    provider = exp.FileFormatProperty(this=exp.Literal.string("parquet"))
+    expression.args["properties"].append("expressions", provider)
+    return expression
+
+
 class Spark2(Hive):
     class Parser(Hive.Parser):
         TRIM_PATTERN_FIRST = True
@@ -121,7 +134,6 @@ class Spark2(Hive):
                 ),
                 zone=seq_get(args, 1),
             ),
-            "IIF": exp.If.from_arg_list,
             "INT": _parse_as_cast("int"),
             "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
             "RLIKE": exp.RegexpLike.from_arg_list,
@@ -193,6 +205,15 @@ class Spark2(Hive):
             e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
             exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
             exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
+            exp.Create: preprocess(
+                [
+                    remove_unique_constraints,
+                    lambda e: ctas_with_tmp_tables_to_create_tmp_view(
+                        e, temporary_storage_provider
+                    ),
+                    move_schema_columns_to_partitioned_by,
+                ]
+            ),
             exp.DateFromParts: rename_func("MAKE_DATE"),
             exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
             exp.DayOfMonth: rename_func("DAYOFMONTH"),
@@ -251,12 +272,6 @@ class Spark2(Hive):
 
             return self.func("STRUCT", *args)
 
-        def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
-            # spark2, spark, Databricks require a storage provider for temporary tables
-            provider = exp.FileFormatProperty(this=exp.Literal.string("parquet"))
-            expression.args["properties"].append("expressions", provider)
-            return expression
-
        def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
            if is_parse_json(expression.this):
                schema = f"'{self.sql(expression, 'to')}'"
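
For context, the new exp.Create entry chains the imported transforms through preprocess, so CREATE statements are rewritten (unique constraints dropped, temporary CTAS turned into a temporary view, schema columns moved into PARTITIONED BY) before the Spark 2 generator emits SQL, and the module-level temporary_storage_provider supplies the parquet storage provider where a temporary table is still materialized. A minimal sketch of how this is expected to surface through the public API; the input statement and the commented output are illustrative assumptions, not taken from this commit, and the exact result depends on the sqlglot version:

import sqlglot

# A CTAS on a temporary table. With the preprocess chain above, the Spark 2
# generator should route it through ctas_with_tmp_tables_to_create_tmp_view,
# passing temporary_storage_provider so that any temporary table that is still
# created gets an explicit parquet storage provider.
sql = "CREATE TEMPORARY TABLE t AS SELECT a, b FROM src"

# Assumed output: a CREATE TEMPORARY VIEW statement in place of the temporary
# CTAS, which Spark 2 cannot express directly.
print(sqlglot.transpile(sql, write="spark2")[0])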