summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/hive.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/hive.py')
-rw-r--r--sqlglot/dialects/hive.py26
1 files changed, 21 insertions, 5 deletions
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 0110eee..68137ae 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
format_time_lambda,
if_sql,
locate_to_strposition,
+ max_or_greatest,
min_or_least,
no_ilike_sql,
no_recursive_cte_sql,
@@ -34,6 +35,13 @@ DATE_DELTA_INTERVAL = {
"DAY": ("DATE_ADD", 1),
}
+TIME_DIFF_FACTOR = {
+ "MILLISECOND": " * 1000",
+ "SECOND": "",
+ "MINUTE": " / 60",
+ "HOUR": " / 3600",
+}
+
DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
@@ -51,6 +59,14 @@ def _add_date_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
unit = expression.text("unit").upper()
+
+ factor = TIME_DIFF_FACTOR.get(unit)
+ if factor is not None:
+ left = self.sql(expression, "this")
+ right = self.sql(expression, "expression")
+ sec_diff = f"UNIX_TIMESTAMP({left}) - UNIX_TIMESTAMP({right})"
+ return f"({sec_diff}){factor}" if factor else sec_diff
+
sql_func = "MONTHS_BETWEEN" if unit in DIFF_MONTH_SWITCH else "DATEDIFF"
_, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1))
multiplier_sql = f" / {multiplier}" if multiplier > 1 else ""
@@ -237,11 +253,6 @@ class Hive(Dialect):
"FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
"GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
"LOCATE": locate_to_strposition,
- "LOG": (
- lambda args: exp.Log.from_arg_list(args)
- if len(args) > 1
- else exp.Ln.from_arg_list(args)
- ),
"MAP": parse_var_map,
"MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
"PERCENTILE": exp.Quantile.from_arg_list,
@@ -261,6 +272,8 @@ class Hive(Dialect):
),
}
+ LOG_DEFAULTS_TO_LN = True
+
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
@@ -293,6 +306,7 @@ class Hive(Dialect):
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
exp.Map: var_map_sql,
+ exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.VarMap: var_map_sql,
exp.Create: create_with_partitions_sql,
@@ -338,6 +352,8 @@ class Hive(Dialect):
exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA,
}
+ LIMIT_FETCH = "LIMIT"
+
def arrayagg_sql(self, expression: exp.ArrayAgg) -> str:
return self.func(
"COLLECT_LIST",