diff options
Diffstat (limited to 'sqlglot/dialects/snowflake.py')
-rw-r--r-- | sqlglot/dialects/snowflake.py | 56 |
1 files changed, 52 insertions, 4 deletions
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 75dc9dc..77b09e9 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -3,13 +3,15 @@ from __future__ import annotations from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, + datestrtodate_sql, format_time_lambda, inline_array_sql, rename_func, + timestrtotime_sql, var_map_sql, ) from sqlglot.expressions import Literal -from sqlglot.helper import seq_get +from sqlglot.helper import flatten, seq_get from sqlglot.tokens import TokenType @@ -183,7 +185,7 @@ class Snowflake(Dialect): class Tokenizer(tokens.Tokenizer): QUOTES = ["'", "$$"] - ESCAPES = ["\\"] + ESCAPES = ["\\", "'"] SINGLE_TOKENS = { **tokens.Tokenizer.SINGLE_TOKENS, @@ -206,9 +208,10 @@ class Snowflake(Dialect): CREATE_TRANSIENT = True TRANSFORMS = { - **generator.Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore exp.Array: inline_array_sql, exp.ArrayConcat: rename_func("ARRAY_CAT"), + exp.DateStrToDate: datestrtodate_sql, exp.DataType: _datatype_sql, exp.If: rename_func("IFF"), exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), @@ -218,13 +221,14 @@ class Snowflake(Dialect): exp.Matches: rename_func("DECODE"), exp.StrPosition: rename_func("POSITION"), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeStrToTime: timestrtotime_sql, exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})", exp.UnixToTime: _unix_to_time_sql, } TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ", } @@ -246,3 +250,47 @@ class Snowflake(Dialect): if not expression.args.get("distinct", False): self.unsupported("INTERSECT with All is not supported in Snowflake") return super().intersect_op(expression) + + def values_sql(self, expression: exp.Values) -> str: + """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted. + + We also want to make sure that after we find matches where we need to unquote a column that we prevent users + from adding quotes to the column by using the `identify` argument when generating the SQL. + """ + alias = expression.args.get("alias") + if alias and alias.args.get("columns"): + expression = expression.transform( + lambda node: exp.Identifier(**{**node.args, "quoted": False}) + if isinstance(node, exp.Identifier) + and isinstance(node.parent, exp.TableAlias) + and node.arg_key == "columns" + else node, + ) + return self.no_identify(lambda: super(self.__class__, self).values_sql(expression)) + return super().values_sql(expression) + + def select_sql(self, expression: exp.Select) -> str: + """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted and also + that all columns in a SELECT are unquoted. We also want to make sure that after we find matches where we need + to unquote a column that we prevent users from adding quotes to the column by using the `identify` argument when + generating the SQL. + + Note: We make an assumption that any columns referenced in a VALUES expression should be unquoted throughout the + expression. This might not be true in a case where the same column name can be sourced from another table that can + properly quote but should be true in most cases. + """ + values_expressions = expression.find_all(exp.Values) + values_identifiers = set( + flatten( + v.args.get("alias", exp.Alias()).args.get("columns", []) + for v in values_expressions + ) + ) + if values_identifiers: + expression = expression.transform( + lambda node: exp.Identifier(**{**node.args, "quoted": False}) + if isinstance(node, exp.Identifier) and node in values_identifiers + else node, + ) + return self.no_identify(lambda: super(self.__class__, self).select_sql(expression)) + return super().select_sql(expression) |