diff options
author | Daniel Baumann <mail@daniel-baumann.ch> | 2023-12-10 10:46:01 +0000 |
---|---|---|
committer | Daniel Baumann <mail@daniel-baumann.ch> | 2023-12-10 10:46:01 +0000 |
commit | 8fe30fd23dc37ec3516e530a86d1c4b604e71241 (patch) | |
tree | 6e2ebbf565b0351fd0f003f488a8339e771ad90c /sqlglot/dialects/hive.py | |
parent | Releasing debian version 19.0.1-1. (diff) | |
download | sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.tar.xz sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.zip |
Merging upstream version 20.1.0.
Signed-off-by: Daniel Baumann <mail@daniel-baumann.ch>
Diffstat (limited to 'sqlglot/dialects/hive.py')
-rw-r--r-- | sqlglot/dialects/hive.py | 68 |
1 files changed, 46 insertions, 22 deletions
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 3b1c8de..0723e37 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -4,10 +4,13 @@ import typing as t from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( + DATE_ADD_OR_SUB, Dialect, + NormalizationStrategy, approx_count_distinct_sql, arg_max_or_min_no_count, create_with_partitions_sql, + datestrtodate_sql, format_time_lambda, if_sql, is_parse_json, @@ -76,7 +79,10 @@ def _create_sql(self, expression: exp.Create) -> str: return create_with_partitions_sql(self, expression) -def _add_date_sql(self: Hive.Generator, expression: exp.DateAdd | exp.DateSub) -> str: +def _add_date_sql(self: Hive.Generator, expression: DATE_ADD_OR_SUB) -> str: + if isinstance(expression, exp.TsOrDsAdd) and not expression.unit: + return self.func("DATE_ADD", expression.this, expression.expression) + unit = expression.text("unit").upper() func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1)) @@ -95,7 +101,7 @@ def _add_date_sql(self: Hive.Generator, expression: exp.DateAdd | exp.DateSub) - return self.func(func, expression.this, modified_increment) -def _date_diff_sql(self: Hive.Generator, expression: exp.DateDiff) -> str: +def _date_diff_sql(self: Hive.Generator, expression: exp.DateDiff | exp.TsOrDsDiff) -> str: unit = expression.text("unit").upper() factor = TIME_DIFF_FACTOR.get(unit) @@ -111,25 +117,31 @@ def _date_diff_sql(self: Hive.Generator, expression: exp.DateDiff) -> str: multiplier_sql = f" / {multiplier}" if multiplier > 1 else "" diff_sql = f"{sql_func}({self.format_args(expression.this, expression.expression)})" - if months_between: - # MONTHS_BETWEEN returns a float, so we need to truncate the fractional part - diff_sql = f"CAST({diff_sql} AS INT)" + if months_between or multiplier_sql: + # MONTHS_BETWEEN returns a float, so we need to truncate the fractional part. + # For the same reason, we want to truncate if there's a divisor present. + diff_sql = f"CAST({diff_sql}{multiplier_sql} AS INT)" - return f"{diff_sql}{multiplier_sql}" + return diff_sql def _json_format_sql(self: Hive.Generator, expression: exp.JSONFormat) -> str: this = expression.this - if is_parse_json(this) and this.this.is_string: - # Since FROM_JSON requires a nested type, we always wrap the json string with - # an array to ensure that "naked" strings like "'a'" will be handled correctly - wrapped_json = exp.Literal.string(f"[{this.this.name}]") - from_json = self.func("FROM_JSON", wrapped_json, self.func("SCHEMA_OF_JSON", wrapped_json)) - to_json = self.func("TO_JSON", from_json) + if is_parse_json(this): + if this.this.is_string: + # Since FROM_JSON requires a nested type, we always wrap the json string with + # an array to ensure that "naked" strings like "'a'" will be handled correctly + wrapped_json = exp.Literal.string(f"[{this.this.name}]") + + from_json = self.func( + "FROM_JSON", wrapped_json, self.func("SCHEMA_OF_JSON", wrapped_json) + ) + to_json = self.func("TO_JSON", from_json) - # This strips the [, ] delimiters of the dummy array printed by TO_JSON - return self.func("REGEXP_EXTRACT", to_json, "'^.(.*).$'", "1") + # This strips the [, ] delimiters of the dummy array printed by TO_JSON + return self.func("REGEXP_EXTRACT", to_json, "'^.(.*).$'", "1") + return self.sql(this) return self.func("TO_JSON", this, expression.args.get("options")) @@ -175,6 +187,8 @@ def _to_date_sql(self: Hive.Generator, expression: exp.TsOrDsToDate) -> str: time_format = self.format_time(expression) if time_format and time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT): return f"TO_DATE({this}, {time_format})" + if isinstance(expression.this, exp.TsOrDsToDate): + return this return f"TO_DATE({this})" @@ -182,9 +196,10 @@ class Hive(Dialect): ALIAS_POST_TABLESAMPLE = True IDENTIFIERS_CAN_START_WITH_DIGIT = True SUPPORTS_USER_DEFINED_TYPES = False + SAFE_DIVISION = True # https://spark.apache.org/docs/latest/sql-ref-identifier.html#description - RESOLVES_IDENTIFIERS_AS_UPPERCASE = None + NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE TIME_MAPPING = { "y": "%Y", @@ -241,10 +256,10 @@ class Hive(Dialect): "ADD JAR": TokenType.COMMAND, "ADD JARS": TokenType.COMMAND, "MSCK REPAIR": TokenType.COMMAND, - "REFRESH": TokenType.COMMAND, - "WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES, + "REFRESH": TokenType.REFRESH, "TIMESTAMP AS OF": TokenType.TIMESTAMP_SNAPSHOT, "VERSION AS OF": TokenType.VERSION_SNAPSHOT, + "WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES, } NUMERIC_LITERALS = { @@ -264,7 +279,7 @@ class Hive(Dialect): **parser.Parser.FUNCTIONS, "BASE64": exp.ToBase64.from_arg_list, "COLLECT_LIST": exp.ArrayAgg.from_arg_list, - "COLLECT_SET": exp.SetAgg.from_arg_list, + "COLLECT_SET": exp.ArrayUniqueAgg.from_arg_list, "DATE_ADD": lambda args: exp.TsOrDsAdd( this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY") ), @@ -411,7 +426,13 @@ class Hive(Dialect): INDEX_ON = "ON TABLE" EXTRACT_ALLOWS_QUOTES = False NVL2_SUPPORTED = False - SUPPORTS_NESTED_CTES = False + + EXPRESSIONS_WITHOUT_NESTED_CTES = { + exp.Insert, + exp.Select, + exp.Subquery, + exp.Union, + } TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -445,7 +466,7 @@ class Hive(Dialect): exp.With: no_recursive_cte_sql, exp.DateAdd: _add_date_sql, exp.DateDiff: _date_diff_sql, - exp.DateStrToDate: rename_func("TO_DATE"), + exp.DateStrToDate: datestrtodate_sql, exp.DateSub: _add_date_sql, exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.DATEINT_FORMAT}) AS INT)", exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.DATEINT_FORMAT})", @@ -477,7 +498,7 @@ class Hive(Dialect): exp.Right: right_to_substring_sql, exp.SafeDivide: no_safe_divide_sql, exp.SchemaCommentProperty: lambda self, e: self.naked_property(e), - exp.SetAgg: rename_func("COLLECT_SET"), + exp.ArrayUniqueAgg: rename_func("COLLECT_SET"), exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))", exp.StrPosition: strposition_to_locate_sql, exp.StrToDate: _str_to_date_sql, @@ -491,7 +512,8 @@ class Hive(Dialect): exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), exp.ToBase64: rename_func("BASE64"), exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)", - exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.TsOrDsAdd: _add_date_sql, + exp.TsOrDsDiff: _date_diff_sql, exp.TsOrDsToDate: _to_date_sql, exp.TryCast: no_trycast_sql, exp.UnixToStr: lambda self, e: self.func( @@ -571,6 +593,8 @@ class Hive(Dialect): and not expression.expressions ): expression = exp.DataType.build("text") + elif expression.is_type(exp.DataType.Type.TEXT) and expression.expressions: + expression.set("this", exp.DataType.Type.VARCHAR) elif expression.this in exp.DataType.TEMPORAL_TYPES: expression = exp.DataType.build(expression.this) elif expression.is_type("float"): |