summary refs log tree commit diff stats
path: root/sqlglot/dialects/presto.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/presto.py')
-rw-r--r-- sqlglot/dialects/presto.py 85
1 files changed, 56 insertions, 29 deletions
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 07e8f43..489d439 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+import typing as t
+
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
@@ -19,20 +21,20 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
-def _approx_distinct_sql(self, expression):
+def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistinct) -> str:
accuracy = expression.args.get("accuracy")
accuracy = ", " + self.sql(accuracy) if accuracy else ""
return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})"
-def _datatype_sql(self, expression):
+def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
sql = self.datatype_sql(expression)
if expression.this == exp.DataType.Type.TIMESTAMPTZ:
sql = f"{sql} WITH TIME ZONE"
return sql
-def _explode_to_unnest_sql(self, expression):
+def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str:
if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
return self.sql(
exp.Join(
@@ -47,22 +49,22 @@ def _explode_to_unnest_sql(self, expression):
return self.lateral_sql(expression)
-def _initcap_sql(self, expression):
+def _initcap_sql(self: generator.Generator, expression: exp.Initcap) -> str:
regex = r"(\w)(\w*)"
return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"
-def _decode_sql(self, expression):
- _ensure_utf8(expression.args.get("charset"))
+def _decode_sql(self: generator.Generator, expression: exp.Decode) -> str:
+ _ensure_utf8(expression.args["charset"])
return self.func("FROM_UTF8", expression.this, expression.args.get("replace"))
-def _encode_sql(self, expression):
- _ensure_utf8(expression.args.get("charset"))
+def _encode_sql(self: generator.Generator, expression: exp.Encode) -> str:
+ _ensure_utf8(expression.args["charset"])
return f"TO_UTF8({self.sql(expression, 'this')})"
-def _no_sort_array(self, expression):
+def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str:
if expression.args.get("asc") == exp.false():
comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
else:
@@ -70,49 +72,62 @@ def _no_sort_array(self, expression):
return self.func("ARRAY_SORT", expression.this, comparator)
-def _schema_sql(self, expression):
+def _schema_sql(self: generator.Generator, expression: exp.Schema) -> str:
if isinstance(expression.parent, exp.Property):
columns = ", ".join(f"'{c.name}'" for c in expression.expressions)
return f"ARRAY[{columns}]"
- for schema in expression.parent.find_all(exp.Schema):
- if isinstance(schema.parent, exp.Property):
- expression = expression.copy()
- expression.expressions.extend(schema.expressions)
+ if expression.parent:
+ for schema in expression.parent.find_all(exp.Schema):
+ if isinstance(schema.parent, exp.Property):
+ expression = expression.copy()
+ expression.expressions.extend(schema.expressions)
return self.schema_sql(expression)
-def _quantile_sql(self, expression):
+def _quantile_sql(self: generator.Generator, expression: exp.Quantile) -> str:
self.unsupported("Presto does not support exact quantiles")
return f"APPROX_PERCENTILE({self.sql(expression, 'this')}, {self.sql(expression, 'quantile')})"
-def _str_to_time_sql(self, expression):
+def _str_to_time_sql(
+ self: generator.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate
+) -> str:
return f"DATE_PARSE({self.sql(expression, 'this')}, {self.format_time(expression)})"
-def _ts_or_ds_to_date_sql(self, expression):
+def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
time_format = self.format_time(expression)
if time_format and time_format not in (Presto.time_format, Presto.date_format):
return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"
-def _ts_or_ds_add_sql(self, expression):
+def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
+ this = expression.this
+
+ if not isinstance(this, exp.CurrentDate):
+ this = self.func(
+ "DATE_PARSE",
+ self.func(
+ "SUBSTR",
+ this if this.is_string else exp.cast(this, "VARCHAR"),
+ exp.Literal.number(1),
+ exp.Literal.number(10),
+ ),
+ Presto.date_format,
+ )
+
return self.func(
"DATE_ADD",
exp.Literal.string(expression.text("unit") or "day"),
expression.expression,
- self.func(
- "DATE_PARSE",
- self.func("SUBSTR", expression.this, exp.Literal.number(1), exp.Literal.number(10)),
- Presto.date_format,
- ),
+ this,
)
-def _sequence_sql(self, expression):
+def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str:
start = expression.args["start"]
end = expression.args["end"]
step = expression.args.get("step", 1) # Postgres defaults to 1 for generate_series
@@ -135,12 +150,12 @@ def _sequence_sql(self, expression):
return self.func("SEQUENCE", start, end, step)
-def _ensure_utf8(charset):
+def _ensure_utf8(charset: exp.Literal) -> None:
if charset.name.lower() != "utf-8":
raise UnsupportedError(f"Unsupported charset {charset}")
-def _approx_percentile(args):
+def _approx_percentile(args: t.Sequence) -> exp.Expression:
if len(args) == 4:
return exp.ApproxQuantile(
this=seq_get(args, 0),
@@ -157,7 +172,7 @@ def _approx_percentile(args):
return exp.ApproxQuantile.from_arg_list(args)
-def _from_unixtime(args):
+def _from_unixtime(args: t.Sequence) -> exp.Expression:
if len(args) == 3:
return exp.UnixToTime(
this=seq_get(args, 0),
@@ -226,11 +241,15 @@ class Presto(Dialect):
FUNCTION_PARSERS.pop("TRIM")
class Generator(generator.Generator):
+ INTERVAL_ALLOWS_PLURAL_FORM = False
+ JOIN_HINTS = False
+ TABLE_HINTS = False
STRUCT_DELIMITER = ("(", ")")
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.LocationProperty: exp.Properties.Location.UNSUPPORTED,
+ exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
TYPE_MAPPING = {
@@ -246,7 +265,6 @@ class Presto(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
- **transforms.ELIMINATE_QUALIFY, # type: ignore
exp.ApproxDistinct: _approx_distinct_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayConcat: rename_func("CONCAT"),
@@ -284,6 +302,9 @@ class Presto(Dialect):
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
+ exp.Select: transforms.preprocess(
+ [transforms.eliminate_qualify, transforms.explode_to_unnest]
+ ),
exp.SortArray: _no_sort_array,
exp.StrPosition: rename_func("STRPOS"),
exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
@@ -308,7 +329,13 @@ class Presto(Dialect):
exp.VariancePop: rename_func("VAR_POP"),
}
- def transaction_sql(self, expression):
+ def interval_sql(self, expression: exp.Interval) -> str:
+ unit = self.sql(expression, "unit")
+ if expression.this and unit.lower().startswith("week"):
+ return f"({expression.this.name} * INTERVAL '7' day)"
+ return super().interval_sql(expression)
+
+ def transaction_sql(self, expression: exp.Transaction) -> str:
modes = expression.args.get("modes")
modes = f" {', '.join(modes)}" if modes else ""
return f"START TRANSACTION{modes}"