diff options
author | Daniel Baumann <mail@daniel-baumann.ch> | 2023-12-10 10:46:01 +0000 |
---|---|---|
committer | Daniel Baumann <mail@daniel-baumann.ch> | 2023-12-10 10:46:01 +0000 |
commit | 8fe30fd23dc37ec3516e530a86d1c4b604e71241 (patch) | |
tree | 6e2ebbf565b0351fd0f003f488a8339e771ad90c /sqlglot/dialects/snowflake.py | |
parent | Releasing debian version 19.0.1-1. (diff) | |
download | sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.tar.xz sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.zip |
Merging upstream version 20.1.0.
Signed-off-by: Daniel Baumann <mail@daniel-baumann.ch>
Diffstat (limited to 'sqlglot/dialects/snowflake.py')
-rw-r--r-- | sqlglot/dialects/snowflake.py | 115 |
1 file changed, 89 insertions, 26 deletions
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 01f7512..cdbc071 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -3,9 +3,12 @@ from __future__ import annotations import typing as t from sqlglot import exp, generator, parser, tokens, transforms +from sqlglot._typing import E from sqlglot.dialects.dialect import ( Dialect, + NormalizationStrategy, binary_from_function, + date_delta_sql, date_trunc_to_time, datestrtodate_sql, format_time_lambda, @@ -21,7 +24,6 @@ from sqlglot.dialects.dialect import ( ) from sqlglot.expressions import Literal from sqlglot.helper import seq_get -from sqlglot.parser import binary_range_parser from sqlglot.tokens import TokenType @@ -50,7 +52,7 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, elif second_arg.name == "3": timescale = exp.UnixToTime.MILLIS elif second_arg.name == "9": - timescale = exp.UnixToTime.MICROS + timescale = exp.UnixToTime.NANOS return exp.UnixToTime(this=first_arg, scale=timescale) @@ -95,14 +97,17 @@ def _parse_datediff(args: t.List) -> exp.DateDiff: def _unix_to_time_sql(self: Snowflake.Generator, expression: exp.UnixToTime) -> str: scale = expression.args.get("scale") timestamp = self.sql(expression, "this") - if scale in [None, exp.UnixToTime.SECONDS]: + if scale in (None, exp.UnixToTime.SECONDS): return f"TO_TIMESTAMP({timestamp})" if scale == exp.UnixToTime.MILLIS: return f"TO_TIMESTAMP({timestamp}, 3)" if scale == exp.UnixToTime.MICROS: + return f"TO_TIMESTAMP({timestamp} / 1000, 3)" + if scale == exp.UnixToTime.NANOS: return f"TO_TIMESTAMP({timestamp}, 9)" - raise ValueError("Improper scale for timestamp") + self.unsupported(f"Unsupported scale for timestamp: {scale}.") + return "" # https://docs.snowflake.com/en/sql-reference/functions/date_part.html @@ -201,7 +206,7 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[Snowflake.Parser] class Snowflake(Dialect): # 
https://docs.snowflake.com/en/sql-reference/identifiers-syntax - RESOLVES_IDENTIFIERS_AS_UPPERCASE = True + NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE NULL_ORDERING = "nulls_are_large" TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'" SUPPORTS_USER_DEFINED_TYPES = False @@ -236,6 +241,18 @@ class Snowflake(Dialect): "ff6": "%f", } + def quote_identifier(self, expression: E, identify: bool = True) -> E: + # This disables quoting DUAL in SELECT ... FROM DUAL, because Snowflake treats an + # unquoted DUAL keyword in a special way and does not map it to a user-defined table + if ( + isinstance(expression, exp.Identifier) + and isinstance(expression.parent, exp.Table) + and expression.name.lower() == "dual" + ): + return t.cast(E, expression) + + return super().quote_identifier(expression, identify=identify) + class Parser(parser.Parser): IDENTIFY_PIVOT_STRINGS = True @@ -245,6 +262,9 @@ class Snowflake(Dialect): **parser.Parser.FUNCTIONS, "ARRAYAGG": exp.ArrayAgg.from_arg_list, "ARRAY_CONSTRUCT": exp.Array.from_arg_list, + "ARRAY_CONTAINS": lambda args: exp.ArrayContains( + this=seq_get(args, 1), expression=seq_get(args, 0) + ), "ARRAY_GENERATE_RANGE": lambda args: exp.GenerateSeries( # ARRAY_GENERATE_RANGE has an exlusive end; we normalize it to be inclusive start=seq_get(args, 0), @@ -296,8 +316,8 @@ class Snowflake(Dialect): RANGE_PARSERS = { **parser.Parser.RANGE_PARSERS, - TokenType.LIKE_ANY: binary_range_parser(exp.LikeAny), - TokenType.ILIKE_ANY: binary_range_parser(exp.ILikeAny), + TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny), + TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny), } ALTER_PARSERS = { @@ -317,6 +337,11 @@ class Snowflake(Dialect): TokenType.SHOW: lambda self: self._parse_show(), } + PROPERTY_PARSERS = { + **parser.Parser.PROPERTY_PARSERS, + "LOCATION": lambda self: self._parse_location(), + } + SHOW_PARSERS = { "PRIMARY KEYS": _show_parser("PRIMARY KEYS"), "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"), @@ -349,7 
+374,7 @@ class Snowflake(Dialect): table: t.Optional[exp.Expression] = None if self._match_text_seq("@"): table_name = "@" - while True: + while self._curr: self._advance() table_name += self._prev.text if not self._match_set(self.STAGED_FILE_SINGLE_TOKENS, advance=False): @@ -411,6 +436,20 @@ class Snowflake(Dialect): self._match_text_seq("WITH") return self.expression(exp.SwapTable, this=self._parse_table(schema=True)) + def _parse_location(self) -> exp.LocationProperty: + self._match(TokenType.EQ) + + parts = [self._parse_var(any_token=True)] + + while self._match(TokenType.SLASH): + if self._curr and self._prev.end + 1 == self._curr.start: + parts.append(self._parse_var(any_token=True)) + else: + parts.append(exp.Var(this="")) + return self.expression( + exp.LocationProperty, this=exp.var("/".join(str(p) for p in parts)) + ) + class Tokenizer(tokens.Tokenizer): STRING_ESCAPES = ["\\", "'"] HEX_STRINGS = [("x'", "'"), ("X'", "'")] @@ -457,6 +496,7 @@ class Snowflake(Dialect): AGGREGATE_FILTER_SUPPORTED = False SUPPORTS_TABLE_COPY = False COLLATE_IS_FUNC = True + LIMIT_ONLY_LITERALS = True TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -464,15 +504,14 @@ class Snowflake(Dialect): exp.ArgMin: rename_func("MIN_BY"), exp.Array: inline_array_sql, exp.ArrayConcat: rename_func("ARRAY_CAT"), + exp.ArrayContains: lambda self, e: self.func("ARRAY_CONTAINS", e.expression, e.this), exp.ArrayJoin: rename_func("ARRAY_TO_STRING"), exp.AtTimeZone: lambda self, e: self.func( "CONVERT_TIMEZONE", e.args.get("zone"), e.this ), exp.BitwiseXor: rename_func("BITXOR"), - exp.DateAdd: lambda self, e: self.func("DATEADD", e.text("unit"), e.expression, e.this), - exp.DateDiff: lambda self, e: self.func( - "DATEDIFF", e.text("unit"), e.expression, e.this - ), + exp.DateAdd: date_delta_sql("DATEADD"), + exp.DateDiff: date_delta_sql("DATEDIFF"), exp.DateStrToDate: datestrtodate_sql, exp.DataType: _datatype_sql, exp.DayOfMonth: rename_func("DAYOFMONTH"), @@ -501,10 +540,11 @@ class 
Snowflake(Dialect): exp.Select: transforms.preprocess( [ transforms.eliminate_distinct_on, - transforms.explode_to_unnest(0), + transforms.explode_to_unnest(), transforms.eliminate_semi_and_anti_joins, ] ), + exp.SHA: rename_func("SHA1"), exp.StarMap: rename_func("OBJECT_CONSTRUCT"), exp.StartsWith: rename_func("STARTSWITH"), exp.StrPosition: lambda self, e: self.func( @@ -524,6 +564,8 @@ class Snowflake(Dialect): exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression), + exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True), + exp.TsOrDsDiff: date_delta_sql("DATEDIFF"), exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"), exp.UnixToTime: _unix_to_time_sql, exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), @@ -547,6 +589,20 @@ class Snowflake(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def trycast_sql(self, expression: exp.TryCast) -> str: + value = expression.this + + if value.type is None: + from sqlglot.optimizer.annotate_types import annotate_types + + value = annotate_types(value) + + if value.is_type(*exp.DataType.TEXT_TYPES, exp.DataType.Type.UNKNOWN): + return super().trycast_sql(expression) + + # TRY_CAST only works for string values in Snowflake + return self.cast_sql(expression) + def log_sql(self, expression: exp.Log) -> str: if not expression.expression: return self.func("LN", expression.this) @@ -554,24 +610,28 @@ class Snowflake(Dialect): return super().log_sql(expression) def unnest_sql(self, expression: exp.Unnest) -> str: - selects = ["value"] unnest_alias = expression.args.get("alias") - offset = expression.args.get("offset") - if offset: - if unnest_alias: - unnest_alias.append("columns", offset.pop()) - - selects.append("index") - subquery = exp.Subquery( - this=exp.select(*selects).from_( - f"TABLE(FLATTEN(INPUT => 
{self.sql(expression.expressions[0])}))" - ), - ) + columns = [ + exp.to_identifier("seq"), + exp.to_identifier("key"), + exp.to_identifier("path"), + offset.pop() if isinstance(offset, exp.Expression) else exp.to_identifier("index"), + seq_get(unnest_alias.columns if unnest_alias else [], 0) + or exp.to_identifier("value"), + exp.to_identifier("this"), + ] + + if unnest_alias: + unnest_alias.set("columns", columns) + else: + unnest_alias = exp.TableAlias(this="_u", columns=columns) + + explode = f"TABLE(FLATTEN(INPUT => {self.sql(expression.expressions[0])}))" alias = self.sql(unnest_alias) alias = f" AS {alias}" if alias else "" - return f"{self.sql(subquery)}{alias}" + return f"{explode}{alias}" def show_sql(self, expression: exp.Show) -> str: scope = self.sql(expression, "scope") @@ -632,3 +692,6 @@ class Snowflake(Dialect): def swaptable_sql(self, expression: exp.SwapTable) -> str: this = self.sql(expression, "this") return f"SWAP WITH {this}" + + def with_properties(self, properties: exp.Properties) -> str: + return self.properties(properties, wrapped=False, prefix=self.seg(""), sep=" ") |