summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/presto.py
diff options
context:
space:
mode:
authorDaniel Baumann <mail@daniel-baumann.ch>2023-12-10 10:46:01 +0000
committerDaniel Baumann <mail@daniel-baumann.ch>2023-12-10 10:46:01 +0000
commit8fe30fd23dc37ec3516e530a86d1c4b604e71241 (patch)
tree6e2ebbf565b0351fd0f003f488a8339e771ad90c /sqlglot/dialects/presto.py
parentReleasing debian version 19.0.1-1. (diff)
downloadsqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.tar.xz
sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.zip
Merging upstream version 20.1.0.
Signed-off-by: Daniel Baumann <mail@daniel-baumann.ch>
Diffstat (limited to 'sqlglot/dialects/presto.py')
-rw-r--r--sqlglot/dialects/presto.py72
1 files changed, 54 insertions, 18 deletions
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index ded3655..10a6074 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -5,9 +5,11 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
+ NormalizationStrategy,
binary_from_function,
bool_xor_sql,
date_trunc_to_time,
+ datestrtodate_sql,
encode_decode_sql,
format_time_lambda,
if_sql,
@@ -22,6 +24,7 @@ from sqlglot.dialects.dialect import (
struct_extract_sql,
timestamptrunc_sql,
timestrtotime_sql,
+ ts_or_ds_add_cast,
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.helper import apply_index_offset, seq_get
@@ -95,17 +98,16 @@ def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate)
def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:
- this = expression.this
+ expression = ts_or_ds_add_cast(expression)
+ unit = exp.Literal.string(expression.text("unit") or "day")
+ return self.func("DATE_ADD", unit, expression.expression, expression.this)
- if not isinstance(this, exp.CurrentDate):
- this = exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE")
- return self.func(
- "DATE_ADD",
- exp.Literal.string(expression.text("unit") or "day"),
- expression.expression,
- this,
- )
+def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> str:
+ this = exp.cast(expression.this, "TIMESTAMP")
+ expr = exp.cast(expression.expression, "TIMESTAMP")
+ unit = exp.Literal.string(expression.text("unit") or "day")
+ return self.func("DATE_DIFF", unit, expr, this)
def _approx_percentile(args: t.List) -> exp.Expression:
@@ -136,11 +138,11 @@ def _from_unixtime(args: t.List) -> exp.Expression:
return exp.UnixToTime.from_arg_list(args)
-def _parse_element_at(args: t.List) -> exp.SafeBracket:
+def _parse_element_at(args: t.List) -> exp.Bracket:
this = seq_get(args, 0)
index = seq_get(args, 1)
assert isinstance(this, exp.Expression) and isinstance(index, exp.Expression)
- return exp.SafeBracket(this=this, expressions=apply_index_offset(this, [index], -1))
+ return exp.Bracket(this=this, expressions=[index], offset=1, safe=True)
def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
@@ -168,6 +170,22 @@ def _first_last_sql(self: Presto.Generator, expression: exp.First | exp.Last) ->
return rename_func("ARBITRARY")(self, expression)
+def _unix_to_time_sql(self: Presto.Generator, expression: exp.UnixToTime) -> str:
+ scale = expression.args.get("scale")
+ timestamp = self.sql(expression, "this")
+ if scale in (None, exp.UnixToTime.SECONDS):
+ return rename_func("FROM_UNIXTIME")(self, expression)
+ if scale == exp.UnixToTime.MILLIS:
+ return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000)"
+ if scale == exp.UnixToTime.MICROS:
+ return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000)"
+ if scale == exp.UnixToTime.NANOS:
+ return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000000)"
+
+ self.unsupported(f"Unsupported scale for timestamp: {scale}.")
+ return ""
+
+
class Presto(Dialect):
INDEX_OFFSET = 1
NULL_ORDERING = "nulls_are_last"
@@ -175,11 +193,12 @@ class Presto(Dialect):
TIME_MAPPING = MySQL.TIME_MAPPING
STRICT_STRING_CONCAT = True
SUPPORTS_SEMI_ANTI_JOIN = False
+ TYPED_DIVISION = True
# https://github.com/trinodb/trino/issues/17
# https://github.com/trinodb/trino/issues/12289
# https://github.com/prestodb/presto/issues/2863
- RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
+ NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
@@ -229,6 +248,7 @@ class Presto(Dialect):
),
"ROW": exp.Struct.from_arg_list,
"SEQUENCE": exp.GenerateSeries.from_arg_list,
+ "SET_AGG": exp.ArrayUniqueAgg.from_arg_list,
"SPLIT_TO_MAP": exp.StrToMap.from_arg_list,
"STRPOS": lambda args: exp.StrPosition(
this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2)
@@ -253,6 +273,7 @@ class Presto(Dialect):
NVL2_SUPPORTED = False
STRUCT_DELIMITER = ("(", ")")
LIMIT_ONLY_LITERALS = True
+ SUPPORTS_SINGLE_ARG_CONCAT = False
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION,
@@ -284,6 +305,7 @@ class Presto(Dialect):
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayContains: rename_func("CONTAINS"),
exp.ArraySize: rename_func("CARDINALITY"),
+ exp.ArrayUniqueAgg: rename_func("SET_AGG"),
exp.BitwiseAnd: lambda self, e: f"BITWISE_AND({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseLeftShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_LEFT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseNot: lambda self, e: f"BITWISE_NOT({self.sql(e, 'this')})",
@@ -298,7 +320,7 @@ class Presto(Dialect):
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
),
- exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.DATE_FORMAT}) AS DATE)",
+ exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
exp.DateSub: lambda self, e: self.func(
"DATE_ADD",
@@ -330,9 +352,6 @@ class Presto(Dialect):
exp.Quantile: _quantile_sql,
exp.RegexpExtract: regexp_extract_sql,
exp.Right: right_to_substring_sql,
- exp.SafeBracket: lambda self, e: self.func(
- "ELEMENT_AT", e.this, seq_get(apply_index_offset(e.this, e.expressions, 1), 0)
- ),
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.Select: transforms.preprocess(
@@ -361,10 +380,11 @@ class Presto(Dialect):
exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add_sql,
+ exp.TsOrDsDiff: _ts_or_ds_diff_sql,
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
exp.Unhex: rename_func("FROM_HEX"),
exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
- exp.UnixToTime: rename_func("FROM_UNIXTIME"),
+ exp.UnixToTime: _unix_to_time_sql,
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
exp.VariancePop: rename_func("VAR_POP"),
exp.With: transforms.preprocess([transforms.add_recursive_cte_column_names]),
@@ -374,8 +394,24 @@ class Presto(Dialect):
exp.Xor: bool_xor_sql,
}
+ def bracket_sql(self, expression: exp.Bracket) -> str:
+ if expression.args.get("safe"):
+ return self.func(
+ "ELEMENT_AT",
+ expression.this,
+ seq_get(
+ apply_index_offset(
+ expression.this,
+ expression.expressions,
+ 1 - expression.args.get("offset", 0),
+ ),
+ 0,
+ ),
+ )
+ return super().bracket_sql(expression)
+
def struct_sql(self, expression: exp.Struct) -> str:
- if any(isinstance(arg, (exp.EQ, exp.Slice)) for arg in expression.expressions):
+ if any(isinstance(arg, self.KEY_VALUE_DEFINITONS) for arg in expression.expressions):
self.unsupported("Struct with key-value definitions is unsupported.")
return self.function_fallback_sql(expression)