summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/dialects/hive.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/hive.py')
-rw-r--r-- sqlglot/dialects/hive.py 43
1 file changed, 27 insertions, 16 deletions
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index c4b8fa9..0110eee 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+import typing as t
+
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
@@ -35,7 +37,7 @@ DATE_DELTA_INTERVAL = {
DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
-def _add_date_sql(self, expression):
+def _add_date_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
unit = expression.text("unit").upper()
func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
modified_increment = (
@@ -47,7 +49,7 @@ def _add_date_sql(self, expression):
return self.func(func, expression.this, modified_increment.this)
-def _date_diff_sql(self, expression):
+def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
unit = expression.text("unit").upper()
sql_func = "MONTHS_BETWEEN" if unit in DIFF_MONTH_SWITCH else "DATEDIFF"
_, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1))
@@ -56,21 +58,21 @@ def _date_diff_sql(self, expression):
return f"{diff_sql}{multiplier_sql}"
-def _array_sort(self, expression):
+def _array_sort(self: generator.Generator, expression: exp.ArraySort) -> str:
if expression.expression:
self.unsupported("Hive SORT_ARRAY does not support a comparator")
return f"SORT_ARRAY({self.sql(expression, 'this')})"
-def _property_sql(self, expression):
+def _property_sql(self: generator.Generator, expression: exp.Property) -> str:
return f"'{expression.name}'={self.sql(expression, 'value')}"
-def _str_to_unix(self, expression):
+def _str_to_unix(self: generator.Generator, expression: exp.StrToUnix) -> str:
return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression))
-def _str_to_date(self, expression):
+def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format not in (Hive.time_format, Hive.date_format):
@@ -78,7 +80,7 @@ def _str_to_date(self, expression):
return f"CAST({this} AS DATE)"
-def _str_to_time(self, expression):
+def _str_to_time(self: generator.Generator, expression: exp.StrToTime) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format not in (Hive.time_format, Hive.date_format):
@@ -86,20 +88,22 @@ def _str_to_time(self, expression):
return f"CAST({this} AS TIMESTAMP)"
-def _time_format(self, expression):
+def _time_format(
+ self: generator.Generator, expression: exp.UnixToStr | exp.StrToUnix
+) -> t.Optional[str]:
time_format = self.format_time(expression)
if time_format == Hive.time_format:
return None
return time_format
-def _time_to_str(self, expression):
+def _time_to_str(self: generator.Generator, expression: exp.TimeToStr) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
return f"DATE_FORMAT({this}, {time_format})"
-def _to_date_sql(self, expression):
+def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format and time_format not in (Hive.time_format, Hive.date_format):
@@ -107,7 +111,7 @@ def _to_date_sql(self, expression):
return f"TO_DATE({this})"
-def _unnest_to_explode_sql(self, expression):
+def _unnest_to_explode_sql(self: generator.Generator, expression: exp.Join) -> str:
unnest = expression.this
if isinstance(unnest, exp.Unnest):
alias = unnest.args.get("alias")
@@ -117,7 +121,7 @@ def _unnest_to_explode_sql(self, expression):
exp.Lateral(
this=udtf(this=expression),
view=True,
- alias=exp.TableAlias(this=alias.this, columns=[column]),
+ alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore
)
)
for expression, column in zip(unnest.expressions, alias.columns if alias else [])
@@ -125,7 +129,7 @@ def _unnest_to_explode_sql(self, expression):
return self.join_sql(expression)
-def _index_sql(self, expression):
+def _index_sql(self: generator.Generator, expression: exp.Index) -> str:
this = self.sql(expression, "this")
table = self.sql(expression, "table")
columns = self.sql(expression, "columns")
@@ -263,14 +267,15 @@ class Hive(Dialect):
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.DATETIME: "TIMESTAMP",
exp.DataType.Type.VARBINARY: "BINARY",
+ exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
+ **transforms.ELIMINATE_QUALIFY, # type: ignore
exp.Property: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
- exp.ArrayAgg: rename_func("COLLECT_LIST"),
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArraySize: rename_func("SIZE"),
exp.ArraySort: _array_sort,
@@ -333,13 +338,19 @@ class Hive(Dialect):
exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA,
}
- def with_properties(self, properties):
+ def arrayagg_sql(self, expression: exp.ArrayAgg) -> str:
+ return self.func(
+ "COLLECT_LIST",
+ expression.this.this if isinstance(expression.this, exp.Order) else expression.this,
+ )
+
+ def with_properties(self, properties: exp.Properties) -> str:
return self.properties(
properties,
prefix=self.seg("TBLPROPERTIES"),
)
- def datatype_sql(self, expression):
+ def datatype_sql(self, expression: exp.DataType) -> str:
if (
expression.this in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR)
and not expression.expressions