author     Daniel Baumann <daniel.baumann@progress-linux.org>  2023-05-03 09:12:24 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2023-05-03 09:12:24 +0000
commit     98d5537435b2951b36c45f1fda667fa27c165794 (patch)
tree       d26b4dfa6cf91847100fe10a94a04dcc2ad36a86 /sqlglot/parser.py
parent     Adding upstream version 11.5.2. (diff)
Adding upstream version 11.7.1. (tag: upstream/11.7.1)
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/parser.py')
-rw-r--r--  sqlglot/parser.py | 382
1 file changed, 292 insertions(+), 90 deletions(-)
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index b3b899c..abb23ad 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -18,8 +18,13 @@ from sqlglot.trie import in_trie, new_trie
logger = logging.getLogger("sqlglot")
+E = t.TypeVar("E", bound=exp.Expression)
+
def parse_var_map(args: t.Sequence) -> exp.Expression:
+ if len(args) == 1 and args[0].is_star:
+ return exp.StarMap(this=args[0])
+
keys = []
values = []
for i in range(0, len(args), 2):
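For illustration, a minimal sketch of the new short-circuit, assuming the module-level parse_var_map helper is called directly (the argument list here is hypothetical):

    from sqlglot import exp
    from sqlglot.parser import parse_var_map

    # A lone star argument now becomes a StarMap node instead of being
    # split into key/value pairs.
    print(parse_var_map([exp.Star()]))  # StarMap(this=Star())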
@@ -108,6 +113,8 @@ class Parser(metaclass=_Parser):
TokenType.CURRENT_USER: exp.CurrentUser,
}
+ JOIN_HINTS: t.Set[str] = set()
+
NESTED_TYPE_TOKENS = {
TokenType.ARRAY,
TokenType.MAP,
@@ -145,6 +152,7 @@ class Parser(metaclass=_Parser):
TokenType.DATETIME,
TokenType.DATE,
TokenType.DECIMAL,
+ TokenType.BIGDECIMAL,
TokenType.UUID,
TokenType.GEOGRAPHY,
TokenType.GEOMETRY,
@@ -221,8 +229,10 @@ class Parser(metaclass=_Parser):
TokenType.FORMAT,
TokenType.FULL,
TokenType.IF,
+ TokenType.IS,
TokenType.ISNULL,
TokenType.INTERVAL,
+ TokenType.KEEP,
TokenType.LAZY,
TokenType.LEADING,
TokenType.LEFT,
@@ -235,6 +245,7 @@ class Parser(metaclass=_Parser):
TokenType.ONLY,
TokenType.OPTIONS,
TokenType.ORDINALITY,
+ TokenType.OVERWRITE,
TokenType.PARTITION,
TokenType.PERCENT,
TokenType.PIVOT,
@@ -266,6 +277,8 @@ class Parser(metaclass=_Parser):
*NO_PAREN_FUNCTIONS,
}
+ INTERVAL_VARS = ID_VAR_TOKENS - {TokenType.END}
+
TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {
TokenType.APPLY,
TokenType.FULL,
@@ -276,6 +289,8 @@ class Parser(metaclass=_Parser):
TokenType.WINDOW,
}
+ COMMENT_TABLE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.IS}
+
UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET}
TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH}
@@ -400,7 +415,7 @@ class Parser(metaclass=_Parser):
COLUMN_OPERATORS = {
TokenType.DOT: None,
TokenType.DCOLON: lambda self, this, to: self.expression(
- exp.Cast,
+ exp.Cast if self.STRICT_CAST else exp.TryCast,
this=this,
to=to,
),
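The :: operator now respects the dialect's STRICT_CAST flag. A sketch, assuming the default parser (where STRICT_CAST is True); a dialect that sets it to False would build exp.TryCast here instead:

    import sqlglot
    from sqlglot import exp

    # Default parser: STRICT_CAST is True, so x::INT still yields a Cast node.
    node = sqlglot.parse_one("SELECT x::INT").find(exp.Cast)
    assert isinstance(node, exp.Cast)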
@@ -560,7 +575,7 @@ class Parser(metaclass=_Parser):
),
"DEFINER": lambda self: self._parse_definer(),
"DETERMINISTIC": lambda self: self.expression(
- exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
+ exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
),
"DISTKEY": lambda self: self._parse_distkey(),
"DISTSTYLE": lambda self: self._parse_property_assignment(exp.DistStyleProperty),
@@ -571,7 +586,7 @@ class Parser(metaclass=_Parser):
"FREESPACE": lambda self: self._parse_freespace(),
"GLOBAL": lambda self: self._parse_temporary(global_=True),
"IMMUTABLE": lambda self: self.expression(
- exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
+ exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
),
"JOURNAL": lambda self: self._parse_journal(
no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
@@ -600,20 +615,20 @@ class Parser(metaclass=_Parser):
"PARTITIONED_BY": lambda self: self._parse_partitioned_by(),
"RETURNS": lambda self: self._parse_returns(),
"ROW": lambda self: self._parse_row(),
+ "ROW_FORMAT": lambda self: self._parse_property_assignment(exp.RowFormatProperty),
"SET": lambda self: self.expression(exp.SetProperty, multi=False),
"SORTKEY": lambda self: self._parse_sortkey(),
"STABLE": lambda self: self.expression(
- exp.VolatilityProperty, this=exp.Literal.string("STABLE")
+ exp.StabilityProperty, this=exp.Literal.string("STABLE")
),
- "STORED": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
+ "STORED": lambda self: self._parse_stored(),
"TABLE_FORMAT": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
"TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property),
+ "TEMP": lambda self: self._parse_temporary(global_=False),
"TEMPORARY": lambda self: self._parse_temporary(global_=False),
"TRANSIENT": lambda self: self.expression(exp.TransientProperty),
"USING": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
- "VOLATILE": lambda self: self.expression(
- exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
- ),
+ "VOLATILE": lambda self: self._parse_volatile_property(),
"WITH": lambda self: self._parse_with_property(),
}
@@ -648,8 +663,11 @@ class Parser(metaclass=_Parser):
"LIKE": lambda self: self._parse_create_like(),
"NOT": lambda self: self._parse_not_constraint(),
"NULL": lambda self: self.expression(exp.NotNullColumnConstraint, allow_null=True),
+ "ON": lambda self: self._match(TokenType.UPDATE)
+ and self.expression(exp.OnUpdateColumnConstraint, this=self._parse_function()),
"PATH": lambda self: self.expression(exp.PathColumnConstraint, this=self._parse_string()),
"PRIMARY KEY": lambda self: self._parse_primary_key(),
+ "REFERENCES": lambda self: self._parse_references(match=False),
"TITLE": lambda self: self.expression(
exp.TitleColumnConstraint, this=self._parse_var_or_string()
),
@@ -668,9 +686,14 @@ class Parser(metaclass=_Parser):
SCHEMA_UNNAMED_CONSTRAINTS = {"CHECK", "FOREIGN KEY", "LIKE", "PRIMARY KEY", "UNIQUE"}
NO_PAREN_FUNCTION_PARSERS = {
+ TokenType.ANY: lambda self: self.expression(exp.Any, this=self._parse_bitwise()),
TokenType.CASE: lambda self: self._parse_case(),
TokenType.IF: lambda self: self._parse_if(),
- TokenType.ANY: lambda self: self.expression(exp.Any, this=self._parse_bitwise()),
+ TokenType.NEXT_VALUE_FOR: lambda self: self.expression(
+ exp.NextValueFor,
+ this=self._parse_column(),
+ order=self._match(TokenType.OVER) and self._parse_wrapped(self._parse_order),
+ ),
}
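NEXT VALUE FOR is now parsed as a dedicated node with an optional OVER (ORDER BY ...) clause. A sketch, assuming the tsql dialect routes the keyword here; db.seq is a placeholder name:

    import sqlglot

    ast = sqlglot.parse_one("SELECT NEXT VALUE FOR db.seq OVER (ORDER BY x)", read="tsql")
    print(ast.sql(dialect="tsql"))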
FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
@@ -715,6 +738,8 @@ class Parser(metaclass=_Parser):
SHOW_PARSERS: t.Dict[str, t.Callable] = {}
+ TYPE_LITERAL_PARSERS: t.Dict[exp.DataType.Type, t.Callable] = {}
+
MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
@@ -731,6 +756,7 @@ class Parser(metaclass=_Parser):
INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}
WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS}
+ WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER}
ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}
@@ -738,6 +764,9 @@ class Parser(metaclass=_Parser):
CONVERT_TYPE_FIRST = False
+ QUOTED_PIVOT_COLUMNS: t.Optional[bool] = None
+ PREFIXED_PIVOT_COLUMNS = False
+
LOG_BASE_FIRST = True
LOG_DEFAULTS_TO_LN = False
@@ -895,8 +924,8 @@ class Parser(metaclass=_Parser):
error level setting.
"""
token = token or self._curr or self._prev or Token.string("")
- start = self._find_token(token)
- end = start + len(token.text)
+ start = token.start
+ end = token.end
start_context = self.sql[max(start - self.error_message_context, 0) : start]
highlight = self.sql[start:end]
end_context = self.sql[end : end + self.error_message_context]
@@ -918,8 +947,8 @@ class Parser(metaclass=_Parser):
self.errors.append(error)
def expression(
- self, exp_class: t.Type[exp.Expression], comments: t.Optional[t.List[str]] = None, **kwargs
- ) -> exp.Expression:
+ self, exp_class: t.Type[E], comments: t.Optional[t.List[str]] = None, **kwargs
+ ) -> E:
"""
Creates a new, validated Expression.
@@ -958,22 +987,7 @@ class Parser(metaclass=_Parser):
self.raise_error(error_message)
def _find_sql(self, start: Token, end: Token) -> str:
- return self.sql[self._find_token(start) : self._find_token(end) + len(end.text)]
-
- def _find_token(self, token: Token) -> int:
- line = 1
- col = 1
- index = 0
-
- while line < token.line or col < token.col:
- if Tokenizer.WHITE_SPACE.get(self.sql[index]) == TokenType.BREAK:
- line += 1
- col = 1
- else:
- col += 1
- index += 1
-
- return index
+ return self.sql[start.start : end.end]
def _advance(self, times: int = 1) -> None:
self._index += times
@@ -990,7 +1004,7 @@ class Parser(metaclass=_Parser):
if index != self._index:
self._advance(index - self._index)
- def _parse_command(self) -> exp.Expression:
+ def _parse_command(self) -> exp.Command:
return self.expression(exp.Command, this=self._prev.text, expression=self._parse_string())
def _parse_comment(self, allow_exists: bool = True) -> exp.Expression:
@@ -1007,7 +1021,7 @@ class Parser(metaclass=_Parser):
if kind.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
this = self._parse_user_defined_function(kind=kind.token_type)
elif kind.token_type == TokenType.TABLE:
- this = self._parse_table()
+ this = self._parse_table(alias_tokens=self.COMMENT_TABLE_ALIAS_TOKENS)
elif kind.token_type == TokenType.COLUMN:
this = self._parse_column()
else:
@@ -1035,16 +1049,13 @@ class Parser(metaclass=_Parser):
self._parse_query_modifiers(expression)
return expression
- def _parse_drop(self, default_kind: t.Optional[str] = None) -> t.Optional[exp.Expression]:
+ def _parse_drop(self) -> t.Optional[exp.Drop | exp.Command]:
start = self._prev
temporary = self._match(TokenType.TEMPORARY)
materialized = self._match(TokenType.MATERIALIZED)
kind = self._match_set(self.CREATABLES) and self._prev.text
if not kind:
- if default_kind:
- kind = default_kind
- else:
- return self._parse_as_command(start)
+ return self._parse_as_command(start)
return self.expression(
exp.Drop,
@@ -1055,6 +1066,7 @@ class Parser(metaclass=_Parser):
materialized=materialized,
cascade=self._match(TokenType.CASCADE),
constraints=self._match_text_seq("CONSTRAINTS"),
+ purge=self._match_text_seq("PURGE"),
)
def _parse_exists(self, not_: bool = False) -> t.Optional[bool]:
@@ -1070,7 +1082,6 @@ class Parser(metaclass=_Parser):
TokenType.OR, TokenType.REPLACE
)
unique = self._match(TokenType.UNIQUE)
- volatile = self._match(TokenType.VOLATILE)
if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False):
self._match(TokenType.TABLE)
@@ -1179,7 +1190,6 @@ class Parser(metaclass=_Parser):
kind=create_token.text,
replace=replace,
unique=unique,
- volatile=volatile,
expression=expression,
exists=exists,
properties=properties,
@@ -1225,6 +1235,21 @@ class Parser(metaclass=_Parser):
return None
+ def _parse_stored(self) -> exp.Expression:
+ self._match(TokenType.ALIAS)
+
+ input_format = self._parse_string() if self._match_text_seq("INPUTFORMAT") else None
+ output_format = self._parse_string() if self._match_text_seq("OUTPUTFORMAT") else None
+
+ return self.expression(
+ exp.FileFormatProperty,
+ this=self.expression(
+ exp.InputOutputFormat, input_format=input_format, output_format=output_format
+ )
+ if input_format or output_format
+ else self._parse_var_or_string() or self._parse_number() or self._parse_id_var(),
+ )
+
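STORED AS can now carry Hive's INPUTFORMAT/OUTPUTFORMAT pair as an InputOutputFormat node rather than a single format name. A sketch, assuming the hive dialect; the format class names are placeholders:

    import sqlglot

    sql = (
        "CREATE TABLE t (a INT) STORED AS "
        "INPUTFORMAT 'com.example.In' OUTPUTFORMAT 'com.example.Out'"
    )
    ast = sqlglot.parse_one(sql, read="hive")  # FileFormatProperty wrapping InputOutputFormat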
def _parse_property_assignment(self, exp_class: t.Type[exp.Expression]) -> exp.Expression:
self._match(TokenType.EQ)
self._match(TokenType.ALIAS)
@@ -1258,6 +1283,21 @@ class Parser(metaclass=_Parser):
exp.FallbackProperty, no=no, protection=self._match_text_seq("PROTECTION")
)
+ def _parse_volatile_property(self) -> exp.Expression:
+ if self._index >= 2:
+ pre_volatile_token = self._tokens[self._index - 2]
+ else:
+ pre_volatile_token = None
+
+ if pre_volatile_token and pre_volatile_token.token_type in (
+ TokenType.CREATE,
+ TokenType.REPLACE,
+ TokenType.UNIQUE,
+ ):
+ return exp.VolatileProperty()
+
+ return self.expression(exp.StabilityProperty, this=exp.Literal.string("VOLATILE"))
+
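VOLATILE is now disambiguated by the token two positions back: directly after CREATE [OR REPLACE] [UNIQUE] it is a table property (Teradata-style volatile tables); anywhere else it remains a function stability marker. A sketch, assuming the teradata dialect:

    import sqlglot

    # Table-property form: the token before VOLATILE is CREATE, so this
    # yields exp.VolatileProperty rather than exp.StabilityProperty.
    sqlglot.parse_one("CREATE VOLATILE TABLE t (c INT)", read="teradata")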
def _parse_with_property(
self,
) -> t.Union[t.Optional[exp.Expression], t.List[t.Optional[exp.Expression]]]:
@@ -1574,11 +1614,46 @@ class Parser(metaclass=_Parser):
exists=self._parse_exists(),
partition=self._parse_partition(),
expression=self._parse_ddl_select(),
+ conflict=self._parse_on_conflict(),
returning=self._parse_returning(),
overwrite=overwrite,
alternative=alternative,
)
+ def _parse_on_conflict(self) -> t.Optional[exp.Expression]:
+ conflict = self._match_text_seq("ON", "CONFLICT")
+ duplicate = self._match_text_seq("ON", "DUPLICATE", "KEY")
+
+ if not (conflict or duplicate):
+ return None
+
+ nothing = None
+ expressions = None
+ key = None
+ constraint = None
+
+ if conflict:
+ if self._match_text_seq("ON", "CONSTRAINT"):
+ constraint = self._parse_id_var()
+ else:
+ key = self._parse_csv(self._parse_value)
+
+ self._match_text_seq("DO")
+ if self._match_text_seq("NOTHING"):
+ nothing = True
+ else:
+ self._match(TokenType.UPDATE)
+ expressions = self._match(TokenType.SET) and self._parse_csv(self._parse_equality)
+
+ return self.expression(
+ exp.OnConflict,
+ duplicate=duplicate,
+ expressions=expressions,
+ nothing=nothing,
+ key=key,
+ constraint=constraint,
+ )
+
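INSERT statements now capture upsert handling in a single exp.OnConflict node, covering both Postgres-style ON CONFLICT and MySQL's ON DUPLICATE KEY. A sketch of the Postgres forms (table and column names are placeholders):

    import sqlglot

    sqlglot.parse_one("INSERT INTO t VALUES (1) ON CONFLICT (id) DO NOTHING", read="postgres")
    sqlglot.parse_one(
        "INSERT INTO t VALUES (1) ON CONFLICT (id) DO UPDATE SET x = 1", read="postgres"
    )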
def _parse_returning(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.RETURNING):
return None
@@ -1639,7 +1714,7 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.Delete,
- this=self._parse_table(schema=True),
+ this=self._parse_table(),
using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table()),
where=self._parse_where(),
returning=self._parse_returning(),
@@ -1792,6 +1867,7 @@ class Parser(metaclass=_Parser):
if not skip_with_token and not self._match(TokenType.WITH):
return None
+ comments = self._prev_comments
recursive = self._match(TokenType.RECURSIVE)
expressions = []
@@ -1803,7 +1879,9 @@ class Parser(metaclass=_Parser):
else:
self._match(TokenType.WITH)
- return self.expression(exp.With, expressions=expressions, recursive=recursive)
+ return self.expression(
+ exp.With, comments=comments, expressions=expressions, recursive=recursive
+ )
def _parse_cte(self) -> exp.Expression:
alias = self._parse_table_alias()
@@ -1856,15 +1934,20 @@ class Parser(metaclass=_Parser):
table = isinstance(this, exp.Table)
while True:
- lateral = self._parse_lateral()
join = self._parse_join()
- comma = None if table else self._match(TokenType.COMMA)
- if lateral:
- this.append("laterals", lateral)
if join:
this.append("joins", join)
+
+ lateral = None
+ if not join:
+ lateral = self._parse_lateral()
+ if lateral:
+ this.append("laterals", lateral)
+
+ comma = None if table else self._match(TokenType.COMMA)
if comma:
this.args["from"].append("expressions", self._parse_table())
+
if not (lateral or join or comma):
break
@@ -1906,14 +1989,13 @@ class Parser(metaclass=_Parser):
def _parse_match_recognize(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.MATCH_RECOGNIZE):
return None
+
self._match_l_paren()
partition = self._parse_partition_by()
order = self._parse_order()
measures = (
- self._parse_alias(self._parse_conjunction())
- if self._match_text_seq("MEASURES")
- else None
+ self._parse_csv(self._parse_expression) if self._match_text_seq("MEASURES") else None
)
if self._match_text_seq("ONE", "ROW", "PER", "MATCH"):
@@ -1967,8 +2049,17 @@ class Parser(metaclass=_Parser):
pattern = None
define = (
- self._parse_alias(self._parse_conjunction()) if self._match_text_seq("DEFINE") else None
+ self._parse_csv(
+ lambda: self.expression(
+ exp.Alias,
+ alias=self._parse_id_var(any_token=True),
+ this=self._match(TokenType.ALIAS) and self._parse_conjunction(),
+ )
+ )
+ if self._match_text_seq("DEFINE")
+ else None
)
+
self._match_r_paren()
return self.expression(
@@ -1980,6 +2071,7 @@ class Parser(metaclass=_Parser):
after=after,
pattern=pattern,
define=define,
+ alias=self._parse_table_alias(),
)
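MEASURES and DEFINE are now parsed as comma-separated alias lists, and a trailing table alias is captured on the MATCH_RECOGNIZE node itself. A sketch of the accepted shape (all identifiers are placeholders):

    import sqlglot

    sqlglot.parse_one(
        """SELECT * FROM t MATCH_RECOGNIZE (
          PARTITION BY a ORDER BY b
          MEASURES b AS mb
          ONE ROW PER MATCH
          PATTERN (x)
          DEFINE x AS b > 0
        ) AS mr"""
    )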
def _parse_lateral(self) -> t.Optional[exp.Expression]:
@@ -2022,9 +2114,6 @@ class Parser(metaclass=_Parser):
alias=table_alias,
)
- if outer_apply or cross_apply:
- return self.expression(exp.Join, this=expression, side=None if cross_apply else "LEFT")
-
return expression
def _parse_join_side_and_kind(
@@ -2037,11 +2126,26 @@ class Parser(metaclass=_Parser):
)
def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]:
+ index = self._index
natural, side, kind = self._parse_join_side_and_kind()
+ hint = self._prev.text if self._match_texts(self.JOIN_HINTS) else None
+ join = self._match(TokenType.JOIN)
- if not skip_join_token and not self._match(TokenType.JOIN):
+ if not skip_join_token and not join:
+ self._retreat(index)
+ kind = None
+ natural = None
+ side = None
+
+ outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY, False)
+ cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY, False)
+
+ if not skip_join_token and not join and not outer_apply and not cross_apply:
return None
+ if outer_apply:
+ side = Token(TokenType.LEFT, "LEFT")
+
kwargs: t.Dict[
str, t.Optional[exp.Expression] | bool | str | t.List[t.Optional[exp.Expression]]
] = {"this": self._parse_table()}
@@ -2052,6 +2156,8 @@ class Parser(metaclass=_Parser):
kwargs["side"] = side.text
if kind:
kwargs["kind"] = kind.text
+ if hint:
+ kwargs["hint"] = hint
if self._match(TokenType.ON):
kwargs["on"] = self._parse_conjunction()
@@ -2179,7 +2285,7 @@ class Parser(metaclass=_Parser):
return None
expressions = self._parse_wrapped_csv(self._parse_column)
- ordinality = bool(self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY))
+ ordinality = self._match_pair(TokenType.WITH, TokenType.ORDINALITY)
alias = self._parse_table_alias()
if alias and self.unnest_column_only:
@@ -2191,7 +2297,7 @@ class Parser(metaclass=_Parser):
offset = None
if self._match_pair(TokenType.WITH, TokenType.OFFSET):
self._match(TokenType.ALIAS)
- offset = self._parse_conjunction()
+ offset = self._parse_id_var() or exp.Identifier(this="offset")
return self.expression(
exp.Unnest,
@@ -2294,6 +2400,9 @@ class Parser(metaclass=_Parser):
else:
expressions = self._parse_csv(lambda: self._parse_alias(self._parse_function()))
+ if not expressions:
+ self.raise_error("Failed to parse PIVOT's aggregation list")
+
if not self._match(TokenType.FOR):
self.raise_error("Expecting FOR")
@@ -2311,8 +2420,26 @@ class Parser(metaclass=_Parser):
if not self._match_set((TokenType.PIVOT, TokenType.UNPIVOT), advance=False):
pivot.set("alias", self._parse_table_alias())
+ if not unpivot:
+ names = self._pivot_column_names(t.cast(t.List[exp.Expression], expressions))
+
+ columns: t.List[exp.Expression] = []
+ for col in pivot.args["field"].expressions:
+ for name in names:
+ if self.PREFIXED_PIVOT_COLUMNS:
+ name = f"{name}_{col.alias_or_name}" if name else col.alias_or_name
+ else:
+ name = f"{col.alias_or_name}_{name}" if name else col.alias_or_name
+
+ columns.append(exp.to_identifier(name, quoted=self.QUOTED_PIVOT_COLUMNS))
+
+ pivot.set("columns", columns)
+
return pivot
+ def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
+ return [agg.alias for agg in pivot_columns]
+
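Pivot output columns are now synthesized from the IN values combined with the aggregation aliases, with two dialect knobs: PREFIXED_PIVOT_COLUMNS controls name order and QUOTED_PIVOT_COLUMNS controls quoting. A sketch, assuming the snowflake dialect; names are placeholders:

    import sqlglot
    from sqlglot import exp

    sql = "SELECT * FROM sales PIVOT (SUM(amount) FOR month IN ('JAN', 'FEB')) AS p"
    pivot = sqlglot.parse_one(sql, read="snowflake").find(exp.Pivot)
    print(pivot.args.get("columns"))  # synthesized identifiers, one per IN value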
def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Expression]:
if not skip_where_token and not self._match(TokenType.WHERE):
return None
@@ -2433,10 +2560,25 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.FETCH):
direction = self._match_set((TokenType.FIRST, TokenType.NEXT))
direction = self._prev.text if direction else "FIRST"
+
count = self._parse_number()
+ percent = self._match(TokenType.PERCENT)
+
self._match_set((TokenType.ROW, TokenType.ROWS))
- self._match(TokenType.ONLY)
- return self.expression(exp.Fetch, direction=direction, count=count)
+
+ only = self._match(TokenType.ONLY)
+ with_ties = self._match_text_seq("WITH", "TIES")
+
+ if only and with_ties:
+ self.raise_error("Cannot specify both ONLY and WITH TIES in FETCH clause")
+
+ return self.expression(
+ exp.Fetch,
+ direction=direction,
+ count=count,
+ percent=percent,
+ with_ties=with_ties,
+ )
return this
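FETCH now records PERCENT and WITH TIES, and rejects the contradictory ONLY + WITH TIES combination. A sketch against the default dialect:

    import sqlglot

    sqlglot.parse_one("SELECT * FROM t ORDER BY x FETCH FIRST 10 PERCENT ROWS WITH TIES")
    # "FETCH FIRST 10 ROWS ONLY WITH TIES" would now raise a parse error instead.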
@@ -2493,7 +2635,11 @@ class Parser(metaclass=_Parser):
negate = self._match(TokenType.NOT)
if self._match_set(self.RANGE_PARSERS):
- this = self.RANGE_PARSERS[self._prev.token_type](self, this)
+ expression = self.RANGE_PARSERS[self._prev.token_type](self, this)
+ if not expression:
+ return this
+
+ this = expression
elif self._match(TokenType.ISNULL):
this = self.expression(exp.Is, this=this, expression=exp.Null())
@@ -2511,17 +2657,19 @@ class Parser(metaclass=_Parser):
return this
- def _parse_is(self, this: t.Optional[exp.Expression]) -> exp.Expression:
+ def _parse_is(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+ index = self._index - 1
negate = self._match(TokenType.NOT)
if self._match(TokenType.DISTINCT_FROM):
klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
return self.expression(klass, this=this, expression=self._parse_expression())
- this = self.expression(
- exp.Is,
- this=this,
- expression=self._parse_null() or self._parse_boolean(),
- )
+ expression = self._parse_null() or self._parse_boolean()
+ if not expression:
+ self._retreat(index)
+ return None
+
+ this = self.expression(exp.Is, this=this, expression=expression)
return self.expression(exp.Not, this=this) if negate else this
def _parse_in(self, this: t.Optional[exp.Expression]) -> exp.Expression:
@@ -2553,6 +2701,27 @@ class Parser(metaclass=_Parser):
return this
return self.expression(exp.Escape, this=this, expression=self._parse_string())
+ def _parse_interval(self) -> t.Optional[exp.Expression]:
+ if not self._match(TokenType.INTERVAL):
+ return None
+
+ this = self._parse_primary() or self._parse_term()
+ unit = self._parse_function() or self._parse_var()
+
+ # Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse
+ # each INTERVAL expression into this canonical form so it's easy to transpile
+ if this and isinstance(this, exp.Literal):
+ if this.is_number:
+ this = exp.Literal.string(this.name)
+
+ # Try to not clutter Snowflake's multi-part intervals like INTERVAL '1 day, 1 year'
+ parts = this.name.split()
+ if not unit and len(parts) <= 2:
+ this = exp.Literal.string(seq_get(parts, 0))
+ unit = self.expression(exp.Var, this=seq_get(parts, 1))
+
+ return self.expression(exp.Interval, this=this, unit=unit)
+
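INTERVAL expressions are canonicalized into the quoted-value-plus-unit form so every dialect transpiles them consistently, while multi-part strings with more than two parts (e.g. Snowflake's INTERVAL '1 day, 1 year') are left untouched. A sketch:

    import sqlglot

    print(sqlglot.parse_one("SELECT INTERVAL 5 day").sql())     # SELECT INTERVAL '5' day
    print(sqlglot.parse_one("SELECT INTERVAL '1 week'").sql())  # SELECT INTERVAL '1' week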
def _parse_bitwise(self) -> t.Optional[exp.Expression]:
this = self._parse_term()
@@ -2588,20 +2757,24 @@ class Parser(metaclass=_Parser):
return self._parse_at_time_zone(self._parse_type())
def _parse_type(self) -> t.Optional[exp.Expression]:
- if self._match(TokenType.INTERVAL):
- return self.expression(exp.Interval, this=self._parse_term(), unit=self._parse_field())
+ interval = self._parse_interval()
+ if interval:
+ return interval
index = self._index
- type_token = self._parse_types(check_func=True)
+ data_type = self._parse_types(check_func=True)
this = self._parse_column()
- if type_token:
+ if data_type:
if isinstance(this, exp.Literal):
- return self.expression(exp.Cast, this=this, to=type_token)
- if not type_token.args.get("expressions"):
+ parser = self.TYPE_LITERAL_PARSERS.get(data_type.this)
+ if parser:
+ return parser(self, this, data_type)
+ return self.expression(exp.Cast, this=this, to=data_type)
+ if not data_type.args.get("expressions"):
self._retreat(index)
return self._parse_column()
- return type_token
+ return data_type
return this
@@ -2631,11 +2804,10 @@ class Parser(metaclass=_Parser):
else:
expressions = self._parse_csv(self._parse_conjunction)
- if not expressions:
+ if not expressions or not self._match(TokenType.R_PAREN):
self._retreat(index)
return None
- self._match_r_paren()
maybe_func = True
if self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
@@ -2720,15 +2892,14 @@ class Parser(metaclass=_Parser):
)
def _parse_struct_kwargs(self) -> t.Optional[exp.Expression]:
- if self._curr and self._curr.token_type in self.TYPE_TOKENS:
- return self._parse_types()
-
+ index = self._index
this = self._parse_id_var()
self._match(TokenType.COLON)
data_type = self._parse_types()
if not data_type:
- return None
+ self._retreat(index)
+ return self._parse_types()
return self.expression(exp.StructKwarg, this=this, expression=data_type)
def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
@@ -2825,6 +2996,7 @@ class Parser(metaclass=_Parser):
this = self.expression(exp.Paren, this=self._parse_set_operations(this))
self._match_r_paren()
+ comments.extend(self._prev_comments)
if this and comments:
this.comments = comments
@@ -2833,8 +3005,16 @@ class Parser(metaclass=_Parser):
return None
- def _parse_field(self, any_token: bool = False) -> t.Optional[exp.Expression]:
- return self._parse_primary() or self._parse_function() or self._parse_id_var(any_token)
+ def _parse_field(
+ self,
+ any_token: bool = False,
+ tokens: t.Optional[t.Collection[TokenType]] = None,
+ ) -> t.Optional[exp.Expression]:
+ return (
+ self._parse_primary()
+ or self._parse_function()
+ or self._parse_id_var(any_token=any_token, tokens=tokens)
+ )
def _parse_function(
self, functions: t.Optional[t.Dict[str, t.Callable]] = None
@@ -3079,12 +3259,10 @@ class Parser(metaclass=_Parser):
return None
def _parse_column_constraint(self) -> t.Optional[exp.Expression]:
- this = self._parse_references()
- if this:
- return this
-
if self._match(TokenType.CONSTRAINT):
this = self._parse_id_var()
+ else:
+ this = None
if self._match_texts(self.CONSTRAINT_PARSERS):
return self.expression(
@@ -3164,8 +3342,8 @@ class Parser(metaclass=_Parser):
return options
- def _parse_references(self) -> t.Optional[exp.Expression]:
- if not self._match(TokenType.REFERENCES):
+ def _parse_references(self, match=True) -> t.Optional[exp.Expression]:
+ if match and not self._match(TokenType.REFERENCES):
return None
expressions = None
@@ -3234,7 +3412,7 @@ class Parser(metaclass=_Parser):
elif not this or this.name.upper() == "ARRAY":
this = self.expression(exp.Array, expressions=expressions)
else:
- expressions = apply_index_offset(expressions, -self.index_offset)
+ expressions = apply_index_offset(this, expressions, -self.index_offset)
this = self.expression(exp.Bracket, this=this, expressions=expressions)
if not self._match(TokenType.R_BRACKET) and bracket_kind == TokenType.L_BRACKET:
@@ -3279,7 +3457,13 @@ class Parser(metaclass=_Parser):
self.validate_expression(this, args)
self._match_r_paren()
else:
+ index = self._index - 1
condition = self._parse_conjunction()
+
+ if not condition:
+ self._retreat(index)
+ return None
+
self._match(TokenType.THEN)
true = self._parse_conjunction()
false = self._parse_conjunction() if self._match(TokenType.ELSE) else None
@@ -3591,14 +3775,24 @@ class Parser(metaclass=_Parser):
# bigquery select from window x AS (partition by ...)
if alias:
+ over = None
self._match(TokenType.ALIAS)
- elif not self._match(TokenType.OVER):
+ elif not self._match_set(self.WINDOW_BEFORE_PAREN_TOKENS):
return this
+ else:
+ over = self._prev.text.upper()
if not self._match(TokenType.L_PAREN):
- return self.expression(exp.Window, this=this, alias=self._parse_id_var(False))
+ return self.expression(
+ exp.Window, this=this, alias=self._parse_id_var(False), over=over
+ )
window_alias = self._parse_id_var(any_token=False, tokens=self.WINDOW_ALIAS_TOKENS)
+
+ first = self._match(TokenType.FIRST)
+ if self._match_text_seq("LAST"):
+ first = False
+
partition = self._parse_partition_by()
order = self._parse_order()
kind = self._match_set((TokenType.ROWS, TokenType.RANGE)) and self._prev.text
@@ -3629,6 +3823,8 @@ class Parser(metaclass=_Parser):
order=order,
spec=spec,
alias=window_alias,
+ over=over,
+ first=first,
)
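Window parsing now remembers which keyword opened the spec (any member of WINDOW_BEFORE_PAREN_TOKENS, OVER by default) plus an optional FIRST/LAST flag, letting dialects reuse this machinery for Oracle-style KEEP clauses. A sketch, assuming the oracle dialect registers KEEP in WINDOW_BEFORE_PAREN_TOKENS:

    import sqlglot

    sqlglot.parse_one("SELECT MAX(a) KEEP (DENSE_RANK FIRST ORDER BY b) FROM t", read="oracle")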
def _parse_window_spec(self) -> t.Dict[str, t.Optional[str | exp.Expression]]:
@@ -3886,7 +4082,10 @@ class Parser(metaclass=_Parser):
return expression
def _parse_drop_column(self) -> t.Optional[exp.Expression]:
- return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN")
+ drop = self._match(TokenType.DROP) and self._parse_drop()
+ if drop and not isinstance(drop, exp.Command):
+ drop.set("kind", drop.args.get("kind", "COLUMN"))
+ return drop
# https://docs.aws.amazon.com/athena/latest/ug/alter-table-drop-partition.html
def _parse_drop_partition(self, exists: t.Optional[bool] = None) -> exp.Expression:
@@ -4010,7 +4209,7 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.INSERT):
_this = self._parse_star()
if _this:
- then = self.expression(exp.Insert, this=_this)
+ then: t.Optional[exp.Expression] = self.expression(exp.Insert, this=_this)
else:
then = self.expression(
exp.Insert,
@@ -4239,5 +4438,8 @@ class Parser(metaclass=_Parser):
break
parent = parent.parent
else:
- column.replace(dot_or_id)
+ if column is node:
+ node = dot_or_id
+ else:
+ column.replace(dot_or_id)
return node