diff options
Diffstat (limited to 'sqlglot/dialects/snowflake.py')
-rw-r--r-- | sqlglot/dialects/snowflake.py | 59 |
1 files changed, 37 insertions, 22 deletions
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 34bc3bd..0829669 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -23,14 +23,14 @@ from sqlglot.parser import binary_range_parser from sqlglot.tokens import TokenType -def _check_int(s): +def _check_int(s: str) -> bool: if s[0] in ("-", "+"): return s[1:].isdigit() return s.isdigit() # from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html -def _snowflake_to_timestamp(args): +def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.UnixToTime]: if len(args) == 2: first_arg, second_arg = args if second_arg.is_string: @@ -69,7 +69,7 @@ def _snowflake_to_timestamp(args): return exp.UnixToTime.from_arg_list(args) -def _unix_to_time_sql(self, expression): +def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str: scale = expression.args.get("scale") timestamp = self.sql(expression, "this") if scale in [None, exp.UnixToTime.SECONDS]: @@ -84,8 +84,12 @@ def _unix_to_time_sql(self, expression): # https://docs.snowflake.com/en/sql-reference/functions/date_part.html # https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts -def _parse_date_part(self): +def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]: this = self._parse_var() or self._parse_type() + + if not this: + return None + self._match(TokenType.COMMA) expression = self._parse_bitwise() @@ -101,7 +105,7 @@ def _parse_date_part(self): scale = None ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP")) - to_unix = self.expression(exp.TimeToUnix, this=ts) + to_unix: exp.Expression = self.expression(exp.TimeToUnix, this=ts) if scale: to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale)) @@ -112,7 +116,7 @@ def _parse_date_part(self): # https://docs.snowflake.com/en/sql-reference/functions/div0 -def _div0_to_if(args): +def _div0_to_if(args: t.Sequence) -> exp.Expression: cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0)) true = exp.Literal.number(0) false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1)) @@ -120,18 +124,18 @@ def _div0_to_if(args): # https://docs.snowflake.com/en/sql-reference/functions/zeroifnull -def _zeroifnull_to_if(args): +def _zeroifnull_to_if(args: t.Sequence) -> exp.Expression: cond = exp.Is(this=seq_get(args, 0), expression=exp.Null()) return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0)) # https://docs.snowflake.com/en/sql-reference/functions/zeroifnull -def _nullifzero_to_if(args): +def _nullifzero_to_if(args: t.Sequence) -> exp.Expression: cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0)) return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0)) -def _datatype_sql(self, expression): +def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: if expression.this == exp.DataType.Type.ARRAY: return "ARRAY" elif expression.this == exp.DataType.Type.MAP: @@ -155,9 +159,8 @@ class Snowflake(Dialect): "MM": "%m", "mm": "%m", "DD": "%d", - "dd": "%d", - "d": "%-d", - "DY": "%w", + "dd": "%-d", + "DY": "%a", "dy": "%w", "HH24": "%H", "hh24": "%H", @@ -174,6 +177,8 @@ class Snowflake(Dialect): } class Parser(parser.Parser): + QUOTED_PIVOT_COLUMNS = True + FUNCTIONS = { **parser.Parser.FUNCTIONS, "ARRAYAGG": exp.ArrayAgg.from_arg_list, @@ -269,9 +274,14 @@ class Snowflake(Dialect): "$": TokenType.PARAMETER, } + VAR_SINGLE_TOKENS = {"$"} + class Generator(generator.Generator): PARAMETER_TOKEN = "$" MATCHED_BY_SOURCE = False + SINGLE_STRING_INTERVAL = True + JOIN_HINTS = False + TABLE_HINTS = False TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore @@ -287,26 +297,30 @@ class Snowflake(Dialect): ), exp.DateStrToDate: datestrtodate_sql, exp.DataType: _datatype_sql, + exp.DayOfWeek: rename_func("DAYOFWEEK"), exp.If: rename_func("IFF"), - exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), - exp.LogicalOr: rename_func("BOOLOR_AGG"), exp.LogicalAnd: rename_func("BOOLAND_AGG"), - exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), + exp.LogicalOr: rename_func("BOOLOR_AGG"), + exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), + exp.Max: max_or_greatest, + exp.Min: min_or_least, exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", + exp.StarMap: rename_func("OBJECT_CONSTRUCT"), exp.StrPosition: lambda self, e: self.func( "POSITION", e.args.get("substr"), e.this, e.args.get("position") ), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", - exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToTime: timestrtotime_sql, exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", - exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression), + exp.TimeToStr: lambda self, e: self.func( + "TO_CHAR", exp.cast(e.this, "timestamp"), self.format_time(e) + ), + exp.TimestampTrunc: timestamptrunc_sql, exp.ToChar: lambda self, e: self.function_fallback_sql(e), + exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression), exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"), exp.UnixToTime: _unix_to_time_sql, - exp.DayOfWeek: rename_func("DAYOFWEEK"), - exp.Max: max_or_greatest, - exp.Min: min_or_least, + exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), } TYPE_MAPPING = { @@ -322,14 +336,15 @@ class Snowflake(Dialect): PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore exp.SetProperty: exp.Properties.Location.UNSUPPORTED, + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } - def except_op(self, expression): + def except_op(self, expression: exp.Except) -> str: if not expression.args.get("distinct", False): self.unsupported("EXCEPT with All is not supported in Snowflake") return super().except_op(expression) - def intersect_op(self, expression): + def intersect_op(self, expression: exp.Intersect) -> str: if not expression.args.get("distinct", False): self.unsupported("INTERSECT with All is not supported in Snowflake") return super().intersect_op(expression) |