diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-06-02 23:59:40 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-06-02 23:59:46 +0000 |
commit | 20739a12c39121a9e7ad3c9a2469ec5a6876199d (patch) | |
tree | c000de91c59fd29b2d9beecf9f93b84e69727f37 /sqlglot/dialects/snowflake.py | |
parent | Releasing debian version 12.2.0-1. (diff) | |
download | sqlglot-20739a12c39121a9e7ad3c9a2469ec5a6876199d.tar.xz sqlglot-20739a12c39121a9e7ad3c9a2469ec5a6876199d.zip |
Merging upstream version 15.0.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/snowflake.py')
-rw-r--r-- | sqlglot/dialects/snowflake.py | 111 |
1 files changed, 49 insertions, 62 deletions
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 70dcaa9..756e8e9 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -18,7 +18,7 @@ from sqlglot.dialects.dialect import ( var_map_sql, ) from sqlglot.expressions import Literal -from sqlglot.helper import flatten, seq_get +from sqlglot.helper import seq_get from sqlglot.parser import binary_range_parser from sqlglot.tokens import TokenType @@ -30,7 +30,7 @@ def _check_int(s: str) -> bool: # from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html -def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.UnixToTime]: +def _snowflake_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]: if len(args) == 2: first_arg, second_arg = args if second_arg.is_string: @@ -52,8 +52,12 @@ def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.Unix return exp.UnixToTime(this=first_arg, scale=timescale) + from sqlglot.optimizer.simplify import simplify_literals + + # The first argument might be an expression like 40 * 365 * 86400, so we try to + # reduce it using `simplify_literals` first and then check if it's a Literal. first_arg = seq_get(args, 0) - if not isinstance(first_arg, Literal): + if not isinstance(simplify_literals(first_arg, root=True), Literal): # case: <variant_expr> return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args) @@ -69,6 +73,19 @@ def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.Unix return exp.UnixToTime.from_arg_list(args) +def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]: + expression = parser.parse_var_map(args) + + if isinstance(expression, exp.StarMap): + return expression + + return exp.Struct( + expressions=[ + t.cast(exp.Condition, k).eq(v) for k, v in zip(expression.keys, expression.values) + ] + ) + + def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str: scale = expression.args.get("scale") timestamp = self.sql(expression, "this") @@ -116,7 +133,7 @@ def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]: # https://docs.snowflake.com/en/sql-reference/functions/div0 -def _div0_to_if(args: t.Sequence) -> exp.Expression: +def _div0_to_if(args: t.List) -> exp.Expression: cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0)) true = exp.Literal.number(0) false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1)) @@ -124,13 +141,13 @@ def _div0_to_if(args: t.Sequence) -> exp.Expression: # https://docs.snowflake.com/en/sql-reference/functions/zeroifnull -def _zeroifnull_to_if(args: t.Sequence) -> exp.Expression: +def _zeroifnull_to_if(args: t.List) -> exp.Expression: cond = exp.Is(this=seq_get(args, 0), expression=exp.Null()) return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0)) # https://docs.snowflake.com/en/sql-reference/functions/zeroifnull -def _nullifzero_to_if(args: t.Sequence) -> exp.Expression: +def _nullifzero_to_if(args: t.List) -> exp.Expression: cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0)) return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0)) @@ -143,6 +160,12 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: return self.datatype_sql(expression) +def _parse_convert_timezone(args: t.List) -> exp.Expression: + if len(args) == 3: + return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args) + return exp.AtTimeZone(this=seq_get(args, 1), zone=seq_get(args, 0)) + + class Snowflake(Dialect): null_ordering = "nulls_are_large" time_format = "'yyyy-mm-dd hh24:mi:ss'" @@ -177,17 +200,14 @@ class Snowflake(Dialect): } class Parser(parser.Parser): - QUOTED_PIVOT_COLUMNS = True + IDENTIFY_PIVOT_STRINGS = True FUNCTIONS = { **parser.Parser.FUNCTIONS, "ARRAYAGG": exp.ArrayAgg.from_arg_list, "ARRAY_CONSTRUCT": exp.Array.from_arg_list, "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list, - "CONVERT_TIMEZONE": lambda args: exp.AtTimeZone( - this=seq_get(args, 1), - zone=seq_get(args, 0), - ), + "CONVERT_TIMEZONE": _parse_convert_timezone, "DATE_TRUNC": date_trunc_to_time, "DATEADD": lambda args: exp.DateAdd( this=seq_get(args, 2), @@ -202,7 +222,7 @@ class Snowflake(Dialect): "DIV0": _div0_to_if, "IFF": exp.If.from_arg_list, "NULLIFZERO": _nullifzero_to_if, - "OBJECT_CONSTRUCT": parser.parse_var_map, + "OBJECT_CONSTRUCT": _parse_object_construct, "RLIKE": exp.RegexpLike.from_arg_list, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), "TO_ARRAY": exp.Array.from_arg_list, @@ -224,7 +244,7 @@ class Snowflake(Dialect): } COLUMN_OPERATORS = { - **parser.Parser.COLUMN_OPERATORS, # type: ignore + **parser.Parser.COLUMN_OPERATORS, TokenType.COLON: lambda self, this, path: self.expression( exp.Bracket, this=this, @@ -232,14 +252,16 @@ class Snowflake(Dialect): ), } + TIMESTAMPS = parser.Parser.TIMESTAMPS.copy() - {TokenType.TIME} + RANGE_PARSERS = { - **parser.Parser.RANGE_PARSERS, # type: ignore + **parser.Parser.RANGE_PARSERS, TokenType.LIKE_ANY: binary_range_parser(exp.LikeAny), TokenType.ILIKE_ANY: binary_range_parser(exp.ILikeAny), } ALTER_PARSERS = { - **parser.Parser.ALTER_PARSERS, # type: ignore + **parser.Parser.ALTER_PARSERS, "UNSET": lambda self: self._parse_alter_table_set_tag(unset=True), "SET": lambda self: self._parse_alter_table_set_tag(), } @@ -256,17 +278,20 @@ class Snowflake(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "CHAR VARYING": TokenType.VARCHAR, + "CHARACTER VARYING": TokenType.VARCHAR, "EXCLUDE": TokenType.EXCEPT, "ILIKE ANY": TokenType.ILIKE_ANY, "LIKE ANY": TokenType.LIKE_ANY, "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, + "MINUS": TokenType.EXCEPT, + "NCHAR VARYING": TokenType.VARCHAR, "PUT": TokenType.COMMAND, "RENAME": TokenType.REPLACE, "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ, "TIMESTAMP_NTZ": TokenType.TIMESTAMP, "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ, "TIMESTAMPNTZ": TokenType.TIMESTAMP, - "MINUS": TokenType.EXCEPT, "SAMPLE": TokenType.TABLE_SAMPLE, } @@ -285,7 +310,7 @@ class Snowflake(Dialect): TABLE_HINTS = False TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, exp.Array: inline_array_sql, exp.ArrayConcat: rename_func("ARRAY_CAT"), exp.ArrayJoin: rename_func("ARRAY_TO_STRING"), @@ -299,6 +324,7 @@ class Snowflake(Dialect): exp.DateStrToDate: datestrtodate_sql, exp.DataType: _datatype_sql, exp.DayOfWeek: rename_func("DAYOFWEEK"), + exp.Extract: rename_func("DATE_PART"), exp.If: rename_func("IFF"), exp.LogicalAnd: rename_func("BOOLAND_AGG"), exp.LogicalOr: rename_func("BOOLOR_AGG"), @@ -312,6 +338,10 @@ class Snowflake(Dialect): "POSITION", e.args.get("substr"), e.this, e.args.get("position") ), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.Struct: lambda self, e: self.func( + "OBJECT_CONSTRUCT", + *(arg for expression in e.expressions for arg in expression.flatten()), + ), exp.TimeStrToTime: timestrtotime_sql, exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", exp.TimeToStr: lambda self, e: self.func( @@ -326,7 +356,7 @@ class Snowflake(Dialect): } TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ", } @@ -336,7 +366,7 @@ class Snowflake(Dialect): } PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.SetProperty: exp.Properties.Location.UNSUPPORTED, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } @@ -351,53 +381,10 @@ class Snowflake(Dialect): self.unsupported("INTERSECT with All is not supported in Snowflake") return super().intersect_op(expression) - def values_sql(self, expression: exp.Values) -> str: - """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted. - - We also want to make sure that after we find matches where we need to unquote a column that we prevent users - from adding quotes to the column by using the `identify` argument when generating the SQL. - """ - alias = expression.args.get("alias") - if alias and alias.args.get("columns"): - expression = expression.transform( - lambda node: exp.Identifier(**{**node.args, "quoted": False}) - if isinstance(node, exp.Identifier) - and isinstance(node.parent, exp.TableAlias) - and node.arg_key == "columns" - else node, - ) - return self.no_identify(lambda: super(self.__class__, self).values_sql(expression)) - return super().values_sql(expression) - def settag_sql(self, expression: exp.SetTag) -> str: action = "UNSET" if expression.args.get("unset") else "SET" return f"{action} TAG {self.expressions(expression)}" - def select_sql(self, expression: exp.Select) -> str: - """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted and also - that all columns in a SELECT are unquoted. We also want to make sure that after we find matches where we need - to unquote a column that we prevent users from adding quotes to the column by using the `identify` argument when - generating the SQL. - - Note: We make an assumption that any columns referenced in a VALUES expression should be unquoted throughout the - expression. This might not be true in a case where the same column name can be sourced from another table that can - properly quote but should be true in most cases. - """ - values_identifiers = set( - flatten( - (v.args.get("alias") or exp.Alias()).args.get("columns", []) - for v in expression.find_all(exp.Values) - ) - ) - if values_identifiers: - expression = expression.transform( - lambda node: exp.Identifier(**{**node.args, "quoted": False}) - if isinstance(node, exp.Identifier) and node in values_identifiers - else node, - ) - return self.no_identify(lambda: super(self.__class__, self).select_sql(expression)) - return super().select_sql(expression) - def describe_sql(self, expression: exp.Describe) -> str: # Default to table if kind is unknown kind_value = expression.args.get("kind") or "TABLE" |