Diffstat (limited to 'sqlglot/dialects/spark2.py')
-rw-r--r--  sqlglot/dialects/spark2.py  | 64
1 file changed, 32 insertions(+), 32 deletions(-)
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index fa55b51..60cf8e1 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -5,7 +5,7 @@ import typing as t
 from sqlglot import exp, transforms
 from sqlglot.dialects.dialect import (
     binary_from_function,
-    format_time_lambda,
+    build_formatted_time,
     is_parse_json,
     pivot_column_names,
     rename_func,
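
build_formatted_time is the new name for format_time_lambda, matching the same parse_* -> build_* rename applied to the local helpers below; behavior is unchanged. A minimal sketch of what the builder does, assuming a sqlglot version that exports build_formatted_time from sqlglot.dialects.dialect:

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_formatted_time

    # Returns a callable that turns a parsed arg list into an expression whose
    # "format" argument is normalized through the spark time mapping.
    builder = build_formatted_time(exp.StrToTime, "spark")
    node = builder([exp.column("x"), exp.Literal.string("yyyy-MM-dd")])
    print(node.sql("spark2"))  # TO_TIMESTAMP(x, 'yyyy-MM-dd')
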
@@ -26,36 +26,37 @@ def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str:
     values = expression.args.get("values")
 
     if not keys or not values:
-        return "MAP()"
+        return self.func("MAP")
 
-    return f"MAP_FROM_ARRAYS({self.sql(keys)}, {self.sql(values)})"
+    return self.func("MAP_FROM_ARRAYS", keys, values)
-def _parse_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]:
+def _build_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]:
     return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type))
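The renamed _build_as_cast still returns a builder that collapses a one-argument, type-named call into a CAST node at parse time. The node it constructs for DATE(x), as a sketch:

    from sqlglot import exp

    cast = exp.Cast(this=exp.column("x"), to=exp.DataType.build("date"))
    print(cast.sql())  # CAST(x AS DATE)
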
 def _str_to_date(self: Spark2.Generator, expression: exp.StrToDate) -> str:
-    this = self.sql(expression, "this")
     time_format = self.format_time(expression)
     if time_format == Hive.DATE_FORMAT:
-        return f"TO_DATE({this})"
-    return f"TO_DATE({this}, {time_format})"
+        return self.func("TO_DATE", expression.this)
+    return self.func("TO_DATE", expression.this, time_format)
 def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str:
     scale = expression.args.get("scale")
-    timestamp = self.sql(expression, "this")
+    timestamp = expression.this
+
     if scale is None:
-        return f"CAST(FROM_UNIXTIME({timestamp}) AS TIMESTAMP)"
+        return self.sql(exp.cast(exp.func("from_unixtime", timestamp), "timestamp"))
     if scale == exp.UnixToTime.SECONDS:
-        return f"TIMESTAMP_SECONDS({timestamp})"
+        return self.func("TIMESTAMP_SECONDS", timestamp)
     if scale == exp.UnixToTime.MILLIS:
-        return f"TIMESTAMP_MILLIS({timestamp})"
+        return self.func("TIMESTAMP_MILLIS", timestamp)
     if scale == exp.UnixToTime.MICROS:
-        return f"TIMESTAMP_MICROS({timestamp})"
+        return self.func("TIMESTAMP_MICROS", timestamp)
 
-    return f"TIMESTAMP_SECONDS({timestamp} / POW(10, {scale}))"
+    unix_seconds = exp.Div(this=timestamp, expression=exp.func("POW", 10, scale))
+    return self.func("TIMESTAMP_SECONDS", unix_seconds)
 def _unalias_pivot(expression: exp.Expression) -> exp.Expression:
@@ -116,16 +117,16 @@ class Spark2(Hive):
             **Hive.Parser.FUNCTIONS,
             "AGGREGATE": exp.Reduce.from_arg_list,
             "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
-            "BOOLEAN": _parse_as_cast("boolean"),
-            "DATE": _parse_as_cast("date"),
+            "BOOLEAN": _build_as_cast("boolean"),
+            "DATE": _build_as_cast("date"),
             "DATE_TRUNC": lambda args: exp.TimestampTrunc(
                 this=seq_get(args, 1), unit=exp.var(seq_get(args, 0))
             ),
             "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
             "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
             "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
-            "DOUBLE": _parse_as_cast("double"),
-            "FLOAT": _parse_as_cast("float"),
+            "DOUBLE": _build_as_cast("double"),
+            "FLOAT": _build_as_cast("float"),
             "FROM_UTC_TIMESTAMP": lambda args: exp.AtTimeZone(
                 this=exp.cast_unless(
                     seq_get(args, 0) or exp.Var(this=""),
@@ -134,17 +135,17 @@ class Spark2(Hive):
                 ),
                 zone=seq_get(args, 1),
             ),
-            "INT": _parse_as_cast("int"),
+            "INT": _build_as_cast("int"),
             "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
             "RLIKE": exp.RegexpLike.from_arg_list,
             "SHIFTLEFT": binary_from_function(exp.BitwiseLeftShift),
             "SHIFTRIGHT": binary_from_function(exp.BitwiseRightShift),
-            "STRING": _parse_as_cast("string"),
-            "TIMESTAMP": _parse_as_cast("timestamp"),
+            "STRING": _build_as_cast("string"),
+            "TIMESTAMP": _build_as_cast("timestamp"),
             "TO_TIMESTAMP": lambda args: (
-                _parse_as_cast("timestamp")(args)
+                _build_as_cast("timestamp")(args)
                 if len(args) == 1
-                else format_time_lambda(exp.StrToTime, "spark")(args)
+                else build_formatted_time(exp.StrToTime, "spark")(args)
             ),
             "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
             "TO_UTC_TIMESTAMP": lambda args: exp.FromTimeZone(
@@ -187,6 +188,7 @@ class Spark2(Hive):
     class Generator(Hive.Generator):
         QUERY_HINTS = True
         NVL2_SUPPORTED = True
+        CAN_IMPLEMENT_ARRAY_ANY = True
 
         PROPERTIES_LOCATION = {
             **Hive.Generator.PROPERTIES_LOCATION,
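NVL2_SUPPORTED means NVL2 calls can be emitted as-is rather than lowered to a CASE expression, and the newly added CAN_IMPLEMENT_ARRAY_ANY advertises to the shared transforms that an ArrayAny rewrite is available for this dialect. A sketch of the NVL2 flag's effect, assuming NVL2 parses to exp.Nvl2 on the reading side:

    import sqlglot

    print(sqlglot.transpile("SELECT NVL2(a, b, c)", read="oracle", write="spark2")[0])
    # SELECT NVL2(a, b, c)
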
@@ -201,8 +203,9 @@ class Spark2(Hive):
             exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
             exp.ArraySum: lambda self,
             e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
-            exp.AtTimeZone: lambda self,
-            e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
+            exp.AtTimeZone: lambda self, e: self.func(
+                "FROM_UTC_TIMESTAMP", e.this, e.args.get("zone")
+            ),
             exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
             exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
             exp.Create: preprocess(
@@ -221,8 +224,9 @@ class Spark2(Hive):
             exp.DayOfYear: rename_func("DAYOFYEAR"),
             exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
             exp.From: transforms.preprocess([_unalias_pivot]),
-            exp.FromTimeZone: lambda self,
-            e: f"TO_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
+            exp.FromTimeZone: lambda self, e: self.func(
+                "TO_UTC_TIMESTAMP", e.this, e.args.get("zone")
+            ),
             exp.LogicalAnd: rename_func("BOOL_AND"),
             exp.LogicalOr: rename_func("BOOL_OR"),
             exp.Map: _map_sql,
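Both timezone lambdas now go through self.func, mirroring the parser entries FROM_UTC_TIMESTAMP -> exp.AtTimeZone and TO_UTC_TIMESTAMP -> exp.FromTimeZone. A round-trip sketch; the parser's cast_unless is expected to add a CAST around an untyped column:

    import sqlglot

    print(sqlglot.transpile("SELECT FROM_UTC_TIMESTAMP(ts, 'UTC')", read="spark2", write="spark2")[0])
    # SELECT FROM_UTC_TIMESTAMP(CAST(ts AS TIMESTAMP), 'UTC')
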
@@ -236,8 +240,7 @@ class Spark2(Hive):
e.args.get("position"),
),
exp.StrToDate: _str_to_date,
- exp.StrToTime: lambda self,
- e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
exp.TimestampTrunc: lambda self, e: self.func(
"DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
),
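exp.StrToTime likewise renders through self.func now, so the two-argument TO_TIMESTAMP form (parsed via build_formatted_time above) round-trips cleanly. A sketch:

    import sqlglot

    print(sqlglot.transpile("SELECT TO_TIMESTAMP(x, 'yyyy-MM-dd')", read="spark2", write="spark2")[0])
    # SELECT TO_TIMESTAMP(x, 'yyyy-MM-dd')
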
@@ -263,10 +266,7 @@ class Spark2(Hive):
             args = []
             for arg in expression.expressions:
                 if isinstance(arg, self.KEY_VALUE_DEFINITIONS):
-                    if isinstance(arg, exp.Bracket):
-                        args.append(exp.alias_(arg.this, arg.expressions[0].name))
-                    else:
-                        args.append(exp.alias_(arg.expression, arg.this.name))
+                    args.append(exp.alias_(arg.expression, arg.this.name))
                 else:
                     args.append(arg)
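
Dropping the exp.Bracket branch means every key-value item is handled uniformly: the value expression is aliased by the key's name. What exp.alias_ produces, as a sketch:

    from sqlglot import exp, parse_one

    aliased = exp.alias_(parse_one("a + 1"), "b")
    print(aliased.sql())  # a + 1 AS b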