summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/presto.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/presto.py')
-rw-r--r--sqlglot/dialects/presto.py30
1 files changed, 21 insertions, 9 deletions
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 489d439..6133a27 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -130,7 +130,7 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s
def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str:
start = expression.args["start"]
end = expression.args["end"]
- step = expression.args.get("step", 1) # Postgres defaults to 1 for generate_series
+ step = expression.args.get("step")
target_type = None
@@ -147,7 +147,11 @@ def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) ->
else:
start = exp.Cast(this=start, to=to)
- return self.func("SEQUENCE", start, end, step)
+ sql = self.func("SEQUENCE", start, end, step)
+ if isinstance(expression.parent, exp.Table):
+ sql = f"UNNEST({sql})"
+
+ return sql
def _ensure_utf8(charset: exp.Literal) -> None:
@@ -204,6 +208,7 @@ class Presto(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
+ "APPROX_PERCENTILE": _approx_percentile,
"CARDINALITY": exp.ArraySize.from_arg_list,
"CONTAINS": exp.ArrayContains.from_arg_list,
"DATE_ADD": lambda args: exp.DateAdd(
@@ -219,23 +224,23 @@ class Presto(Dialect):
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
"DATE_TRUNC": date_trunc_to_time,
+ "FROM_HEX": exp.Unhex.from_arg_list,
"FROM_UNIXTIME": _from_unixtime,
+ "FROM_UTF8": lambda args: exp.Decode(
+ this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8")
+ ),
"NOW": exp.CurrentTimestamp.from_arg_list,
+ "SEQUENCE": exp.GenerateSeries.from_arg_list,
"STRPOS": lambda args: exp.StrPosition(
this=seq_get(args, 0),
substr=seq_get(args, 1),
instance=seq_get(args, 2),
),
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
- "APPROX_PERCENTILE": _approx_percentile,
- "FROM_HEX": exp.Unhex.from_arg_list,
"TO_HEX": exp.Hex.from_arg_list,
"TO_UTF8": lambda args: exp.Encode(
this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
),
- "FROM_UTF8": lambda args: exp.Decode(
- this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8")
- ),
}
FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
FUNCTION_PARSERS.pop("TRIM")
@@ -264,7 +269,6 @@ class Presto(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
- **transforms.UNALIAS_GROUP, # type: ignore
exp.ApproxDistinct: _approx_distinct_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayConcat: rename_func("CONCAT"),
@@ -290,6 +294,7 @@ class Presto(Dialect):
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
exp.Encode: _encode_sql,
exp.GenerateSeries: _sequence_sql,
+ exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql,
exp.ILike: no_ilike_sql,
@@ -303,7 +308,11 @@ class Presto(Dialect):
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.Select: transforms.preprocess(
- [transforms.eliminate_qualify, transforms.explode_to_unnest]
+ [
+ transforms.eliminate_qualify,
+ transforms.eliminate_distinct_on,
+ transforms.explode_to_unnest,
+ ]
),
exp.SortArray: _no_sort_array,
exp.StrPosition: rename_func("STRPOS"),
@@ -327,6 +336,9 @@ class Presto(Dialect):
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
exp.VariancePop: rename_func("VAR_POP"),
+ exp.WithinGroup: transforms.preprocess(
+ [transforms.remove_within_group_for_percentiles]
+ ),
}
def interval_sql(self, expression: exp.Interval) -> str: