path: root/sqlglot/dialects/snowflake.py
author     Daniel Baumann <daniel.baumann@progress-linux.org>  2023-09-07 11:39:48 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2023-09-07 11:39:48 +0000
commit     f73e9af131151f1e058446361c35b05c4c90bf10 (patch)
tree       ed425b89f12d3f5e4709290bdc03d876f365bc97 /sqlglot/dialects/snowflake.py
parent     Releasing debian version 17.12.0-1. (diff)
download   sqlglot-f73e9af131151f1e058446361c35b05c4c90bf10.tar.xz
           sqlglot-f73e9af131151f1e058446361c35b05c4c90bf10.zip
Merging upstream version 18.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/snowflake.py')
-rw-r--r--  sqlglot/dialects/snowflake.py  78
1 file changed, 68 insertions(+), 10 deletions(-)
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 9733a85..8d8183c 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -90,7 +90,7 @@ def _parse_datediff(args: t.List) -> exp.DateDiff:
return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
-def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str:
+def _unix_to_time_sql(self: Snowflake.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale in [None, exp.UnixToTime.SECONDS]:
@@ -105,7 +105,7 @@ def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) ->
# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
# https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts
-def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]:
+def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]:
this = self._parse_var() or self._parse_type()
if not this:
@@ -156,7 +156,7 @@ def _nullifzero_to_if(args: t.List) -> exp.If:
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
-def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
+def _datatype_sql(self: Snowflake.Generator, expression: exp.DataType) -> str:
if expression.is_type("array"):
return "ARRAY"
elif expression.is_type("map"):
@@ -164,6 +164,17 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
return self.datatype_sql(expression)
+def _regexpilike_sql(self: Snowflake.Generator, expression: exp.RegexpILike) -> str:
+ flag = expression.text("flag")
+
+ if "i" not in flag:
+ flag += "i"
+
+ return self.func(
+ "REGEXP_LIKE", expression.this, expression.expression, exp.Literal.string(flag)
+ )
+
+
def _parse_convert_timezone(args: t.List) -> t.Union[exp.Anonymous, exp.AtTimeZone]:
if len(args) == 3:
return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args)
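As a rough illustration of the new _regexpilike_sql hook (a sketch using sqlglot's public expression API, not part of this diff), building an exp.RegexpILike node and rendering it with the Snowflake dialect should emit REGEXP_LIKE with the case-insensitive flag appended:

# Sketch only; assumes the usual sqlglot expression-building API.
from sqlglot import exp

node = exp.RegexpILike(
    this=exp.column("col"),
    expression=exp.Literal.string("^a.*"),
)
# _regexpilike_sql adds the "i" flag when it is missing.
print(node.sql(dialect="snowflake"))  # REGEXP_LIKE(col, '^a.*', 'i')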
@@ -179,6 +190,13 @@ def _parse_regexp_replace(args: t.List) -> exp.RegexpReplace:
return regexp_replace
+def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[Snowflake.Parser], exp.Show]:
+ def _parse(self: Snowflake.Parser) -> exp.Show:
+ return self._parse_show_snowflake(*args, **kwargs)
+
+ return _parse
+
+
class Snowflake(Dialect):
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
RESOLVES_IDENTIFIERS_AS_UPPERCASE = True
@@ -216,6 +234,7 @@ class Snowflake(Dialect):
class Parser(parser.Parser):
IDENTIFY_PIVOT_STRINGS = True
+ SUPPORTS_USER_DEFINED_TYPES = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@@ -230,6 +249,7 @@ class Snowflake(Dialect):
"DATEDIFF": _parse_datediff,
"DIV0": _div0_to_if,
"IFF": exp.If.from_arg_list,
+ "LISTAGG": exp.GroupConcat.from_arg_list,
"NULLIFZERO": _nullifzero_to_if,
"OBJECT_CONSTRUCT": _parse_object_construct,
"REGEXP_REPLACE": _parse_regexp_replace,
@@ -250,11 +270,6 @@ class Snowflake(Dialect):
}
FUNCTION_PARSERS.pop("TRIM")
- FUNC_TOKENS = {
- *parser.Parser.FUNC_TOKENS,
- TokenType.TABLE,
- }
-
COLUMN_OPERATORS = {
**parser.Parser.COLUMN_OPERATORS,
TokenType.COLON: lambda self, this, path: self.expression(
@@ -281,6 +296,16 @@ class Snowflake(Dialect):
),
}
+ STATEMENT_PARSERS = {
+ **parser.Parser.STATEMENT_PARSERS,
+ TokenType.SHOW: lambda self: self._parse_show(),
+ }
+
+ SHOW_PARSERS = {
+ "PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
+ "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
+ }
+
def _parse_id_var(
self,
any_token: bool = True,
@@ -296,8 +321,24 @@ class Snowflake(Dialect):
return super()._parse_id_var(any_token=any_token, tokens=tokens)
+ def _parse_show_snowflake(self, this: str) -> exp.Show:
+ scope = None
+ scope_kind = None
+
+ if self._match(TokenType.IN):
+ if self._match_text_seq("ACCOUNT"):
+ scope_kind = "ACCOUNT"
+ elif self._match_set(self.DB_CREATABLES):
+ scope_kind = self._prev.text
+ if self._curr:
+ scope = self._parse_table()
+ elif self._curr:
+ scope_kind = "TABLE"
+ scope = self._parse_table()
+
+ return self.expression(exp.Show, this=this, scope=scope, scope_kind=scope_kind)
+
class Tokenizer(tokens.Tokenizer):
- QUOTES = ["'"]
STRING_ESCAPES = ["\\", "'"]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
RAW_STRINGS = ["$$"]
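Taken together, the SHOW pieces above (STATEMENT_PARSERS, SHOW_PARSERS and _parse_show_snowflake) mean SHOW PRIMARY KEYS now parses into a structured exp.Show node instead of an opaque command. A minimal sketch, assuming sqlglot's standard parse_one entry point:

import sqlglot
from sqlglot import exp

# Sketch: the "PRIMARY KEYS" show parser captures the scope kind from IN ACCOUNT.
show = sqlglot.parse_one("SHOW PRIMARY KEYS IN ACCOUNT", read="snowflake")
assert isinstance(show, exp.Show)
print(show.name)                    # PRIMARY KEYS
print(show.args.get("scope_kind"))  # ACCOUNT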
@@ -331,6 +372,8 @@ class Snowflake(Dialect):
VAR_SINGLE_TOKENS = {"$"}
+ COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW}
+
class Generator(generator.Generator):
PARAMETER_TOKEN = "$"
MATCHED_BY_SOURCE = False
@@ -355,6 +398,7 @@ class Snowflake(Dialect):
exp.DataType: _datatype_sql,
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.Extract: rename_func("DATE_PART"),
+ exp.GroupConcat: rename_func("LISTAGG"),
exp.If: rename_func("IFF"),
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
exp.LogicalOr: rename_func("BOOLOR_AGG"),
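With LISTAGG mapped to exp.GroupConcat in the parser and exp.GroupConcat rendered back as LISTAGG here, the function should round-trip through the Snowflake dialect. A small sketch (standard sqlglot API assumed):

import sqlglot
from sqlglot import exp

ast = sqlglot.parse_one("SELECT LISTAGG(col, ',') FROM t", read="snowflake")
assert ast.find(exp.GroupConcat) is not None  # LISTAGG parsed as GroupConcat
print(ast.sql(dialect="snowflake"))           # SELECT LISTAGG(col, ',') FROM t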
@@ -362,6 +406,7 @@ class Snowflake(Dialect):
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
+ exp.RegexpILike: _regexpilike_sql,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
exp.StartsWith: rename_func("STARTSWITH"),
@@ -373,6 +418,7 @@ class Snowflake(Dialect):
"OBJECT_CONSTRUCT",
*(arg for expression in e.expressions for arg in expression.flatten()),
),
+ exp.Stuff: rename_func("INSERT"),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: lambda self, e: self.func(
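The exp.Stuff mapping renders that node with Snowflake's INSERT() string function. A hedged sketch, assuming the T-SQL dialect parses STUFF() into exp.Stuff (that mapping is not part of this diff):

import sqlglot

# Assumption: read="tsql" produces an exp.Stuff node for STUFF(); the Snowflake
# generator then renames it to INSERT() per the mapping above.
print(sqlglot.transpile("SELECT STUFF('abcdef', 2, 3, 'xyz')", read="tsql", write="snowflake")[0])
# expected output along the lines of: SELECT INSERT('abcdef', 2, 3, 'xyz')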
@@ -403,6 +449,16 @@ class Snowflake(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def show_sql(self, expression: exp.Show) -> str:
+ scope = self.sql(expression, "scope")
+ scope = f" {scope}" if scope else ""
+
+ scope_kind = self.sql(expression, "scope_kind")
+ if scope_kind:
+ scope_kind = f" IN {scope_kind}"
+
+ return f"SHOW {expression.name}{scope_kind}{scope}"
+
def regexpextract_sql(self, expression: exp.RegexpExtract) -> str:
# Other dialects don't support all of the following parameters, so we need to
# generate default values as necessary to ensure the transpilation is correct
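A round-trip sketch for show_sql (standard sqlglot API assumed); TERSE is normalized away because both SHOW_PARSERS entries map to the plain "PRIMARY KEYS" form:

import sqlglot

sql = sqlglot.transpile(
    "SHOW TERSE PRIMARY KEYS IN ACCOUNT", read="snowflake", write="snowflake"
)[0]
print(sql)  # SHOW PRIMARY KEYS IN ACCOUNT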
@@ -436,7 +492,9 @@ class Snowflake(Dialect):
kind_value = expression.args.get("kind") or "TABLE"
kind = f" {kind_value}" if kind_value else ""
this = f" {self.sql(expression, 'this')}"
- return f"DESCRIBE{kind}{this}"
+ expressions = self.expressions(expression, flat=True)
+ expressions = f" {expressions}" if expressions else ""
+ return f"DESCRIBE{kind}{this}{expressions}"
def generatedasidentitycolumnconstraint_sql(
self, expression: exp.GeneratedAsIdentityColumnConstraint