summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/presto.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/presto.py')
-rw-r--r--sqlglot/dialects/presto.py26
1 files changed, 6 insertions, 20 deletions
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 14ec3dd..291b478 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -7,6 +7,7 @@ from sqlglot.dialects.dialect import (
Dialect,
binary_from_function,
date_trunc_to_time,
+ encode_decode_sql,
format_time_lambda,
if_sql,
left_to_substring_sql,
@@ -21,7 +22,6 @@ from sqlglot.dialects.dialect import (
timestrtotime_sql,
)
from sqlglot.dialects.mysql import MySQL
-from sqlglot.errors import UnsupportedError
from sqlglot.helper import apply_index_offset, seq_get
from sqlglot.tokens import TokenType
@@ -41,6 +41,7 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str:
if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
+ expression = expression.copy()
return self.sql(
exp.Join(
this=exp.Unnest(
@@ -59,16 +60,6 @@ def _initcap_sql(self: generator.Generator, expression: exp.Initcap) -> str:
return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"
-def _decode_sql(self: generator.Generator, expression: exp.Decode) -> str:
- _ensure_utf8(expression.args["charset"])
- return self.func("FROM_UTF8", expression.this, expression.args.get("replace"))
-
-
-def _encode_sql(self: generator.Generator, expression: exp.Encode) -> str:
- _ensure_utf8(expression.args["charset"])
- return f"TO_UTF8({self.sql(expression, 'this')})"
-
-
def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str:
if expression.args.get("asc") == exp.false():
comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
@@ -106,14 +97,14 @@ def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDat
time_format = self.format_time(expression)
if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT):
return exp.cast(_str_to_time_sql(self, expression), "DATE").sql(dialect="presto")
- return exp.cast(exp.cast(expression.this, "TIMESTAMP"), "DATE").sql(dialect="presto")
+ return exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE").sql(dialect="presto")
def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
this = expression.this
if not isinstance(this, exp.CurrentDate):
- this = exp.cast(exp.cast(expression.this, "TIMESTAMP"), "DATE")
+ this = exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE")
return self.func(
"DATE_ADD",
@@ -123,11 +114,6 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s
)
-def _ensure_utf8(charset: exp.Literal) -> None:
- if charset.name.lower() != "utf-8":
- raise UnsupportedError(f"Unsupported charset {charset}")
-
-
def _approx_percentile(args: t.List) -> exp.Expression:
if len(args) == 4:
return exp.ApproxQuantile(
@@ -288,9 +274,9 @@ class Presto(Dialect):
),
exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.DATE_FORMAT}) AS DATE)",
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
- exp.Decode: _decode_sql,
+ exp.Decode: lambda self, e: encode_decode_sql(self, e, "FROM_UTF8"),
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)",
- exp.Encode: _encode_sql,
+ exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"),
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Hex: rename_func("TO_HEX"),