diff options
Diffstat (limited to 'sqlglot/dialects/presto.py')
-rw-r--r-- | sqlglot/dialects/presto.py | 85 |
1 files changed, 56 insertions, 29 deletions
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 07e8f43..489d439 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, @@ -19,20 +21,20 @@ from sqlglot.helper import seq_get from sqlglot.tokens import TokenType -def _approx_distinct_sql(self, expression): +def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistinct) -> str: accuracy = expression.args.get("accuracy") accuracy = ", " + self.sql(accuracy) if accuracy else "" return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})" -def _datatype_sql(self, expression): +def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: sql = self.datatype_sql(expression) if expression.this == exp.DataType.Type.TIMESTAMPTZ: sql = f"{sql} WITH TIME ZONE" return sql -def _explode_to_unnest_sql(self, expression): +def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str: if isinstance(expression.this, (exp.Explode, exp.Posexplode)): return self.sql( exp.Join( @@ -47,22 +49,22 @@ def _explode_to_unnest_sql(self, expression): return self.lateral_sql(expression) -def _initcap_sql(self, expression): +def _initcap_sql(self: generator.Generator, expression: exp.Initcap) -> str: regex = r"(\w)(\w*)" return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))" -def _decode_sql(self, expression): - _ensure_utf8(expression.args.get("charset")) +def _decode_sql(self: generator.Generator, expression: exp.Decode) -> str: + _ensure_utf8(expression.args["charset"]) return self.func("FROM_UTF8", expression.this, expression.args.get("replace")) -def _encode_sql(self, expression): - _ensure_utf8(expression.args.get("charset")) +def _encode_sql(self: generator.Generator, expression: exp.Encode) -> str: + _ensure_utf8(expression.args["charset"]) return f"TO_UTF8({self.sql(expression, 'this')})" -def _no_sort_array(self, expression): +def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str: if expression.args.get("asc") == exp.false(): comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END" else: @@ -70,49 +72,62 @@ def _no_sort_array(self, expression): return self.func("ARRAY_SORT", expression.this, comparator) -def _schema_sql(self, expression): +def _schema_sql(self: generator.Generator, expression: exp.Schema) -> str: if isinstance(expression.parent, exp.Property): columns = ", ".join(f"'{c.name}'" for c in expression.expressions) return f"ARRAY[{columns}]" - for schema in expression.parent.find_all(exp.Schema): - if isinstance(schema.parent, exp.Property): - expression = expression.copy() - expression.expressions.extend(schema.expressions) + if expression.parent: + for schema in expression.parent.find_all(exp.Schema): + if isinstance(schema.parent, exp.Property): + expression = expression.copy() + expression.expressions.extend(schema.expressions) return self.schema_sql(expression) -def _quantile_sql(self, expression): +def _quantile_sql(self: generator.Generator, expression: exp.Quantile) -> str: self.unsupported("Presto does not support exact quantiles") return f"APPROX_PERCENTILE({self.sql(expression, 'this')}, {self.sql(expression, 'quantile')})" -def _str_to_time_sql(self, expression): +def _str_to_time_sql( + self: generator.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate +) -> str: return f"DATE_PARSE({self.sql(expression, 'this')}, {self.format_time(expression)})" -def _ts_or_ds_to_date_sql(self, expression): +def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str: time_format = self.format_time(expression) if time_format and time_format not in (Presto.time_format, Presto.date_format): return f"CAST({_str_to_time_sql(self, expression)} AS DATE)" return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)" -def _ts_or_ds_add_sql(self, expression): +def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str: + this = expression.this + + if not isinstance(this, exp.CurrentDate): + this = self.func( + "DATE_PARSE", + self.func( + "SUBSTR", + this if this.is_string else exp.cast(this, "VARCHAR"), + exp.Literal.number(1), + exp.Literal.number(10), + ), + Presto.date_format, + ) + return self.func( "DATE_ADD", exp.Literal.string(expression.text("unit") or "day"), expression.expression, - self.func( - "DATE_PARSE", - self.func("SUBSTR", expression.this, exp.Literal.number(1), exp.Literal.number(10)), - Presto.date_format, - ), + this, ) -def _sequence_sql(self, expression): +def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str: start = expression.args["start"] end = expression.args["end"] step = expression.args.get("step", 1) # Postgres defaults to 1 for generate_series @@ -135,12 +150,12 @@ def _sequence_sql(self, expression): return self.func("SEQUENCE", start, end, step) -def _ensure_utf8(charset): +def _ensure_utf8(charset: exp.Literal) -> None: if charset.name.lower() != "utf-8": raise UnsupportedError(f"Unsupported charset {charset}") -def _approx_percentile(args): +def _approx_percentile(args: t.Sequence) -> exp.Expression: if len(args) == 4: return exp.ApproxQuantile( this=seq_get(args, 0), @@ -157,7 +172,7 @@ def _approx_percentile(args): return exp.ApproxQuantile.from_arg_list(args) -def _from_unixtime(args): +def _from_unixtime(args: t.Sequence) -> exp.Expression: if len(args) == 3: return exp.UnixToTime( this=seq_get(args, 0), @@ -226,11 +241,15 @@ class Presto(Dialect): FUNCTION_PARSERS.pop("TRIM") class Generator(generator.Generator): + INTERVAL_ALLOWS_PLURAL_FORM = False + JOIN_HINTS = False + TABLE_HINTS = False STRUCT_DELIMITER = ("(", ")") PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore exp.LocationProperty: exp.Properties.Location.UNSUPPORTED, + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } TYPE_MAPPING = { @@ -246,7 +265,6 @@ class Presto(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore - **transforms.ELIMINATE_QUALIFY, # type: ignore exp.ApproxDistinct: _approx_distinct_sql, exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", exp.ArrayConcat: rename_func("CONCAT"), @@ -284,6 +302,9 @@ class Presto(Dialect): exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), exp.SafeDivide: no_safe_divide_sql, exp.Schema: _schema_sql, + exp.Select: transforms.preprocess( + [transforms.eliminate_qualify, transforms.explode_to_unnest] + ), exp.SortArray: _no_sort_array, exp.StrPosition: rename_func("STRPOS"), exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)", @@ -308,7 +329,13 @@ class Presto(Dialect): exp.VariancePop: rename_func("VAR_POP"), } - def transaction_sql(self, expression): + def interval_sql(self, expression: exp.Interval) -> str: + unit = self.sql(expression, "unit") + if expression.this and unit.lower().startswith("week"): + return f"({expression.this.name} * INTERVAL '7' day)" + return super().interval_sql(expression) + + def transaction_sql(self, expression: exp.Transaction) -> str: modes = expression.args.get("modes") modes = f" {', '.join(modes)}" if modes else "" return f"START TRANSACTION{modes}" |