diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-03-19 10:22:04 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-03-19 10:22:04 +0000 |
commit | 57c3067868d0a1da90ec0f2201dd91f031241274 (patch) | |
tree | 3b16819683e27ccbc7e7726675ab8d3e978fc8aa /sqlglot/dialects/hive.py | |
parent | Adding upstream version 11.3.6. (diff) | |
download | sqlglot-57c3067868d0a1da90ec0f2201dd91f031241274.tar.xz sqlglot-57c3067868d0a1da90ec0f2201dd91f031241274.zip |
Adding upstream version 11.4.1. (upstream/11.4.1)
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/hive.py')
-rw-r--r-- | sqlglot/dialects/hive.py | 43 |
1 files changed, 27 insertions, 16 deletions
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index c4b8fa9..0110eee 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, @@ -35,7 +37,7 @@ DATE_DELTA_INTERVAL = { DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH") -def _add_date_sql(self, expression): +def _add_date_sql(self: generator.Generator, expression: exp.DateAdd) -> str: unit = expression.text("unit").upper() func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1)) modified_increment = ( @@ -47,7 +49,7 @@ def _add_date_sql(self, expression): return self.func(func, expression.this, modified_increment.this) -def _date_diff_sql(self, expression): +def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str: unit = expression.text("unit").upper() sql_func = "MONTHS_BETWEEN" if unit in DIFF_MONTH_SWITCH else "DATEDIFF" _, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1)) @@ -56,21 +58,21 @@ def _date_diff_sql(self, expression): return f"{diff_sql}{multiplier_sql}" -def _array_sort(self, expression): +def _array_sort(self: generator.Generator, expression: exp.ArraySort) -> str: if expression.expression: self.unsupported("Hive SORT_ARRAY does not support a comparator") return f"SORT_ARRAY({self.sql(expression, 'this')})" -def _property_sql(self, expression): +def _property_sql(self: generator.Generator, expression: exp.Property) -> str: return f"'{expression.name}'={self.sql(expression, 'value')}" -def _str_to_unix(self, expression): +def _str_to_unix(self: generator.Generator, expression: exp.StrToUnix) -> str: return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression)) -def _str_to_date(self, expression): +def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str: this = self.sql(expression, "this") time_format = 
self.format_time(expression) if time_format not in (Hive.time_format, Hive.date_format): @@ -78,7 +80,7 @@ def _str_to_date(self, expression): return f"CAST({this} AS DATE)" -def _str_to_time(self, expression): +def _str_to_time(self: generator.Generator, expression: exp.StrToTime) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format not in (Hive.time_format, Hive.date_format): @@ -86,20 +88,22 @@ def _str_to_time(self, expression): return f"CAST({this} AS TIMESTAMP)" -def _time_format(self, expression): +def _time_format( + self: generator.Generator, expression: exp.UnixToStr | exp.StrToUnix +) -> t.Optional[str]: time_format = self.format_time(expression) if time_format == Hive.time_format: return None return time_format -def _time_to_str(self, expression): +def _time_to_str(self: generator.Generator, expression: exp.TimeToStr) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) return f"DATE_FORMAT({this}, {time_format})" -def _to_date_sql(self, expression): +def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format and time_format not in (Hive.time_format, Hive.date_format): @@ -107,7 +111,7 @@ def _to_date_sql(self, expression): return f"TO_DATE({this})" -def _unnest_to_explode_sql(self, expression): +def _unnest_to_explode_sql(self: generator.Generator, expression: exp.Join) -> str: unnest = expression.this if isinstance(unnest, exp.Unnest): alias = unnest.args.get("alias") @@ -117,7 +121,7 @@ def _unnest_to_explode_sql(self, expression): exp.Lateral( this=udtf(this=expression), view=True, - alias=exp.TableAlias(this=alias.this, columns=[column]), + alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore ) ) for expression, column in zip(unnest.expressions, alias.columns if alias else []) @@ -125,7 +129,7 @@ def _unnest_to_explode_sql(self, 
expression): return self.join_sql(expression) -def _index_sql(self, expression): +def _index_sql(self: generator.Generator, expression: exp.Index) -> str: this = self.sql(expression, "this") table = self.sql(expression, "table") columns = self.sql(expression, "columns") @@ -263,14 +267,15 @@ class Hive(Dialect): exp.DataType.Type.TEXT: "STRING", exp.DataType.Type.DATETIME: "TIMESTAMP", exp.DataType.Type.VARBINARY: "BINARY", + exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", } TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore + **transforms.ELIMINATE_QUALIFY, # type: ignore exp.Property: _property_sql, exp.ApproxDistinct: approx_count_distinct_sql, - exp.ArrayAgg: rename_func("COLLECT_LIST"), exp.ArrayConcat: rename_func("CONCAT"), exp.ArraySize: rename_func("SIZE"), exp.ArraySort: _array_sort, @@ -333,13 +338,19 @@ class Hive(Dialect): exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA, } - def with_properties(self, properties): + def arrayagg_sql(self, expression: exp.ArrayAgg) -> str: + return self.func( + "COLLECT_LIST", + expression.this.this if isinstance(expression.this, exp.Order) else expression.this, + ) + + def with_properties(self, properties: exp.Properties) -> str: return self.properties( properties, prefix=self.seg("TBLPROPERTIES"), ) - def datatype_sql(self, expression): + def datatype_sql(self, expression: exp.DataType) -> str: if ( expression.this in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR) and not expression.expressions |