diff options
Diffstat (limited to 'sqlglot/dialects/presto.py')
-rw-r--r-- | sqlglot/dialects/presto.py | 54 |
1 files changed, 45 insertions, 9 deletions
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 078da0b..4b54e95 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -26,13 +26,13 @@ from sqlglot.helper import apply_index_offset, seq_get from sqlglot.tokens import TokenType -def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistinct) -> str: +def _approx_distinct_sql(self: Presto.Generator, expression: exp.ApproxDistinct) -> str: accuracy = expression.args.get("accuracy") accuracy = ", " + self.sql(accuracy) if accuracy else "" return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})" -def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str: +def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> str: if isinstance(expression.this, (exp.Explode, exp.Posexplode)): expression = expression.copy() return self.sql( @@ -48,12 +48,12 @@ def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) - return self.lateral_sql(expression) -def _initcap_sql(self: generator.Generator, expression: exp.Initcap) -> str: +def _initcap_sql(self: Presto.Generator, expression: exp.Initcap) -> str: regex = r"(\w)(\w*)" return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))" -def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str: +def _no_sort_array(self: Presto.Generator, expression: exp.SortArray) -> str: if expression.args.get("asc") == exp.false(): comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END" else: @@ -61,7 +61,7 @@ def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str: return self.func("ARRAY_SORT", expression.this, comparator) -def _schema_sql(self: generator.Generator, expression: exp.Schema) -> str: +def _schema_sql(self: Presto.Generator, expression: exp.Schema) -> str: if isinstance(expression.parent, exp.Property): columns = ", ".join(f"'{c.name}'" for c in expression.expressions) return f"ARRAY[{columns}]" @@ -75,25 +75,25 @@ def _schema_sql(self: generator.Generator, expression: exp.Schema) -> str: return self.schema_sql(expression) -def _quantile_sql(self: generator.Generator, expression: exp.Quantile) -> str: +def _quantile_sql(self: Presto.Generator, expression: exp.Quantile) -> str: self.unsupported("Presto does not support exact quantiles") return f"APPROX_PERCENTILE({self.sql(expression, 'this')}, {self.sql(expression, 'quantile')})" def _str_to_time_sql( - self: generator.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate + self: Presto.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate ) -> str: return f"DATE_PARSE({self.sql(expression, 'this')}, {self.format_time(expression)})" -def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str: +def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate) -> str: time_format = self.format_time(expression) if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT): return exp.cast(_str_to_time_sql(self, expression), "DATE").sql(dialect="presto") return exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE").sql(dialect="presto") -def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str: +def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str: this = expression.this if not isinstance(this, exp.CurrentDate): @@ -153,6 +153,20 @@ def _unnest_sequence(expression: exp.Expression) -> exp.Expression: return expression +def _first_last_sql(self: Presto.Generator, expression: exp.First | exp.Last) -> str: + """ + Trino doesn't support FIRST / LAST as functions, but they're valid in the context + of MATCH_RECOGNIZE, so we need to preserve them in that case. In all other cases + they're converted into an ARBITRARY call. + + Reference: https://trino.io/docs/current/sql/match-recognize.html#logical-navigation-functions + """ + if isinstance(expression.find_ancestor(exp.MatchRecognize, exp.Select), exp.MatchRecognize): + return self.function_fallback_sql(expression) + + return rename_func("ARBITRARY")(self, expression) + + class Presto(Dialect): INDEX_OFFSET = 1 NULL_ORDERING = "nulls_are_last" @@ -178,6 +192,7 @@ class Presto(Dialect): class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, + "ARBITRARY": exp.AnyValue.from_arg_list, "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list, "APPROX_PERCENTILE": _approx_percentile, "BITWISE_AND": binary_from_function(exp.BitwiseAnd), @@ -205,7 +220,14 @@ class Presto(Dialect): "REGEXP_EXTRACT": lambda args: exp.RegexpExtract( this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2) ), + "REGEXP_REPLACE": lambda args: exp.RegexpReplace( + this=seq_get(args, 0), + expression=seq_get(args, 1), + replacement=seq_get(args, 2) or exp.Literal.string(""), + ), + "ROW": exp.Struct.from_arg_list, "SEQUENCE": exp.GenerateSeries.from_arg_list, + "SPLIT_TO_MAP": exp.StrToMap.from_arg_list, "STRPOS": lambda args: exp.StrPosition( this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2) ), @@ -225,6 +247,7 @@ class Presto(Dialect): QUERY_HINTS = False IS_BOOL_ALLOWED = False TZ_TO_WITH_TIME_ZONE = True + NVL2_SUPPORTED = False STRUCT_DELIMITER = ("(", ")") PROPERTIES_LOCATION = { @@ -242,10 +265,13 @@ class Presto(Dialect): exp.DataType.Type.TIMETZ: "TIME", exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", exp.DataType.Type.STRUCT: "ROW", + exp.DataType.Type.DATETIME: "TIMESTAMP", + exp.DataType.Type.DATETIME64: "TIMESTAMP", } TRANSFORMS = { **generator.Generator.TRANSFORMS, + exp.AnyValue: rename_func("ARBITRARY"), exp.ApproxDistinct: _approx_distinct_sql, exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", @@ -268,15 +294,23 @@ class Presto(Dialect): ), exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.DATE_FORMAT}) AS DATE)", exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)", + exp.DateSub: lambda self, e: self.func( + "DATE_ADD", + exp.Literal.string(e.text("unit") or "day"), + e.expression * -1, + e.this, + ), exp.Decode: lambda self, e: encode_decode_sql(self, e, "FROM_UTF8"), exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)", exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"), exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", + exp.First: _first_last_sql, exp.Group: transforms.preprocess([transforms.unalias_group]), exp.Hex: rename_func("TO_HEX"), exp.If: if_sql, exp.ILike: no_ilike_sql, exp.Initcap: _initcap_sql, + exp.Last: _first_last_sql, exp.Lateral: _explode_to_unnest_sql, exp.Left: left_to_substring_sql, exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), @@ -301,8 +335,10 @@ class Presto(Dialect): exp.SortArray: _no_sort_array, exp.StrPosition: rename_func("STRPOS"), exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)", + exp.StrToMap: rename_func("SPLIT_TO_MAP"), exp.StrToTime: _str_to_time_sql, exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))", + exp.Struct: rename_func("ROW"), exp.StructExtract: struct_extract_sql, exp.Table: transforms.preprocess([_unnest_sequence]), exp.TimestampTrunc: timestamptrunc_sql, |