summary refs log tree commit diff stats
path: root/sqlglot/parser.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/parser.py')
-rw-r--r-- sqlglot/parser.py | 652
1 files changed, 453 insertions, 199 deletions
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 308f363..bd95db8 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -5,7 +5,13 @@ import typing as t
from sqlglot import exp
from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors
-from sqlglot.helper import apply_index_offset, ensure_collection, ensure_list, seq_get
+from sqlglot.helper import (
+ apply_index_offset,
+ count_params,
+ ensure_collection,
+ ensure_list,
+ seq_get,
+)
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import in_trie, new_trie
@@ -54,7 +60,7 @@ class Parser(metaclass=_Parser):
Default: "nulls_are_small"
"""
- FUNCTIONS = {
+ FUNCTIONS: t.Dict[str, t.Callable] = {
**{name: f.from_arg_list for f in exp.ALL_FUNCTIONS for name in f.sql_names()},
"DATE_TO_DATE_STR": lambda args: exp.Cast(
this=seq_get(args, 0),
@@ -106,6 +112,7 @@ class Parser(metaclass=_Parser):
TokenType.JSON,
TokenType.JSONB,
TokenType.INTERVAL,
+ TokenType.TIME,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
TokenType.TIMESTAMPLTZ,
@@ -164,6 +171,7 @@ class Parser(metaclass=_Parser):
TokenType.DELETE,
TokenType.DESCRIBE,
TokenType.DETERMINISTIC,
+ TokenType.DIV,
TokenType.DISTKEY,
TokenType.DISTSTYLE,
TokenType.EXECUTE,
@@ -252,6 +260,7 @@ class Parser(metaclass=_Parser):
TokenType.FIRST,
TokenType.FORMAT,
TokenType.IDENTIFIER,
+ TokenType.INDEX,
TokenType.ISNULL,
TokenType.MERGE,
TokenType.OFFSET,
@@ -312,6 +321,7 @@ class Parser(metaclass=_Parser):
}
TIMESTAMPS = {
+ TokenType.TIME,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
TokenType.TIMESTAMPLTZ,
@@ -387,6 +397,7 @@ class Parser(metaclass=_Parser):
}
EXPRESSION_PARSERS = {
+ exp.Column: lambda self: self._parse_column(),
exp.DataType: lambda self: self._parse_types(),
exp.From: lambda self: self._parse_from(),
exp.Group: lambda self: self._parse_group(),
@@ -419,6 +430,7 @@ class Parser(metaclass=_Parser):
TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
TokenType.CREATE: lambda self: self._parse_create(),
TokenType.DELETE: lambda self: self._parse_delete(),
+ TokenType.DESC: lambda self: self._parse_describe(),
TokenType.DESCRIBE: lambda self: self._parse_describe(),
TokenType.DROP: lambda self: self._parse_drop(),
TokenType.END: lambda self: self._parse_commit_or_rollback(),
@@ -583,6 +595,11 @@ class Parser(metaclass=_Parser):
TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
+ WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS}
+
+ # allows tables to have special tokens as prefixes
+ TABLE_PREFIX_TOKENS: t.Set[TokenType] = set()
+
STRICT_CAST = True
__slots__ = (
@@ -608,13 +625,13 @@ class Parser(metaclass=_Parser):
def __init__(
self,
- error_level=None,
- error_message_context=100,
- index_offset=0,
- unnest_column_only=False,
- alias_post_tablesample=False,
- max_errors=3,
- null_ordering=None,
+ error_level: t.Optional[ErrorLevel] = None,
+ error_message_context: int = 100,
+ index_offset: int = 0,
+ unnest_column_only: bool = False,
+ alias_post_tablesample: bool = False,
+ max_errors: int = 3,
+ null_ordering: t.Optional[str] = None,
):
self.error_level = error_level or ErrorLevel.IMMEDIATE
self.error_message_context = error_message_context
@@ -636,23 +653,43 @@ class Parser(metaclass=_Parser):
self._prev = None
self._prev_comments = None
- def parse(self, raw_tokens, sql=None):
+ def parse(
+ self, raw_tokens: t.List[Token], sql: t.Optional[str] = None
+ ) -> t.List[t.Optional[exp.Expression]]:
"""
- Parses the given list of tokens and returns a list of syntax trees, one tree
+ Parses a list of tokens and returns a list of syntax trees, one tree
per parsed SQL statement.
- Args
- raw_tokens (list): the list of tokens (:class:`~sqlglot.tokens.Token`).
- sql (str): the original SQL string. Used to produce helpful debug messages.
+ Args:
+ raw_tokens: the list of tokens.
+ sql: the original SQL string, used to produce helpful debug messages.
- Returns
- the list of syntax trees (:class:`~sqlglot.expressions.Expression`).
+ Returns:
+ The list of syntax trees.
"""
return self._parse(
parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql
)
- def parse_into(self, expression_types, raw_tokens, sql=None):
+ def parse_into(
+ self,
+ expression_types: str | exp.Expression | t.Collection[exp.Expression | str],
+ raw_tokens: t.List[Token],
+ sql: t.Optional[str] = None,
+ ) -> t.List[t.Optional[exp.Expression]]:
+ """
+ Parses a list of tokens into a given Expression type. If a collection of Expression
+ types is given instead, this method will try to parse the token list into each one
+ of them, stopping at the first for which the parsing succeeds.
+
+ Args:
+ expression_types: the expression type(s) to try and parse the token list into.
+ raw_tokens: the list of tokens.
+ sql: the original SQL string, used to produce helpful debug messages.
+
+ Returns:
+ The target Expression.
+ """
errors = []
for expression_type in ensure_collection(expression_types):
parser = self.EXPRESSION_PARSERS.get(expression_type)
@@ -668,7 +705,12 @@ class Parser(metaclass=_Parser):
errors=merge_errors(errors),
) from errors[-1]
- def _parse(self, parse_method, raw_tokens, sql=None):
+ def _parse(
+ self,
+ parse_method: t.Callable[[Parser], t.Optional[exp.Expression]],
+ raw_tokens: t.List[Token],
+ sql: t.Optional[str] = None,
+ ) -> t.List[t.Optional[exp.Expression]]:
self.reset()
self.sql = sql or ""
total = len(raw_tokens)
@@ -686,6 +728,7 @@ class Parser(metaclass=_Parser):
self._index = -1
self._tokens = tokens
self._advance()
+
expressions.append(parse_method(self))
if self._index < len(self._tokens):
@@ -695,7 +738,10 @@ class Parser(metaclass=_Parser):
return expressions
- def check_errors(self):
+ def check_errors(self) -> None:
+ """
+ Logs or raises any found errors, depending on the chosen error level setting.
+ """
if self.error_level == ErrorLevel.WARN:
for error in self.errors:
logger.error(str(error))
@@ -705,13 +751,18 @@ class Parser(metaclass=_Parser):
errors=merge_errors(self.errors),
)
- def raise_error(self, message, token=None):
+ def raise_error(self, message: str, token: t.Optional[Token] = None) -> None:
+ """
+ Appends an error in the list of recorded errors or raises it, depending on the chosen
+ error level setting.
+ """
token = token or self._curr or self._prev or Token.string("")
start = self._find_token(token, self.sql)
end = start + len(token.text)
start_context = self.sql[max(start - self.error_message_context, 0) : start]
highlight = self.sql[start:end]
end_context = self.sql[end : end + self.error_message_context]
+
error = ParseError.new(
f"{message}. Line {token.line}, Col: {token.col}.\n"
f" {start_context}\033[4m{highlight}\033[0m{end_context}",
@@ -722,11 +773,26 @@ class Parser(metaclass=_Parser):
highlight=highlight,
end_context=end_context,
)
+
if self.error_level == ErrorLevel.IMMEDIATE:
raise error
+
self.errors.append(error)
- def expression(self, exp_class, comments=None, **kwargs):
+ def expression(
+ self, exp_class: t.Type[exp.Expression], comments: t.Optional[t.List[str]] = None, **kwargs
+ ) -> exp.Expression:
+ """
+ Creates a new, validated Expression.
+
+ Args:
+ exp_class: the expression class to instantiate.
+ comments: an optional list of comments to attach to the expression.
+ kwargs: the arguments to set for the expression along with their respective values.
+
+ Returns:
+ The target expression.
+ """
instance = exp_class(**kwargs)
if self._prev_comments:
instance.comments = self._prev_comments
@@ -736,7 +802,17 @@ class Parser(metaclass=_Parser):
self.validate_expression(instance)
return instance
- def validate_expression(self, expression, args=None):
+ def validate_expression(
+ self, expression: exp.Expression, args: t.Optional[t.List] = None
+ ) -> None:
+ """
+ Validates an already instantiated expression, making sure that all its mandatory arguments
+ are set.
+
+ Args:
+ expression: the expression to validate.
+ args: an optional list of items that was used to instantiate the expression, if it's a Func.
+ """
if self.error_level == ErrorLevel.IGNORE:
return
@@ -748,13 +824,18 @@ class Parser(metaclass=_Parser):
if mandatory and (v is None or (isinstance(v, list) and not v)):
self.raise_error(f"Required keyword: '{k}' missing for {expression.__class__}")
- if args and len(args) > len(expression.arg_types) and not expression.is_var_len_args:
+ if (
+ args
+ and isinstance(expression, exp.Func)
+ and len(args) > len(expression.arg_types)
+ and not expression.is_var_len_args
+ ):
self.raise_error(
f"The number of provided arguments ({len(args)}) is greater than "
f"the maximum number of supported arguments ({len(expression.arg_types)})"
)
- def _find_token(self, token, sql):
+ def _find_token(self, token: Token, sql: str) -> int:
line = 1
col = 1
index = 0
@@ -769,7 +850,7 @@ class Parser(metaclass=_Parser):
return index
- def _advance(self, times=1):
+ def _advance(self, times: int = 1) -> None:
self._index += times
self._curr = seq_get(self._tokens, self._index)
self._next = seq_get(self._tokens, self._index + 1)
@@ -780,10 +861,10 @@ class Parser(metaclass=_Parser):
self._prev = None
self._prev_comments = None
- def _retreat(self, index):
+ def _retreat(self, index: int) -> None:
self._advance(index - self._index)
- def _parse_statement(self):
+ def _parse_statement(self) -> t.Optional[exp.Expression]:
if self._curr is None:
return None
@@ -803,7 +884,7 @@ class Parser(metaclass=_Parser):
self._parse_query_modifiers(expression)
return expression
- def _parse_drop(self, default_kind=None):
+ def _parse_drop(self, default_kind: t.Optional[str] = None) -> t.Optional[exp.Expression]:
temporary = self._match(TokenType.TEMPORARY)
materialized = self._match(TokenType.MATERIALIZED)
kind = self._match_set(self.CREATABLES) and self._prev.text
@@ -812,7 +893,7 @@ class Parser(metaclass=_Parser):
kind = default_kind
else:
self.raise_error(f"Expected {self.CREATABLES}")
- return
+ return None
return self.expression(
exp.Drop,
@@ -824,14 +905,14 @@ class Parser(metaclass=_Parser):
cascade=self._match(TokenType.CASCADE),
)
- def _parse_exists(self, not_=False):
+ def _parse_exists(self, not_: bool = False) -> t.Optional[bool]:
return (
self._match(TokenType.IF)
and (not not_ or self._match(TokenType.NOT))
and self._match(TokenType.EXISTS)
)
- def _parse_create(self):
+ def _parse_create(self) -> t.Optional[exp.Expression]:
replace = self._match_pair(TokenType.OR, TokenType.REPLACE)
temporary = self._match(TokenType.TEMPORARY)
transient = self._match_text_seq("TRANSIENT")
@@ -846,12 +927,16 @@ class Parser(metaclass=_Parser):
if not create_token:
self.raise_error(f"Expected {self.CREATABLES}")
- return
+ return None
exists = self._parse_exists(not_=True)
this = None
expression = None
properties = None
+ data = None
+ statistics = None
+ no_primary_index = None
+ indexes = None
if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
this = self._parse_user_defined_function()
@@ -868,7 +953,28 @@ class Parser(metaclass=_Parser):
this = self._parse_table(schema=True)
properties = self._parse_properties()
if self._match(TokenType.ALIAS):
- expression = self._parse_select(nested=True)
+ expression = self._parse_ddl_select()
+
+ if create_token.token_type == TokenType.TABLE:
+ if self._match_text_seq("WITH", "DATA"):
+ data = True
+ elif self._match_text_seq("WITH", "NO", "DATA"):
+ data = False
+
+ if self._match_text_seq("AND", "STATISTICS"):
+ statistics = True
+ elif self._match_text_seq("AND", "NO", "STATISTICS"):
+ statistics = False
+
+ no_primary_index = self._match_text_seq("NO", "PRIMARY", "INDEX")
+
+ indexes = []
+ while True:
+ index = self._parse_create_table_index()
+ if not index:
+ break
+ else:
+ indexes.append(index)
return self.expression(
exp.Create,
@@ -883,9 +989,13 @@ class Parser(metaclass=_Parser):
replace=replace,
unique=unique,
materialized=materialized,
+ data=data,
+ statistics=statistics,
+ no_primary_index=no_primary_index,
+ indexes=indexes,
)
- def _parse_property(self):
+ def _parse_property(self) -> t.Optional[exp.Expression]:
if self._match_set(self.PROPERTY_PARSERS):
return self.PROPERTY_PARSERS[self._prev.token_type](self)
@@ -906,7 +1016,7 @@ class Parser(metaclass=_Parser):
return None
- def _parse_property_assignment(self, exp_class):
+ def _parse_property_assignment(self, exp_class: t.Type[exp.Expression]) -> exp.Expression:
self._match(TokenType.EQ)
self._match(TokenType.ALIAS)
return self.expression(
@@ -914,42 +1024,50 @@ class Parser(metaclass=_Parser):
this=self._parse_var_or_string() or self._parse_number() or self._parse_id_var(),
)
- def _parse_partitioned_by(self):
+ def _parse_partitioned_by(self) -> exp.Expression:
self._match(TokenType.EQ)
return self.expression(
exp.PartitionedByProperty,
this=self._parse_schema() or self._parse_bracket(self._parse_field()),
)
- def _parse_distkey(self):
+ def _parse_distkey(self) -> exp.Expression:
return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var))
- def _parse_create_like(self):
+ def _parse_create_like(self) -> t.Optional[exp.Expression]:
table = self._parse_table(schema=True)
options = []
while self._match_texts(("INCLUDING", "EXCLUDING")):
+ this = self._prev.text.upper()
+ id_var = self._parse_id_var()
+
+ if not id_var:
+ return None
+
options.append(
self.expression(
exp.Property,
- this=self._prev.text.upper(),
- value=exp.Var(this=self._parse_id_var().this.upper()),
+ this=this,
+ value=exp.Var(this=id_var.this.upper()),
)
)
return self.expression(exp.LikeProperty, this=table, expressions=options)
- def _parse_sortkey(self, compound=False):
+ def _parse_sortkey(self, compound: bool = False) -> exp.Expression:
return self.expression(
exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_id_var), compound=compound
)
- def _parse_character_set(self, default=False):
+ def _parse_character_set(self, default: bool = False) -> exp.Expression:
self._match(TokenType.EQ)
return self.expression(
exp.CharacterSetProperty, this=self._parse_var_or_string(), default=default
)
- def _parse_returns(self):
+ def _parse_returns(self) -> exp.Expression:
+ value: t.Optional[exp.Expression]
is_table = self._match(TokenType.TABLE)
+
if is_table:
if self._match(TokenType.LT):
value = self.expression(
@@ -960,13 +1078,13 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.GT):
self.raise_error("Expecting >")
else:
- value = self._parse_schema("TABLE")
+ value = self._parse_schema(exp.Literal.string("TABLE"))
else:
value = self._parse_types()
return self.expression(exp.ReturnsProperty, this=value, is_table=is_table)
- def _parse_properties(self):
+ def _parse_properties(self) -> t.Optional[exp.Expression]:
properties = []
while True:
@@ -978,15 +1096,21 @@ class Parser(metaclass=_Parser):
if properties:
return self.expression(exp.Properties, expressions=properties)
+
return None
- def _parse_describe(self):
- self._match(TokenType.TABLE)
- return self.expression(exp.Describe, this=self._parse_id_var())
+ def _parse_describe(self) -> exp.Expression:
+ kind = self._match_set(self.CREATABLES) and self._prev.text
+ this = self._parse_table()
- def _parse_insert(self):
+ return self.expression(exp.Describe, this=this, kind=kind)
+
+ def _parse_insert(self) -> exp.Expression:
overwrite = self._match(TokenType.OVERWRITE)
local = self._match(TokenType.LOCAL)
+
+ this: t.Optional[exp.Expression]
+
if self._match_text_seq("DIRECTORY"):
this = self.expression(
exp.Directory,
@@ -998,21 +1122,22 @@ class Parser(metaclass=_Parser):
self._match(TokenType.INTO)
self._match(TokenType.TABLE)
this = self._parse_table(schema=True)
+
return self.expression(
exp.Insert,
this=this,
exists=self._parse_exists(),
partition=self._parse_partition(),
- expression=self._parse_select(nested=True),
+ expression=self._parse_ddl_select(),
overwrite=overwrite,
)
- def _parse_row(self):
+ def _parse_row(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.FORMAT):
return None
return self._parse_row_format()
- def _parse_row_format(self, match_row=False):
+ def _parse_row_format(self, match_row: bool = False) -> t.Optional[exp.Expression]:
if match_row and not self._match_pair(TokenType.ROW, TokenType.FORMAT):
return None
@@ -1035,9 +1160,10 @@ class Parser(metaclass=_Parser):
kwargs["lines"] = self._parse_string()
if self._match_text_seq("NULL", "DEFINED", "AS"):
kwargs["null"] = self._parse_string()
- return self.expression(exp.RowFormatDelimitedProperty, **kwargs)
- def _parse_load_data(self):
+ return self.expression(exp.RowFormatDelimitedProperty, **kwargs) # type: ignore
+
+ def _parse_load_data(self) -> exp.Expression:
local = self._match(TokenType.LOCAL)
self._match_text_seq("INPATH")
inpath = self._parse_string()
@@ -1055,7 +1181,7 @@ class Parser(metaclass=_Parser):
serde=self._match_text_seq("SERDE") and self._parse_string(),
)
- def _parse_delete(self):
+ def _parse_delete(self) -> exp.Expression:
self._match(TokenType.FROM)
return self.expression(
@@ -1065,10 +1191,10 @@ class Parser(metaclass=_Parser):
where=self._parse_where(),
)
- def _parse_update(self):
+ def _parse_update(self) -> exp.Expression:
return self.expression(
exp.Update,
- **{
+ **{ # type: ignore
"this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
"expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality),
"from": self._parse_from(),
@@ -1076,16 +1202,17 @@ class Parser(metaclass=_Parser):
},
)
- def _parse_uncache(self):
+ def _parse_uncache(self) -> exp.Expression:
if not self._match(TokenType.TABLE):
self.raise_error("Expecting TABLE after UNCACHE")
+
return self.expression(
exp.Uncache,
exists=self._parse_exists(),
this=self._parse_table(schema=True),
)
- def _parse_cache(self):
+ def _parse_cache(self) -> exp.Expression:
lazy = self._match(TokenType.LAZY)
self._match(TokenType.TABLE)
table = self._parse_table(schema=True)
@@ -1108,21 +1235,23 @@ class Parser(metaclass=_Parser):
expression=self._parse_select(nested=True),
)
- def _parse_partition(self):
+ def _parse_partition(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.PARTITION):
return None
- def parse_values():
+ def parse_values() -> exp.Property:
props = self._parse_csv(self._parse_var_or_string, sep=TokenType.EQ)
return exp.Property(this=seq_get(props, 0), value=seq_get(props, 1))
return self.expression(exp.Partition, this=self._parse_wrapped_csv(parse_values))
- def _parse_value(self):
+ def _parse_value(self) -> exp.Expression:
expressions = self._parse_wrapped_csv(self._parse_conjunction)
return self.expression(exp.Tuple, expressions=expressions)
- def _parse_select(self, nested=False, table=False):
+ def _parse_select(
+ self, nested: bool = False, table: bool = False, parse_subquery_alias: bool = True
+ ) -> t.Optional[exp.Expression]:
cte = self._parse_with()
if cte:
this = self._parse_statement()
@@ -1178,10 +1307,11 @@ class Parser(metaclass=_Parser):
self._parse_query_modifiers(this)
this = self._parse_set_operations(this)
self._match_r_paren()
+
# early return so that subquery unions aren't parsed again
# SELECT * FROM (SELECT 1) UNION ALL SELECT 1
# Union ALL should be a property of the top select node, not the subquery
- return self._parse_subquery(this)
+ return self._parse_subquery(this, parse_alias=parse_subquery_alias)
elif self._match(TokenType.VALUES):
if self._curr.token_type == TokenType.L_PAREN:
# We don't consume the left paren because it's consumed in _parse_value
@@ -1203,7 +1333,7 @@ class Parser(metaclass=_Parser):
return self._parse_set_operations(this)
- def _parse_with(self, skip_with_token=False):
+ def _parse_with(self, skip_with_token: bool = False) -> t.Optional[exp.Expression]:
if not skip_with_token and not self._match(TokenType.WITH):
return None
@@ -1220,7 +1350,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.With, expressions=expressions, recursive=recursive)
- def _parse_cte(self):
+ def _parse_cte(self) -> exp.Expression:
alias = self._parse_table_alias()
if not alias or not alias.this:
self.raise_error("Expected CTE to have alias")
@@ -1234,7 +1364,9 @@ class Parser(metaclass=_Parser):
alias=alias,
)
- def _parse_table_alias(self, alias_tokens=None):
+ def _parse_table_alias(
+ self, alias_tokens: t.Optional[t.Collection[TokenType]] = None
+ ) -> t.Optional[exp.Expression]:
any_token = self._match(TokenType.ALIAS)
alias = self._parse_id_var(
any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS
@@ -1251,15 +1383,17 @@ class Parser(metaclass=_Parser):
return self.expression(exp.TableAlias, this=alias, columns=columns)
- def _parse_subquery(self, this):
+ def _parse_subquery(
+ self, this: t.Optional[exp.Expression], parse_alias: bool = True
+ ) -> exp.Expression:
return self.expression(
exp.Subquery,
this=this,
pivots=self._parse_pivots(),
- alias=self._parse_table_alias(),
+ alias=self._parse_table_alias() if parse_alias else None,
)
- def _parse_query_modifiers(self, this):
+ def _parse_query_modifiers(self, this: t.Optional[exp.Expression]) -> None:
if not isinstance(this, self.MODIFIABLES):
return
@@ -1284,15 +1418,16 @@ class Parser(metaclass=_Parser):
if expression:
this.set(key, expression)
- def _parse_hint(self):
+ def _parse_hint(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.HINT):
hints = self._parse_csv(self._parse_function)
if not self._match_pair(TokenType.STAR, TokenType.SLASH):
self.raise_error("Expected */ after HINT")
return self.expression(exp.Hint, expressions=hints)
+
return None
- def _parse_into(self):
+ def _parse_into(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.INTO):
return None
@@ -1304,14 +1439,15 @@ class Parser(metaclass=_Parser):
exp.Into, this=self._parse_table(schema=True), temporary=temp, unlogged=unlogged
)
- def _parse_from(self):
+ def _parse_from(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.FROM):
return None
+
return self.expression(
exp.From, comments=self._prev_comments, expressions=self._parse_csv(self._parse_table)
)
- def _parse_lateral(self):
+ def _parse_lateral(self) -> t.Optional[exp.Expression]:
outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY)
cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY)
@@ -1334,6 +1470,8 @@ class Parser(metaclass=_Parser):
expression=self._parse_function() or self._parse_id_var(any_token=False),
)
+ table_alias: t.Optional[exp.Expression]
+
if view:
table = self._parse_id_var(any_token=False)
columns = self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else []
@@ -1354,20 +1492,24 @@ class Parser(metaclass=_Parser):
return expression
- def _parse_join_side_and_kind(self):
+ def _parse_join_side_and_kind(
+ self,
+ ) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]:
return (
self._match(TokenType.NATURAL) and self._prev,
self._match_set(self.JOIN_SIDES) and self._prev,
self._match_set(self.JOIN_KINDS) and self._prev,
)
- def _parse_join(self, skip_join_token=False):
+ def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]:
natural, side, kind = self._parse_join_side_and_kind()
if not skip_join_token and not self._match(TokenType.JOIN):
return None
- kwargs = {"this": self._parse_table()}
+ kwargs: t.Dict[
+ str, t.Optional[exp.Expression] | bool | str | t.List[t.Optional[exp.Expression]]
+ ] = {"this": self._parse_table()}
if natural:
kwargs["natural"] = True
@@ -1381,12 +1523,13 @@ class Parser(metaclass=_Parser):
elif self._match(TokenType.USING):
kwargs["using"] = self._parse_wrapped_id_vars()
- return self.expression(exp.Join, **kwargs)
+ return self.expression(exp.Join, **kwargs) # type: ignore
- def _parse_index(self):
+ def _parse_index(self) -> exp.Expression:
index = self._parse_id_var()
self._match(TokenType.ON)
self._match(TokenType.TABLE) # hive
+
return self.expression(
exp.Index,
this=index,
@@ -1394,7 +1537,28 @@ class Parser(metaclass=_Parser):
columns=self._parse_expression(),
)
- def _parse_table(self, schema=False, alias_tokens=None):
+ def _parse_create_table_index(self) -> t.Optional[exp.Expression]:
+ unique = self._match(TokenType.UNIQUE)
+ primary = self._match_text_seq("PRIMARY")
+ amp = self._match_text_seq("AMP")
+ if not self._match(TokenType.INDEX):
+ return None
+ index = self._parse_id_var()
+ columns = None
+ if self._curr and self._curr.token_type == TokenType.L_PAREN:
+ columns = self._parse_wrapped_csv(self._parse_column)
+ return self.expression(
+ exp.Index,
+ this=index,
+ columns=columns,
+ unique=unique,
+ primary=primary,
+ amp=amp,
+ )
+
+ def _parse_table(
+ self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None
+ ) -> t.Optional[exp.Expression]:
lateral = self._parse_lateral()
if lateral:
@@ -1417,7 +1581,9 @@ class Parser(metaclass=_Parser):
catalog = None
db = None
- table = (not schema and self._parse_function()) or self._parse_id_var(False)
+ table = (not schema and self._parse_function()) or self._parse_id_var(
+ any_token=False, prefix_tokens=self.TABLE_PREFIX_TOKENS
+ )
while self._match(TokenType.DOT):
if catalog:
@@ -1446,6 +1612,14 @@ class Parser(metaclass=_Parser):
if alias:
this.set("alias", alias)
+ if self._match(TokenType.WITH):
+ this.set(
+ "hints",
+ self._parse_wrapped_csv(
+ lambda: self._parse_function() or self._parse_var(any_token=True)
+ ),
+ )
+
if not self.alias_post_tablesample:
table_sample = self._parse_table_sample()
@@ -1455,7 +1629,7 @@ class Parser(metaclass=_Parser):
return this
- def _parse_unnest(self):
+ def _parse_unnest(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.UNNEST):
return None
@@ -1473,7 +1647,7 @@ class Parser(metaclass=_Parser):
exp.Unnest, expressions=expressions, ordinality=ordinality, alias=alias
)
- def _parse_derived_table_values(self):
+ def _parse_derived_table_values(self) -> t.Optional[exp.Expression]:
is_derived = self._match_pair(TokenType.L_PAREN, TokenType.VALUES)
if not is_derived and not self._match(TokenType.VALUES):
return None
@@ -1485,7 +1659,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Values, expressions=expressions, alias=self._parse_table_alias())
- def _parse_table_sample(self):
+ def _parse_table_sample(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.TABLE_SAMPLE):
return None
@@ -1533,10 +1707,10 @@ class Parser(metaclass=_Parser):
seed=seed,
)
- def _parse_pivots(self):
+ def _parse_pivots(self) -> t.List[t.Optional[exp.Expression]]:
return list(iter(self._parse_pivot, None))
- def _parse_pivot(self):
+ def _parse_pivot(self) -> t.Optional[exp.Expression]:
index = self._index
if self._match(TokenType.PIVOT):
@@ -1572,16 +1746,18 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Pivot, expressions=expressions, field=field, unpivot=unpivot)
- def _parse_where(self, skip_where_token=False):
+ def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Expression]:
if not skip_where_token and not self._match(TokenType.WHERE):
return None
+
return self.expression(
exp.Where, comments=self._prev_comments, this=self._parse_conjunction()
)
- def _parse_group(self, skip_group_by_token=False):
+ def _parse_group(self, skip_group_by_token: bool = False) -> t.Optional[exp.Expression]:
if not skip_group_by_token and not self._match(TokenType.GROUP_BY):
return None
+
return self.expression(
exp.Group,
expressions=self._parse_csv(self._parse_conjunction),
@@ -1590,29 +1766,33 @@ class Parser(metaclass=_Parser):
rollup=self._match(TokenType.ROLLUP) and self._parse_wrapped_id_vars(),
)
- def _parse_grouping_sets(self):
+ def _parse_grouping_sets(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
if not self._match(TokenType.GROUPING_SETS):
return None
+
return self._parse_wrapped_csv(self._parse_grouping_set)
- def _parse_grouping_set(self):
+ def _parse_grouping_set(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.L_PAREN):
grouping_set = self._parse_csv(self._parse_id_var)
self._match_r_paren()
return self.expression(exp.Tuple, expressions=grouping_set)
+
return self._parse_id_var()
- def _parse_having(self, skip_having_token=False):
+ def _parse_having(self, skip_having_token: bool = False) -> t.Optional[exp.Expression]:
if not skip_having_token and not self._match(TokenType.HAVING):
return None
return self.expression(exp.Having, this=self._parse_conjunction())
- def _parse_qualify(self):
+ def _parse_qualify(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.QUALIFY):
return None
return self.expression(exp.Qualify, this=self._parse_conjunction())
- def _parse_order(self, this=None, skip_order_token=False):
+ def _parse_order(
+ self, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False
+ ) -> t.Optional[exp.Expression]:
if not skip_order_token and not self._match(TokenType.ORDER_BY):
return this
@@ -1620,12 +1800,14 @@ class Parser(metaclass=_Parser):
exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)
)
- def _parse_sort(self, token_type, exp_class):
+ def _parse_sort(
+ self, token_type: TokenType, exp_class: t.Type[exp.Expression]
+ ) -> t.Optional[exp.Expression]:
if not self._match(token_type):
return None
return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered))
- def _parse_ordered(self):
+ def _parse_ordered(self) -> exp.Expression:
this = self._parse_conjunction()
self._match(TokenType.ASC)
is_desc = self._match(TokenType.DESC)
@@ -1647,7 +1829,9 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Ordered, this=this, desc=desc, nulls_first=nulls_first)
- def _parse_limit(self, this=None, top=False):
+ def _parse_limit(
+ self, this: t.Optional[exp.Expression] = None, top: bool = False
+ ) -> t.Optional[exp.Expression]:
if self._match(TokenType.TOP if top else TokenType.LIMIT):
limit_paren = self._match(TokenType.L_PAREN)
limit_exp = self.expression(exp.Limit, this=this, expression=self._parse_number())
@@ -1667,7 +1851,7 @@ class Parser(metaclass=_Parser):
return this
- def _parse_offset(self, this=None):
+ def _parse_offset(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
if not self._match_set((TokenType.OFFSET, TokenType.COMMA)):
return this
@@ -1675,7 +1859,7 @@ class Parser(metaclass=_Parser):
self._match_set((TokenType.ROW, TokenType.ROWS))
return self.expression(exp.Offset, this=this, expression=count)
- def _parse_set_operations(self, this):
+ def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if not self._match_set(self.SET_OPERATIONS):
return this
@@ -1695,19 +1879,19 @@ class Parser(metaclass=_Parser):
expression=self._parse_select(nested=True),
)
- def _parse_expression(self):
+ def _parse_expression(self) -> t.Optional[exp.Expression]:
return self._parse_alias(self._parse_conjunction())
- def _parse_conjunction(self):
+ def _parse_conjunction(self) -> t.Optional[exp.Expression]:
return self._parse_tokens(self._parse_equality, self.CONJUNCTION)
- def _parse_equality(self):
+ def _parse_equality(self) -> t.Optional[exp.Expression]:
return self._parse_tokens(self._parse_comparison, self.EQUALITY)
- def _parse_comparison(self):
+ def _parse_comparison(self) -> t.Optional[exp.Expression]:
return self._parse_tokens(self._parse_range, self.COMPARISON)
- def _parse_range(self):
+ def _parse_range(self) -> t.Optional[exp.Expression]:
this = self._parse_bitwise()
negate = self._match(TokenType.NOT)
@@ -1730,7 +1914,7 @@ class Parser(metaclass=_Parser):
return this
- def _parse_is(self, this):
+ def _parse_is(self, this: t.Optional[exp.Expression]) -> exp.Expression:
negate = self._match(TokenType.NOT)
if self._match(TokenType.DISTINCT_FROM):
klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
@@ -1743,7 +1927,7 @@ class Parser(metaclass=_Parser):
)
return self.expression(exp.Not, this=this) if negate else this
- def _parse_in(self, this):
+ def _parse_in(self, this: t.Optional[exp.Expression]) -> exp.Expression:
unnest = self._parse_unnest()
if unnest:
this = self.expression(exp.In, this=this, unnest=unnest)
@@ -1761,18 +1945,18 @@ class Parser(metaclass=_Parser):
return this
def _parse_between(self, this: exp.Expression) -> exp.Expression:
    """Parse the `low AND high` tail of a BETWEEN predicate for *this*."""
    low = self._parse_bitwise()
    self._match(TokenType.AND)  # the separator between the two bounds
    high = self._parse_bitwise()
    return self.expression(exp.Between, this=this, low=low, high=high)

def _parse_escape(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
    """Wrap *this* in an ESCAPE clause when one follows; otherwise pass it through."""
    if self._match(TokenType.ESCAPE):
        return self.expression(exp.Escape, this=this, expression=self._parse_string())
    return this
- def _parse_bitwise(self):
+ def _parse_bitwise(self) -> t.Optional[exp.Expression]:
this = self._parse_term()
while True:
@@ -1795,18 +1979,18 @@ class Parser(metaclass=_Parser):
return this
def _parse_term(self) -> t.Optional[exp.Expression]:
    """Parse operators at term precedence (the self.TERM table)."""
    return self._parse_tokens(self._parse_factor, self.TERM)

def _parse_factor(self) -> t.Optional[exp.Expression]:
    """Parse operators at factor precedence (the self.FACTOR table)."""
    return self._parse_tokens(self._parse_unary, self.FACTOR)

def _parse_unary(self) -> t.Optional[exp.Expression]:
    """Parse a unary operator when one is next; otherwise parse a typed operand."""
    if self._match_set(self.UNARY_PARSERS):
        unary_parser = self.UNARY_PARSERS[self._prev.token_type]
        return unary_parser(self)
    return self._parse_at_time_zone(self._parse_type())
- def _parse_type(self):
+ def _parse_type(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.INTERVAL):
return self.expression(exp.Interval, this=self._parse_term(), unit=self._parse_var())
@@ -1824,7 +2008,7 @@ class Parser(metaclass=_Parser):
return this
- def _parse_types(self, check_func=False):
+ def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]:
index = self._index
if not self._match_set(self.TYPE_TOKENS):
@@ -1875,7 +2059,7 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.GT):
self.raise_error("Expecting >")
- value = None
+ value: t.Optional[exp.Expression] = None
if type_token in self.TIMESTAMPS:
if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ:
value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions)
@@ -1884,7 +2068,10 @@ class Parser(metaclass=_Parser):
):
value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions)
elif self._match(TokenType.WITHOUT_TIME_ZONE):
- value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)
+ if type_token == TokenType.TIME:
+ value = exp.DataType(this=exp.DataType.Type.TIME, expressions=expressions)
+ else:
+ value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)
maybe_func = maybe_func and value is None
@@ -1912,7 +2099,7 @@ class Parser(metaclass=_Parser):
nested=nested,
)
- def _parse_struct_kwargs(self):
+ def _parse_struct_kwargs(self) -> t.Optional[exp.Expression]:
this = self._parse_id_var()
self._match(TokenType.COLON)
data_type = self._parse_types()
@@ -1921,12 +2108,12 @@ class Parser(metaclass=_Parser):
return None
return self.expression(exp.StructKwarg, this=this, expression=data_type)
def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
    """Attach an AT TIME ZONE modifier to *this* when the clause is present."""
    if self._match(TokenType.AT_TIME_ZONE):
        return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary())
    return this
- def _parse_column(self):
+ def _parse_column(self) -> t.Optional[exp.Expression]:
this = self._parse_field()
if isinstance(this, exp.Identifier):
this = self.expression(exp.Column, this=this)
@@ -1943,7 +2130,8 @@ class Parser(metaclass=_Parser):
if not field:
self.raise_error("Expected type")
elif op:
- field = exp.Literal.string(self._advance() or self._prev.text)
+ self._advance()
+ field = exp.Literal.string(self._prev.text)
else:
field = self._parse_star() or self._parse_function() or self._parse_id_var()
@@ -1963,7 +2151,7 @@ class Parser(metaclass=_Parser):
return this
- def _parse_primary(self):
+ def _parse_primary(self) -> t.Optional[exp.Expression]:
if self._match_set(self.PRIMARY_PARSERS):
token_type = self._prev.token_type
primary = self.PRIMARY_PARSERS[token_type](self, self._prev)
@@ -1995,21 +2183,27 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
if isinstance(this, exp.Subqueryable):
- this = self._parse_set_operations(self._parse_subquery(this))
+ this = self._parse_set_operations(
+ self._parse_subquery(this=this, parse_alias=False)
+ )
elif len(expressions) > 1:
this = self.expression(exp.Tuple, expressions=expressions)
else:
this = self.expression(exp.Paren, this=this)
- if comments:
+
+ if this and comments:
this.comments = comments
+
return this
return None
def _parse_field(self, any_token: bool = False) -> t.Optional[exp.Expression]:
    """Parse a field: a primary expression, a function call, or an identifier."""
    return (
        self._parse_primary()
        or self._parse_function()
        or self._parse_id_var(any_token)
    )
- def _parse_function(self, functions=None):
+ def _parse_function(
+ self, functions: t.Optional[t.Dict[str, t.Callable]] = None
+ ) -> t.Optional[exp.Expression]:
if not self._curr:
return None
@@ -2020,7 +2214,9 @@ class Parser(metaclass=_Parser):
if not self._next or self._next.token_type != TokenType.L_PAREN:
if token_type in self.NO_PAREN_FUNCTIONS:
- return self.expression(self._advance() or self.NO_PAREN_FUNCTIONS[token_type])
+ self._advance()
+ return self.expression(self.NO_PAREN_FUNCTIONS[token_type])
+
return None
if token_type not in self.FUNC_TOKENS:
@@ -2049,7 +2245,18 @@ class Parser(metaclass=_Parser):
args = self._parse_csv(self._parse_lambda)
if function:
- this = function(args)
+
+ # Clickhouse supports function calls like foo(x, y)(z), so for these we need to also parse the
+ # second parameter list (i.e. "(z)") and the corresponding function will receive both arg lists.
+ if count_params(function) == 2:
+ params = None
+ if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN):
+ params = self._parse_csv(self._parse_lambda)
+
+ this = function(args, params)
+ else:
+ this = function(args)
+
self.validate_expression(this, args)
else:
this = self.expression(exp.Anonymous, this=this, expressions=args)
@@ -2057,7 +2264,7 @@ class Parser(metaclass=_Parser):
self._match_r_paren(this)
return self._parse_window(this)
- def _parse_user_defined_function(self):
+ def _parse_user_defined_function(self) -> t.Optional[exp.Expression]:
this = self._parse_id_var()
while self._match(TokenType.DOT):
@@ -2070,27 +2277,27 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
def _parse_introducer(self, token: Token) -> t.Optional[exp.Expression]:
    """Parse a literal introducer (e.g. a charset prefix before a literal).

    A bare introducer with no following literal degrades to an Identifier.
    """
    literal = self._parse_primary()
    if not literal:
        return self.expression(exp.Identifier, this=token.text)
    return self.expression(exp.Introducer, this=token.text, expression=literal)

def _parse_national(self, token: Token) -> exp.Expression:
    """Wrap a national string token's text in an exp.National node."""
    return self.expression(exp.National, this=exp.Literal.string(token.text))

def _parse_session_parameter(self) -> exp.Expression:
    """Parse a session parameter, optionally qualified as `kind.name`."""
    kind = None
    this = self._parse_id_var() or self._parse_primary()
    if this and self._match(TokenType.DOT):
        # The first part qualifies the parameter; the real name follows the dot.
        kind = this.name
        this = self._parse_var() or self._parse_primary()
    return self.expression(exp.SessionParameter, this=this, kind=kind)
- def _parse_udf_kwarg(self):
+ def _parse_udf_kwarg(self) -> t.Optional[exp.Expression]:
this = self._parse_id_var()
kind = self._parse_types()
@@ -2099,7 +2306,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.UserDefinedFunctionKwarg, this=this, kind=kind)
- def _parse_lambda(self):
+ def _parse_lambda(self) -> t.Optional[exp.Expression]:
index = self._index
if self._match(TokenType.L_PAREN):
@@ -2115,6 +2322,8 @@ class Parser(metaclass=_Parser):
self._retreat(index)
+ this: t.Optional[exp.Expression]
+
if self._match(TokenType.DISTINCT):
this = self.expression(
exp.Distinct, expressions=self._parse_csv(self._parse_conjunction)
@@ -2129,7 +2338,7 @@ class Parser(metaclass=_Parser):
return self._parse_limit(self._parse_order(this))
- def _parse_schema(self, this=None):
+ def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
index = self._index
if not self._match(TokenType.L_PAREN) or self._match(TokenType.SELECT):
self._retreat(index)
@@ -2140,14 +2349,15 @@ class Parser(metaclass=_Parser):
or self._parse_column_def(self._parse_field(any_token=True))
)
self._match_r_paren()
+
+ if isinstance(this, exp.Literal):
+ this = this.name
+
return self.expression(exp.Schema, this=this, expressions=args)
- def _parse_column_def(self, this):
+ def _parse_column_def(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
kind = self._parse_types()
- if not kind:
- return this
-
constraints = []
while True:
constraint = self._parse_column_constraint()
@@ -2155,9 +2365,12 @@ class Parser(metaclass=_Parser):
break
constraints.append(constraint)
+ if not kind and not constraints:
+ return this
+
return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints)
- def _parse_column_constraint(self):
+ def _parse_column_constraint(self) -> t.Optional[exp.Expression]:
this = self._parse_references()
if this:
@@ -2166,6 +2379,8 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.CONSTRAINT):
this = self._parse_id_var()
+ kind: exp.Expression
+
if self._match(TokenType.AUTO_INCREMENT):
kind = exp.AutoIncrementColumnConstraint()
elif self._match(TokenType.CHECK):
@@ -2202,7 +2417,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.ColumnConstraint, this=this, kind=kind)
- def _parse_constraint(self):
+ def _parse_constraint(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.CONSTRAINT):
return self._parse_unnamed_constraint()
@@ -2217,24 +2432,25 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Constraint, this=this, expressions=expressions)
def _parse_unnamed_constraint(self) -> t.Optional[exp.Expression]:
    """Dispatch to a constraint parser keyed by the matched token, if any."""
    if self._match_set(self.CONSTRAINT_PARSERS):
        return self.CONSTRAINT_PARSERS[self._prev.token_type](self)
    return None

def _parse_unique(self) -> exp.Expression:
    """Parse a UNIQUE constraint together with its parenthesized column list."""
    return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars())

def _parse_references(self) -> t.Optional[exp.Expression]:
    """Parse a REFERENCES clause: the target name plus its column list."""
    if not self._match(TokenType.REFERENCES):
        return None

    # Parse the referenced name first, then its wrapped column list;
    # the order matters because both consume tokens.
    target = self._parse_id_var()
    columns = self._parse_wrapped_id_vars()
    return self.expression(exp.Reference, this=target, expressions=columns)
- def _parse_foreign_key(self):
+ def _parse_foreign_key(self) -> exp.Expression:
expressions = self._parse_wrapped_id_vars()
reference = self._parse_references()
options = {}
@@ -2260,13 +2476,15 @@ class Parser(metaclass=_Parser):
exp.ForeignKey,
expressions=expressions,
reference=reference,
- **options,
+ **options, # type: ignore
)
- def _parse_bracket(self, this):
+ def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if not self._match(TokenType.L_BRACKET):
return this
+ expressions: t.List[t.Optional[exp.Expression]]
+
if self._match(TokenType.COLON):
expressions = [self.expression(exp.Slice, expression=self._parse_conjunction())]
else:
@@ -2284,12 +2502,12 @@ class Parser(metaclass=_Parser):
this.comments = self._prev_comments
return self._parse_bracket(this)
def _parse_slice(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
    """Turn *this* into a Slice node when a ':' follows; otherwise return it unchanged."""
    if not self._match(TokenType.COLON):
        return this
    return self.expression(exp.Slice, this=this, expression=self._parse_conjunction())
- def _parse_case(self):
+ def _parse_case(self) -> t.Optional[exp.Expression]:
ifs = []
default = None
@@ -2311,7 +2529,7 @@ class Parser(metaclass=_Parser):
self.expression(exp.Case, this=expression, ifs=ifs, default=default)
)
- def _parse_if(self):
+ def _parse_if(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.L_PAREN):
args = self._parse_csv(self._parse_conjunction)
this = exp.If.from_arg_list(args)
@@ -2324,9 +2542,10 @@ class Parser(metaclass=_Parser):
false = self._parse_conjunction() if self._match(TokenType.ELSE) else None
self._match(TokenType.END)
this = self.expression(exp.If, this=condition, true=true, false=false)
+
return self._parse_window(this)
- def _parse_extract(self):
+ def _parse_extract(self) -> exp.Expression:
this = self._parse_function() or self._parse_var() or self._parse_type()
if self._match(TokenType.FROM):
@@ -2337,7 +2556,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())
- def _parse_cast(self, strict):
+ def _parse_cast(self, strict: bool) -> exp.Expression:
this = self._parse_conjunction()
if not self._match(TokenType.ALIAS):
@@ -2353,7 +2572,9 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
- def _parse_string_agg(self):
+ def _parse_string_agg(self) -> exp.Expression:
+ expression: t.Optional[exp.Expression]
+
if self._match(TokenType.DISTINCT):
args = self._parse_csv(self._parse_conjunction)
expression = self.expression(exp.Distinct, expressions=[seq_get(args, 0)])
@@ -2380,8 +2601,10 @@ class Parser(metaclass=_Parser):
order = self._parse_order(this=expression)
return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))
- def _parse_convert(self, strict):
+ def _parse_convert(self, strict: bool) -> exp.Expression:
+ to: t.Optional[exp.Expression]
this = self._parse_column()
+
if self._match(TokenType.USING):
to = self.expression(exp.CharacterSet, this=self._parse_var())
elif self._match(TokenType.COMMA):
@@ -2390,7 +2613,7 @@ class Parser(metaclass=_Parser):
to = None
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
- def _parse_position(self):
+ def _parse_position(self) -> exp.Expression:
args = self._parse_csv(self._parse_bitwise)
if self._match(TokenType.IN):
@@ -2402,11 +2625,11 @@ class Parser(metaclass=_Parser):
return this
def _parse_join_hint(self, func_name: str) -> exp.Expression:
    """Parse a join hint's table arguments into an exp.JoinHint node."""
    tables = self._parse_csv(self._parse_table)
    return exp.JoinHint(this=func_name.upper(), expressions=tables)
- def _parse_substring(self):
+ def _parse_substring(self) -> exp.Expression:
# Postgres supports the form: substring(string [from int] [for int])
# https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6
@@ -2422,7 +2645,7 @@ class Parser(metaclass=_Parser):
return this
- def _parse_trim(self):
+ def _parse_trim(self) -> exp.Expression:
# https://www.w3resource.com/sql/character-functions/trim.php
# https://docs.oracle.com/javadb/10.8.3.0/ref/rreftrimfunc.html
@@ -2450,13 +2673,15 @@ class Parser(metaclass=_Parser):
collation=collation,
)
def _parse_window_clause(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
    """Parse a WINDOW clause into its comma-separated named windows, if present."""
    # Preserve the short-circuit: when WINDOW is absent, the _match result
    # itself (falsy) is returned unchanged.
    matched = self._match(TokenType.WINDOW)
    return matched and self._parse_csv(self._parse_named_window)

def _parse_named_window(self) -> t.Optional[exp.Expression]:
    """Parse a single named-window entry of a WINDOW clause."""
    return self._parse_window(self._parse_id_var(), alias=True)
- def _parse_window(self, this, alias=False):
+ def _parse_window(
+ self, this: t.Optional[exp.Expression], alias: bool = False
+ ) -> t.Optional[exp.Expression]:
if self._match(TokenType.FILTER):
where = self._parse_wrapped(self._parse_where)
this = self.expression(exp.Filter, this=this, expression=where)
@@ -2495,7 +2720,7 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.L_PAREN):
return self.expression(exp.Window, this=this, alias=self._parse_id_var(False))
- alias = self._parse_id_var(False)
+ window_alias = self._parse_id_var(any_token=False, tokens=self.WINDOW_ALIAS_TOKENS)
partition = None
if self._match(TokenType.PARTITION_BY):
@@ -2529,10 +2754,10 @@ class Parser(metaclass=_Parser):
partition_by=partition,
order=order,
spec=spec,
- alias=alias,
+ alias=window_alias,
)
- def _parse_window_spec(self):
+ def _parse_window_spec(self) -> t.Dict[str, t.Optional[str | exp.Expression]]:
self._match(TokenType.BETWEEN)
return {
@@ -2543,7 +2768,9 @@ class Parser(metaclass=_Parser):
"side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) and self._prev.text,
}
- def _parse_alias(self, this, explicit=False):
+ def _parse_alias(
+ self, this: t.Optional[exp.Expression], explicit: bool = False
+ ) -> t.Optional[exp.Expression]:
any_token = self._match(TokenType.ALIAS)
if explicit and not any_token:
@@ -2565,63 +2792,74 @@ class Parser(metaclass=_Parser):
return this
def _parse_id_var(
    self,
    any_token: bool = True,
    tokens: t.Optional[t.Collection[TokenType]] = None,
    prefix_tokens: t.Optional[t.Collection[TokenType]] = None,
) -> t.Optional[exp.Expression]:
    """Parse an identifier-like name.

    A quoted identifier wins outright.  Otherwise any matched *prefix_tokens*
    are glued onto the front of the next accepted token's text and the result
    is returned as an unquoted Identifier.  Returns None when nothing matches.
    """
    identifier = self._parse_identifier()
    if identifier:
        return identifier

    parts: t.List[str] = []
    if prefix_tokens:
        while self._match_set(prefix_tokens):
            parts.append(self._prev.text)

    accepted = (any_token and self._advance_any()) or self._match_set(
        tokens or self.ID_VAR_TOKENS
    )
    if accepted:
        parts.append(self._prev.text)
        return exp.Identifier(this="".join(parts), quoted=False)

    return None
def _parse_string(self) -> t.Optional[exp.Expression]:
    """Parse a string literal, falling back to a placeholder."""
    if not self._match(TokenType.STRING):
        return self._parse_placeholder()
    return self.PRIMARY_PARSERS[TokenType.STRING](self, self._prev)

def _parse_number(self) -> t.Optional[exp.Expression]:
    """Parse a numeric literal, falling back to a placeholder."""
    if not self._match(TokenType.NUMBER):
        return self._parse_placeholder()
    return self.PRIMARY_PARSERS[TokenType.NUMBER](self, self._prev)

def _parse_identifier(self) -> t.Optional[exp.Expression]:
    """Parse a quoted identifier token, falling back to a placeholder."""
    if not self._match(TokenType.IDENTIFIER):
        return self._parse_placeholder()
    return self.expression(exp.Identifier, this=self._prev.text, quoted=True)

def _parse_var(self, any_token: bool = False) -> t.Optional[exp.Expression]:
    """Parse a VAR token (or any non-reserved token when *any_token*), else a placeholder."""
    accepted = (any_token and self._advance_any()) or self._match(TokenType.VAR)
    if accepted:
        return self.expression(exp.Var, this=self._prev.text)
    return self._parse_placeholder()
def _advance_any(self) -> t.Optional[Token]:
    """Consume the current token unless it is a reserved keyword.

    Returns the consumed token, or None when nothing was consumed.
    """
    current = self._curr
    if not current or current.token_type in self.RESERVED_KEYWORDS:
        return None
    self._advance()
    return self._prev

def _parse_var_or_string(self) -> t.Optional[exp.Expression]:
    """Parse either a variable or a string literal, whichever matches first."""
    return self._parse_var() or self._parse_string()

def _parse_null(self) -> t.Optional[exp.Expression]:
    """Parse a NULL literal, if present."""
    if self._match(TokenType.NULL):
        return self.PRIMARY_PARSERS[TokenType.NULL](self, self._prev)
    return None

def _parse_boolean(self) -> t.Optional[exp.Expression]:
    """Parse a TRUE or FALSE literal, if present."""
    for token_type in (TokenType.TRUE, TokenType.FALSE):
        if self._match(token_type):
            return self.PRIMARY_PARSERS[token_type](self, self._prev)
    return None

def _parse_star(self) -> t.Optional[exp.Expression]:
    """Parse a star (*) token, if present."""
    if self._match(TokenType.STAR):
        return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev)
    return None
- def _parse_placeholder(self):
+ def _parse_placeholder(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.PLACEHOLDER):
return self.expression(exp.Placeholder)
elif self._match(TokenType.COLON):
@@ -2630,18 +2868,20 @@ class Parser(metaclass=_Parser):
self._advance(-1)
return None
def _parse_except(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
    """Parse an EXCEPT (...) identifier list, if present."""
    if self._match(TokenType.EXCEPT):
        return self._parse_wrapped_id_vars()
    return None

def _parse_replace(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
    """Parse a REPLACE (...) list of optionally aliased expressions, if present."""
    if self._match(TokenType.REPLACE):
        return self._parse_wrapped_csv(lambda: self._parse_alias(self._parse_expression()))
    return None
- def _parse_csv(self, parse_method, sep=TokenType.COMMA):
+ def _parse_csv(
+ self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA
+ ) -> t.List[t.Optional[exp.Expression]]:
parse_result = parse_method()
items = [parse_result] if parse_result is not None else []
@@ -2655,7 +2895,9 @@ class Parser(metaclass=_Parser):
return items
- def _parse_tokens(self, parse_method, expressions):
+ def _parse_tokens(
+ self, parse_method: t.Callable, expressions: t.Dict
+ ) -> t.Optional[exp.Expression]:
this = parse_method()
while self._match_set(expressions):
@@ -2668,22 +2910,29 @@ class Parser(metaclass=_Parser):
return this
def _parse_wrapped_id_vars(self) -> t.List[t.Optional[exp.Expression]]:
    """Parse a parenthesized, comma-separated identifier list."""
    return self._parse_wrapped_csv(self._parse_id_var)

def _parse_wrapped_csv(
    self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA
) -> t.List[t.Optional[exp.Expression]]:
    """Parse a parenthesized, *sep*-separated list of items via *parse_method*."""
    return self._parse_wrapped(lambda: self._parse_csv(parse_method, sep=sep))

def _parse_wrapped(self, parse_method: t.Callable) -> t.Any:
    """Run *parse_method* between a matched pair of parentheses."""
    self._match_l_paren()
    result = parse_method()
    self._match_r_paren()
    return result
def _parse_select_or_expression(self) -> t.Optional[exp.Expression]:
    """Parse a SELECT statement, or fall back to a plain expression."""
    return self._parse_select() or self._parse_expression()

def _parse_ddl_select(self) -> t.Optional[exp.Expression]:
    """Parse the SELECT part of a DDL statement, including trailing set operations."""
    select = self._parse_select(nested=True, parse_subquery_alias=False)
    return self._parse_set_operations(select)
+
+ def _parse_transaction(self) -> exp.Expression:
this = None
if self._match_texts(self.TRANSACTION_KIND):
this = self._prev.text
@@ -2703,7 +2952,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Transaction, this=this, modes=modes)
- def _parse_commit_or_rollback(self):
+ def _parse_commit_or_rollback(self) -> exp.Expression:
chain = None
savepoint = None
is_rollback = self._prev.token_type == TokenType.ROLLBACK
@@ -2722,27 +2971,30 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Rollback, savepoint=savepoint)
return self.expression(exp.Commit, chain=chain)
def _parse_add_column(self) -> t.Optional[exp.Expression]:
    """Parse an ALTER TABLE ADD [COLUMN] action into a column definition.

    Returns None when the next tokens are not an ADD clause.
    """
    if not self._match_text_seq("ADD"):
        return None

    self._match(TokenType.COLUMN)  # the COLUMN keyword is optional
    exists_column = self._parse_exists(not_=True)
    column_def = self._parse_column_def(self._parse_field(any_token=True))

    # Guard against a missing definition before attaching the exists flag.
    if column_def:
        column_def.set("exists", exists_column)

    return column_def
def _parse_drop_column(self) -> t.Optional[exp.Expression]:
    """Parse an ALTER TABLE DROP action as a COLUMN drop, if DROP is next."""
    # Preserve the short-circuit: a failed _match is returned as-is (falsy).
    matched = self._match(TokenType.DROP)
    return matched and self._parse_drop(default_kind="COLUMN")
- def _parse_alter(self):
+ def _parse_alter(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.TABLE):
return None
exists = self._parse_exists()
this = self._parse_table(schema=True)
- actions = None
+ actions: t.Optional[exp.Expression | t.List[t.Optional[exp.Expression]]] = None
if self._match_text_seq("ADD", advance=False):
actions = self._parse_csv(self._parse_add_column)
elif self._match_text_seq("DROP", advance=False):
@@ -2770,24 +3022,24 @@ class Parser(metaclass=_Parser):
actions = ensure_list(actions)
return self.expression(exp.AlterTable, this=this, exists=exists, actions=actions)
def _parse_show(self) -> t.Optional[exp.Expression]:
    """Parse a SHOW statement via a trie-matched sub-parser when one applies."""
    parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)  # type: ignore
    if parser:
        return parser(self)

    # No specific parser matched: consume one token and emit a generic SHOW.
    self._advance()
    return self.expression(exp.Show, this=self._prev.text.upper())
def _default_parse_set_item(self) -> exp.Expression:
    """Fallback SET item parser: wrap the remaining statement in exp.SetItem."""
    return self.expression(exp.SetItem, this=self._parse_statement())

def _parse_set_item(self) -> t.Optional[exp.Expression]:
    """Parse one item of a SET statement via trie-matched parsers, else the default."""
    parser = self._find_parser(self.SET_PARSERS, self._set_trie)  # type: ignore
    if parser:
        return parser(self)
    return self._default_parse_set_item()
- def _parse_merge(self):
+ def _parse_merge(self) -> exp.Expression:
self._match(TokenType.INTO)
target = self._parse_table(schema=True)
@@ -2835,10 +3087,12 @@ class Parser(metaclass=_Parser):
expressions=whens,
)
def _parse_set(self) -> exp.Expression:
    """Parse a SET statement as a comma-separated list of set items."""
    items = self._parse_csv(self._parse_set_item)
    return self.expression(exp.Set, expressions=items)
- def _find_parser(self, parsers, trie):
+ def _find_parser(
+ self, parsers: t.Dict[str, t.Callable], trie: t.Dict
+ ) -> t.Optional[t.Callable]:
index = self._index
this = []
while True: