summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/spark2.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--sqlglot/dialects/spark2.py29
1 files changed, 22 insertions, 7 deletions
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index 9378d99..fa55b51 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -13,6 +13,12 @@ from sqlglot.dialects.dialect import (
)
from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get
+from sqlglot.transforms import (
+ preprocess,
+ remove_unique_constraints,
+ ctas_with_tmp_tables_to_create_tmp_view,
+ move_schema_columns_to_partitioned_by,
+)
def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str:
@@ -95,6 +101,13 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
return expression
+def temporary_storage_provider(expression: exp.Expression) -> exp.Expression:
+ # spark2, spark, Databricks require a storage provider for temporary tables
+ provider = exp.FileFormatProperty(this=exp.Literal.string("parquet"))
+ expression.args["properties"].append("expressions", provider)
+ return expression
+
+
class Spark2(Hive):
class Parser(Hive.Parser):
TRIM_PATTERN_FIRST = True
@@ -121,7 +134,6 @@ class Spark2(Hive):
),
zone=seq_get(args, 1),
),
- "IIF": exp.If.from_arg_list,
"INT": _parse_as_cast("int"),
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
"RLIKE": exp.RegexpLike.from_arg_list,
@@ -193,6 +205,15 @@ class Spark2(Hive):
e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
+ exp.Create: preprocess(
+ [
+ remove_unique_constraints,
+ lambda e: ctas_with_tmp_tables_to_create_tmp_view(
+ e, temporary_storage_provider
+ ),
+ move_schema_columns_to_partitioned_by,
+ ]
+ ),
exp.DateFromParts: rename_func("MAKE_DATE"),
exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
exp.DayOfMonth: rename_func("DAYOFMONTH"),
@@ -251,12 +272,6 @@ class Spark2(Hive):
return self.func("STRUCT", *args)
- def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
- # spark2, spark, Databricks require a storage provider for temporary tables
- provider = exp.FileFormatProperty(this=exp.Literal.string("parquet"))
- expression.args["properties"].append("expressions", provider)
- return expression
-
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
if is_parse_json(expression.this):
schema = f"'{self.sql(expression, 'to')}'"