summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/snowflake.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/snowflake.py')
-rw-r--r--sqlglot/dialects/snowflake.py22
1 files changed, 18 insertions, 4 deletions
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 4a090c2..6413f6d 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
datestrtodate_sql,
format_time_lambda,
inline_array_sql,
+ min_or_least,
rename_func,
timestrtotime_sql,
ts_or_ds_to_date_sql,
@@ -116,10 +117,16 @@ def _div0_to_if(args):
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
def _zeroifnull_to_if(args):
- cond = exp.EQ(this=seq_get(args, 0), expression=exp.Null())
+ cond = exp.Is(this=seq_get(args, 0), expression=exp.Null())
return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0))
+# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
+def _nullifzero_to_if(args):
+ cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0))
+ return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
+
+
def _datatype_sql(self, expression):
if expression.this == exp.DataType.Type.ARRAY:
return "ARRAY"
@@ -167,6 +174,11 @@ class Snowflake(Dialect):
**parser.Parser.FUNCTIONS,
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
"ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list,
+ "DATEADD": lambda args: exp.DateAdd(
+ this=seq_get(args, 2),
+ expression=seq_get(args, 1),
+ unit=seq_get(args, 0),
+ ),
"DATE_TRUNC": lambda args: exp.DateTrunc(
unit=exp.Literal.string(seq_get(args, 0).name), # type: ignore
this=seq_get(args, 1),
@@ -180,6 +192,7 @@ class Snowflake(Dialect):
"DECODE": exp.Matches.from_arg_list,
"OBJECT_CONSTRUCT": parser.parse_var_map,
"ZEROIFNULL": _zeroifnull_to_if,
+ "NULLIFZERO": _nullifzero_to_if,
}
FUNCTION_PARSERS = {
@@ -254,6 +267,7 @@ class Snowflake(Dialect):
class Generator(generator.Generator):
PARAMETER_TOKEN = "$"
INTEGER_DIVISION = False
+ MATCHED_BY_SOURCE = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
@@ -278,6 +292,7 @@ class Snowflake(Dialect):
exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"),
exp.UnixToTime: _unix_to_time_sql,
exp.DayOfWeek: rename_func("DAYOFWEEK"),
+ exp.Min: min_or_least,
}
TYPE_MAPPING = {
@@ -343,11 +358,10 @@ class Snowflake(Dialect):
expression. This might not be true in a case where the same column name can be sourced from another table that can
properly quote but should be true in most cases.
"""
- values_expressions = expression.find_all(exp.Values)
values_identifiers = set(
flatten(
- v.args.get("alias", exp.Alias()).args.get("columns", [])
- for v in values_expressions
+ (v.args.get("alias") or exp.Alias()).args.get("columns", [])
+ for v in expression.find_all(exp.Values)
)
)
if values_identifiers: