From ebec59cc5cb6c6856705bf82ced7fe8d9f75b0d0 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Tue, 7 Mar 2023 19:09:31 +0100 Subject: Merging upstream version 11.3.0. Signed-off-by: Daniel Baumann --- sqlglot/dialects/snowflake.py | 41 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) (limited to 'sqlglot/dialects/snowflake.py') diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 5931364..4a090c2 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, @@ -104,6 +106,20 @@ def _parse_date_part(self): return self.expression(exp.Extract, this=this, expression=expression) +# https://docs.snowflake.com/en/sql-reference/functions/div0 +def _div0_to_if(args): + cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0)) + true = exp.Literal.number(0) + false = exp.FloatDiv(this=seq_get(args, 0), expression=seq_get(args, 1)) + return exp.If(this=cond, true=true, false=false) + + +# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull +def _zeroifnull_to_if(args): + cond = exp.EQ(this=seq_get(args, 0), expression=exp.Null()) + return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0)) + + def _datatype_sql(self, expression): if expression.this == exp.DataType.Type.ARRAY: return "ARRAY" @@ -150,16 +166,20 @@ class Snowflake(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, "ARRAYAGG": exp.ArrayAgg.from_arg_list, + "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list, "DATE_TRUNC": lambda args: exp.DateTrunc( unit=exp.Literal.string(seq_get(args, 0).name), # type: ignore this=seq_get(args, 1), ), + "DIV0": _div0_to_if, "IFF": exp.If.from_arg_list, + "TO_ARRAY": exp.Array.from_arg_list, "TO_TIMESTAMP": _snowflake_to_timestamp, "ARRAY_CONSTRUCT": exp.Array.from_arg_list, "RLIKE": exp.RegexpLike.from_arg_list, "DECODE": exp.Matches.from_arg_list, "OBJECT_CONSTRUCT": parser.parse_var_map, + "ZEROIFNULL": _zeroifnull_to_if, } FUNCTION_PARSERS = { @@ -193,6 +213,19 @@ class Snowflake(Dialect): ), } + ALTER_PARSERS = { + **parser.Parser.ALTER_PARSERS, # type: ignore + "UNSET": lambda self: self._parse_alter_table_set_tag(unset=True), + "SET": lambda self: self._parse_alter_table_set_tag(), + } + + INTEGER_DIVISION = False + + def _parse_alter_table_set_tag(self, unset: bool = False) -> exp.Expression: + self._match_text_seq("TAG") + parser = t.cast(t.Callable, self._parse_id_var if unset else self._parse_conjunction) + return self.expression(exp.SetTag, expressions=self._parse_csv(parser), unset=unset) + class Tokenizer(tokens.Tokenizer): QUOTES = ["'", "$$"] STRING_ESCAPES = ["\\", "'"] @@ -220,12 +253,14 @@ class Snowflake(Dialect): class Generator(generator.Generator): PARAMETER_TOKEN = "$" + INTEGER_DIVISION = False TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore exp.Array: inline_array_sql, exp.ArrayConcat: rename_func("ARRAY_CAT"), - exp.DateAdd: rename_func("DATEADD"), + exp.ArrayJoin: rename_func("ARRAY_TO_STRING"), + exp.DateAdd: lambda self, e: self.func("DATEADD", e.text("unit"), e.expression, e.this), exp.DateStrToDate: datestrtodate_sql, exp.DataType: _datatype_sql, exp.If: rename_func("IFF"), @@ -294,6 +329,10 @@ class Snowflake(Dialect): return self.no_identify(lambda: super(self.__class__, self).values_sql(expression)) return super().values_sql(expression) + def settag_sql(self, expression: exp.SetTag) -> str: + action = "UNSET" if expression.args.get("unset") else "SET" + return f"{action} TAG {self.expressions(expression)}" + def select_sql(self, expression: exp.Select) -> str: """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted and also that all columns in a SELECT are unquoted. We also want to make sure that after we find matches where we need -- cgit v1.2.3