diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-01-23 05:06:14 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-01-23 05:06:14 +0000 |
commit | 38e6461a8afbd7cb83709ddb998f03d40ba87755 (patch) | |
tree | 64b68a893a3b946111b9cab69503f83ca233c335 /sqlglot/dialects/presto.py | |
parent | Releasing debian version 20.4.0-1. (diff) | |
download | sqlglot-38e6461a8afbd7cb83709ddb998f03d40ba87755.tar.xz sqlglot-38e6461a8afbd7cb83709ddb998f03d40ba87755.zip |
Merging upstream version 20.9.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/presto.py')
-rw-r--r-- | sqlglot/dialects/presto.py | 40 |
1 files changed, 16 insertions, 24 deletions
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 360ab65..9b421e7 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -18,6 +18,7 @@ from sqlglot.dialects.dialect import ( no_pivot_sql, no_safe_divide_sql, no_timestamp_sql, + path_to_jsonpath, regexp_extract_sql, rename_func, right_to_substring_sql, @@ -99,14 +100,14 @@ def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate) def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str: expression = ts_or_ds_add_cast(expression) - unit = exp.Literal.string(expression.text("unit") or "day") + unit = exp.Literal.string(expression.text("unit") or "DAY") return self.func("DATE_ADD", unit, expression.expression, expression.this) def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> str: this = exp.cast(expression.this, "TIMESTAMP") expr = exp.cast(expression.expression, "TIMESTAMP") - unit = exp.Literal.string(expression.text("unit") or "day") + unit = exp.Literal.string(expression.text("unit") or "DAY") return self.func("DATE_DIFF", unit, expr, this) @@ -138,13 +139,6 @@ def _from_unixtime(args: t.List) -> exp.Expression: return exp.UnixToTime.from_arg_list(args) -def _parse_element_at(args: t.List) -> exp.Bracket: - this = seq_get(args, 0) - index = seq_get(args, 1) - assert isinstance(this, exp.Expression) and isinstance(index, exp.Expression) - return exp.Bracket(this=this, expressions=[index], offset=1, safe=True) - - def _unnest_sequence(expression: exp.Expression) -> exp.Expression: if isinstance(expression, exp.Table): if isinstance(expression.this, exp.GenerateSeries): @@ -175,15 +169,8 @@ def _unix_to_time_sql(self: Presto.Generator, expression: exp.UnixToTime) -> str timestamp = self.sql(expression, "this") if scale in (None, exp.UnixToTime.SECONDS): return rename_func("FROM_UNIXTIME")(self, expression) - if scale == exp.UnixToTime.MILLIS: - return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000)" - if scale == exp.UnixToTime.MICROS: - return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000)" - if scale == exp.UnixToTime.NANOS: - return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000000)" - self.unsupported(f"Unsupported scale for timestamp: {scale}.") - return "" + return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / POW(10, {scale}))" def _to_int(expression: exp.Expression) -> exp.Expression: @@ -215,6 +202,7 @@ class Presto(Dialect): STRICT_STRING_CONCAT = True SUPPORTS_SEMI_ANTI_JOIN = False TYPED_DIVISION = True + TABLESAMPLE_SIZE_IS_PERCENT = True # https://github.com/trinodb/trino/issues/17 # https://github.com/trinodb/trino/issues/12289 @@ -258,7 +246,9 @@ class Presto(Dialect): "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"), "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"), "DATE_TRUNC": date_trunc_to_time, - "ELEMENT_AT": _parse_element_at, + "ELEMENT_AT": lambda args: exp.Bracket( + this=seq_get(args, 0), expressions=[seq_get(args, 1)], offset=1, safe=True + ), "FROM_HEX": exp.Unhex.from_arg_list, "FROM_UNIXTIME": _from_unixtime, "FROM_UTF8": lambda args: exp.Decode( @@ -344,20 +334,20 @@ class Presto(Dialect): exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.DateAdd: lambda self, e: self.func( "DATE_ADD", - exp.Literal.string(e.text("unit") or "day"), + exp.Literal.string(e.text("unit") or "DAY"), _to_int( e.expression, ), e.this, ), exp.DateDiff: lambda self, e: self.func( - "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this + "DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.expression, e.this ), exp.DateStrToDate: datestrtodate_sql, exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)", exp.DateSub: lambda self, e: self.func( "DATE_ADD", - exp.Literal.string(e.text("unit") or "day"), + exp.Literal.string(e.text("unit") or "DAY"), _to_int(e.expression * -1), e.this, ), @@ -366,6 +356,7 @@ class Presto(Dialect): exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"), exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", exp.First: _first_last_sql, + exp.GetPath: path_to_jsonpath(), exp.Group: transforms.preprocess([transforms.unalias_group]), exp.GroupConcat: lambda self, e: self.func( "ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator") @@ -376,6 +367,7 @@ class Presto(Dialect): exp.Initcap: _initcap_sql, exp.ParseJSON: rename_func("JSON_PARSE"), exp.Last: _first_last_sql, + exp.LastDay: lambda self, e: self.func("LAST_DAY_OF_MONTH", e.this), exp.Lateral: _explode_to_unnest_sql, exp.Left: left_to_substring_sql, exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), @@ -446,7 +438,7 @@ class Presto(Dialect): return super().bracket_sql(expression) def struct_sql(self, expression: exp.Struct) -> str: - if any(isinstance(arg, self.KEY_VALUE_DEFINITONS) for arg in expression.expressions): + if any(isinstance(arg, self.KEY_VALUE_DEFINITIONS) for arg in expression.expressions): self.unsupported("Struct with key-value definitions is unsupported.") return self.function_fallback_sql(expression) @@ -454,8 +446,8 @@ class Presto(Dialect): def interval_sql(self, expression: exp.Interval) -> str: unit = self.sql(expression, "unit") - if expression.this and unit.lower().startswith("week"): - return f"({expression.this.name} * INTERVAL '7' day)" + if expression.this and unit.startswith("WEEK"): + return f"({expression.this.name} * INTERVAL '7' DAY)" return super().interval_sql(expression) def transaction_sql(self, expression: exp.Transaction) -> str: |