diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/dialects/hive.py | 45 |
1 files changed, 33 insertions, 12 deletions
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 6746fcf..871a180 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -81,7 +81,20 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str: return f"{diff_sql}{multiplier_sql}" -def _array_sort(self: generator.Generator, expression: exp.ArraySort) -> str: +def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str: + this = expression.this + + if not this.type: + from sqlglot.optimizer.annotate_types import annotate_types + + annotate_types(this) + + if this.type.is_type(exp.DataType.Type.JSON): + return self.sql(this) + return self.func("TO_JSON", this, expression.args.get("options")) + + +def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str: if expression.expression: self.unsupported("Hive SORT_ARRAY does not support a comparator") return f"SORT_ARRAY({self.sql(expression, 'this')})" @@ -91,11 +104,11 @@ def _property_sql(self: generator.Generator, expression: exp.Property) -> str: return f"'{expression.name}'={self.sql(expression, 'value')}" -def _str_to_unix(self: generator.Generator, expression: exp.StrToUnix) -> str: +def _str_to_unix_sql(self: generator.Generator, expression: exp.StrToUnix) -> str: return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression)) -def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str: +def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format not in (Hive.time_format, Hive.date_format): @@ -103,7 +116,7 @@ def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str: return f"CAST({this} AS DATE)" -def _str_to_time(self: generator.Generator, expression: exp.StrToTime) -> str: +def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format not in (Hive.time_format, Hive.date_format): @@ -214,6 +227,7 @@ class Hive(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, + "BASE64": exp.ToBase64.from_arg_list, "COLLECT_LIST": exp.ArrayAgg.from_arg_list, "DATE_ADD": lambda args: exp.TsOrDsAdd( this=seq_get(args, 0), @@ -251,6 +265,7 @@ class Hive(Dialect): "SPLIT": exp.RegexpSplit.from_arg_list, "TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"), "TO_JSON": exp.JSONFormat.from_arg_list, + "UNBASE64": exp.FromBase64.from_arg_list, "UNIX_TIMESTAMP": format_time_lambda(exp.StrToUnix, "hive", True), "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)), } @@ -280,16 +295,20 @@ class Hive(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore - **transforms.UNALIAS_GROUP, # type: ignore - **transforms.ELIMINATE_QUALIFY, # type: ignore + exp.Group: transforms.preprocess([transforms.unalias_group]), exp.Select: transforms.preprocess( - [transforms.eliminate_qualify, transforms.unnest_to_explode] + [ + transforms.eliminate_qualify, + transforms.eliminate_distinct_on, + transforms.unnest_to_explode, + ] ), exp.Property: _property_sql, exp.ApproxDistinct: approx_count_distinct_sql, exp.ArrayConcat: rename_func("CONCAT"), + exp.ArrayJoin: lambda self, e: self.func("CONCAT_WS", e.expression, e.this), exp.ArraySize: rename_func("SIZE"), - exp.ArraySort: _array_sort, + exp.ArraySort: _array_sort_sql, exp.With: no_recursive_cte_sql, exp.DateAdd: _add_date_sql, exp.DateDiff: _date_diff_sql, @@ -298,12 +317,13 @@ class Hive(Dialect): exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)", exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})", exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}", + exp.FromBase64: rename_func("UNBASE64"), exp.If: if_sql, exp.Index: _index_sql, exp.ILike: no_ilike_sql, exp.JSONExtract: rename_func("GET_JSON_OBJECT"), exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"), - exp.JSONFormat: rename_func("TO_JSON"), + exp.JSONFormat: _json_format_sql, exp.Map: var_map_sql, exp.Max: max_or_greatest, exp.Min: min_or_least, @@ -318,9 +338,9 @@ class Hive(Dialect): exp.SetAgg: rename_func("COLLECT_SET"), exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))", exp.StrPosition: strposition_to_locate_sql, - exp.StrToDate: _str_to_date, - exp.StrToTime: _str_to_time, - exp.StrToUnix: _str_to_unix, + exp.StrToDate: _str_to_date_sql, + exp.StrToTime: _str_to_time_sql, + exp.StrToUnix: _str_to_unix_sql, exp.StructExtract: struct_extract_sql, exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}", exp.TimeStrToDate: rename_func("TO_DATE"), @@ -328,6 +348,7 @@ class Hive(Dialect): exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), exp.TimeToStr: _time_to_str, exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), + exp.ToBase64: rename_func("BASE64"), exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)", exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.TsOrDsToDate: _to_date_sql, |