summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/snowflake.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-12-19 11:01:55 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-12-19 11:01:55 +0000
commitf1c2dbe3b17a0d5edffbb65b85b642d0bb2756c5 (patch)
tree5dce0fe2a11381761496eb973c20750f44db56d5 /sqlglot/dialects/snowflake.py
parentReleasing debian version 20.1.0-1. (diff)
downloadsqlglot-f1c2dbe3b17a0d5edffbb65b85b642d0bb2756c5.tar.xz
sqlglot-f1c2dbe3b17a0d5edffbb65b85b642d0bb2756c5.zip
Merging upstream version 20.3.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/snowflake.py')
-rw-r--r--sqlglot/dialects/snowflake.py80
1 files changed, 51 insertions, 29 deletions
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index cdbc071..f09a990 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -293,7 +293,6 @@ class Snowflake(Dialect):
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"TIMEDIFF": _parse_datediff,
"TIMESTAMPDIFF": _parse_datediff,
- "TO_ARRAY": exp.Array.from_arg_list,
"TO_TIMESTAMP": _parse_to_timestamp,
"TO_VARCHAR": exp.ToChar.from_arg_list,
"ZEROIFNULL": _zeroifnull_to_if,
@@ -369,36 +368,58 @@ class Snowflake(Dialect):
return lateral
+ def _parse_at_before(self, table: exp.Table) -> exp.Table:
+ # https://docs.snowflake.com/en/sql-reference/constructs/at-before
+ index = self._index
+ if self._match_texts(("AT", "BEFORE")):
+ this = self._prev.text.upper()
+ kind = (
+ self._match(TokenType.L_PAREN)
+ and self._match_texts(self.HISTORICAL_DATA_KIND)
+ and self._prev.text.upper()
+ )
+ expression = self._match(TokenType.FARROW) and self._parse_bitwise()
+
+ if expression:
+ self._match_r_paren()
+ when = self.expression(
+ exp.HistoricalData, this=this, kind=kind, expression=expression
+ )
+ table.set("when", when)
+ else:
+ self._retreat(index)
+
+ return table
+
def _parse_table_parts(self, schema: bool = False) -> exp.Table:
# https://docs.snowflake.com/en/user-guide/querying-stage
- table: t.Optional[exp.Expression] = None
- if self._match_text_seq("@"):
- table_name = "@"
- while self._curr:
- self._advance()
- table_name += self._prev.text
- if not self._match_set(self.STAGED_FILE_SINGLE_TOKENS, advance=False):
- break
- while self._match_set(self.STAGED_FILE_SINGLE_TOKENS):
- table_name += self._prev.text
-
- table = exp.var(table_name)
- elif self._match(TokenType.STRING, advance=False):
+ if self._match(TokenType.STRING, advance=False):
table = self._parse_string()
+ elif self._match_text_seq("@", advance=False):
+ table = self._parse_location_path()
+ else:
+ table = None
if table:
file_format = None
pattern = None
- if self._match_text_seq("(", "FILE_FORMAT", "=>"):
- file_format = self._parse_string() or super()._parse_table_parts()
- if self._match_text_seq(",", "PATTERN", "=>"):
+ self._match(TokenType.L_PAREN)
+ while self._curr and not self._match(TokenType.R_PAREN):
+ if self._match_text_seq("FILE_FORMAT", "=>"):
+ file_format = self._parse_string() or super()._parse_table_parts()
+ elif self._match_text_seq("PATTERN", "=>"):
pattern = self._parse_string()
- self._match_r_paren()
+ else:
+ break
+
+ self._match(TokenType.COMMA)
- return self.expression(exp.Table, this=table, format=file_format, pattern=pattern)
+ table = self.expression(exp.Table, this=table, format=file_format, pattern=pattern)
+ else:
+ table = super()._parse_table_parts(schema=schema)
- return super()._parse_table_parts(schema=schema)
+ return self._parse_at_before(table)
def _parse_id_var(
self,
@@ -438,17 +459,17 @@ class Snowflake(Dialect):
def _parse_location(self) -> exp.LocationProperty:
self._match(TokenType.EQ)
+ return self.expression(exp.LocationProperty, this=self._parse_location_path())
- parts = [self._parse_var(any_token=True)]
+ def _parse_location_path(self) -> exp.Var:
+ parts = [self._advance_any(ignore_reserved=True)]
- while self._match(TokenType.SLASH):
- if self._curr and self._prev.end + 1 == self._curr.start:
- parts.append(self._parse_var(any_token=True))
- else:
- parts.append(exp.Var(this=""))
- return self.expression(
- exp.LocationProperty, this=exp.var("/".join(str(p) for p in parts))
- )
+ # We avoid consuming a comma token because external tables like @foo and @bar
+ # can be joined in a query with a comma separator.
+ while self._is_connected() and not self._match(TokenType.COMMA, advance=False):
+ parts.append(self._advance_any(ignore_reserved=True))
+
+ return exp.var("".join(part.text for part in parts if part))
class Tokenizer(tokens.Tokenizer):
STRING_ESCAPES = ["\\", "'"]
@@ -562,6 +583,7 @@ class Snowflake(Dialect):
"TO_CHAR", exp.cast(e.this, "timestamp"), self.format_time(e)
),
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
+ exp.ToArray: rename_func("TO_ARRAY"),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),