From 4cc7d5a6dcda8f275b4156a9a23bbe5380be1b53 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Thu, 10 Aug 2023 11:23:50 +0200 Subject: Merging upstream version 17.11.0. Signed-off-by: Daniel Baumann --- sqlglot/dialects/presto.py | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) (limited to 'sqlglot/dialects/presto.py') diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 14ec3dd..291b478 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -7,6 +7,7 @@ from sqlglot.dialects.dialect import ( Dialect, binary_from_function, date_trunc_to_time, + encode_decode_sql, format_time_lambda, if_sql, left_to_substring_sql, @@ -21,7 +22,6 @@ from sqlglot.dialects.dialect import ( timestrtotime_sql, ) from sqlglot.dialects.mysql import MySQL -from sqlglot.errors import UnsupportedError from sqlglot.helper import apply_index_offset, seq_get from sqlglot.tokens import TokenType @@ -41,6 +41,7 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str: if isinstance(expression.this, (exp.Explode, exp.Posexplode)): + expression = expression.copy() return self.sql( exp.Join( this=exp.Unnest( @@ -59,16 +60,6 @@ def _initcap_sql(self: generator.Generator, expression: exp.Initcap) -> str: return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))" -def _decode_sql(self: generator.Generator, expression: exp.Decode) -> str: - _ensure_utf8(expression.args["charset"]) - return self.func("FROM_UTF8", expression.this, expression.args.get("replace")) - - -def _encode_sql(self: generator.Generator, expression: exp.Encode) -> str: - _ensure_utf8(expression.args["charset"]) - return f"TO_UTF8({self.sql(expression, 'this')})" - - def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str: if expression.args.get("asc") == exp.false(): comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END" @@ -106,14 +97,14 @@ def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDat time_format = self.format_time(expression) if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT): return exp.cast(_str_to_time_sql(self, expression), "DATE").sql(dialect="presto") - return exp.cast(exp.cast(expression.this, "TIMESTAMP"), "DATE").sql(dialect="presto") + return exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE").sql(dialect="presto") def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str: this = expression.this if not isinstance(this, exp.CurrentDate): - this = exp.cast(exp.cast(expression.this, "TIMESTAMP"), "DATE") + this = exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE") return self.func( "DATE_ADD", @@ -123,11 +114,6 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s ) -def _ensure_utf8(charset: exp.Literal) -> None: - if charset.name.lower() != "utf-8": - raise UnsupportedError(f"Unsupported charset {charset}") - - def _approx_percentile(args: t.List) -> exp.Expression: if len(args) == 4: return exp.ApproxQuantile( @@ -288,9 +274,9 @@ class Presto(Dialect): ), exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.DATE_FORMAT}) AS DATE)", exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)", - exp.Decode: _decode_sql, + exp.Decode: lambda self, e: encode_decode_sql(self, e, "FROM_UTF8"), exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)", - exp.Encode: _encode_sql, + exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"), exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", exp.Group: transforms.preprocess([transforms.unalias_group]), exp.Hex: rename_func("TO_HEX"), -- cgit v1.2.3