diff options
Diffstat (limited to 'sqlglot/dialects/snowflake.py')
-rw-r--r-- | sqlglot/dialects/snowflake.py | 22 |
1 files changed, 18 insertions, 4 deletions
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 4a090c2..6413f6d 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import ( datestrtodate_sql, format_time_lambda, inline_array_sql, + min_or_least, rename_func, timestrtotime_sql, ts_or_ds_to_date_sql, @@ -116,10 +117,16 @@ def _div0_to_if(args): # https://docs.snowflake.com/en/sql-reference/functions/zeroifnull def _zeroifnull_to_if(args): - cond = exp.EQ(this=seq_get(args, 0), expression=exp.Null()) + cond = exp.Is(this=seq_get(args, 0), expression=exp.Null()) return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0)) +# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull +def _nullifzero_to_if(args): + cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0)) + return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0)) + + def _datatype_sql(self, expression): if expression.this == exp.DataType.Type.ARRAY: return "ARRAY" @@ -167,6 +174,11 @@ class Snowflake(Dialect): **parser.Parser.FUNCTIONS, "ARRAYAGG": exp.ArrayAgg.from_arg_list, "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list, + "DATEADD": lambda args: exp.DateAdd( + this=seq_get(args, 2), + expression=seq_get(args, 1), + unit=seq_get(args, 0), + ), "DATE_TRUNC": lambda args: exp.DateTrunc( unit=exp.Literal.string(seq_get(args, 0).name), # type: ignore this=seq_get(args, 1), @@ -180,6 +192,7 @@ class Snowflake(Dialect): "DECODE": exp.Matches.from_arg_list, "OBJECT_CONSTRUCT": parser.parse_var_map, "ZEROIFNULL": _zeroifnull_to_if, + "NULLIFZERO": _nullifzero_to_if, } FUNCTION_PARSERS = { @@ -254,6 +267,7 @@ class Snowflake(Dialect): class Generator(generator.Generator): PARAMETER_TOKEN = "$" INTEGER_DIVISION = False + MATCHED_BY_SOURCE = False TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore @@ -278,6 +292,7 @@ class Snowflake(Dialect): exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"), exp.UnixToTime: _unix_to_time_sql, exp.DayOfWeek: rename_func("DAYOFWEEK"), + exp.Min: min_or_least, } TYPE_MAPPING = { @@ -343,11 +358,10 @@ class Snowflake(Dialect): expression. This might not be true in a case where the same column name can be sourced from another table that can properly quote but should be true in most cases. """ - values_expressions = expression.find_all(exp.Values) values_identifiers = set( flatten( - v.args.get("alias", exp.Alias()).args.get("columns", []) - for v in values_expressions + (v.args.get("alias") or exp.Alias()).args.get("columns", []) + for v in expression.find_all(exp.Values) ) ) if values_identifiers: |