summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/presto.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-01-23 05:06:14 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-01-23 05:06:14 +0000
commit38e6461a8afbd7cb83709ddb998f03d40ba87755 (patch)
tree64b68a893a3b946111b9cab69503f83ca233c335 /sqlglot/dialects/presto.py
parentReleasing debian version 20.4.0-1. (diff)
downloadsqlglot-38e6461a8afbd7cb83709ddb998f03d40ba87755.tar.xz
sqlglot-38e6461a8afbd7cb83709ddb998f03d40ba87755.zip
Merging upstream version 20.9.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/presto.py')
-rw-r--r--sqlglot/dialects/presto.py40
1 files changed, 16 insertions, 24 deletions
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 360ab65..9b421e7 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -18,6 +18,7 @@ from sqlglot.dialects.dialect import (
no_pivot_sql,
no_safe_divide_sql,
no_timestamp_sql,
+ path_to_jsonpath,
regexp_extract_sql,
rename_func,
right_to_substring_sql,
@@ -99,14 +100,14 @@ def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate)
def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:
expression = ts_or_ds_add_cast(expression)
- unit = exp.Literal.string(expression.text("unit") or "day")
+ unit = exp.Literal.string(expression.text("unit") or "DAY")
return self.func("DATE_ADD", unit, expression.expression, expression.this)
def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> str:
this = exp.cast(expression.this, "TIMESTAMP")
expr = exp.cast(expression.expression, "TIMESTAMP")
- unit = exp.Literal.string(expression.text("unit") or "day")
+ unit = exp.Literal.string(expression.text("unit") or "DAY")
return self.func("DATE_DIFF", unit, expr, this)
@@ -138,13 +139,6 @@ def _from_unixtime(args: t.List) -> exp.Expression:
return exp.UnixToTime.from_arg_list(args)
-def _parse_element_at(args: t.List) -> exp.Bracket:
- this = seq_get(args, 0)
- index = seq_get(args, 1)
- assert isinstance(this, exp.Expression) and isinstance(index, exp.Expression)
- return exp.Bracket(this=this, expressions=[index], offset=1, safe=True)
-
-
def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.Table):
if isinstance(expression.this, exp.GenerateSeries):
@@ -175,15 +169,8 @@ def _unix_to_time_sql(self: Presto.Generator, expression: exp.UnixToTime) -> str
timestamp = self.sql(expression, "this")
if scale in (None, exp.UnixToTime.SECONDS):
return rename_func("FROM_UNIXTIME")(self, expression)
- if scale == exp.UnixToTime.MILLIS:
- return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000)"
- if scale == exp.UnixToTime.MICROS:
- return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000)"
- if scale == exp.UnixToTime.NANOS:
- return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000000)"
- self.unsupported(f"Unsupported scale for timestamp: {scale}.")
- return ""
+ return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / POW(10, {scale}))"
def _to_int(expression: exp.Expression) -> exp.Expression:
@@ -215,6 +202,7 @@ class Presto(Dialect):
STRICT_STRING_CONCAT = True
SUPPORTS_SEMI_ANTI_JOIN = False
TYPED_DIVISION = True
+ TABLESAMPLE_SIZE_IS_PERCENT = True
# https://github.com/trinodb/trino/issues/17
# https://github.com/trinodb/trino/issues/12289
@@ -258,7 +246,9 @@ class Presto(Dialect):
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
"DATE_TRUNC": date_trunc_to_time,
- "ELEMENT_AT": _parse_element_at,
+ "ELEMENT_AT": lambda args: exp.Bracket(
+ this=seq_get(args, 0), expressions=[seq_get(args, 1)], offset=1, safe=True
+ ),
"FROM_HEX": exp.Unhex.from_arg_list,
"FROM_UNIXTIME": _from_unixtime,
"FROM_UTF8": lambda args: exp.Decode(
@@ -344,20 +334,20 @@ class Presto(Dialect):
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: lambda self, e: self.func(
"DATE_ADD",
- exp.Literal.string(e.text("unit") or "day"),
+ exp.Literal.string(e.text("unit") or "DAY"),
_to_int(
e.expression,
),
e.this,
),
exp.DateDiff: lambda self, e: self.func(
- "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
+ "DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
exp.DateSub: lambda self, e: self.func(
"DATE_ADD",
- exp.Literal.string(e.text("unit") or "day"),
+ exp.Literal.string(e.text("unit") or "DAY"),
_to_int(e.expression * -1),
e.this,
),
@@ -366,6 +356,7 @@ class Presto(Dialect):
exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"),
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.First: _first_last_sql,
+ exp.GetPath: path_to_jsonpath(),
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.GroupConcat: lambda self, e: self.func(
"ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator")
@@ -376,6 +367,7 @@ class Presto(Dialect):
exp.Initcap: _initcap_sql,
exp.ParseJSON: rename_func("JSON_PARSE"),
exp.Last: _first_last_sql,
+ exp.LastDay: lambda self, e: self.func("LAST_DAY_OF_MONTH", e.this),
exp.Lateral: _explode_to_unnest_sql,
exp.Left: left_to_substring_sql,
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
@@ -446,7 +438,7 @@ class Presto(Dialect):
return super().bracket_sql(expression)
def struct_sql(self, expression: exp.Struct) -> str:
- if any(isinstance(arg, self.KEY_VALUE_DEFINITONS) for arg in expression.expressions):
+ if any(isinstance(arg, self.KEY_VALUE_DEFINITIONS) for arg in expression.expressions):
self.unsupported("Struct with key-value definitions is unsupported.")
return self.function_fallback_sql(expression)
@@ -454,8 +446,8 @@ class Presto(Dialect):
def interval_sql(self, expression: exp.Interval) -> str:
unit = self.sql(expression, "unit")
- if expression.this and unit.lower().startswith("week"):
- return f"({expression.this.name} * INTERVAL '7' day)"
+ if expression.this and unit.startswith("WEEK"):
+ return f"({expression.this.name} * INTERVAL '7' DAY)"
return super().interval_sql(expression)
def transaction_sql(self, expression: exp.Transaction) -> str: