summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/snowflake.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/snowflake.py')
-rw-r--r--sqlglot/dialects/snowflake.py59
1 files changed, 37 insertions, 22 deletions
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 34bc3bd..0829669 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -23,14 +23,14 @@ from sqlglot.parser import binary_range_parser
from sqlglot.tokens import TokenType
-def _check_int(s):
+def _check_int(s: str) -> bool:
if s[0] in ("-", "+"):
return s[1:].isdigit()
return s.isdigit()
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
-def _snowflake_to_timestamp(args):
+def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.UnixToTime]:
if len(args) == 2:
first_arg, second_arg = args
if second_arg.is_string:
@@ -69,7 +69,7 @@ def _snowflake_to_timestamp(args):
return exp.UnixToTime.from_arg_list(args)
-def _unix_to_time_sql(self, expression):
+def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale in [None, exp.UnixToTime.SECONDS]:
@@ -84,8 +84,12 @@ def _unix_to_time_sql(self, expression):
# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
# https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts
-def _parse_date_part(self):
+def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]:
this = self._parse_var() or self._parse_type()
+
+ if not this:
+ return None
+
self._match(TokenType.COMMA)
expression = self._parse_bitwise()
@@ -101,7 +105,7 @@ def _parse_date_part(self):
scale = None
ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP"))
- to_unix = self.expression(exp.TimeToUnix, this=ts)
+ to_unix: exp.Expression = self.expression(exp.TimeToUnix, this=ts)
if scale:
to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale))
@@ -112,7 +116,7 @@ def _parse_date_part(self):
# https://docs.snowflake.com/en/sql-reference/functions/div0
-def _div0_to_if(args):
+def _div0_to_if(args: t.Sequence) -> exp.Expression:
cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
true = exp.Literal.number(0)
false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1))
@@ -120,18 +124,18 @@ def _div0_to_if(args):
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
-def _zeroifnull_to_if(args):
+def _zeroifnull_to_if(args: t.Sequence) -> exp.Expression:
cond = exp.Is(this=seq_get(args, 0), expression=exp.Null())
return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0))
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
-def _nullifzero_to_if(args):
+def _nullifzero_to_if(args: t.Sequence) -> exp.Expression:
cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0))
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
-def _datatype_sql(self, expression):
+def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
if expression.this == exp.DataType.Type.ARRAY:
return "ARRAY"
elif expression.this == exp.DataType.Type.MAP:
@@ -155,9 +159,8 @@ class Snowflake(Dialect):
"MM": "%m",
"mm": "%m",
"DD": "%d",
- "dd": "%d",
- "d": "%-d",
- "DY": "%w",
+ "dd": "%-d",
+ "DY": "%a",
"dy": "%w",
"HH24": "%H",
"hh24": "%H",
@@ -174,6 +177,8 @@ class Snowflake(Dialect):
}
class Parser(parser.Parser):
+ QUOTED_PIVOT_COLUMNS = True
+
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
@@ -269,9 +274,14 @@ class Snowflake(Dialect):
"$": TokenType.PARAMETER,
}
+ VAR_SINGLE_TOKENS = {"$"}
+
class Generator(generator.Generator):
PARAMETER_TOKEN = "$"
MATCHED_BY_SOURCE = False
+ SINGLE_STRING_INTERVAL = True
+ JOIN_HINTS = False
+ TABLE_HINTS = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
@@ -287,26 +297,30 @@ class Snowflake(Dialect):
),
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
+ exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.If: rename_func("IFF"),
- exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
- exp.LogicalOr: rename_func("BOOLOR_AGG"),
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
- exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
+ exp.LogicalOr: rename_func("BOOLOR_AGG"),
+ exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
+ exp.Max: max_or_greatest,
+ exp.Min: min_or_least,
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
+ exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
exp.StrPosition: lambda self, e: self.func(
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
- exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
- exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
+ exp.TimeToStr: lambda self, e: self.func(
+ "TO_CHAR", exp.cast(e.this, "timestamp"), self.format_time(e)
+ ),
+ exp.TimestampTrunc: timestamptrunc_sql,
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
+ exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"),
exp.UnixToTime: _unix_to_time_sql,
- exp.DayOfWeek: rename_func("DAYOFWEEK"),
- exp.Max: max_or_greatest,
- exp.Min: min_or_least,
+ exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
}
TYPE_MAPPING = {
@@ -322,14 +336,15 @@ class Snowflake(Dialect):
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.SetProperty: exp.Properties.Location.UNSUPPORTED,
+ exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
- def except_op(self, expression):
+ def except_op(self, expression: exp.Except) -> str:
if not expression.args.get("distinct", False):
self.unsupported("EXCEPT with All is not supported in Snowflake")
return super().except_op(expression)
- def intersect_op(self, expression):
+ def intersect_op(self, expression: exp.Intersect) -> str:
if not expression.args.get("distinct", False):
self.unsupported("INTERSECT with All is not supported in Snowflake")
return super().intersect_op(expression)