summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/snowflake.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/snowflake.py')
-rw-r--r--sqlglot/dialects/snowflake.py41
1 files changed, 40 insertions, 1 deletions
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 5931364..4a090c2 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+import typing as t
+
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
@@ -104,6 +106,20 @@ def _parse_date_part(self):
return self.expression(exp.Extract, this=this, expression=expression)
+# https://docs.snowflake.com/en/sql-reference/functions/div0
+def _div0_to_if(args):
+ cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
+ true = exp.Literal.number(0)
+ false = exp.FloatDiv(this=seq_get(args, 0), expression=seq_get(args, 1))
+ return exp.If(this=cond, true=true, false=false)
+
+
+# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
+def _zeroifnull_to_if(args):
+ cond = exp.EQ(this=seq_get(args, 0), expression=exp.Null())
+ return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0))
+
+
def _datatype_sql(self, expression):
if expression.this == exp.DataType.Type.ARRAY:
return "ARRAY"
@@ -150,16 +166,20 @@ class Snowflake(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
+ "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list,
"DATE_TRUNC": lambda args: exp.DateTrunc(
unit=exp.Literal.string(seq_get(args, 0).name), # type: ignore
this=seq_get(args, 1),
),
+ "DIV0": _div0_to_if,
"IFF": exp.If.from_arg_list,
+ "TO_ARRAY": exp.Array.from_arg_list,
"TO_TIMESTAMP": _snowflake_to_timestamp,
"ARRAY_CONSTRUCT": exp.Array.from_arg_list,
"RLIKE": exp.RegexpLike.from_arg_list,
"DECODE": exp.Matches.from_arg_list,
"OBJECT_CONSTRUCT": parser.parse_var_map,
+ "ZEROIFNULL": _zeroifnull_to_if,
}
FUNCTION_PARSERS = {
@@ -193,6 +213,19 @@ class Snowflake(Dialect):
),
}
+ ALTER_PARSERS = {
+ **parser.Parser.ALTER_PARSERS, # type: ignore
+ "UNSET": lambda self: self._parse_alter_table_set_tag(unset=True),
+ "SET": lambda self: self._parse_alter_table_set_tag(),
+ }
+
+ INTEGER_DIVISION = False
+
+ def _parse_alter_table_set_tag(self, unset: bool = False) -> exp.Expression:
+ self._match_text_seq("TAG")
+ parser = t.cast(t.Callable, self._parse_id_var if unset else self._parse_conjunction)
+ return self.expression(exp.SetTag, expressions=self._parse_csv(parser), unset=unset)
+
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"]
STRING_ESCAPES = ["\\", "'"]
@@ -220,12 +253,14 @@ class Snowflake(Dialect):
class Generator(generator.Generator):
PARAMETER_TOKEN = "$"
+ INTEGER_DIVISION = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
- exp.DateAdd: rename_func("DATEADD"),
+ exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
+ exp.DateAdd: lambda self, e: self.func("DATEADD", e.text("unit"), e.expression, e.this),
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
exp.If: rename_func("IFF"),
@@ -294,6 +329,10 @@ class Snowflake(Dialect):
return self.no_identify(lambda: super(self.__class__, self).values_sql(expression))
return super().values_sql(expression)
+ def settag_sql(self, expression: exp.SetTag) -> str:
+ action = "UNSET" if expression.args.get("unset") else "SET"
+ return f"{action} TAG {self.expressions(expression)}"
+
def select_sql(self, expression: exp.Select) -> str:
"""Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted and also
that all columns in a SELECT are unquoted. We also want to make sure that after we find matches where we need