summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/presto.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/presto.py')
-rw-r--r--sqlglot/dialects/presto.py38
1 files changed, 36 insertions, 2 deletions
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 11ea778..9d5cc11 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import (
struct_extract_sql,
)
from sqlglot.dialects.mysql import MySQL
+from sqlglot.errors import UnsupportedError
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -61,8 +62,18 @@ def _initcap_sql(self, expression):
return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"
+def _decode_sql(self, expression):
+ _ensure_utf8(expression.args.get("charset"))
+ return f"FROM_UTF8({self.sql(expression, 'this')})"
+
+
+def _encode_sql(self, expression):
+ _ensure_utf8(expression.args.get("charset"))
+ return f"TO_UTF8({self.sql(expression, 'this')})"
+
+
def _no_sort_array(self, expression):
- if expression.args.get("asc") == exp.FALSE:
+ if expression.args.get("asc") == exp.false():
comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
else:
comparator = None
@@ -72,7 +83,7 @@ def _no_sort_array(self, expression):
def _schema_sql(self, expression):
if isinstance(expression.parent, exp.Property):
- columns = ", ".join(f"'{c.text('this')}'" for c in expression.expressions)
+ columns = ", ".join(f"'{c.name}'" for c in expression.expressions)
return f"ARRAY[{columns}]"
for schema in expression.parent.find_all(exp.Schema):
@@ -106,6 +117,11 @@ def _ts_or_ds_add_sql(self, expression):
return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))"
+def _ensure_utf8(charset):
+ if charset.name.lower() != "utf-8":
+ raise UnsupportedError(f"Unsupported charset {charset}")
+
+
class Presto(Dialect):
index_offset = 1
null_ordering = "nulls_are_last"
@@ -115,6 +131,7 @@ class Presto(Dialect):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "START": TokenType.BEGIN,
"ROW": TokenType.STRUCT,
}
@@ -140,6 +157,14 @@ class Presto(Dialect):
"STRPOS": exp.StrPosition.from_arg_list,
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
+ "FROM_HEX": exp.Unhex.from_arg_list,
+ "TO_HEX": exp.Hex.from_arg_list,
+ "TO_UTF8": lambda args: exp.Encode(
+ this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
+ ),
+ "FROM_UTF8": lambda args: exp.Decode(
+ this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
+ ),
}
class Generator(generator.Generator):
@@ -187,7 +212,10 @@ class Presto(Dialect):
exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)",
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)",
+ exp.Decode: _decode_sql,
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
+ exp.Encode: _encode_sql,
+ exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql,
exp.ILike: no_ilike_sql,
exp.Initcap: _initcap_sql,
@@ -212,7 +240,13 @@ class Presto(Dialect):
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
+ exp.Unhex: rename_func("FROM_HEX"),
exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
}
+
+ def transaction_sql(self, expression):
+ modes = expression.args.get("modes")
+ modes = f" {', '.join(modes)}" if modes else ""
+ return f"START TRANSACTION{modes}"