diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/dialects/spark2.py | 64 |
1 file changed, 32 insertions, 32 deletions
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index fa55b51..60cf8e1 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -5,7 +5,7 @@ import typing as t
 from sqlglot import exp, transforms
 from sqlglot.dialects.dialect import (
     binary_from_function,
-    format_time_lambda,
+    build_formatted_time,
     is_parse_json,
     pivot_column_names,
     rename_func,
@@ -26,36 +26,37 @@ def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str:
     values = expression.args.get("values")
 
     if not keys or not values:
-        return "MAP()"
+        return self.func("MAP")
 
-    return f"MAP_FROM_ARRAYS({self.sql(keys)}, {self.sql(values)})"
+    return self.func("MAP_FROM_ARRAYS", keys, values)
 
 
-def _parse_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]:
+def _build_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]:
     return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type))
 
 
 def _str_to_date(self: Spark2.Generator, expression: exp.StrToDate) -> str:
-    this = self.sql(expression, "this")
     time_format = self.format_time(expression)
     if time_format == Hive.DATE_FORMAT:
-        return f"TO_DATE({this})"
-    return f"TO_DATE({this}, {time_format})"
+        return self.func("TO_DATE", expression.this)
+    return self.func("TO_DATE", expression.this, time_format)
 
 
 def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str:
     scale = expression.args.get("scale")
-    timestamp = self.sql(expression, "this")
+    timestamp = expression.this
+
     if scale is None:
-        return f"CAST(FROM_UNIXTIME({timestamp}) AS TIMESTAMP)"
+        return self.sql(exp.cast(exp.func("from_unixtime", timestamp), "timestamp"))
     if scale == exp.UnixToTime.SECONDS:
-        return f"TIMESTAMP_SECONDS({timestamp})"
+        return self.func("TIMESTAMP_SECONDS", timestamp)
     if scale == exp.UnixToTime.MILLIS:
-        return f"TIMESTAMP_MILLIS({timestamp})"
+        return self.func("TIMESTAMP_MILLIS", timestamp)
     if scale == exp.UnixToTime.MICROS:
-        return f"TIMESTAMP_MICROS({timestamp})"
+        return self.func("TIMESTAMP_MICROS", timestamp)
 
-    return f"TIMESTAMP_SECONDS({timestamp} / POW(10, {scale}))"
+    unix_seconds = exp.Div(this=timestamp, expression=exp.func("POW", 10, scale))
+    return self.func("TIMESTAMP_SECONDS", unix_seconds)
 
 
 def _unalias_pivot(expression: exp.Expression) -> exp.Expression:
@@ -116,16 +117,16 @@ class Spark2(Hive):
             **Hive.Parser.FUNCTIONS,
             "AGGREGATE": exp.Reduce.from_arg_list,
             "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
-            "BOOLEAN": _parse_as_cast("boolean"),
-            "DATE": _parse_as_cast("date"),
+            "BOOLEAN": _build_as_cast("boolean"),
+            "DATE": _build_as_cast("date"),
             "DATE_TRUNC": lambda args: exp.TimestampTrunc(
                 this=seq_get(args, 1), unit=exp.var(seq_get(args, 0))
             ),
             "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
             "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
             "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
-            "DOUBLE": _parse_as_cast("double"),
-            "FLOAT": _parse_as_cast("float"),
+            "DOUBLE": _build_as_cast("double"),
+            "FLOAT": _build_as_cast("float"),
             "FROM_UTC_TIMESTAMP": lambda args: exp.AtTimeZone(
                 this=exp.cast_unless(
                     seq_get(args, 0) or exp.Var(this=""),
@@ -134,17 +135,17 @@ class Spark2(Hive):
                 ),
                 zone=seq_get(args, 1),
             ),
-            "INT": _parse_as_cast("int"),
+            "INT": _build_as_cast("int"),
             "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
             "RLIKE": exp.RegexpLike.from_arg_list,
             "SHIFTLEFT": binary_from_function(exp.BitwiseLeftShift),
             "SHIFTRIGHT": binary_from_function(exp.BitwiseRightShift),
-            "STRING": _parse_as_cast("string"),
-            "TIMESTAMP": _parse_as_cast("timestamp"),
+            "STRING": _build_as_cast("string"),
+            "TIMESTAMP": _build_as_cast("timestamp"),
             "TO_TIMESTAMP": lambda args: (
-                _parse_as_cast("timestamp")(args)
+                _build_as_cast("timestamp")(args)
                 if len(args) == 1
-                else format_time_lambda(exp.StrToTime, "spark")(args)
+                else build_formatted_time(exp.StrToTime, "spark")(args)
             ),
             "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
             "TO_UTC_TIMESTAMP": lambda args: exp.FromTimeZone(
@@ -187,6 +188,7 @@ class Spark2(Hive):
     class Generator(Hive.Generator):
         QUERY_HINTS = True
         NVL2_SUPPORTED = True
+        CAN_IMPLEMENT_ARRAY_ANY = True
 
         PROPERTIES_LOCATION = {
             **Hive.Generator.PROPERTIES_LOCATION,
@@ -201,8 +203,9 @@ class Spark2(Hive):
             exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
             exp.ArraySum: lambda self,
             e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
-            exp.AtTimeZone: lambda self,
-            e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
+            exp.AtTimeZone: lambda self, e: self.func(
+                "FROM_UTC_TIMESTAMP", e.this, e.args.get("zone")
+            ),
             exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
             exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
             exp.Create: preprocess(
@@ -221,8 +224,9 @@ class Spark2(Hive):
             exp.DayOfYear: rename_func("DAYOFYEAR"),
             exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
             exp.From: transforms.preprocess([_unalias_pivot]),
-            exp.FromTimeZone: lambda self,
-            e: f"TO_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
+            exp.FromTimeZone: lambda self, e: self.func(
+                "TO_UTC_TIMESTAMP", e.this, e.args.get("zone")
+            ),
             exp.LogicalAnd: rename_func("BOOL_AND"),
             exp.LogicalOr: rename_func("BOOL_OR"),
             exp.Map: _map_sql,
@@ -236,8 +240,7 @@ class Spark2(Hive):
                 e.args.get("position"),
             ),
             exp.StrToDate: _str_to_date,
-            exp.StrToTime: lambda self,
-            e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+            exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
             exp.TimestampTrunc: lambda self, e: self.func(
                 "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
             ),
@@ -263,10 +266,7 @@ class Spark2(Hive):
             args = []
             for arg in expression.expressions:
                 if isinstance(arg, self.KEY_VALUE_DEFINITIONS):
-                    if isinstance(arg, exp.Bracket):
-                        args.append(exp.alias_(arg.this, arg.expressions[0].name))
-                    else:
-                        args.append(exp.alias_(arg.expression, arg.this.name))
+                    args.append(exp.alias_(arg.expression, arg.this.name))
                 else:
                     args.append(arg)
 