diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/dialects/spark2.py | 26 |
1 files changed, 8 insertions, 18 deletions
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index 3dc9838..4130375 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -5,7 +5,6 @@ import typing as t from sqlglot import exp, transforms from sqlglot.dialects.dialect import ( binary_from_function, - create_with_partitions_sql, format_time_lambda, is_parse_json, move_insert_cte_sql, @@ -17,22 +16,6 @@ from sqlglot.dialects.hive import Hive from sqlglot.helper import seq_get -def _create_sql(self: Spark2.Generator, e: exp.Create) -> str: - kind = e.args["kind"] - properties = e.args.get("properties") - - if ( - kind.upper() == "TABLE" - and e.expression - and any( - isinstance(prop, exp.TemporaryProperty) - for prop in (properties.expressions if properties else []) - ) - ): - return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}" - return create_with_partitions_sql(self, e) - - def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str: keys = expression.args.get("keys") values = expression.args.get("values") @@ -118,6 +101,8 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression: class Spark2(Hive): class Parser(Hive.Parser): + TRIM_PATTERN_FIRST = True + FUNCTIONS = { **Hive.Parser.FUNCTIONS, "AGGREGATE": exp.Reduce.from_arg_list, @@ -192,7 +177,6 @@ class Spark2(Hive): exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})", exp.BitwiseLeftShift: rename_func("SHIFTLEFT"), exp.BitwiseRightShift: rename_func("SHIFTRIGHT"), - exp.Create: _create_sql, exp.DateFromParts: rename_func("MAKE_DATE"), exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")), exp.DayOfMonth: rename_func("DAYOFMONTH"), @@ -236,6 +220,12 @@ class Spark2(Hive): WRAP_DERIVED_VALUES = False CREATE_FUNCTION_RETURN_AS = False + def temporary_storage_provider(self, expression: exp.Create) -> exp.Create: + # spark2, spark, Databricks require a storage provider for temporary tables + provider = exp.FileFormatProperty(this=exp.Literal.string("parquet")) + expression.args["properties"].append("expressions", provider) + return expression + def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: if is_parse_json(expression.this): schema = f"'{self.sql(expression, 'to')}'" |