summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/presto.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/presto.py')
-rw-r--r--sqlglot/dialects/presto.py62
1 files changed, 46 insertions, 16 deletions
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index e16ea1d..a79a9f9 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -8,7 +8,6 @@ from sqlglot.dialects.dialect import (
no_ilike_sql,
no_safe_divide_sql,
rename_func,
- str_position_sql,
struct_extract_sql,
timestrtotime_sql,
)
@@ -24,14 +23,6 @@ def _approx_distinct_sql(self, expression):
return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})"
-def _concat_ws_sql(self, expression):
- sep, *args = expression.expressions
- sep = self.sql(sep)
- if len(args) > 1:
- return f"ARRAY_JOIN(ARRAY[{self.format_args(*args)}], {sep})"
- return f"ARRAY_JOIN({self.sql(args[0])}, {sep})"
-
-
def _datatype_sql(self, expression):
sql = self.datatype_sql(expression)
if expression.this == exp.DataType.Type.TIMESTAMPTZ:
@@ -61,7 +52,7 @@ def _initcap_sql(self, expression):
def _decode_sql(self, expression):
_ensure_utf8(expression.args.get("charset"))
- return f"FROM_UTF8({self.sql(expression, 'this')})"
+ return f"FROM_UTF8({self.format_args(expression.this, expression.args.get('replace'))})"
def _encode_sql(self, expression):
@@ -119,6 +110,38 @@ def _ensure_utf8(charset):
raise UnsupportedError(f"Unsupported charset {charset}")
+def _approx_percentile(args):
+ if len(args) == 4:
+ return exp.ApproxQuantile(
+ this=seq_get(args, 0),
+ weight=seq_get(args, 1),
+ quantile=seq_get(args, 2),
+ accuracy=seq_get(args, 3),
+ )
+ if len(args) == 3:
+ return exp.ApproxQuantile(
+ this=seq_get(args, 0),
+ quantile=seq_get(args, 1),
+ accuracy=seq_get(args, 2),
+ )
+ return exp.ApproxQuantile.from_arg_list(args)
+
+
+def _from_unixtime(args):
+ if len(args) == 3:
+ return exp.UnixToTime(
+ this=seq_get(args, 0),
+ hours=seq_get(args, 1),
+ minutes=seq_get(args, 2),
+ )
+ if len(args) == 2:
+ return exp.UnixToTime(
+ this=seq_get(args, 0),
+ zone=seq_get(args, 1),
+ )
+ return exp.UnixToTime.from_arg_list(args)
+
+
class Presto(Dialect):
index_offset = 1
null_ordering = "nulls_are_last"
@@ -150,19 +173,25 @@ class Presto(Dialect):
),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
- "FROM_UNIXTIME": exp.UnixToTime.from_arg_list,
- "STRPOS": exp.StrPosition.from_arg_list,
+ "FROM_UNIXTIME": _from_unixtime,
+ "STRPOS": lambda args: exp.StrPosition(
+ this=seq_get(args, 0),
+ substr=seq_get(args, 1),
+ instance=seq_get(args, 2),
+ ),
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
- "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
+ "APPROX_PERCENTILE": _approx_percentile,
"FROM_HEX": exp.Unhex.from_arg_list,
"TO_HEX": exp.Hex.from_arg_list,
"TO_UTF8": lambda args: exp.Encode(
this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
),
"FROM_UTF8": lambda args: exp.Decode(
- this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
+ this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8")
),
}
+ FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
+ FUNCTION_PARSERS.pop("TRIM")
class Generator(generator.Generator):
@@ -194,7 +223,6 @@ class Presto(Dialect):
exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
- exp.ConcatWs: _concat_ws_sql,
exp.DataType: _datatype_sql,
exp.DateAdd: lambda self, e: f"""DATE_ADD({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
@@ -209,12 +237,13 @@ class Presto(Dialect):
exp.Initcap: _initcap_sql,
exp.Lateral: _explode_to_unnest_sql,
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
+ exp.LogicalOr: rename_func("BOOL_OR"),
exp.Quantile: _quantile_sql,
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.SortArray: _no_sort_array,
- exp.StrPosition: str_position_sql,
+ exp.StrPosition: rename_func("STRPOS"),
exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
@@ -233,6 +262,7 @@ class Presto(Dialect):
exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
+ exp.VariancePop: rename_func("VAR_POP"),
}
def transaction_sql(self, expression):