summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/snowflake.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/snowflake.py')
-rw-r--r--sqlglot/dialects/snowflake.py111
1 file changed, 49 insertions, 62 deletions
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 70dcaa9..756e8e9 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -18,7 +18,7 @@ from sqlglot.dialects.dialect import (
var_map_sql,
)
from sqlglot.expressions import Literal
-from sqlglot.helper import flatten, seq_get
+from sqlglot.helper import seq_get
from sqlglot.parser import binary_range_parser
from sqlglot.tokens import TokenType
@@ -30,7 +30,7 @@ def _check_int(s: str) -> bool:
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
-def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.UnixToTime]:
+def _snowflake_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]:
if len(args) == 2:
first_arg, second_arg = args
if second_arg.is_string:
@@ -52,8 +52,12 @@ def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.Unix
return exp.UnixToTime(this=first_arg, scale=timescale)
+ from sqlglot.optimizer.simplify import simplify_literals
+
+ # The first argument might be an expression like 40 * 365 * 86400, so we try to
+ # reduce it using `simplify_literals` first and then check if it's a Literal.
first_arg = seq_get(args, 0)
- if not isinstance(first_arg, Literal):
+ if not isinstance(simplify_literals(first_arg, root=True), Literal):
# case: <variant_expr>
return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args)
@@ -69,6 +73,19 @@ def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.Unix
return exp.UnixToTime.from_arg_list(args)
+def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
+ expression = parser.parse_var_map(args)
+
+ if isinstance(expression, exp.StarMap):
+ return expression
+
+ return exp.Struct(
+ expressions=[
+ t.cast(exp.Condition, k).eq(v) for k, v in zip(expression.keys, expression.values)
+ ]
+ )
+
+
def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
@@ -116,7 +133,7 @@ def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]:
# https://docs.snowflake.com/en/sql-reference/functions/div0
-def _div0_to_if(args: t.Sequence) -> exp.Expression:
+def _div0_to_if(args: t.List) -> exp.Expression:
cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
true = exp.Literal.number(0)
false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1))
@@ -124,13 +141,13 @@ def _div0_to_if(args: t.Sequence) -> exp.Expression:
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
-def _zeroifnull_to_if(args: t.Sequence) -> exp.Expression:
+def _zeroifnull_to_if(args: t.List) -> exp.Expression:
cond = exp.Is(this=seq_get(args, 0), expression=exp.Null())
return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0))
# https://docs.snowflake.com/en/sql-reference/functions/nullifzero
-def _nullifzero_to_if(args: t.Sequence) -> exp.Expression:
+def _nullifzero_to_if(args: t.List) -> exp.Expression:
cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0))
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
@@ -143,6 +160,12 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
return self.datatype_sql(expression)
+def _parse_convert_timezone(args: t.List) -> exp.Expression:
+ if len(args) == 3:
+ return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args)
+ return exp.AtTimeZone(this=seq_get(args, 1), zone=seq_get(args, 0))
+
+
class Snowflake(Dialect):
null_ordering = "nulls_are_large"
time_format = "'yyyy-mm-dd hh24:mi:ss'"
@@ -177,17 +200,14 @@ class Snowflake(Dialect):
}
class Parser(parser.Parser):
- QUOTED_PIVOT_COLUMNS = True
+ IDENTIFY_PIVOT_STRINGS = True
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
"ARRAY_CONSTRUCT": exp.Array.from_arg_list,
"ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list,
- "CONVERT_TIMEZONE": lambda args: exp.AtTimeZone(
- this=seq_get(args, 1),
- zone=seq_get(args, 0),
- ),
+ "CONVERT_TIMEZONE": _parse_convert_timezone,
"DATE_TRUNC": date_trunc_to_time,
"DATEADD": lambda args: exp.DateAdd(
this=seq_get(args, 2),
@@ -202,7 +222,7 @@ class Snowflake(Dialect):
"DIV0": _div0_to_if,
"IFF": exp.If.from_arg_list,
"NULLIFZERO": _nullifzero_to_if,
- "OBJECT_CONSTRUCT": parser.parse_var_map,
+ "OBJECT_CONSTRUCT": _parse_object_construct,
"RLIKE": exp.RegexpLike.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"TO_ARRAY": exp.Array.from_arg_list,
@@ -224,7 +244,7 @@ class Snowflake(Dialect):
}
COLUMN_OPERATORS = {
- **parser.Parser.COLUMN_OPERATORS, # type: ignore
+ **parser.Parser.COLUMN_OPERATORS,
TokenType.COLON: lambda self, this, path: self.expression(
exp.Bracket,
this=this,
@@ -232,14 +252,16 @@ class Snowflake(Dialect):
),
}
+ TIMESTAMPS = parser.Parser.TIMESTAMPS.copy() - {TokenType.TIME}
+
RANGE_PARSERS = {
- **parser.Parser.RANGE_PARSERS, # type: ignore
+ **parser.Parser.RANGE_PARSERS,
TokenType.LIKE_ANY: binary_range_parser(exp.LikeAny),
TokenType.ILIKE_ANY: binary_range_parser(exp.ILikeAny),
}
ALTER_PARSERS = {
- **parser.Parser.ALTER_PARSERS, # type: ignore
+ **parser.Parser.ALTER_PARSERS,
"UNSET": lambda self: self._parse_alter_table_set_tag(unset=True),
"SET": lambda self: self._parse_alter_table_set_tag(),
}
@@ -256,17 +278,20 @@ class Snowflake(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "CHAR VARYING": TokenType.VARCHAR,
+ "CHARACTER VARYING": TokenType.VARCHAR,
"EXCLUDE": TokenType.EXCEPT,
"ILIKE ANY": TokenType.ILIKE_ANY,
"LIKE ANY": TokenType.LIKE_ANY,
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
+ "MINUS": TokenType.EXCEPT,
+ "NCHAR VARYING": TokenType.VARCHAR,
"PUT": TokenType.COMMAND,
"RENAME": TokenType.REPLACE,
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
"TIMESTAMPNTZ": TokenType.TIMESTAMP,
- "MINUS": TokenType.EXCEPT,
"SAMPLE": TokenType.TABLE_SAMPLE,
}
@@ -285,7 +310,7 @@ class Snowflake(Dialect):
TABLE_HINTS = False
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
@@ -299,6 +324,7 @@ class Snowflake(Dialect):
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
exp.DayOfWeek: rename_func("DAYOFWEEK"),
+ exp.Extract: rename_func("DATE_PART"),
exp.If: rename_func("IFF"),
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
exp.LogicalOr: rename_func("BOOLOR_AGG"),
@@ -312,6 +338,10 @@ class Snowflake(Dialect):
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.Struct: lambda self, e: self.func(
+ "OBJECT_CONSTRUCT",
+ *(arg for expression in e.expressions for arg in expression.flatten()),
+ ),
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.TimeToStr: lambda self, e: self.func(
@@ -326,7 +356,7 @@ class Snowflake(Dialect):
}
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
}
@@ -336,7 +366,7 @@ class Snowflake(Dialect):
}
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.SetProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
@@ -351,53 +381,10 @@ class Snowflake(Dialect):
self.unsupported("INTERSECT with All is not supported in Snowflake")
return super().intersect_op(expression)
- def values_sql(self, expression: exp.Values) -> str:
- """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted.
-
- We also want to make sure that after we find matches where we need to unquote a column that we prevent users
- from adding quotes to the column by using the `identify` argument when generating the SQL.
- """
- alias = expression.args.get("alias")
- if alias and alias.args.get("columns"):
- expression = expression.transform(
- lambda node: exp.Identifier(**{**node.args, "quoted": False})
- if isinstance(node, exp.Identifier)
- and isinstance(node.parent, exp.TableAlias)
- and node.arg_key == "columns"
- else node,
- )
- return self.no_identify(lambda: super(self.__class__, self).values_sql(expression))
- return super().values_sql(expression)
-
def settag_sql(self, expression: exp.SetTag) -> str:
action = "UNSET" if expression.args.get("unset") else "SET"
return f"{action} TAG {self.expressions(expression)}"
- def select_sql(self, expression: exp.Select) -> str:
- """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted and also
- that all columns in a SELECT are unquoted. We also want to make sure that after we find matches where we need
- to unquote a column that we prevent users from adding quotes to the column by using the `identify` argument when
- generating the SQL.
-
- Note: We make an assumption that any columns referenced in a VALUES expression should be unquoted throughout the
- expression. This might not be true in a case where the same column name can be sourced from another table that can
- properly quote but should be true in most cases.
- """
- values_identifiers = set(
- flatten(
- (v.args.get("alias") or exp.Alias()).args.get("columns", [])
- for v in expression.find_all(exp.Values)
- )
- )
- if values_identifiers:
- expression = expression.transform(
- lambda node: exp.Identifier(**{**node.args, "quoted": False})
- if isinstance(node, exp.Identifier) and node in values_identifiers
- else node,
- )
- return self.no_identify(lambda: super(self.__class__, self).select_sql(expression))
- return super().select_sql(expression)
-
def describe_sql(self, expression: exp.Describe) -> str:
# Default to table if kind is unknown
kind_value = expression.args.get("kind") or "TABLE"