diff options
Diffstat (limited to 'sqlglot/dialects/hive.py')
-rw-r--r-- | sqlglot/dialects/hive.py | 26 |
1 files changed, 21 insertions, 5 deletions
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 0110eee..68137ae 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import ( format_time_lambda, if_sql, locate_to_strposition, + max_or_greatest, min_or_least, no_ilike_sql, no_recursive_cte_sql, @@ -34,6 +35,13 @@ DATE_DELTA_INTERVAL = { "DAY": ("DATE_ADD", 1), } +TIME_DIFF_FACTOR = { + "MILLISECOND": " * 1000", + "SECOND": "", + "MINUTE": " / 60", + "HOUR": " / 3600", +} + DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH") @@ -51,6 +59,14 @@ def _add_date_sql(self: generator.Generator, expression: exp.DateAdd) -> str: def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str: unit = expression.text("unit").upper() + + factor = TIME_DIFF_FACTOR.get(unit) + if factor is not None: + left = self.sql(expression, "this") + right = self.sql(expression, "expression") + sec_diff = f"UNIX_TIMESTAMP({left}) - UNIX_TIMESTAMP({right})" + return f"({sec_diff}){factor}" if factor else sec_diff + sql_func = "MONTHS_BETWEEN" if unit in DIFF_MONTH_SWITCH else "DATEDIFF" _, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1)) multiplier_sql = f" / {multiplier}" if multiplier > 1 else "" @@ -237,11 +253,6 @@ class Hive(Dialect): "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True), "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list, "LOCATE": locate_to_strposition, - "LOG": ( - lambda args: exp.Log.from_arg_list(args) - if len(args) > 1 - else exp.Ln.from_arg_list(args) - ), "MAP": parse_var_map, "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)), "PERCENTILE": exp.Quantile.from_arg_list, @@ -261,6 +272,8 @@ class Hive(Dialect): ), } + LOG_DEFAULTS_TO_LN = True + class Generator(generator.Generator): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore @@ -293,6 +306,7 @@ class Hive(Dialect): exp.JSONExtract: rename_func("GET_JSON_OBJECT"), exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"), exp.Map: var_map_sql, + exp.Max: max_or_greatest, exp.Min: min_or_least, exp.VarMap: var_map_sql, exp.Create: create_with_partitions_sql, @@ -338,6 +352,8 @@ class Hive(Dialect): exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA, } + LIMIT_FETCH = "LIMIT" + def arrayagg_sql(self, expression: exp.ArrayAgg) -> str: return self.func( "COLLECT_LIST", |