diff options
Diffstat (limited to 'sqlglot/dialects/snowflake.py')
-rw-r--r-- | sqlglot/dialects/snowflake.py | 53 |
1 files changed, 34 insertions, 19 deletions
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 281167d..37f9761 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -21,19 +21,13 @@ from sqlglot.dialects.dialect import ( var_map_sql, ) from sqlglot.expressions import Literal -from sqlglot.helper import seq_get +from sqlglot.helper import is_int, seq_get from sqlglot.tokens import TokenType if t.TYPE_CHECKING: from sqlglot._typing import E -def _check_int(s: str) -> bool: - if s[0] in ("-", "+"): - return s[1:].isdigit() - return s.isdigit() - - # from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]: if len(args) == 2: @@ -53,7 +47,7 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, return exp.TimeStrToTime.from_arg_list(args) if first_arg.is_string: - if _check_int(first_arg.this): + if is_int(first_arg.this): # case: <integer> return exp.UnixToTime.from_arg_list(args) @@ -241,7 +235,6 @@ DATE_PART_MAPPING = { "NSECOND": "NANOSECOND", "NSECONDS": "NANOSECOND", "NANOSECS": "NANOSECOND", - "NSECONDS": "NANOSECOND", "EPOCH": "EPOCH_SECOND", "EPOCH_SECONDS": "EPOCH_SECOND", "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", @@ -291,7 +284,9 @@ def _parse_colon_get_path( path = exp.Literal.string(path.sql(dialect="snowflake")) # The extraction operator : is left-associative - this = self.expression(exp.GetPath, this=this, expression=path) + this = self.expression( + exp.JSONExtract, this=this, expression=self.dialect.to_json_path(path) + ) if target_type: this = exp.cast(this, target_type) @@ -411,6 +406,9 @@ class Snowflake(Dialect): "DATEDIFF": _parse_datediff, "DIV0": _div0_to_if, "FLATTEN": exp.Explode.from_arg_list, + "GET_PATH": lambda args, dialect: exp.JSONExtract( + this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1)) + ), "IFF": exp.If.from_arg_list, "LAST_DAY": lambda args: exp.LastDay( this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1)) @@ -474,6 +472,8 @@ class Snowflake(Dialect): "TERSE SCHEMAS": _show_parser("SCHEMAS"), "OBJECTS": _show_parser("OBJECTS"), "TERSE OBJECTS": _show_parser("OBJECTS"), + "TABLES": _show_parser("TABLES"), + "TERSE TABLES": _show_parser("TABLES"), "PRIMARY KEYS": _show_parser("PRIMARY KEYS"), "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"), "COLUMNS": _show_parser("COLUMNS"), @@ -534,7 +534,9 @@ class Snowflake(Dialect): return table - def _parse_table_parts(self, schema: bool = False) -> exp.Table: + def _parse_table_parts( + self, schema: bool = False, is_db_reference: bool = False + ) -> exp.Table: # https://docs.snowflake.com/en/user-guide/querying-stage if self._match(TokenType.STRING, advance=False): table = self._parse_string() @@ -550,7 +552,9 @@ class Snowflake(Dialect): self._match(TokenType.L_PAREN) while self._curr and not self._match(TokenType.R_PAREN): if self._match_text_seq("FILE_FORMAT", "=>"): - file_format = self._parse_string() or super()._parse_table_parts() + file_format = self._parse_string() or super()._parse_table_parts( + is_db_reference=is_db_reference + ) elif self._match_text_seq("PATTERN", "=>"): pattern = self._parse_string() else: @@ -560,7 +564,7 @@ class Snowflake(Dialect): table = self.expression(exp.Table, this=table, format=file_format, pattern=pattern) else: - table = super()._parse_table_parts(schema=schema) + table = super()._parse_table_parts(schema=schema, is_db_reference=is_db_reference) return self._parse_at_before(table) @@ -587,6 +591,8 @@ class Snowflake(Dialect): # which is syntactically valid but has no effect on the output terse = self._tokens[self._index - 2].text.upper() == "TERSE" + history = self._match_text_seq("HISTORY") + like = self._parse_string() if self._match(TokenType.LIKE) else None if self._match(TokenType.IN): @@ -597,7 +603,7 @@ class Snowflake(Dialect): if self._curr: scope = self._parse_table_parts() elif self._curr: - scope_kind = "SCHEMA" if this == "OBJECTS" else "TABLE" + scope_kind = "SCHEMA" if this in ("OBJECTS", "TABLES") else "TABLE" scope = self._parse_table_parts() return self.expression( @@ -605,6 +611,7 @@ class Snowflake(Dialect): **{ "terse": terse, "this": this, + "history": history, "like": like, "scope": scope, "scope_kind": scope_kind, @@ -715,8 +722,10 @@ class Snowflake(Dialect): ), exp.GroupConcat: rename_func("LISTAGG"), exp.If: if_sql(name="IFF", false_value="NULL"), - exp.JSONExtract: lambda self, e: f"{self.sql(e, 'this')}[{self.sql(e, 'expression')}]", + exp.JSONExtract: rename_func("GET_PATH"), + exp.JSONExtractScalar: rename_func("JSON_EXTRACT_PATH_TEXT"), exp.JSONObject: lambda self, e: self.func("OBJECT_CONSTRUCT_KEEP_NULL", *e.expressions), + exp.JSONPathRoot: lambda *_: "", exp.LogicalAnd: rename_func("BOOLAND_AGG"), exp.LogicalOr: rename_func("BOOLOR_AGG"), exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), @@ -745,7 +754,8 @@ class Snowflake(Dialect): exp.StrPosition: lambda self, e: self.func( "POSITION", e.args.get("substr"), e.this, e.args.get("position") ), - exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.StrToTime: lambda self, + e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Struct: lambda self, e: self.func( "OBJECT_CONSTRUCT", *(arg for expression in e.expressions for arg in expression.flatten()), @@ -771,6 +781,12 @@ class Snowflake(Dialect): exp.Xor: rename_func("BOOLXOR"), } + SUPPORTED_JSON_PATH_PARTS = { + exp.JSONPathKey, + exp.JSONPathRoot, + exp.JSONPathSubscript, + } + TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ", @@ -841,6 +857,7 @@ class Snowflake(Dialect): def show_sql(self, expression: exp.Show) -> str: terse = "TERSE " if expression.args.get("terse") else "" + history = " HISTORY" if expression.args.get("history") else "" like = self.sql(expression, "like") like = f" LIKE {like}" if like else "" @@ -861,9 +878,7 @@ class Snowflake(Dialect): if from_: from_ = f" FROM {from_}" - return ( - f"SHOW {terse}{expression.name}{like}{scope_kind}{scope}{starts_with}{limit}{from_}" - ) + return f"SHOW {terse}{expression.name}{history}{like}{scope_kind}{scope}{starts_with}{limit}{from_}" def regexpextract_sql(self, expression: exp.RegexpExtract) -> str: # Other dialects don't support all of the following parameters, so we need to |