From beba715b97dd2349e01dde9b077d2535680ebdca Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 10 May 2023 08:44:58 +0200 Subject: Merging upstream version 12.2.0. Signed-off-by: Daniel Baumann --- sqlglot/dialects/presto.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) (limited to 'sqlglot/dialects/presto.py') diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 489d439..6133a27 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -130,7 +130,7 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str: start = expression.args["start"] end = expression.args["end"] - step = expression.args.get("step", 1) # Postgres defaults to 1 for generate_series + step = expression.args.get("step") target_type = None @@ -147,7 +147,11 @@ def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> else: start = exp.Cast(this=start, to=to) - return self.func("SEQUENCE", start, end, step) + sql = self.func("SEQUENCE", start, end, step) + if isinstance(expression.parent, exp.Table): + sql = f"UNNEST({sql})" + + return sql def _ensure_utf8(charset: exp.Literal) -> None: @@ -204,6 +208,7 @@ class Presto(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list, + "APPROX_PERCENTILE": _approx_percentile, "CARDINALITY": exp.ArraySize.from_arg_list, "CONTAINS": exp.ArrayContains.from_arg_list, "DATE_ADD": lambda args: exp.DateAdd( @@ -219,23 +224,23 @@ class Presto(Dialect): "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"), "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"), "DATE_TRUNC": date_trunc_to_time, + "FROM_HEX": exp.Unhex.from_arg_list, "FROM_UNIXTIME": _from_unixtime, + "FROM_UTF8": lambda args: exp.Decode( + this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8") + ), "NOW": exp.CurrentTimestamp.from_arg_list, + "SEQUENCE": exp.GenerateSeries.from_arg_list, "STRPOS": lambda args: exp.StrPosition( this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2), ), "TO_UNIXTIME": exp.TimeToUnix.from_arg_list, - "APPROX_PERCENTILE": _approx_percentile, - "FROM_HEX": exp.Unhex.from_arg_list, "TO_HEX": exp.Hex.from_arg_list, "TO_UTF8": lambda args: exp.Encode( this=seq_get(args, 0), charset=exp.Literal.string("utf-8") ), - "FROM_UTF8": lambda args: exp.Decode( - this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8") - ), } FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy() FUNCTION_PARSERS.pop("TRIM") @@ -264,7 +269,6 @@ class Presto(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore - **transforms.UNALIAS_GROUP, # type: ignore exp.ApproxDistinct: _approx_distinct_sql, exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", exp.ArrayConcat: rename_func("CONCAT"), @@ -290,6 +294,7 @@ class Presto(Dialect): exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)", exp.Encode: _encode_sql, exp.GenerateSeries: _sequence_sql, + exp.Group: transforms.preprocess([transforms.unalias_group]), exp.Hex: rename_func("TO_HEX"), exp.If: if_sql, exp.ILike: no_ilike_sql, @@ -303,7 +308,11 @@ class Presto(Dialect): exp.SafeDivide: no_safe_divide_sql, exp.Schema: _schema_sql, exp.Select: transforms.preprocess( - [transforms.eliminate_qualify, transforms.explode_to_unnest] + [ + transforms.eliminate_qualify, + transforms.eliminate_distinct_on, + transforms.explode_to_unnest, + ] ), exp.SortArray: _no_sort_array, exp.StrPosition: rename_func("STRPOS"), @@ -327,6 +336,9 @@ class Presto(Dialect): exp.UnixToTime: rename_func("FROM_UNIXTIME"), exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)", exp.VariancePop: rename_func("VAR_POP"), + exp.WithinGroup: transforms.preprocess( + [transforms.remove_within_group_for_percentiles] + ), } def interval_sql(self, expression: exp.Interval) -> str: -- cgit v1.2.3