diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-18 05:35:55 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-18 05:35:55 +0000 |
commit | fe979e8421c04c038353a0a2d07d81779516186a (patch) | |
tree | efb70a52261e5cf4862a7eb69e1d7cd16356fcba /sqlglot/dialects/snowflake.py | |
parent | Releasing debian version 23.13.7-1. (diff) | |
download | sqlglot-fe979e8421c04c038353a0a2d07d81779516186a.tar.xz sqlglot-fe979e8421c04c038353a0a2d07d81779516186a.zip |
Merging upstream version 23.16.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/snowflake.py')
-rw-r--r-- | sqlglot/dialects/snowflake.py | 78 |
1 files changed, 56 insertions, 22 deletions
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 5f1e052..2e8a647 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -7,6 +7,7 @@ from sqlglot.dialects.dialect import ( Dialect, NormalizationStrategy, binary_from_function, + build_default_decimal_type, date_delta_sql, date_trunc_to_time, datestrtodate_sql, @@ -334,6 +335,7 @@ class Snowflake(Dialect): class Parser(parser.Parser): IDENTIFY_PIVOT_STRINGS = True + DEFAULT_SAMPLING_METHOD = "BERNOULLI" ID_VAR_TOKENS = { *parser.Parser.ID_VAR_TOKENS, @@ -345,6 +347,7 @@ class Snowflake(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, + "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, "ARRAYAGG": exp.ArrayAgg.from_arg_list, "ARRAY_CONSTRUCT": exp.Array.from_arg_list, "ARRAY_CONTAINS": lambda args: exp.ArrayContains( @@ -423,7 +426,6 @@ class Snowflake(Dialect): ALTER_PARSERS = { **parser.Parser.ALTER_PARSERS, - "SET": lambda self: self._parse_set(tag=self._match_text_seq("TAG")), "UNSET": lambda self: self.expression( exp.Set, tag=self._match_text_seq("TAG"), @@ -443,6 +445,11 @@ class Snowflake(Dialect): "LOCATION": lambda self: self._parse_location_property(), } + TYPE_CONVERTER = { + # https://docs.snowflake.com/en/sql-reference/data-types-numeric#number + exp.DataType.Type.DECIMAL: build_default_decimal_type(precision=38, scale=0), + } + SHOW_PARSERS = { "SCHEMAS": _show_parser("SCHEMAS"), "TERSE SCHEMAS": _show_parser("SCHEMAS"), @@ -475,6 +482,14 @@ class Snowflake(Dialect): SCHEMA_KINDS = {"OBJECTS", "TABLES", "VIEWS", "SEQUENCES", "UNIQUE KEYS", "IMPORTED KEYS"} + def _parse_create(self) -> exp.Create | exp.Command: + expression = super()._parse_create() + if isinstance(expression, exp.Create) and expression.kind == "TAG": + # Replace the Table node with the enclosed Identifier + expression.this.replace(expression.this.this) + + return expression + def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: this = super()._parse_column_ops(this) @@ -600,8 +615,8 @@ class Snowflake(Dialect): file_format = None pattern = None - self._match(TokenType.L_PAREN) - while self._curr and not self._match(TokenType.R_PAREN): + wrapped = self._match(TokenType.L_PAREN) + while self._curr and wrapped and not self._match(TokenType.R_PAREN): if self._match_text_seq("FILE_FORMAT", "=>"): file_format = self._parse_string() or super()._parse_table_parts( is_db_reference=is_db_reference @@ -681,14 +696,22 @@ class Snowflake(Dialect): return self.expression(exp.LocationProperty, this=self._parse_location_path()) def _parse_file_location(self) -> t.Optional[exp.Expression]: - return self._parse_table_parts() + # Parse either a subquery or a staged file + return ( + self._parse_select(table=True) + if self._match(TokenType.L_PAREN, advance=False) + else self._parse_table_parts() + ) def _parse_location_path(self) -> exp.Var: parts = [self._advance_any(ignore_reserved=True)] # We avoid consuming a comma token because external tables like @foo and @bar - # can be joined in a query with a comma separator. - while self._is_connected() and not self._match(TokenType.COMMA, advance=False): + # can be joined in a query with a comma separator, as well as closing paren + # in case of subqueries + while self._is_connected() and not self._match_set( + (TokenType.COMMA, TokenType.R_PAREN), advance=False + ): parts.append(self._advance_any(ignore_reserved=True)) return exp.var("".join(part.text for part in parts if part)) @@ -713,12 +736,12 @@ class Snowflake(Dialect): "NCHAR VARYING": TokenType.VARCHAR, "PUT": TokenType.COMMAND, "REMOVE": TokenType.COMMAND, - "RENAME": TokenType.REPLACE, "RM": TokenType.COMMAND, "SAMPLE": TokenType.TABLE_SAMPLE, "SQL_DOUBLE": TokenType.DOUBLE, "SQL_VARCHAR": TokenType.VARCHAR, "STORAGE INTEGRATION": TokenType.STORAGE_INTEGRATION, + "TAG": TokenType.TAG, "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ, "TOP": TokenType.TOP, } @@ -748,6 +771,7 @@ class Snowflake(Dialect): STRUCT_DELIMITER = ("(", ")") COPY_PARAMS_ARE_WRAPPED = False COPY_PARAMS_EQ_REQUIRED = True + STAR_EXCEPT = "EXCLUDE" TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -818,7 +842,7 @@ class Snowflake(Dialect): exp.TimestampDiff: lambda self, e: self.func( "TIMESTAMPDIFF", e.unit, e.expression, e.this ), - exp.TimestampTrunc: timestamptrunc_sql, + exp.TimestampTrunc: timestamptrunc_sql(), exp.TimeStrToTime: timestrtotime_sql, exp.TimeToStr: lambda self, e: self.func( "TO_CHAR", exp.cast(e.this, exp.DataType.Type.TIMESTAMP), self.format_time(e) @@ -850,11 +874,6 @@ class Snowflake(Dialect): exp.DataType.Type.STRUCT: "OBJECT", } - STAR_MAPPING = { - "except": "EXCLUDE", - "replace": "RENAME", - } - PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, exp.SetProperty: exp.Properties.Location.UNSUPPORTED, @@ -862,9 +881,15 @@ class Snowflake(Dialect): } UNSUPPORTED_VALUES_EXPRESSIONS = { + exp.Map, + exp.StarMap, exp.Struct, + exp.VarMap, } + def with_properties(self, properties: exp.Properties) -> str: + return self.properties(properties, wrapped=False, prefix=self.sep(""), sep=" ") + def values_sql(self, expression: exp.Values, values_as_table: bool = True) -> str: if expression.find(*self.UNSUPPORTED_VALUES_EXPRESSIONS): values_as_table = False @@ -1019,9 +1044,6 @@ class Snowflake(Dialect): this = self.sql(expression, "this") return f"SWAP WITH {this}" - def with_properties(self, properties: exp.Properties) -> str: - return self.properties(properties, wrapped=False, prefix=self.seg(""), sep=" ") - def cluster_sql(self, expression: exp.Cluster) -> str: return f"CLUSTER BY ({self.expressions(expression, flat=True)})" @@ -1041,10 +1063,22 @@ class Snowflake(Dialect): return self.func("OBJECT_CONSTRUCT", *flatten(zip(keys, values))) - def copyparameter_sql(self, expression: exp.CopyParameter) -> str: - option = self.sql(expression, "this").upper() - if option == "FILE_FORMAT": - values = self.expressions(expression, key="expression", flat=True, sep=" ") - return f"{option} = ({values})" + def approxquantile_sql(self, expression: exp.ApproxQuantile) -> str: + if expression.args.get("weight") or expression.args.get("accuracy"): + self.unsupported( + "APPROX_PERCENTILE with weight and/or accuracy arguments are not supported in Snowflake" + ) + + return self.func("APPROX_PERCENTILE", expression.this, expression.args.get("quantile")) + + def alterset_sql(self, expression: exp.AlterSet) -> str: + exprs = self.expressions(expression, flat=True) + exprs = f" {exprs}" if exprs else "" + file_format = self.expressions(expression, key="file_format", flat=True, sep=" ") + file_format = f" STAGE_FILE_FORMAT = ({file_format})" if file_format else "" + copy_options = self.expressions(expression, key="copy_options", flat=True, sep=" ") + copy_options = f" STAGE_COPY_OPTIONS = ({copy_options})" if copy_options else "" + tag = self.expressions(expression, key="tag", flat=True) + tag = f" TAG {tag}" if tag else "" - return super().copyparameter_sql(expression) + return f"SET{exprs}{file_format}{copy_options}{tag}" |