diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-09-07 11:39:48 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-09-07 11:39:48 +0000 |
commit | f73e9af131151f1e058446361c35b05c4c90bf10 (patch) | |
tree | ed425b89f12d3f5e4709290bdc03d876f365bc97 /sqlglot/dialects/snowflake.py | |
parent | Releasing debian version 17.12.0-1. (diff) | |
download | sqlglot-f73e9af131151f1e058446361c35b05c4c90bf10.tar.xz sqlglot-f73e9af131151f1e058446361c35b05c4c90bf10.zip |
Merging upstream version 18.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/snowflake.py')
-rw-r--r-- | sqlglot/dialects/snowflake.py | 78 |
1 files changed, 68 insertions, 10 deletions
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 9733a85..8d8183c 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -90,7 +90,7 @@ def _parse_datediff(args: t.List) -> exp.DateDiff: return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)) -def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str: +def _unix_to_time_sql(self: Snowflake.Generator, expression: exp.UnixToTime) -> str: scale = expression.args.get("scale") timestamp = self.sql(expression, "this") if scale in [None, exp.UnixToTime.SECONDS]: @@ -105,7 +105,7 @@ def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> # https://docs.snowflake.com/en/sql-reference/functions/date_part.html # https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts -def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]: +def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]: this = self._parse_var() or self._parse_type() if not this: @@ -156,7 +156,7 @@ def _nullifzero_to_if(args: t.List) -> exp.If: return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0)) -def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: +def _datatype_sql(self: Snowflake.Generator, expression: exp.DataType) -> str: if expression.is_type("array"): return "ARRAY" elif expression.is_type("map"): @@ -164,6 +164,17 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: return self.datatype_sql(expression) +def _regexpilike_sql(self: Snowflake.Generator, expression: exp.RegexpILike) -> str: + flag = expression.text("flag") + + if "i" not in flag: + flag += "i" + + return self.func( + "REGEXP_LIKE", expression.this, expression.expression, exp.Literal.string(flag) + ) + + def _parse_convert_timezone(args: t.List) -> t.Union[exp.Anonymous, exp.AtTimeZone]: if len(args) == 3: return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args) @@ -179,6 +190,13 @@ def _parse_regexp_replace(args: t.List) -> exp.RegexpReplace: return regexp_replace +def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[Snowflake.Parser], exp.Show]: + def _parse(self: Snowflake.Parser) -> exp.Show: + return self._parse_show_snowflake(*args, **kwargs) + + return _parse + + class Snowflake(Dialect): # https://docs.snowflake.com/en/sql-reference/identifiers-syntax RESOLVES_IDENTIFIERS_AS_UPPERCASE = True @@ -216,6 +234,7 @@ class Snowflake(Dialect): class Parser(parser.Parser): IDENTIFY_PIVOT_STRINGS = True + SUPPORTS_USER_DEFINED_TYPES = False FUNCTIONS = { **parser.Parser.FUNCTIONS, @@ -230,6 +249,7 @@ class Snowflake(Dialect): "DATEDIFF": _parse_datediff, "DIV0": _div0_to_if, "IFF": exp.If.from_arg_list, + "LISTAGG": exp.GroupConcat.from_arg_list, "NULLIFZERO": _nullifzero_to_if, "OBJECT_CONSTRUCT": _parse_object_construct, "REGEXP_REPLACE": _parse_regexp_replace, @@ -250,11 +270,6 @@ class Snowflake(Dialect): } FUNCTION_PARSERS.pop("TRIM") - FUNC_TOKENS = { - *parser.Parser.FUNC_TOKENS, - TokenType.TABLE, - } - COLUMN_OPERATORS = { **parser.Parser.COLUMN_OPERATORS, TokenType.COLON: lambda self, this, path: self.expression( @@ -281,6 +296,16 @@ class Snowflake(Dialect): ), } + STATEMENT_PARSERS = { + **parser.Parser.STATEMENT_PARSERS, + TokenType.SHOW: lambda self: self._parse_show(), + } + + SHOW_PARSERS = { + "PRIMARY KEYS": _show_parser("PRIMARY KEYS"), + "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"), + } + def _parse_id_var( self, any_token: bool = True, @@ -296,8 +321,24 @@ class Snowflake(Dialect): return super()._parse_id_var(any_token=any_token, tokens=tokens) + def _parse_show_snowflake(self, this: str) -> exp.Show: + scope = None + scope_kind = None + + if self._match(TokenType.IN): + if self._match_text_seq("ACCOUNT"): + scope_kind = "ACCOUNT" + elif self._match_set(self.DB_CREATABLES): + scope_kind = self._prev.text + if self._curr: + scope = self._parse_table() + elif self._curr: + scope_kind = "TABLE" + scope = self._parse_table() + + return self.expression(exp.Show, this=this, scope=scope, scope_kind=scope_kind) + class Tokenizer(tokens.Tokenizer): - QUOTES = ["'"] STRING_ESCAPES = ["\\", "'"] HEX_STRINGS = [("x'", "'"), ("X'", "'")] RAW_STRINGS = ["$$"] @@ -331,6 +372,8 @@ class Snowflake(Dialect): VAR_SINGLE_TOKENS = {"$"} + COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW} + class Generator(generator.Generator): PARAMETER_TOKEN = "$" MATCHED_BY_SOURCE = False @@ -355,6 +398,7 @@ class Snowflake(Dialect): exp.DataType: _datatype_sql, exp.DayOfWeek: rename_func("DAYOFWEEK"), exp.Extract: rename_func("DATE_PART"), + exp.GroupConcat: rename_func("LISTAGG"), exp.If: rename_func("IFF"), exp.LogicalAnd: rename_func("BOOLAND_AGG"), exp.LogicalOr: rename_func("BOOLOR_AGG"), @@ -362,6 +406,7 @@ class Snowflake(Dialect): exp.Max: max_or_greatest, exp.Min: min_or_least, exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", + exp.RegexpILike: _regexpilike_sql, exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.StarMap: rename_func("OBJECT_CONSTRUCT"), exp.StartsWith: rename_func("STARTSWITH"), @@ -373,6 +418,7 @@ class Snowflake(Dialect): "OBJECT_CONSTRUCT", *(arg for expression in e.expressions for arg in expression.flatten()), ), + exp.Stuff: rename_func("INSERT"), exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToTime: timestrtotime_sql, exp.TimeToStr: lambda self, e: self.func( @@ -403,6 +449,16 @@ class Snowflake(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def show_sql(self, expression: exp.Show) -> str: + scope = self.sql(expression, "scope") + scope = f" {scope}" if scope else "" + + scope_kind = self.sql(expression, "scope_kind") + if scope_kind: + scope_kind = f" IN {scope_kind}" + + return f"SHOW {expression.name}{scope_kind}{scope}" + def regexpextract_sql(self, expression: exp.RegexpExtract) -> str: # Other dialects don't support all of the following parameters, so we need to # generate default values as necessary to ensure the transpilation is correct @@ -436,7 +492,9 @@ class Snowflake(Dialect): kind_value = expression.args.get("kind") or "TABLE" kind = f" {kind_value}" if kind_value else "" this = f" {self.sql(expression, 'this')}" - return f"DESCRIBE{kind}{this}" + expressions = self.expressions(expression, flat=True) + expressions = f" {expressions}" if expressions else "" + return f"DESCRIBE{kind}{this}{expressions}" def generatedasidentitycolumnconstraint_sql( self, expression: exp.GeneratedAsIdentityColumnConstraint |