summary refs log tree commit diff stats
path: root/sqlglot/dialects/snowflake.py
diff options
context:
space:
mode:
author: Daniel Baumann <mail@daniel-baumann.ch> 2023-12-10 10:46:01 +0000
committer: Daniel Baumann <mail@daniel-baumann.ch> 2023-12-10 10:46:01 +0000
commit: 8fe30fd23dc37ec3516e530a86d1c4b604e71241 (patch)
tree: 6e2ebbf565b0351fd0f003f488a8339e771ad90c /sqlglot/dialects/snowflake.py
parent: Releasing debian version 19.0.1-1. (diff)
download: sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.tar.xz
          sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.zip
Merging upstream version 20.1.0.
Signed-off-by: Daniel Baumann <mail@daniel-baumann.ch>
Diffstat (limited to 'sqlglot/dialects/snowflake.py')
-rw-r--r-- sqlglot/dialects/snowflake.py | 115
1 file changed, 89 insertions(+), 26 deletions(-)
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 01f7512..cdbc071 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -3,9 +3,12 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
+from sqlglot._typing import E
from sqlglot.dialects.dialect import (
Dialect,
+ NormalizationStrategy,
binary_from_function,
+ date_delta_sql,
date_trunc_to_time,
datestrtodate_sql,
format_time_lambda,
@@ -21,7 +24,6 @@ from sqlglot.dialects.dialect import (
)
from sqlglot.expressions import Literal
from sqlglot.helper import seq_get
-from sqlglot.parser import binary_range_parser
from sqlglot.tokens import TokenType
@@ -50,7 +52,7 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime,
elif second_arg.name == "3":
timescale = exp.UnixToTime.MILLIS
elif second_arg.name == "9":
- timescale = exp.UnixToTime.MICROS
+ timescale = exp.UnixToTime.NANOS
return exp.UnixToTime(this=first_arg, scale=timescale)
@@ -95,14 +97,17 @@ def _parse_datediff(args: t.List) -> exp.DateDiff:
def _unix_to_time_sql(self: Snowflake.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
- if scale in [None, exp.UnixToTime.SECONDS]:
+ if scale in (None, exp.UnixToTime.SECONDS):
return f"TO_TIMESTAMP({timestamp})"
if scale == exp.UnixToTime.MILLIS:
return f"TO_TIMESTAMP({timestamp}, 3)"
if scale == exp.UnixToTime.MICROS:
+ return f"TO_TIMESTAMP({timestamp} / 1000, 3)"
+ if scale == exp.UnixToTime.NANOS:
return f"TO_TIMESTAMP({timestamp}, 9)"
- raise ValueError("Improper scale for timestamp")
+ self.unsupported(f"Unsupported scale for timestamp: {scale}.")
+ return ""
# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
@@ -201,7 +206,7 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[Snowflake.Parser]
class Snowflake(Dialect):
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
- RESOLVES_IDENTIFIERS_AS_UPPERCASE = True
+ NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
NULL_ORDERING = "nulls_are_large"
TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
SUPPORTS_USER_DEFINED_TYPES = False
@@ -236,6 +241,18 @@ class Snowflake(Dialect):
"ff6": "%f",
}
+ def quote_identifier(self, expression: E, identify: bool = True) -> E:
+ # This disables quoting DUAL in SELECT ... FROM DUAL, because Snowflake treats an
+ # unquoted DUAL keyword in a special way and does not map it to a user-defined table
+ if (
+ isinstance(expression, exp.Identifier)
+ and isinstance(expression.parent, exp.Table)
+ and expression.name.lower() == "dual"
+ ):
+ return t.cast(E, expression)
+
+ return super().quote_identifier(expression, identify=identify)
+
class Parser(parser.Parser):
IDENTIFY_PIVOT_STRINGS = True
@@ -245,6 +262,9 @@ class Snowflake(Dialect):
**parser.Parser.FUNCTIONS,
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
"ARRAY_CONSTRUCT": exp.Array.from_arg_list,
+ "ARRAY_CONTAINS": lambda args: exp.ArrayContains(
+ this=seq_get(args, 1), expression=seq_get(args, 0)
+ ),
"ARRAY_GENERATE_RANGE": lambda args: exp.GenerateSeries(
# ARRAY_GENERATE_RANGE has an exclusive end; we normalize it to be inclusive
start=seq_get(args, 0),
@@ -296,8 +316,8 @@ class Snowflake(Dialect):
RANGE_PARSERS = {
**parser.Parser.RANGE_PARSERS,
- TokenType.LIKE_ANY: binary_range_parser(exp.LikeAny),
- TokenType.ILIKE_ANY: binary_range_parser(exp.ILikeAny),
+ TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny),
+ TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny),
}
ALTER_PARSERS = {
@@ -317,6 +337,11 @@ class Snowflake(Dialect):
TokenType.SHOW: lambda self: self._parse_show(),
}
+ PROPERTY_PARSERS = {
+ **parser.Parser.PROPERTY_PARSERS,
+ "LOCATION": lambda self: self._parse_location(),
+ }
+
SHOW_PARSERS = {
"PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
"TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
@@ -349,7 +374,7 @@ class Snowflake(Dialect):
table: t.Optional[exp.Expression] = None
if self._match_text_seq("@"):
table_name = "@"
- while True:
+ while self._curr:
self._advance()
table_name += self._prev.text
if not self._match_set(self.STAGED_FILE_SINGLE_TOKENS, advance=False):
@@ -411,6 +436,20 @@ class Snowflake(Dialect):
self._match_text_seq("WITH")
return self.expression(exp.SwapTable, this=self._parse_table(schema=True))
+ def _parse_location(self) -> exp.LocationProperty:
+ self._match(TokenType.EQ)
+
+ parts = [self._parse_var(any_token=True)]
+
+ while self._match(TokenType.SLASH):
+ if self._curr and self._prev.end + 1 == self._curr.start:
+ parts.append(self._parse_var(any_token=True))
+ else:
+ parts.append(exp.Var(this=""))
+ return self.expression(
+ exp.LocationProperty, this=exp.var("/".join(str(p) for p in parts))
+ )
+
class Tokenizer(tokens.Tokenizer):
STRING_ESCAPES = ["\\", "'"]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
@@ -457,6 +496,7 @@ class Snowflake(Dialect):
AGGREGATE_FILTER_SUPPORTED = False
SUPPORTS_TABLE_COPY = False
COLLATE_IS_FUNC = True
+ LIMIT_ONLY_LITERALS = True
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -464,15 +504,14 @@ class Snowflake(Dialect):
exp.ArgMin: rename_func("MIN_BY"),
exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
+ exp.ArrayContains: lambda self, e: self.func("ARRAY_CONTAINS", e.expression, e.this),
exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
exp.AtTimeZone: lambda self, e: self.func(
"CONVERT_TIMEZONE", e.args.get("zone"), e.this
),
exp.BitwiseXor: rename_func("BITXOR"),
- exp.DateAdd: lambda self, e: self.func("DATEADD", e.text("unit"), e.expression, e.this),
- exp.DateDiff: lambda self, e: self.func(
- "DATEDIFF", e.text("unit"), e.expression, e.this
- ),
+ exp.DateAdd: date_delta_sql("DATEADD"),
+ exp.DateDiff: date_delta_sql("DATEDIFF"),
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
exp.DayOfMonth: rename_func("DAYOFMONTH"),
@@ -501,10 +540,11 @@ class Snowflake(Dialect):
exp.Select: transforms.preprocess(
[
transforms.eliminate_distinct_on,
- transforms.explode_to_unnest(0),
+ transforms.explode_to_unnest(),
transforms.eliminate_semi_and_anti_joins,
]
),
+ exp.SHA: rename_func("SHA1"),
exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
exp.StartsWith: rename_func("STARTSWITH"),
exp.StrPosition: lambda self, e: self.func(
@@ -524,6 +564,8 @@ class Snowflake(Dialect):
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
+ exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
+ exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"),
exp.UnixToTime: _unix_to_time_sql,
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
@@ -547,6 +589,20 @@ class Snowflake(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def trycast_sql(self, expression: exp.TryCast) -> str:
+ value = expression.this
+
+ if value.type is None:
+ from sqlglot.optimizer.annotate_types import annotate_types
+
+ value = annotate_types(value)
+
+ if value.is_type(*exp.DataType.TEXT_TYPES, exp.DataType.Type.UNKNOWN):
+ return super().trycast_sql(expression)
+
+ # TRY_CAST only works for string values in Snowflake
+ return self.cast_sql(expression)
+
def log_sql(self, expression: exp.Log) -> str:
if not expression.expression:
return self.func("LN", expression.this)
@@ -554,24 +610,28 @@ class Snowflake(Dialect):
return super().log_sql(expression)
def unnest_sql(self, expression: exp.Unnest) -> str:
- selects = ["value"]
unnest_alias = expression.args.get("alias")
-
offset = expression.args.get("offset")
- if offset:
- if unnest_alias:
- unnest_alias.append("columns", offset.pop())
-
- selects.append("index")
- subquery = exp.Subquery(
- this=exp.select(*selects).from_(
- f"TABLE(FLATTEN(INPUT => {self.sql(expression.expressions[0])}))"
- ),
- )
+ columns = [
+ exp.to_identifier("seq"),
+ exp.to_identifier("key"),
+ exp.to_identifier("path"),
+ offset.pop() if isinstance(offset, exp.Expression) else exp.to_identifier("index"),
+ seq_get(unnest_alias.columns if unnest_alias else [], 0)
+ or exp.to_identifier("value"),
+ exp.to_identifier("this"),
+ ]
+
+ if unnest_alias:
+ unnest_alias.set("columns", columns)
+ else:
+ unnest_alias = exp.TableAlias(this="_u", columns=columns)
+
+ explode = f"TABLE(FLATTEN(INPUT => {self.sql(expression.expressions[0])}))"
alias = self.sql(unnest_alias)
alias = f" AS {alias}" if alias else ""
- return f"{self.sql(subquery)}{alias}"
+ return f"{explode}{alias}"
def show_sql(self, expression: exp.Show) -> str:
scope = self.sql(expression, "scope")
@@ -632,3 +692,6 @@ class Snowflake(Dialect):
def swaptable_sql(self, expression: exp.SwapTable) -> str:
this = self.sql(expression, "this")
return f"SWAP WITH {this}"
+
+ def with_properties(self, properties: exp.Properties) -> str:
+ return self.properties(properties, wrapped=False, prefix=self.seg(""), sep=" ")