diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-11-19 14:50:35 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-11-19 14:50:35 +0000 |
commit | 2272764864555f26095563937e06a3389d42d789 (patch) | |
tree | 9dc37b7bff42ec0343028e5ecfb0aa147c5d3279 /sqlglot/dialects/presto.py | |
parent | Adding upstream version 10.0.1. (diff) | |
download | sqlglot-2272764864555f26095563937e06a3389d42d789.tar.xz sqlglot-2272764864555f26095563937e06a3389d42d789.zip |
Adding upstream version 10.0.8.upstream/10.0.8
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/presto.py')
-rw-r--r-- | sqlglot/dialects/presto.py | 38 |
1 files changed, 36 insertions, 2 deletions
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 11ea778..9d5cc11 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import ( struct_extract_sql, ) from sqlglot.dialects.mysql import MySQL +from sqlglot.errors import UnsupportedError from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -61,8 +62,18 @@ def _initcap_sql(self, expression): return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))" +def _decode_sql(self, expression): + _ensure_utf8(expression.args.get("charset")) + return f"FROM_UTF8({self.sql(expression, 'this')})" + + +def _encode_sql(self, expression): + _ensure_utf8(expression.args.get("charset")) + return f"TO_UTF8({self.sql(expression, 'this')})" + + def _no_sort_array(self, expression): - if expression.args.get("asc") == exp.FALSE: + if expression.args.get("asc") == exp.false(): comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END" else: comparator = None @@ -72,7 +83,7 @@ def _no_sort_array(self, expression): def _schema_sql(self, expression): if isinstance(expression.parent, exp.Property): - columns = ", ".join(f"'{c.text('this')}'" for c in expression.expressions) + columns = ", ".join(f"'{c.name}'" for c in expression.expressions) return f"ARRAY[{columns}]" for schema in expression.parent.find_all(exp.Schema): @@ -106,6 +117,11 @@ def _ts_or_ds_add_sql(self, expression): return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))" +def _ensure_utf8(charset): + if charset.name.lower() != "utf-8": + raise UnsupportedError(f"Unsupported charset {charset}") + + class Presto(Dialect): index_offset = 1 null_ordering = "nulls_are_last" @@ -115,6 +131,7 @@ class Presto(Dialect): class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "START": TokenType.BEGIN, "ROW": TokenType.STRUCT, } @@ -140,6 +157,14 @@ class Presto(Dialect): "STRPOS": exp.StrPosition.from_arg_list, "TO_UNIXTIME": exp.TimeToUnix.from_arg_list, "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, + "FROM_HEX": exp.Unhex.from_arg_list, + "TO_HEX": exp.Hex.from_arg_list, + "TO_UTF8": lambda args: exp.Encode( + this=seq_get(args, 0), charset=exp.Literal.string("utf-8") + ), + "FROM_UTF8": lambda args: exp.Decode( + this=seq_get(args, 0), charset=exp.Literal.string("utf-8") + ), } class Generator(generator.Generator): @@ -187,7 +212,10 @@ class Presto(Dialect): exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""", exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)", exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)", + exp.Decode: _decode_sql, exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)", + exp.Encode: _encode_sql, + exp.Hex: rename_func("TO_HEX"), exp.If: if_sql, exp.ILike: no_ilike_sql, exp.Initcap: _initcap_sql, @@ -212,7 +240,13 @@ class Presto(Dialect): exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", exp.TsOrDsAdd: _ts_or_ds_add_sql, exp.TsOrDsToDate: _ts_or_ds_to_date_sql, + exp.Unhex: rename_func("FROM_HEX"), exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})", exp.UnixToTime: rename_func("FROM_UNIXTIME"), exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)", } + + def transaction_sql(self, expression): + modes = expression.args.get("modes") + modes = f" {', '.join(modes)}" if modes else "" + return f"START TRANSACTION{modes}" |