Diffstat:
 sqlglot/dialects/spark2.py | 26 ++++++++------------------
 1 file changed, 8 insertions(+), 18 deletions(-)
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index 3dc9838..4130375 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -5,7 +5,6 @@ import typing as t
 from sqlglot import exp, transforms
 from sqlglot.dialects.dialect import (
     binary_from_function,
-    create_with_partitions_sql,
     format_time_lambda,
     is_parse_json,
     move_insert_cte_sql,
@@ -17,22 +16,6 @@ from sqlglot.dialects.hive import Hive
 from sqlglot.helper import seq_get
 
 
-def _create_sql(self: Spark2.Generator, e: exp.Create) -> str:
-    kind = e.args["kind"]
-    properties = e.args.get("properties")
-
-    if (
-        kind.upper() == "TABLE"
-        and e.expression
-        and any(
-            isinstance(prop, exp.TemporaryProperty)
-            for prop in (properties.expressions if properties else [])
-        )
-    ):
-        return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
-    return create_with_partitions_sql(self, e)
-
-
 def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str:
     keys = expression.args.get("keys")
     values = expression.args.get("values")
@@ -118,6 +101,8 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
 
 class Spark2(Hive):
     class Parser(Hive.Parser):
+        TRIM_PATTERN_FIRST = True
+
         FUNCTIONS = {
             **Hive.Parser.FUNCTIONS,
             "AGGREGATE": exp.Reduce.from_arg_list,
@@ -192,7 +177,6 @@ class Spark2(Hive):
             exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
             exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
             exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
-            exp.Create: _create_sql,
             exp.DateFromParts: rename_func("MAKE_DATE"),
             exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
             exp.DayOfMonth: rename_func("DAYOFMONTH"),
@@ -236,6 +220,12 @@ class Spark2(Hive):
         WRAP_DERIVED_VALUES = False
         CREATE_FUNCTION_RETURN_AS = False
 
+        def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
+            # spark2, spark, Databricks require a storage provider for temporary tables
+            provider = exp.FileFormatProperty(this=exp.Literal.string("parquet"))
+            expression.args["properties"].append("expressions", provider)
+            return expression
+
        def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
            if is_parse_json(expression.this):
                schema = f"'{self.sql(expression, 'to')}'"
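
For readers following along, a minimal sketch of how the new hook would surface during transpilation. The input statement and the expected output shape are assumptions for illustration, not verbatim test fixtures, and they assume the generator wires temporary_storage_provider into its CREATE handling (that wiring is outside the hunks shown here):

import sqlglot

# Hypothetical temporary-table DDL chosen for illustration.
sql = "CREATE TEMPORARY TABLE t (id BIGINT)"
print(sqlglot.transpile(sql, write="spark2")[0])
# temporary_storage_provider appends a parquet FileFormatProperty, so the
# output is expected to carry an explicit storage provider clause for the
# temporary table, e.g. "CREATE TEMPORARY TABLE t (id BIGINT) USING PARQUET"
# (the exact rendering depends on the sqlglot version).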