summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/snowflake.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/snowflake.py')
-rw-r--r--sqlglot/dialects/snowflake.py56
1 files changed, 52 insertions, 4 deletions
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 75dc9dc..77b09e9 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -3,13 +3,15 @@ from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
+ datestrtodate_sql,
format_time_lambda,
inline_array_sql,
rename_func,
+ timestrtotime_sql,
var_map_sql,
)
from sqlglot.expressions import Literal
-from sqlglot.helper import seq_get
+from sqlglot.helper import flatten, seq_get
from sqlglot.tokens import TokenType
@@ -183,7 +185,7 @@ class Snowflake(Dialect):
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"]
- ESCAPES = ["\\"]
+ ESCAPES = ["\\", "'"]
SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS,
@@ -206,9 +208,10 @@ class Snowflake(Dialect):
CREATE_TRANSIENT = True
TRANSFORMS = {
- **generator.Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS, # type: ignore
exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
+ exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
exp.If: rename_func("IFF"),
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
@@ -218,13 +221,14 @@ class Snowflake(Dialect):
exp.Matches: rename_func("DECODE"),
exp.StrPosition: rename_func("POSITION"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
exp.UnixToTime: _unix_to_time_sql,
}
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
}
@@ -246,3 +250,47 @@ class Snowflake(Dialect):
if not expression.args.get("distinct", False):
self.unsupported("INTERSECT with All is not supported in Snowflake")
return super().intersect_op(expression)
+
+ def values_sql(self, expression: exp.Values) -> str:
+ """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted.
+
+ We also want to make sure that after we find matches where we need to unquote a column that we prevent users
+ from adding quotes to the column by using the `identify` argument when generating the SQL.
+ """
+ alias = expression.args.get("alias")
+ if alias and alias.args.get("columns"):
+ expression = expression.transform(
+ lambda node: exp.Identifier(**{**node.args, "quoted": False})
+ if isinstance(node, exp.Identifier)
+ and isinstance(node.parent, exp.TableAlias)
+ and node.arg_key == "columns"
+ else node,
+ )
+ return self.no_identify(lambda: super(self.__class__, self).values_sql(expression))
+ return super().values_sql(expression)
+
+ def select_sql(self, expression: exp.Select) -> str:
+ """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted and also
+ that all columns in a SELECT are unquoted. We also want to make sure that after we find matches where we need
+ to unquote a column that we prevent users from adding quotes to the column by using the `identify` argument when
+ generating the SQL.
+
+ Note: We make an assumption that any columns referenced in a VALUES expression should be unquoted throughout the
+ expression. This might not be true in a case where the same column name can be sourced from another table that can
+ properly quote but should be true in most cases.
+ """
+ values_expressions = expression.find_all(exp.Values)
+ values_identifiers = set(
+ flatten(
+ v.args.get("alias", exp.Alias()).args.get("columns", [])
+ for v in values_expressions
+ )
+ )
+ if values_identifiers:
+ expression = expression.transform(
+ lambda node: exp.Identifier(**{**node.args, "quoted": False})
+ if isinstance(node, exp.Identifier) and node in values_identifiers
+ else node,
+ )
+ return self.no_identify(lambda: super(self.__class__, self).select_sql(expression))
+ return super().select_sql(expression)