diff options
Diffstat (limited to 'sqlglot/dialects/presto.py')
-rw-r--r-- | sqlglot/dialects/presto.py | 94 |
1 files changed, 54 insertions, 40 deletions
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 6133a27..52a04a4 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import ( format_time_lambda, if_sql, no_ilike_sql, + no_pivot_sql, no_safe_divide_sql, rename_func, struct_extract_sql, @@ -127,39 +128,12 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s ) -def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str: - start = expression.args["start"] - end = expression.args["end"] - step = expression.args.get("step") - - target_type = None - - if isinstance(start, exp.Cast): - target_type = start.to - elif isinstance(end, exp.Cast): - target_type = end.to - - if target_type and target_type.this == exp.DataType.Type.TIMESTAMP: - to = target_type.copy() - - if target_type is start.to: - end = exp.Cast(this=end, to=to) - else: - start = exp.Cast(this=start, to=to) - - sql = self.func("SEQUENCE", start, end, step) - if isinstance(expression.parent, exp.Table): - sql = f"UNNEST({sql})" - - return sql - - def _ensure_utf8(charset: exp.Literal) -> None: if charset.name.lower() != "utf-8": raise UnsupportedError(f"Unsupported charset {charset}") -def _approx_percentile(args: t.Sequence) -> exp.Expression: +def _approx_percentile(args: t.List) -> exp.Expression: if len(args) == 4: return exp.ApproxQuantile( this=seq_get(args, 0), @@ -176,7 +150,7 @@ def _approx_percentile(args: t.Sequence) -> exp.Expression: return exp.ApproxQuantile.from_arg_list(args) -def _from_unixtime(args: t.Sequence) -> exp.Expression: +def _from_unixtime(args: t.List) -> exp.Expression: if len(args) == 3: return exp.UnixToTime( this=seq_get(args, 0), @@ -191,22 +165,39 @@ def _from_unixtime(args: t.Sequence) -> exp.Expression: return exp.UnixToTime.from_arg_list(args) +def _unnest_sequence(expression: exp.Expression) -> exp.Expression: + if isinstance(expression, exp.Table): + if isinstance(expression.this, exp.GenerateSeries): + unnest = exp.Unnest(expressions=[expression.this]) + + if expression.alias: + return exp.alias_( + unnest, + alias="_u", + table=[expression.alias], + copy=False, + ) + return unnest + return expression + + class Presto(Dialect): index_offset = 1 null_ordering = "nulls_are_last" - time_format = MySQL.time_format # type: ignore - time_mapping = MySQL.time_mapping # type: ignore + time_format = MySQL.time_format + time_mapping = MySQL.time_mapping class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "START": TokenType.BEGIN, + "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, "ROW": TokenType.STRUCT, } class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore + **parser.Parser.FUNCTIONS, "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list, "APPROX_PERCENTILE": _approx_percentile, "CARDINALITY": exp.ArraySize.from_arg_list, @@ -252,13 +243,13 @@ class Presto(Dialect): STRUCT_DELIMITER = ("(", ")") PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.LocationProperty: exp.Properties.Location.UNSUPPORTED, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.FLOAT: "REAL", exp.DataType.Type.BINARY: "VARBINARY", @@ -268,8 +259,9 @@ class Presto(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, exp.ApproxDistinct: _approx_distinct_sql, + exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", exp.ArrayConcat: rename_func("CONCAT"), exp.ArrayContains: rename_func("CONTAINS"), @@ -293,7 +285,7 @@ class Presto(Dialect): exp.Decode: _decode_sql, exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)", exp.Encode: _encode_sql, - exp.GenerateSeries: _sequence_sql, + exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", exp.Group: transforms.preprocess([transforms.unalias_group]), exp.Hex: rename_func("TO_HEX"), exp.If: if_sql, @@ -301,10 +293,10 @@ class Presto(Dialect): exp.Initcap: _initcap_sql, exp.Lateral: _explode_to_unnest_sql, exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), - exp.LogicalOr: rename_func("BOOL_OR"), exp.LogicalAnd: rename_func("BOOL_AND"), + exp.LogicalOr: rename_func("BOOL_OR"), + exp.Pivot: no_pivot_sql, exp.Quantile: _quantile_sql, - exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), exp.SafeDivide: no_safe_divide_sql, exp.Schema: _schema_sql, exp.Select: transforms.preprocess( @@ -320,8 +312,7 @@ class Presto(Dialect): exp.StrToTime: _str_to_time_sql, exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))", exp.StructExtract: struct_extract_sql, - exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'", - exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", + exp.Table: transforms.preprocess([_unnest_sequence]), exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToDate: timestrtotime_sql, exp.TimeStrToTime: timestrtotime_sql, @@ -336,6 +327,7 @@ class Presto(Dialect): exp.UnixToTime: rename_func("FROM_UNIXTIME"), exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)", exp.VariancePop: rename_func("VAR_POP"), + exp.With: transforms.preprocess([transforms.add_recursive_cte_column_names]), exp.WithinGroup: transforms.preprocess( [transforms.remove_within_group_for_percentiles] ), @@ -351,3 +343,25 @@ class Presto(Dialect): modes = expression.args.get("modes") modes = f" {', '.join(modes)}" if modes else "" return f"START TRANSACTION{modes}" + + def generateseries_sql(self, expression: exp.GenerateSeries) -> str: + start = expression.args["start"] + end = expression.args["end"] + step = expression.args.get("step") + + if isinstance(start, exp.Cast): + target_type = start.to + elif isinstance(end, exp.Cast): + target_type = end.to + else: + target_type = None + + if target_type and target_type.is_type(exp.DataType.Type.TIMESTAMP): + to = target_type.copy() + + if target_type is start.to: + end = exp.Cast(this=end, to=to) + else: + start = exp.Cast(this=start, to=to) + + return self.func("SEQUENCE", start, end, step) |