diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/dialects/snowflake.py | 80 |
1 files changed, 51 insertions, 29 deletions
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index cdbc071..f09a990 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -293,7 +293,6 @@ class Snowflake(Dialect):
             "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
             "TIMEDIFF": _parse_datediff,
             "TIMESTAMPDIFF": _parse_datediff,
-            "TO_ARRAY": exp.Array.from_arg_list,
             "TO_TIMESTAMP": _parse_to_timestamp,
             "TO_VARCHAR": exp.ToChar.from_arg_list,
             "ZEROIFNULL": _zeroifnull_to_if,
@@ -369,36 +368,58 @@ class Snowflake(Dialect):
 
             return lateral
 
+        def _parse_at_before(self, table: exp.Table) -> exp.Table:
+            # https://docs.snowflake.com/en/sql-reference/constructs/at-before
+            index = self._index
+            if self._match_texts(("AT", "BEFORE")):
+                this = self._prev.text.upper()
+                kind = (
+                    self._match(TokenType.L_PAREN)
+                    and self._match_texts(self.HISTORICAL_DATA_KIND)
+                    and self._prev.text.upper()
+                )
+                expression = self._match(TokenType.FARROW) and self._parse_bitwise()
+
+                if expression:
+                    self._match_r_paren()
+                    when = self.expression(
+                        exp.HistoricalData, this=this, kind=kind, expression=expression
+                    )
+                    table.set("when", when)
+                else:
+                    self._retreat(index)
+
+            return table
+
         def _parse_table_parts(self, schema: bool = False) -> exp.Table:
             # https://docs.snowflake.com/en/user-guide/querying-stage
-            table: t.Optional[exp.Expression] = None
-            if self._match_text_seq("@"):
-                table_name = "@"
-                while self._curr:
-                    self._advance()
-                    table_name += self._prev.text
-                    if not self._match_set(self.STAGED_FILE_SINGLE_TOKENS, advance=False):
-                        break
-                    while self._match_set(self.STAGED_FILE_SINGLE_TOKENS):
-                        table_name += self._prev.text
-
-                table = exp.var(table_name)
-            elif self._match(TokenType.STRING, advance=False):
+            if self._match(TokenType.STRING, advance=False):
                 table = self._parse_string()
+            elif self._match_text_seq("@", advance=False):
+                table = self._parse_location_path()
+            else:
+                table = None
 
             if table:
                 file_format = None
                 pattern = None
 
-                if self._match_text_seq("(", "FILE_FORMAT", "=>"):
-                    file_format = self._parse_string() or super()._parse_table_parts()
-                    if self._match_text_seq(",", "PATTERN", "=>"):
+                self._match(TokenType.L_PAREN)
+                while self._curr and not self._match(TokenType.R_PAREN):
+                    if self._match_text_seq("FILE_FORMAT", "=>"):
+                        file_format = self._parse_string() or super()._parse_table_parts()
+                    elif self._match_text_seq("PATTERN", "=>"):
                         pattern = self._parse_string()
-                    self._match_r_paren()
+                    else:
+                        break
+
+                    self._match(TokenType.COMMA)
 
-                return self.expression(exp.Table, this=table, format=file_format, pattern=pattern)
+                table = self.expression(exp.Table, this=table, format=file_format, pattern=pattern)
+            else:
+                table = super()._parse_table_parts(schema=schema)
 
-            return super()._parse_table_parts(schema=schema)
+            return self._parse_at_before(table)
 
         def _parse_id_var(
             self,
@@ -438,17 +459,17 @@ class Snowflake(Dialect):
 
         def _parse_location(self) -> exp.LocationProperty:
             self._match(TokenType.EQ)
+            return self.expression(exp.LocationProperty, this=self._parse_location_path())
 
-            parts = [self._parse_var(any_token=True)]
+        def _parse_location_path(self) -> exp.Var:
+            parts = [self._advance_any(ignore_reserved=True)]
 
-            while self._match(TokenType.SLASH):
-                if self._curr and self._prev.end + 1 == self._curr.start:
-                    parts.append(self._parse_var(any_token=True))
-                else:
-                    parts.append(exp.Var(this=""))
-            return self.expression(
-                exp.LocationProperty, this=exp.var("/".join(str(p) for p in parts))
-            )
+            # We avoid consuming a comma token because external tables like @foo and @bar
+            # can be joined in a query with a comma separator.
+            while self._is_connected() and not self._match(TokenType.COMMA, advance=False):
+                parts.append(self._advance_any(ignore_reserved=True))
+
+            return exp.var("".join(part.text for part in parts if part))
 
     class Tokenizer(tokens.Tokenizer):
         STRING_ESCAPES = ["\\", "'"]
@@ -562,6 +583,7 @@ class Snowflake(Dialect):
                 "TO_CHAR", exp.cast(e.this, "timestamp"), self.format_time(e)
             ),
             exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
+            exp.ToArray: rename_func("TO_ARRAY"),
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
             exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
             exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),