summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/presto.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/presto.py')
-rw-r--r--sqlglot/dialects/presto.py94
1 files changed, 54 insertions, 40 deletions
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 6133a27..52a04a4 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
format_time_lambda,
if_sql,
no_ilike_sql,
+ no_pivot_sql,
no_safe_divide_sql,
rename_func,
struct_extract_sql,
@@ -127,39 +128,12 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s
)
-def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str:
- start = expression.args["start"]
- end = expression.args["end"]
- step = expression.args.get("step")
-
- target_type = None
-
- if isinstance(start, exp.Cast):
- target_type = start.to
- elif isinstance(end, exp.Cast):
- target_type = end.to
-
- if target_type and target_type.this == exp.DataType.Type.TIMESTAMP:
- to = target_type.copy()
-
- if target_type is start.to:
- end = exp.Cast(this=end, to=to)
- else:
- start = exp.Cast(this=start, to=to)
-
- sql = self.func("SEQUENCE", start, end, step)
- if isinstance(expression.parent, exp.Table):
- sql = f"UNNEST({sql})"
-
- return sql
-
-
def _ensure_utf8(charset: exp.Literal) -> None:
if charset.name.lower() != "utf-8":
raise UnsupportedError(f"Unsupported charset {charset}")
-def _approx_percentile(args: t.Sequence) -> exp.Expression:
+def _approx_percentile(args: t.List) -> exp.Expression:
if len(args) == 4:
return exp.ApproxQuantile(
this=seq_get(args, 0),
@@ -176,7 +150,7 @@ def _approx_percentile(args: t.Sequence) -> exp.Expression:
return exp.ApproxQuantile.from_arg_list(args)
-def _from_unixtime(args: t.Sequence) -> exp.Expression:
+def _from_unixtime(args: t.List) -> exp.Expression:
if len(args) == 3:
return exp.UnixToTime(
this=seq_get(args, 0),
@@ -191,22 +165,39 @@ def _from_unixtime(args: t.Sequence) -> exp.Expression:
return exp.UnixToTime.from_arg_list(args)
+def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
+ if isinstance(expression, exp.Table):
+ if isinstance(expression.this, exp.GenerateSeries):
+ unnest = exp.Unnest(expressions=[expression.this])
+
+ if expression.alias:
+ return exp.alias_(
+ unnest,
+ alias="_u",
+ table=[expression.alias],
+ copy=False,
+ )
+ return unnest
+ return expression
+
+
class Presto(Dialect):
index_offset = 1
null_ordering = "nulls_are_last"
- time_format = MySQL.time_format # type: ignore
- time_mapping = MySQL.time_mapping # type: ignore
+ time_format = MySQL.time_format
+ time_mapping = MySQL.time_mapping
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"START": TokenType.BEGIN,
+ "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"ROW": TokenType.STRUCT,
}
class Parser(parser.Parser):
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
+ **parser.Parser.FUNCTIONS,
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
"APPROX_PERCENTILE": _approx_percentile,
"CARDINALITY": exp.ArraySize.from_arg_list,
@@ -252,13 +243,13 @@ class Presto(Dialect):
STRUCT_DELIMITER = ("(", ")")
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.LocationProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.BINARY: "VARBINARY",
@@ -268,8 +259,9 @@ class Presto(Dialect):
}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
exp.ApproxDistinct: _approx_distinct_sql,
+ exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayContains: rename_func("CONTAINS"),
@@ -293,7 +285,7 @@ class Presto(Dialect):
exp.Decode: _decode_sql,
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
exp.Encode: _encode_sql,
- exp.GenerateSeries: _sequence_sql,
+ exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql,
@@ -301,10 +293,10 @@ class Presto(Dialect):
exp.Initcap: _initcap_sql,
exp.Lateral: _explode_to_unnest_sql,
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
- exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
+ exp.LogicalOr: rename_func("BOOL_OR"),
+ exp.Pivot: no_pivot_sql,
exp.Quantile: _quantile_sql,
- exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.Select: transforms.preprocess(
@@ -320,8 +312,7 @@ class Presto(Dialect):
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.StructExtract: struct_extract_sql,
- exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'",
- exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
+ exp.Table: transforms.preprocess([_unnest_sequence]),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToDate: timestrtotime_sql,
exp.TimeStrToTime: timestrtotime_sql,
@@ -336,6 +327,7 @@ class Presto(Dialect):
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
exp.VariancePop: rename_func("VAR_POP"),
+ exp.With: transforms.preprocess([transforms.add_recursive_cte_column_names]),
exp.WithinGroup: transforms.preprocess(
[transforms.remove_within_group_for_percentiles]
),
@@ -351,3 +343,25 @@ class Presto(Dialect):
modes = expression.args.get("modes")
modes = f" {', '.join(modes)}" if modes else ""
return f"START TRANSACTION{modes}"
+
+ def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
+ start = expression.args["start"]
+ end = expression.args["end"]
+ step = expression.args.get("step")
+
+ if isinstance(start, exp.Cast):
+ target_type = start.to
+ elif isinstance(end, exp.Cast):
+ target_type = end.to
+ else:
+ target_type = None
+
+ if target_type and target_type.is_type(exp.DataType.Type.TIMESTAMP):
+ to = target_type.copy()
+
+ if target_type is start.to:
+ end = exp.Cast(this=end, to=to)
+ else:
+ start = exp.Cast(this=start, to=to)
+
+ return self.func("SEQUENCE", start, end, step)