diff options
Diffstat (limited to 'sqlglot/dialects/hive.py')
-rw-r--r-- | sqlglot/dialects/hive.py | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 3f925a7..7bff553 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -6,6 +6,7 @@ from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, approx_count_distinct_sql, + arg_max_or_min_no_count, create_with_partitions_sql, format_time_lambda, if_sql, @@ -106,11 +107,16 @@ def _date_diff_sql(self: Hive.Generator, expression: exp.DateDiff) -> str: sec_diff = f"UNIX_TIMESTAMP({left}) - UNIX_TIMESTAMP({right})" return f"({sec_diff}){factor}" if factor else sec_diff - sql_func = "MONTHS_BETWEEN" if unit in DIFF_MONTH_SWITCH else "DATEDIFF" + months_between = unit in DIFF_MONTH_SWITCH + sql_func = "MONTHS_BETWEEN" if months_between else "DATEDIFF" _, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1)) multiplier_sql = f" / {multiplier}" if multiplier > 1 else "" diff_sql = f"{sql_func}({self.format_args(expression.this, expression.expression)})" + if months_between: + # MONTHS_BETWEEN returns a float, so we need to truncate the fractional part + diff_sql = f"CAST({diff_sql} AS INT)" + return f"{diff_sql}{multiplier_sql}" @@ -426,6 +432,8 @@ class Hive(Dialect): exp.Property: _property_sql, exp.AnyValue: rename_func("FIRST"), exp.ApproxDistinct: approx_count_distinct_sql, + exp.ArgMax: arg_max_or_min_no_count("MAX_BY"), + exp.ArgMin: arg_max_or_min_no_count("MIN_BY"), exp.ArrayConcat: rename_func("CONCAT"), exp.ArrayJoin: lambda self, e: self.func("CONCAT_WS", e.expression, e.this), exp.ArraySize: rename_func("SIZE"), |