diff options
author | Daniel Baumann <mail@daniel-baumann.ch> | 2023-12-10 10:46:01 +0000 |
---|---|---|
committer | Daniel Baumann <mail@daniel-baumann.ch> | 2023-12-10 10:46:01 +0000 |
commit | 8fe30fd23dc37ec3516e530a86d1c4b604e71241 (patch) | |
tree | 6e2ebbf565b0351fd0f003f488a8339e771ad90c /sqlglot/dialects/presto.py | |
parent | Releasing debian version 19.0.1-1. (diff) | |
download | sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.tar.xz sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.zip |
Merging upstream version 20.1.0.
Signed-off-by: Daniel Baumann <mail@daniel-baumann.ch>
Diffstat (limited to 'sqlglot/dialects/presto.py')
-rw-r--r-- | sqlglot/dialects/presto.py | 72 |
1 files changed, 54 insertions, 18 deletions
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index ded3655..10a6074 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -5,9 +5,11 @@ import typing as t from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, + NormalizationStrategy, binary_from_function, bool_xor_sql, date_trunc_to_time, + datestrtodate_sql, encode_decode_sql, format_time_lambda, if_sql, @@ -22,6 +24,7 @@ from sqlglot.dialects.dialect import ( struct_extract_sql, timestamptrunc_sql, timestrtotime_sql, + ts_or_ds_add_cast, ) from sqlglot.dialects.mysql import MySQL from sqlglot.helper import apply_index_offset, seq_get @@ -95,17 +98,16 @@ def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate) def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str: - this = expression.this + expression = ts_or_ds_add_cast(expression) + unit = exp.Literal.string(expression.text("unit") or "day") + return self.func("DATE_ADD", unit, expression.expression, expression.this) - if not isinstance(this, exp.CurrentDate): - this = exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE") - return self.func( - "DATE_ADD", - exp.Literal.string(expression.text("unit") or "day"), - expression.expression, - this, - ) +def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> str: + this = exp.cast(expression.this, "TIMESTAMP") + expr = exp.cast(expression.expression, "TIMESTAMP") + unit = exp.Literal.string(expression.text("unit") or "day") + return self.func("DATE_DIFF", unit, expr, this) def _approx_percentile(args: t.List) -> exp.Expression: @@ -136,11 +138,11 @@ def _from_unixtime(args: t.List) -> exp.Expression: return exp.UnixToTime.from_arg_list(args) -def _parse_element_at(args: t.List) -> exp.SafeBracket: +def _parse_element_at(args: t.List) -> exp.Bracket: this = seq_get(args, 0) index = seq_get(args, 1) assert isinstance(this, exp.Expression) and isinstance(index, exp.Expression) - return exp.SafeBracket(this=this, expressions=apply_index_offset(this, [index], -1)) + return exp.Bracket(this=this, expressions=[index], offset=1, safe=True) def _unnest_sequence(expression: exp.Expression) -> exp.Expression: @@ -168,6 +170,22 @@ def _first_last_sql(self: Presto.Generator, expression: exp.First | exp.Last) -> return rename_func("ARBITRARY")(self, expression) +def _unix_to_time_sql(self: Presto.Generator, expression: exp.UnixToTime) -> str: + scale = expression.args.get("scale") + timestamp = self.sql(expression, "this") + if scale in (None, exp.UnixToTime.SECONDS): + return rename_func("FROM_UNIXTIME")(self, expression) + if scale == exp.UnixToTime.MILLIS: + return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000)" + if scale == exp.UnixToTime.MICROS: + return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000)" + if scale == exp.UnixToTime.NANOS: + return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000000)" + + self.unsupported(f"Unsupported scale for timestamp: {scale}.") + return "" + + class Presto(Dialect): INDEX_OFFSET = 1 NULL_ORDERING = "nulls_are_last" @@ -175,11 +193,12 @@ class Presto(Dialect): TIME_MAPPING = MySQL.TIME_MAPPING STRICT_STRING_CONCAT = True SUPPORTS_SEMI_ANTI_JOIN = False + TYPED_DIVISION = True # https://github.com/trinodb/trino/issues/17 # https://github.com/trinodb/trino/issues/12289 # https://github.com/prestodb/presto/issues/2863 - RESOLVES_IDENTIFIERS_AS_UPPERCASE = None + NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE class Tokenizer(tokens.Tokenizer): KEYWORDS = { @@ -229,6 +248,7 @@ class Presto(Dialect): ), "ROW": exp.Struct.from_arg_list, "SEQUENCE": exp.GenerateSeries.from_arg_list, + "SET_AGG": exp.ArrayUniqueAgg.from_arg_list, "SPLIT_TO_MAP": exp.StrToMap.from_arg_list, "STRPOS": lambda args: exp.StrPosition( this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2) @@ -253,6 +273,7 @@ class Presto(Dialect): NVL2_SUPPORTED = False STRUCT_DELIMITER = ("(", ")") LIMIT_ONLY_LITERALS = True + SUPPORTS_SINGLE_ARG_CONCAT = False PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, @@ -284,6 +305,7 @@ class Presto(Dialect): exp.ArrayConcat: rename_func("CONCAT"), exp.ArrayContains: rename_func("CONTAINS"), exp.ArraySize: rename_func("CARDINALITY"), + exp.ArrayUniqueAgg: rename_func("SET_AGG"), exp.BitwiseAnd: lambda self, e: f"BITWISE_AND({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.BitwiseLeftShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_LEFT({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.BitwiseNot: lambda self, e: f"BITWISE_NOT({self.sql(e, 'this')})", @@ -298,7 +320,7 @@ class Presto(Dialect): exp.DateDiff: lambda self, e: self.func( "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this ), - exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.DATE_FORMAT}) AS DATE)", + exp.DateStrToDate: datestrtodate_sql, exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)", exp.DateSub: lambda self, e: self.func( "DATE_ADD", @@ -330,9 +352,6 @@ class Presto(Dialect): exp.Quantile: _quantile_sql, exp.RegexpExtract: regexp_extract_sql, exp.Right: right_to_substring_sql, - exp.SafeBracket: lambda self, e: self.func( - "ELEMENT_AT", e.this, seq_get(apply_index_offset(e.this, e.expressions, 1), 0) - ), exp.SafeDivide: no_safe_divide_sql, exp.Schema: _schema_sql, exp.Select: transforms.preprocess( @@ -361,10 +380,11 @@ class Presto(Dialect): exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]), exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", exp.TsOrDsAdd: _ts_or_ds_add_sql, + exp.TsOrDsDiff: _ts_or_ds_diff_sql, exp.TsOrDsToDate: _ts_or_ds_to_date_sql, exp.Unhex: rename_func("FROM_HEX"), exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})", - exp.UnixToTime: rename_func("FROM_UNIXTIME"), + exp.UnixToTime: _unix_to_time_sql, exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)", exp.VariancePop: rename_func("VAR_POP"), exp.With: transforms.preprocess([transforms.add_recursive_cte_column_names]), @@ -374,8 +394,24 @@ class Presto(Dialect): exp.Xor: bool_xor_sql, } + def bracket_sql(self, expression: exp.Bracket) -> str: + if expression.args.get("safe"): + return self.func( + "ELEMENT_AT", + expression.this, + seq_get( + apply_index_offset( + expression.this, + expression.expressions, + 1 - expression.args.get("offset", 0), + ), + 0, + ), + ) + return super().bracket_sql(expression) + def struct_sql(self, expression: exp.Struct) -> str: - if any(isinstance(arg, (exp.EQ, exp.Slice)) for arg in expression.expressions): + if any(isinstance(arg, self.KEY_VALUE_DEFINITONS) for arg in expression.expressions): self.unsupported("Struct with key-value definitions is unsupported.") return self.function_fallback_sql(expression) |