diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-01-17 10:32:12 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-01-17 10:32:12 +0000 |
commit | 244a05de60c9417daab9528b51788c3d2a00dc5f (patch) | |
tree | 89a9c82aa41d397e1b81c320ad7a287b6c80f313 /sqlglot/parser.py | |
parent | Adding upstream version 10.4.2. (diff) | |
download | sqlglot-244a05de60c9417daab9528b51788c3d2a00dc5f.tar.xz sqlglot-244a05de60c9417daab9528b51788c3d2a00dc5f.zip |
Adding upstream version 10.5.2.upstream/10.5.2
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/parser.py')
-rw-r--r-- | sqlglot/parser.py | 652 |
1 files changed, 453 insertions, 199 deletions
diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 308f363..bd95db8 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -5,7 +5,13 @@ import typing as t from sqlglot import exp from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors -from sqlglot.helper import apply_index_offset, ensure_collection, ensure_list, seq_get +from sqlglot.helper import ( + apply_index_offset, + count_params, + ensure_collection, + ensure_list, + seq_get, +) from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import in_trie, new_trie @@ -54,7 +60,7 @@ class Parser(metaclass=_Parser): Default: "nulls_are_small" """ - FUNCTIONS = { + FUNCTIONS: t.Dict[str, t.Callable] = { **{name: f.from_arg_list for f in exp.ALL_FUNCTIONS for name in f.sql_names()}, "DATE_TO_DATE_STR": lambda args: exp.Cast( this=seq_get(args, 0), @@ -106,6 +112,7 @@ class Parser(metaclass=_Parser): TokenType.JSON, TokenType.JSONB, TokenType.INTERVAL, + TokenType.TIME, TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, TokenType.TIMESTAMPLTZ, @@ -164,6 +171,7 @@ class Parser(metaclass=_Parser): TokenType.DELETE, TokenType.DESCRIBE, TokenType.DETERMINISTIC, + TokenType.DIV, TokenType.DISTKEY, TokenType.DISTSTYLE, TokenType.EXECUTE, @@ -252,6 +260,7 @@ class Parser(metaclass=_Parser): TokenType.FIRST, TokenType.FORMAT, TokenType.IDENTIFIER, + TokenType.INDEX, TokenType.ISNULL, TokenType.MERGE, TokenType.OFFSET, @@ -312,6 +321,7 @@ class Parser(metaclass=_Parser): } TIMESTAMPS = { + TokenType.TIME, TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, TokenType.TIMESTAMPLTZ, @@ -387,6 +397,7 @@ class Parser(metaclass=_Parser): } EXPRESSION_PARSERS = { + exp.Column: lambda self: self._parse_column(), exp.DataType: lambda self: self._parse_types(), exp.From: lambda self: self._parse_from(), exp.Group: lambda self: self._parse_group(), @@ -419,6 +430,7 @@ class Parser(metaclass=_Parser): TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(), TokenType.CREATE: lambda self: self._parse_create(), TokenType.DELETE: lambda self: self._parse_delete(), + TokenType.DESC: lambda self: self._parse_describe(), TokenType.DESCRIBE: lambda self: self._parse_describe(), TokenType.DROP: lambda self: self._parse_drop(), TokenType.END: lambda self: self._parse_commit_or_rollback(), @@ -583,6 +595,11 @@ class Parser(metaclass=_Parser): TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"} + WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS} + + # allows tables to have special tokens as prefixes + TABLE_PREFIX_TOKENS: t.Set[TokenType] = set() + STRICT_CAST = True __slots__ = ( @@ -608,13 +625,13 @@ class Parser(metaclass=_Parser): def __init__( self, - error_level=None, - error_message_context=100, - index_offset=0, - unnest_column_only=False, - alias_post_tablesample=False, - max_errors=3, - null_ordering=None, + error_level: t.Optional[ErrorLevel] = None, + error_message_context: int = 100, + index_offset: int = 0, + unnest_column_only: bool = False, + alias_post_tablesample: bool = False, + max_errors: int = 3, + null_ordering: t.Optional[str] = None, ): self.error_level = error_level or ErrorLevel.IMMEDIATE self.error_message_context = error_message_context @@ -636,23 +653,43 @@ class Parser(metaclass=_Parser): self._prev = None self._prev_comments = None - def parse(self, raw_tokens, sql=None): + def parse( + self, raw_tokens: t.List[Token], sql: t.Optional[str] = None + ) -> t.List[t.Optional[exp.Expression]]: """ - Parses the given list of tokens and returns a list of syntax trees, one tree + Parses a list of tokens and returns a list of syntax trees, one tree per parsed SQL statement. - Args - raw_tokens (list): the list of tokens (:class:`~sqlglot.tokens.Token`). - sql (str): the original SQL string. Used to produce helpful debug messages. + Args: + raw_tokens: the list of tokens. + sql: the original SQL string, used to produce helpful debug messages. - Returns - the list of syntax trees (:class:`~sqlglot.expressions.Expression`). + Returns: + The list of syntax trees. """ return self._parse( parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql ) - def parse_into(self, expression_types, raw_tokens, sql=None): + def parse_into( + self, + expression_types: str | exp.Expression | t.Collection[exp.Expression | str], + raw_tokens: t.List[Token], + sql: t.Optional[str] = None, + ) -> t.List[t.Optional[exp.Expression]]: + """ + Parses a list of tokens into a given Expression type. If a collection of Expression + types is given instead, this method will try to parse the token list into each one + of them, stopping at the first for which the parsing succeeds. + + Args: + expression_types: the expression type(s) to try and parse the token list into. + raw_tokens: the list of tokens. + sql: the original SQL string, used to produce helpful debug messages. + + Returns: + The target Expression. + """ errors = [] for expression_type in ensure_collection(expression_types): parser = self.EXPRESSION_PARSERS.get(expression_type) @@ -668,7 +705,12 @@ class Parser(metaclass=_Parser): errors=merge_errors(errors), ) from errors[-1] - def _parse(self, parse_method, raw_tokens, sql=None): + def _parse( + self, + parse_method: t.Callable[[Parser], t.Optional[exp.Expression]], + raw_tokens: t.List[Token], + sql: t.Optional[str] = None, + ) -> t.List[t.Optional[exp.Expression]]: self.reset() self.sql = sql or "" total = len(raw_tokens) @@ -686,6 +728,7 @@ class Parser(metaclass=_Parser): self._index = -1 self._tokens = tokens self._advance() + expressions.append(parse_method(self)) if self._index < len(self._tokens): @@ -695,7 +738,10 @@ class Parser(metaclass=_Parser): return expressions - def check_errors(self): + def check_errors(self) -> None: + """ + Logs or raises any found errors, depending on the chosen error level setting. + """ if self.error_level == ErrorLevel.WARN: for error in self.errors: logger.error(str(error)) @@ -705,13 +751,18 @@ class Parser(metaclass=_Parser): errors=merge_errors(self.errors), ) - def raise_error(self, message, token=None): + def raise_error(self, message: str, token: t.Optional[Token] = None) -> None: + """ + Appends an error in the list of recorded errors or raises it, depending on the chosen + error level setting. + """ token = token or self._curr or self._prev or Token.string("") start = self._find_token(token, self.sql) end = start + len(token.text) start_context = self.sql[max(start - self.error_message_context, 0) : start] highlight = self.sql[start:end] end_context = self.sql[end : end + self.error_message_context] + error = ParseError.new( f"{message}. Line {token.line}, Col: {token.col}.\n" f" {start_context}\033[4m{highlight}\033[0m{end_context}", @@ -722,11 +773,26 @@ class Parser(metaclass=_Parser): highlight=highlight, end_context=end_context, ) + if self.error_level == ErrorLevel.IMMEDIATE: raise error + self.errors.append(error) - def expression(self, exp_class, comments=None, **kwargs): + def expression( + self, exp_class: t.Type[exp.Expression], comments: t.Optional[t.List[str]] = None, **kwargs + ) -> exp.Expression: + """ + Creates a new, validated Expression. + + Args: + exp_class: the expression class to instantiate. + comments: an optional list of comments to attach to the expression. + kwargs: the arguments to set for the expression along with their respective values. + + Returns: + The target expression. + """ instance = exp_class(**kwargs) if self._prev_comments: instance.comments = self._prev_comments @@ -736,7 +802,17 @@ class Parser(metaclass=_Parser): self.validate_expression(instance) return instance - def validate_expression(self, expression, args=None): + def validate_expression( + self, expression: exp.Expression, args: t.Optional[t.List] = None + ) -> None: + """ + Validates an already instantiated expression, making sure that all its mandatory arguments + are set. + + Args: + expression: the expression to validate. + args: an optional list of items that was used to instantiate the expression, if it's a Func. + """ if self.error_level == ErrorLevel.IGNORE: return @@ -748,13 +824,18 @@ class Parser(metaclass=_Parser): if mandatory and (v is None or (isinstance(v, list) and not v)): self.raise_error(f"Required keyword: '{k}' missing for {expression.__class__}") - if args and len(args) > len(expression.arg_types) and not expression.is_var_len_args: + if ( + args + and isinstance(expression, exp.Func) + and len(args) > len(expression.arg_types) + and not expression.is_var_len_args + ): self.raise_error( f"The number of provided arguments ({len(args)}) is greater than " f"the maximum number of supported arguments ({len(expression.arg_types)})" ) - def _find_token(self, token, sql): + def _find_token(self, token: Token, sql: str) -> int: line = 1 col = 1 index = 0 @@ -769,7 +850,7 @@ class Parser(metaclass=_Parser): return index - def _advance(self, times=1): + def _advance(self, times: int = 1) -> None: self._index += times self._curr = seq_get(self._tokens, self._index) self._next = seq_get(self._tokens, self._index + 1) @@ -780,10 +861,10 @@ class Parser(metaclass=_Parser): self._prev = None self._prev_comments = None - def _retreat(self, index): + def _retreat(self, index: int) -> None: self._advance(index - self._index) - def _parse_statement(self): + def _parse_statement(self) -> t.Optional[exp.Expression]: if self._curr is None: return None @@ -803,7 +884,7 @@ class Parser(metaclass=_Parser): self._parse_query_modifiers(expression) return expression - def _parse_drop(self, default_kind=None): + def _parse_drop(self, default_kind: t.Optional[str] = None) -> t.Optional[exp.Expression]: temporary = self._match(TokenType.TEMPORARY) materialized = self._match(TokenType.MATERIALIZED) kind = self._match_set(self.CREATABLES) and self._prev.text @@ -812,7 +893,7 @@ class Parser(metaclass=_Parser): kind = default_kind else: self.raise_error(f"Expected {self.CREATABLES}") - return + return None return self.expression( exp.Drop, @@ -824,14 +905,14 @@ class Parser(metaclass=_Parser): cascade=self._match(TokenType.CASCADE), ) - def _parse_exists(self, not_=False): + def _parse_exists(self, not_: bool = False) -> t.Optional[bool]: return ( self._match(TokenType.IF) and (not not_ or self._match(TokenType.NOT)) and self._match(TokenType.EXISTS) ) - def _parse_create(self): + def _parse_create(self) -> t.Optional[exp.Expression]: replace = self._match_pair(TokenType.OR, TokenType.REPLACE) temporary = self._match(TokenType.TEMPORARY) transient = self._match_text_seq("TRANSIENT") @@ -846,12 +927,16 @@ class Parser(metaclass=_Parser): if not create_token: self.raise_error(f"Expected {self.CREATABLES}") - return + return None exists = self._parse_exists(not_=True) this = None expression = None properties = None + data = None + statistics = None + no_primary_index = None + indexes = None if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE): this = self._parse_user_defined_function() @@ -868,7 +953,28 @@ class Parser(metaclass=_Parser): this = self._parse_table(schema=True) properties = self._parse_properties() if self._match(TokenType.ALIAS): - expression = self._parse_select(nested=True) + expression = self._parse_ddl_select() + + if create_token.token_type == TokenType.TABLE: + if self._match_text_seq("WITH", "DATA"): + data = True + elif self._match_text_seq("WITH", "NO", "DATA"): + data = False + + if self._match_text_seq("AND", "STATISTICS"): + statistics = True + elif self._match_text_seq("AND", "NO", "STATISTICS"): + statistics = False + + no_primary_index = self._match_text_seq("NO", "PRIMARY", "INDEX") + + indexes = [] + while True: + index = self._parse_create_table_index() + if not index: + break + else: + indexes.append(index) return self.expression( exp.Create, @@ -883,9 +989,13 @@ class Parser(metaclass=_Parser): replace=replace, unique=unique, materialized=materialized, + data=data, + statistics=statistics, + no_primary_index=no_primary_index, + indexes=indexes, ) - def _parse_property(self): + def _parse_property(self) -> t.Optional[exp.Expression]: if self._match_set(self.PROPERTY_PARSERS): return self.PROPERTY_PARSERS[self._prev.token_type](self) @@ -906,7 +1016,7 @@ class Parser(metaclass=_Parser): return None - def _parse_property_assignment(self, exp_class): + def _parse_property_assignment(self, exp_class: t.Type[exp.Expression]) -> exp.Expression: self._match(TokenType.EQ) self._match(TokenType.ALIAS) return self.expression( @@ -914,42 +1024,50 @@ class Parser(metaclass=_Parser): this=self._parse_var_or_string() or self._parse_number() or self._parse_id_var(), ) - def _parse_partitioned_by(self): + def _parse_partitioned_by(self) -> exp.Expression: self._match(TokenType.EQ) return self.expression( exp.PartitionedByProperty, this=self._parse_schema() or self._parse_bracket(self._parse_field()), ) - def _parse_distkey(self): + def _parse_distkey(self) -> exp.Expression: return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var)) - def _parse_create_like(self): + def _parse_create_like(self) -> t.Optional[exp.Expression]: table = self._parse_table(schema=True) options = [] while self._match_texts(("INCLUDING", "EXCLUDING")): + this = self._prev.text.upper() + id_var = self._parse_id_var() + + if not id_var: + return None + options.append( self.expression( exp.Property, - this=self._prev.text.upper(), - value=exp.Var(this=self._parse_id_var().this.upper()), + this=this, + value=exp.Var(this=id_var.this.upper()), ) ) return self.expression(exp.LikeProperty, this=table, expressions=options) - def _parse_sortkey(self, compound=False): + def _parse_sortkey(self, compound: bool = False) -> exp.Expression: return self.expression( exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_id_var), compound=compound ) - def _parse_character_set(self, default=False): + def _parse_character_set(self, default: bool = False) -> exp.Expression: self._match(TokenType.EQ) return self.expression( exp.CharacterSetProperty, this=self._parse_var_or_string(), default=default ) - def _parse_returns(self): + def _parse_returns(self) -> exp.Expression: + value: t.Optional[exp.Expression] is_table = self._match(TokenType.TABLE) + if is_table: if self._match(TokenType.LT): value = self.expression( @@ -960,13 +1078,13 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.GT): self.raise_error("Expecting >") else: - value = self._parse_schema("TABLE") + value = self._parse_schema(exp.Literal.string("TABLE")) else: value = self._parse_types() return self.expression(exp.ReturnsProperty, this=value, is_table=is_table) - def _parse_properties(self): + def _parse_properties(self) -> t.Optional[exp.Expression]: properties = [] while True: @@ -978,15 +1096,21 @@ class Parser(metaclass=_Parser): if properties: return self.expression(exp.Properties, expressions=properties) + return None - def _parse_describe(self): - self._match(TokenType.TABLE) - return self.expression(exp.Describe, this=self._parse_id_var()) + def _parse_describe(self) -> exp.Expression: + kind = self._match_set(self.CREATABLES) and self._prev.text + this = self._parse_table() - def _parse_insert(self): + return self.expression(exp.Describe, this=this, kind=kind) + + def _parse_insert(self) -> exp.Expression: overwrite = self._match(TokenType.OVERWRITE) local = self._match(TokenType.LOCAL) + + this: t.Optional[exp.Expression] + if self._match_text_seq("DIRECTORY"): this = self.expression( exp.Directory, @@ -998,21 +1122,22 @@ class Parser(metaclass=_Parser): self._match(TokenType.INTO) self._match(TokenType.TABLE) this = self._parse_table(schema=True) + return self.expression( exp.Insert, this=this, exists=self._parse_exists(), partition=self._parse_partition(), - expression=self._parse_select(nested=True), + expression=self._parse_ddl_select(), overwrite=overwrite, ) - def _parse_row(self): + def _parse_row(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.FORMAT): return None return self._parse_row_format() - def _parse_row_format(self, match_row=False): + def _parse_row_format(self, match_row: bool = False) -> t.Optional[exp.Expression]: if match_row and not self._match_pair(TokenType.ROW, TokenType.FORMAT): return None @@ -1035,9 +1160,10 @@ class Parser(metaclass=_Parser): kwargs["lines"] = self._parse_string() if self._match_text_seq("NULL", "DEFINED", "AS"): kwargs["null"] = self._parse_string() - return self.expression(exp.RowFormatDelimitedProperty, **kwargs) - def _parse_load_data(self): + return self.expression(exp.RowFormatDelimitedProperty, **kwargs) # type: ignore + + def _parse_load_data(self) -> exp.Expression: local = self._match(TokenType.LOCAL) self._match_text_seq("INPATH") inpath = self._parse_string() @@ -1055,7 +1181,7 @@ class Parser(metaclass=_Parser): serde=self._match_text_seq("SERDE") and self._parse_string(), ) - def _parse_delete(self): + def _parse_delete(self) -> exp.Expression: self._match(TokenType.FROM) return self.expression( @@ -1065,10 +1191,10 @@ class Parser(metaclass=_Parser): where=self._parse_where(), ) - def _parse_update(self): + def _parse_update(self) -> exp.Expression: return self.expression( exp.Update, - **{ + **{ # type: ignore "this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS), "expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality), "from": self._parse_from(), @@ -1076,16 +1202,17 @@ class Parser(metaclass=_Parser): }, ) - def _parse_uncache(self): + def _parse_uncache(self) -> exp.Expression: if not self._match(TokenType.TABLE): self.raise_error("Expecting TABLE after UNCACHE") + return self.expression( exp.Uncache, exists=self._parse_exists(), this=self._parse_table(schema=True), ) - def _parse_cache(self): + def _parse_cache(self) -> exp.Expression: lazy = self._match(TokenType.LAZY) self._match(TokenType.TABLE) table = self._parse_table(schema=True) @@ -1108,21 +1235,23 @@ class Parser(metaclass=_Parser): expression=self._parse_select(nested=True), ) - def _parse_partition(self): + def _parse_partition(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.PARTITION): return None - def parse_values(): + def parse_values() -> exp.Property: props = self._parse_csv(self._parse_var_or_string, sep=TokenType.EQ) return exp.Property(this=seq_get(props, 0), value=seq_get(props, 1)) return self.expression(exp.Partition, this=self._parse_wrapped_csv(parse_values)) - def _parse_value(self): + def _parse_value(self) -> exp.Expression: expressions = self._parse_wrapped_csv(self._parse_conjunction) return self.expression(exp.Tuple, expressions=expressions) - def _parse_select(self, nested=False, table=False): + def _parse_select( + self, nested: bool = False, table: bool = False, parse_subquery_alias: bool = True + ) -> t.Optional[exp.Expression]: cte = self._parse_with() if cte: this = self._parse_statement() @@ -1178,10 +1307,11 @@ class Parser(metaclass=_Parser): self._parse_query_modifiers(this) this = self._parse_set_operations(this) self._match_r_paren() + # early return so that subquery unions aren't parsed again # SELECT * FROM (SELECT 1) UNION ALL SELECT 1 # Union ALL should be a property of the top select node, not the subquery - return self._parse_subquery(this) + return self._parse_subquery(this, parse_alias=parse_subquery_alias) elif self._match(TokenType.VALUES): if self._curr.token_type == TokenType.L_PAREN: # We don't consume the left paren because it's consumed in _parse_value @@ -1203,7 +1333,7 @@ class Parser(metaclass=_Parser): return self._parse_set_operations(this) - def _parse_with(self, skip_with_token=False): + def _parse_with(self, skip_with_token: bool = False) -> t.Optional[exp.Expression]: if not skip_with_token and not self._match(TokenType.WITH): return None @@ -1220,7 +1350,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.With, expressions=expressions, recursive=recursive) - def _parse_cte(self): + def _parse_cte(self) -> exp.Expression: alias = self._parse_table_alias() if not alias or not alias.this: self.raise_error("Expected CTE to have alias") @@ -1234,7 +1364,9 @@ class Parser(metaclass=_Parser): alias=alias, ) - def _parse_table_alias(self, alias_tokens=None): + def _parse_table_alias( + self, alias_tokens: t.Optional[t.Collection[TokenType]] = None + ) -> t.Optional[exp.Expression]: any_token = self._match(TokenType.ALIAS) alias = self._parse_id_var( any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS @@ -1251,15 +1383,17 @@ class Parser(metaclass=_Parser): return self.expression(exp.TableAlias, this=alias, columns=columns) - def _parse_subquery(self, this): + def _parse_subquery( + self, this: t.Optional[exp.Expression], parse_alias: bool = True + ) -> exp.Expression: return self.expression( exp.Subquery, this=this, pivots=self._parse_pivots(), - alias=self._parse_table_alias(), + alias=self._parse_table_alias() if parse_alias else None, ) - def _parse_query_modifiers(self, this): + def _parse_query_modifiers(self, this: t.Optional[exp.Expression]) -> None: if not isinstance(this, self.MODIFIABLES): return @@ -1284,15 +1418,16 @@ class Parser(metaclass=_Parser): if expression: this.set(key, expression) - def _parse_hint(self): + def _parse_hint(self) -> t.Optional[exp.Expression]: if self._match(TokenType.HINT): hints = self._parse_csv(self._parse_function) if not self._match_pair(TokenType.STAR, TokenType.SLASH): self.raise_error("Expected */ after HINT") return self.expression(exp.Hint, expressions=hints) + return None - def _parse_into(self): + def _parse_into(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.INTO): return None @@ -1304,14 +1439,15 @@ class Parser(metaclass=_Parser): exp.Into, this=self._parse_table(schema=True), temporary=temp, unlogged=unlogged ) - def _parse_from(self): + def _parse_from(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.FROM): return None + return self.expression( exp.From, comments=self._prev_comments, expressions=self._parse_csv(self._parse_table) ) - def _parse_lateral(self): + def _parse_lateral(self) -> t.Optional[exp.Expression]: outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY) cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY) @@ -1334,6 +1470,8 @@ class Parser(metaclass=_Parser): expression=self._parse_function() or self._parse_id_var(any_token=False), ) + table_alias: t.Optional[exp.Expression] + if view: table = self._parse_id_var(any_token=False) columns = self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else [] @@ -1354,20 +1492,24 @@ class Parser(metaclass=_Parser): return expression - def _parse_join_side_and_kind(self): + def _parse_join_side_and_kind( + self, + ) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]: return ( self._match(TokenType.NATURAL) and self._prev, self._match_set(self.JOIN_SIDES) and self._prev, self._match_set(self.JOIN_KINDS) and self._prev, ) - def _parse_join(self, skip_join_token=False): + def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]: natural, side, kind = self._parse_join_side_and_kind() if not skip_join_token and not self._match(TokenType.JOIN): return None - kwargs = {"this": self._parse_table()} + kwargs: t.Dict[ + str, t.Optional[exp.Expression] | bool | str | t.List[t.Optional[exp.Expression]] + ] = {"this": self._parse_table()} if natural: kwargs["natural"] = True @@ -1381,12 +1523,13 @@ class Parser(metaclass=_Parser): elif self._match(TokenType.USING): kwargs["using"] = self._parse_wrapped_id_vars() - return self.expression(exp.Join, **kwargs) + return self.expression(exp.Join, **kwargs) # type: ignore - def _parse_index(self): + def _parse_index(self) -> exp.Expression: index = self._parse_id_var() self._match(TokenType.ON) self._match(TokenType.TABLE) # hive + return self.expression( exp.Index, this=index, @@ -1394,7 +1537,28 @@ class Parser(metaclass=_Parser): columns=self._parse_expression(), ) - def _parse_table(self, schema=False, alias_tokens=None): + def _parse_create_table_index(self) -> t.Optional[exp.Expression]: + unique = self._match(TokenType.UNIQUE) + primary = self._match_text_seq("PRIMARY") + amp = self._match_text_seq("AMP") + if not self._match(TokenType.INDEX): + return None + index = self._parse_id_var() + columns = None + if self._curr and self._curr.token_type == TokenType.L_PAREN: + columns = self._parse_wrapped_csv(self._parse_column) + return self.expression( + exp.Index, + this=index, + columns=columns, + unique=unique, + primary=primary, + amp=amp, + ) + + def _parse_table( + self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None + ) -> t.Optional[exp.Expression]: lateral = self._parse_lateral() if lateral: @@ -1417,7 +1581,9 @@ class Parser(metaclass=_Parser): catalog = None db = None - table = (not schema and self._parse_function()) or self._parse_id_var(False) + table = (not schema and self._parse_function()) or self._parse_id_var( + any_token=False, prefix_tokens=self.TABLE_PREFIX_TOKENS + ) while self._match(TokenType.DOT): if catalog: @@ -1446,6 +1612,14 @@ class Parser(metaclass=_Parser): if alias: this.set("alias", alias) + if self._match(TokenType.WITH): + this.set( + "hints", + self._parse_wrapped_csv( + lambda: self._parse_function() or self._parse_var(any_token=True) + ), + ) + if not self.alias_post_tablesample: table_sample = self._parse_table_sample() @@ -1455,7 +1629,7 @@ class Parser(metaclass=_Parser): return this - def _parse_unnest(self): + def _parse_unnest(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.UNNEST): return None @@ -1473,7 +1647,7 @@ class Parser(metaclass=_Parser): exp.Unnest, expressions=expressions, ordinality=ordinality, alias=alias ) - def _parse_derived_table_values(self): + def _parse_derived_table_values(self) -> t.Optional[exp.Expression]: is_derived = self._match_pair(TokenType.L_PAREN, TokenType.VALUES) if not is_derived and not self._match(TokenType.VALUES): return None @@ -1485,7 +1659,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.Values, expressions=expressions, alias=self._parse_table_alias()) - def _parse_table_sample(self): + def _parse_table_sample(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.TABLE_SAMPLE): return None @@ -1533,10 +1707,10 @@ class Parser(metaclass=_Parser): seed=seed, ) - def _parse_pivots(self): + def _parse_pivots(self) -> t.List[t.Optional[exp.Expression]]: return list(iter(self._parse_pivot, None)) - def _parse_pivot(self): + def _parse_pivot(self) -> t.Optional[exp.Expression]: index = self._index if self._match(TokenType.PIVOT): @@ -1572,16 +1746,18 @@ class Parser(metaclass=_Parser): return self.expression(exp.Pivot, expressions=expressions, field=field, unpivot=unpivot) - def _parse_where(self, skip_where_token=False): + def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Expression]: if not skip_where_token and not self._match(TokenType.WHERE): return None + return self.expression( exp.Where, comments=self._prev_comments, this=self._parse_conjunction() ) - def _parse_group(self, skip_group_by_token=False): + def _parse_group(self, skip_group_by_token: bool = False) -> t.Optional[exp.Expression]: if not skip_group_by_token and not self._match(TokenType.GROUP_BY): return None + return self.expression( exp.Group, expressions=self._parse_csv(self._parse_conjunction), @@ -1590,29 +1766,33 @@ class Parser(metaclass=_Parser): rollup=self._match(TokenType.ROLLUP) and self._parse_wrapped_id_vars(), ) - def _parse_grouping_sets(self): + def _parse_grouping_sets(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: if not self._match(TokenType.GROUPING_SETS): return None + return self._parse_wrapped_csv(self._parse_grouping_set) - def _parse_grouping_set(self): + def _parse_grouping_set(self) -> t.Optional[exp.Expression]: if self._match(TokenType.L_PAREN): grouping_set = self._parse_csv(self._parse_id_var) self._match_r_paren() return self.expression(exp.Tuple, expressions=grouping_set) + return self._parse_id_var() - def _parse_having(self, skip_having_token=False): + def _parse_having(self, skip_having_token: bool = False) -> t.Optional[exp.Expression]: if not skip_having_token and not self._match(TokenType.HAVING): return None return self.expression(exp.Having, this=self._parse_conjunction()) - def _parse_qualify(self): + def _parse_qualify(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.QUALIFY): return None return self.expression(exp.Qualify, this=self._parse_conjunction()) - def _parse_order(self, this=None, skip_order_token=False): + def _parse_order( + self, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False + ) -> t.Optional[exp.Expression]: if not skip_order_token and not self._match(TokenType.ORDER_BY): return this @@ -1620,12 +1800,14 @@ class Parser(metaclass=_Parser): exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered) ) - def _parse_sort(self, token_type, exp_class): + def _parse_sort( + self, token_type: TokenType, exp_class: t.Type[exp.Expression] + ) -> t.Optional[exp.Expression]: if not self._match(token_type): return None return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered)) - def _parse_ordered(self): + def _parse_ordered(self) -> exp.Expression: this = self._parse_conjunction() self._match(TokenType.ASC) is_desc = self._match(TokenType.DESC) @@ -1647,7 +1829,9 @@ class Parser(metaclass=_Parser): return self.expression(exp.Ordered, this=this, desc=desc, nulls_first=nulls_first) - def _parse_limit(self, this=None, top=False): + def _parse_limit( + self, this: t.Optional[exp.Expression] = None, top: bool = False + ) -> t.Optional[exp.Expression]: if self._match(TokenType.TOP if top else TokenType.LIMIT): limit_paren = self._match(TokenType.L_PAREN) limit_exp = self.expression(exp.Limit, this=this, expression=self._parse_number()) @@ -1667,7 +1851,7 @@ class Parser(metaclass=_Parser): return this - def _parse_offset(self, this=None): + def _parse_offset(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]: if not self._match_set((TokenType.OFFSET, TokenType.COMMA)): return this @@ -1675,7 +1859,7 @@ class Parser(metaclass=_Parser): self._match_set((TokenType.ROW, TokenType.ROWS)) return self.expression(exp.Offset, this=this, expression=count) - def _parse_set_operations(self, this): + def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if not self._match_set(self.SET_OPERATIONS): return this @@ -1695,19 +1879,19 @@ class Parser(metaclass=_Parser): expression=self._parse_select(nested=True), ) - def _parse_expression(self): + def _parse_expression(self) -> t.Optional[exp.Expression]: return self._parse_alias(self._parse_conjunction()) - def _parse_conjunction(self): + def _parse_conjunction(self) -> t.Optional[exp.Expression]: return self._parse_tokens(self._parse_equality, self.CONJUNCTION) - def _parse_equality(self): + def _parse_equality(self) -> t.Optional[exp.Expression]: return self._parse_tokens(self._parse_comparison, self.EQUALITY) - def _parse_comparison(self): + def _parse_comparison(self) -> t.Optional[exp.Expression]: return self._parse_tokens(self._parse_range, self.COMPARISON) - def _parse_range(self): + def _parse_range(self) -> t.Optional[exp.Expression]: this = self._parse_bitwise() negate = self._match(TokenType.NOT) @@ -1730,7 +1914,7 @@ class Parser(metaclass=_Parser): return this - def _parse_is(self, this): + def _parse_is(self, this: t.Optional[exp.Expression]) -> exp.Expression: negate = self._match(TokenType.NOT) if self._match(TokenType.DISTINCT_FROM): klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ @@ -1743,7 +1927,7 @@ class Parser(metaclass=_Parser): ) return self.expression(exp.Not, this=this) if negate else this - def _parse_in(self, this): + def _parse_in(self, this: t.Optional[exp.Expression]) -> exp.Expression: unnest = self._parse_unnest() if unnest: this = self.expression(exp.In, this=this, unnest=unnest) @@ -1761,18 +1945,18 @@ class Parser(metaclass=_Parser): return this - def _parse_between(self, this): + def _parse_between(self, this: exp.Expression) -> exp.Expression: low = self._parse_bitwise() self._match(TokenType.AND) high = self._parse_bitwise() return self.expression(exp.Between, this=this, low=low, high=high) - def _parse_escape(self, this): + def _parse_escape(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if not self._match(TokenType.ESCAPE): return this return self.expression(exp.Escape, this=this, expression=self._parse_string()) - def _parse_bitwise(self): + def _parse_bitwise(self) -> t.Optional[exp.Expression]: this = self._parse_term() while True: @@ -1795,18 +1979,18 @@ class Parser(metaclass=_Parser): return this - def _parse_term(self): + def _parse_term(self) -> t.Optional[exp.Expression]: return self._parse_tokens(self._parse_factor, self.TERM) - def _parse_factor(self): + def _parse_factor(self) -> t.Optional[exp.Expression]: return self._parse_tokens(self._parse_unary, self.FACTOR) - def _parse_unary(self): + def _parse_unary(self) -> t.Optional[exp.Expression]: if self._match_set(self.UNARY_PARSERS): return self.UNARY_PARSERS[self._prev.token_type](self) return self._parse_at_time_zone(self._parse_type()) - def _parse_type(self): + def _parse_type(self) -> t.Optional[exp.Expression]: if self._match(TokenType.INTERVAL): return self.expression(exp.Interval, this=self._parse_term(), unit=self._parse_var()) @@ -1824,7 +2008,7 @@ class Parser(metaclass=_Parser): return this - def _parse_types(self, check_func=False): + def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]: index = self._index if not self._match_set(self.TYPE_TOKENS): @@ -1875,7 +2059,7 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.GT): self.raise_error("Expecting >") - value = None + value: t.Optional[exp.Expression] = None if type_token in self.TIMESTAMPS: if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ: value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions) @@ -1884,7 +2068,10 @@ class Parser(metaclass=_Parser): ): value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions) elif self._match(TokenType.WITHOUT_TIME_ZONE): - value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions) + if type_token == TokenType.TIME: + value = exp.DataType(this=exp.DataType.Type.TIME, expressions=expressions) + else: + value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions) maybe_func = maybe_func and value is None @@ -1912,7 +2099,7 @@ class Parser(metaclass=_Parser): nested=nested, ) - def _parse_struct_kwargs(self): + def _parse_struct_kwargs(self) -> t.Optional[exp.Expression]: this = self._parse_id_var() self._match(TokenType.COLON) data_type = self._parse_types() @@ -1921,12 +2108,12 @@ class Parser(metaclass=_Parser): return None return self.expression(exp.StructKwarg, this=this, expression=data_type) - def _parse_at_time_zone(self, this): + def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if not self._match(TokenType.AT_TIME_ZONE): return this return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary()) - def _parse_column(self): + def _parse_column(self) -> t.Optional[exp.Expression]: this = self._parse_field() if isinstance(this, exp.Identifier): this = self.expression(exp.Column, this=this) @@ -1943,7 +2130,8 @@ class Parser(metaclass=_Parser): if not field: self.raise_error("Expected type") elif op: - field = exp.Literal.string(self._advance() or self._prev.text) + self._advance() + field = exp.Literal.string(self._prev.text) else: field = self._parse_star() or self._parse_function() or self._parse_id_var() @@ -1963,7 +2151,7 @@ class Parser(metaclass=_Parser): return this - def _parse_primary(self): + def _parse_primary(self) -> t.Optional[exp.Expression]: if self._match_set(self.PRIMARY_PARSERS): token_type = self._prev.token_type primary = self.PRIMARY_PARSERS[token_type](self, self._prev) @@ -1995,21 +2183,27 @@ class Parser(metaclass=_Parser): self._match_r_paren() if isinstance(this, exp.Subqueryable): - this = self._parse_set_operations(self._parse_subquery(this)) + this = self._parse_set_operations( + self._parse_subquery(this=this, parse_alias=False) + ) elif len(expressions) > 1: this = self.expression(exp.Tuple, expressions=expressions) else: this = self.expression(exp.Paren, this=this) - if comments: + + if this and comments: this.comments = comments + return this return None - def _parse_field(self, any_token=False): + def _parse_field(self, any_token: bool = False) -> t.Optional[exp.Expression]: return self._parse_primary() or self._parse_function() or self._parse_id_var(any_token) - def _parse_function(self, functions=None): + def _parse_function( + self, functions: t.Optional[t.Dict[str, t.Callable]] = None + ) -> t.Optional[exp.Expression]: if not self._curr: return None @@ -2020,7 +2214,9 @@ class Parser(metaclass=_Parser): if not self._next or self._next.token_type != TokenType.L_PAREN: if token_type in self.NO_PAREN_FUNCTIONS: - return self.expression(self._advance() or self.NO_PAREN_FUNCTIONS[token_type]) + self._advance() + return self.expression(self.NO_PAREN_FUNCTIONS[token_type]) + return None if token_type not in self.FUNC_TOKENS: @@ -2049,7 +2245,18 @@ class Parser(metaclass=_Parser): args = self._parse_csv(self._parse_lambda) if function: - this = function(args) + + # Clickhouse supports function calls like foo(x, y)(z), so for these we need to also parse the + # second parameter list (i.e. "(z)") and the corresponding function will receive both arg lists. + if count_params(function) == 2: + params = None + if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN): + params = self._parse_csv(self._parse_lambda) + + this = function(args, params) + else: + this = function(args) + self.validate_expression(this, args) else: this = self.expression(exp.Anonymous, this=this, expressions=args) @@ -2057,7 +2264,7 @@ class Parser(metaclass=_Parser): self._match_r_paren(this) return self._parse_window(this) - def _parse_user_defined_function(self): + def _parse_user_defined_function(self) -> t.Optional[exp.Expression]: this = self._parse_id_var() while self._match(TokenType.DOT): @@ -2070,27 +2277,27 @@ class Parser(metaclass=_Parser): self._match_r_paren() return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions) - def _parse_introducer(self, token): + def _parse_introducer(self, token: Token) -> t.Optional[exp.Expression]: literal = self._parse_primary() if literal: return self.expression(exp.Introducer, this=token.text, expression=literal) return self.expression(exp.Identifier, this=token.text) - def _parse_national(self, token): + def _parse_national(self, token: Token) -> exp.Expression: return self.expression(exp.National, this=exp.Literal.string(token.text)) - def _parse_session_parameter(self): + def _parse_session_parameter(self) -> exp.Expression: kind = None this = self._parse_id_var() or self._parse_primary() - if self._match(TokenType.DOT): + if this and self._match(TokenType.DOT): kind = this.name this = self._parse_var() or self._parse_primary() return self.expression(exp.SessionParameter, this=this, kind=kind) - def _parse_udf_kwarg(self): + def _parse_udf_kwarg(self) -> t.Optional[exp.Expression]: this = self._parse_id_var() kind = self._parse_types() @@ -2099,7 +2306,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.UserDefinedFunctionKwarg, this=this, kind=kind) - def _parse_lambda(self): + def _parse_lambda(self) -> t.Optional[exp.Expression]: index = self._index if self._match(TokenType.L_PAREN): @@ -2115,6 +2322,8 @@ class Parser(metaclass=_Parser): self._retreat(index) + this: t.Optional[exp.Expression] + if self._match(TokenType.DISTINCT): this = self.expression( exp.Distinct, expressions=self._parse_csv(self._parse_conjunction) @@ -2129,7 +2338,7 @@ class Parser(metaclass=_Parser): return self._parse_limit(self._parse_order(this)) - def _parse_schema(self, this=None): + def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]: index = self._index if not self._match(TokenType.L_PAREN) or self._match(TokenType.SELECT): self._retreat(index) @@ -2140,14 +2349,15 @@ class Parser(metaclass=_Parser): or self._parse_column_def(self._parse_field(any_token=True)) ) self._match_r_paren() + + if isinstance(this, exp.Literal): + this = this.name + return self.expression(exp.Schema, this=this, expressions=args) - def _parse_column_def(self, this): + def _parse_column_def(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: kind = self._parse_types() - if not kind: - return this - constraints = [] while True: constraint = self._parse_column_constraint() @@ -2155,9 +2365,12 @@ class Parser(metaclass=_Parser): break constraints.append(constraint) + if not kind and not constraints: + return this + return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints) - def _parse_column_constraint(self): + def _parse_column_constraint(self) -> t.Optional[exp.Expression]: this = self._parse_references() if this: @@ -2166,6 +2379,8 @@ class Parser(metaclass=_Parser): if self._match(TokenType.CONSTRAINT): this = self._parse_id_var() + kind: exp.Expression + if self._match(TokenType.AUTO_INCREMENT): kind = exp.AutoIncrementColumnConstraint() elif self._match(TokenType.CHECK): @@ -2202,7 +2417,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.ColumnConstraint, this=this, kind=kind) - def _parse_constraint(self): + def _parse_constraint(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.CONSTRAINT): return self._parse_unnamed_constraint() @@ -2217,24 +2432,25 @@ class Parser(metaclass=_Parser): return self.expression(exp.Constraint, this=this, expressions=expressions) - def _parse_unnamed_constraint(self): + def _parse_unnamed_constraint(self) -> t.Optional[exp.Expression]: if not self._match_set(self.CONSTRAINT_PARSERS): return None return self.CONSTRAINT_PARSERS[self._prev.token_type](self) - def _parse_unique(self): + def _parse_unique(self) -> exp.Expression: return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars()) - def _parse_references(self): + def _parse_references(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.REFERENCES): return None + return self.expression( exp.Reference, this=self._parse_id_var(), expressions=self._parse_wrapped_id_vars(), ) - def _parse_foreign_key(self): + def _parse_foreign_key(self) -> exp.Expression: expressions = self._parse_wrapped_id_vars() reference = self._parse_references() options = {} @@ -2260,13 +2476,15 @@ class Parser(metaclass=_Parser): exp.ForeignKey, expressions=expressions, reference=reference, - **options, + **options, # type: ignore ) - def _parse_bracket(self, this): + def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if not self._match(TokenType.L_BRACKET): return this + expressions: t.List[t.Optional[exp.Expression]] + if self._match(TokenType.COLON): expressions = [self.expression(exp.Slice, expression=self._parse_conjunction())] else: @@ -2284,12 +2502,12 @@ class Parser(metaclass=_Parser): this.comments = self._prev_comments return self._parse_bracket(this) - def _parse_slice(self, this): + def _parse_slice(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if self._match(TokenType.COLON): return self.expression(exp.Slice, this=this, expression=self._parse_conjunction()) return this - def _parse_case(self): + def _parse_case(self) -> t.Optional[exp.Expression]: ifs = [] default = None @@ -2311,7 +2529,7 @@ class Parser(metaclass=_Parser): self.expression(exp.Case, this=expression, ifs=ifs, default=default) ) - def _parse_if(self): + def _parse_if(self) -> t.Optional[exp.Expression]: if self._match(TokenType.L_PAREN): args = self._parse_csv(self._parse_conjunction) this = exp.If.from_arg_list(args) @@ -2324,9 +2542,10 @@ class Parser(metaclass=_Parser): false = self._parse_conjunction() if self._match(TokenType.ELSE) else None self._match(TokenType.END) this = self.expression(exp.If, this=condition, true=true, false=false) + return self._parse_window(this) - def _parse_extract(self): + def _parse_extract(self) -> exp.Expression: this = self._parse_function() or self._parse_var() or self._parse_type() if self._match(TokenType.FROM): @@ -2337,7 +2556,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.Extract, this=this, expression=self._parse_bitwise()) - def _parse_cast(self, strict): + def _parse_cast(self, strict: bool) -> exp.Expression: this = self._parse_conjunction() if not self._match(TokenType.ALIAS): @@ -2353,7 +2572,9 @@ class Parser(metaclass=_Parser): return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) - def _parse_string_agg(self): + def _parse_string_agg(self) -> exp.Expression: + expression: t.Optional[exp.Expression] + if self._match(TokenType.DISTINCT): args = self._parse_csv(self._parse_conjunction) expression = self.expression(exp.Distinct, expressions=[seq_get(args, 0)]) @@ -2380,8 +2601,10 @@ class Parser(metaclass=_Parser): order = self._parse_order(this=expression) return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1)) - def _parse_convert(self, strict): + def _parse_convert(self, strict: bool) -> exp.Expression: + to: t.Optional[exp.Expression] this = self._parse_column() + if self._match(TokenType.USING): to = self.expression(exp.CharacterSet, this=self._parse_var()) elif self._match(TokenType.COMMA): @@ -2390,7 +2613,7 @@ class Parser(metaclass=_Parser): to = None return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) - def _parse_position(self): + def _parse_position(self) -> exp.Expression: args = self._parse_csv(self._parse_bitwise) if self._match(TokenType.IN): @@ -2402,11 +2625,11 @@ class Parser(metaclass=_Parser): return this - def _parse_join_hint(self, func_name): + def _parse_join_hint(self, func_name: str) -> exp.Expression: args = self._parse_csv(self._parse_table) return exp.JoinHint(this=func_name.upper(), expressions=args) - def _parse_substring(self): + def _parse_substring(self) -> exp.Expression: # Postgres supports the form: substring(string [from int] [for int]) # https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6 @@ -2422,7 +2645,7 @@ class Parser(metaclass=_Parser): return this - def _parse_trim(self): + def _parse_trim(self) -> exp.Expression: # https://www.w3resource.com/sql/character-functions/trim.php # https://docs.oracle.com/javadb/10.8.3.0/ref/rreftrimfunc.html @@ -2450,13 +2673,15 @@ class Parser(metaclass=_Parser): collation=collation, ) - def _parse_window_clause(self): + def _parse_window_clause(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: return self._match(TokenType.WINDOW) and self._parse_csv(self._parse_named_window) - def _parse_named_window(self): + def _parse_named_window(self) -> t.Optional[exp.Expression]: return self._parse_window(self._parse_id_var(), alias=True) - def _parse_window(self, this, alias=False): + def _parse_window( + self, this: t.Optional[exp.Expression], alias: bool = False + ) -> t.Optional[exp.Expression]: if self._match(TokenType.FILTER): where = self._parse_wrapped(self._parse_where) this = self.expression(exp.Filter, this=this, expression=where) @@ -2495,7 +2720,7 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.L_PAREN): return self.expression(exp.Window, this=this, alias=self._parse_id_var(False)) - alias = self._parse_id_var(False) + window_alias = self._parse_id_var(any_token=False, tokens=self.WINDOW_ALIAS_TOKENS) partition = None if self._match(TokenType.PARTITION_BY): @@ -2529,10 +2754,10 @@ class Parser(metaclass=_Parser): partition_by=partition, order=order, spec=spec, - alias=alias, + alias=window_alias, ) - def _parse_window_spec(self): + def _parse_window_spec(self) -> t.Dict[str, t.Optional[str | exp.Expression]]: self._match(TokenType.BETWEEN) return { @@ -2543,7 +2768,9 @@ class Parser(metaclass=_Parser): "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) and self._prev.text, } - def _parse_alias(self, this, explicit=False): + def _parse_alias( + self, this: t.Optional[exp.Expression], explicit: bool = False + ) -> t.Optional[exp.Expression]: any_token = self._match(TokenType.ALIAS) if explicit and not any_token: @@ -2565,63 +2792,74 @@ class Parser(metaclass=_Parser): return this - def _parse_id_var(self, any_token=True, tokens=None): + def _parse_id_var( + self, + any_token: bool = True, + tokens: t.Optional[t.Collection[TokenType]] = None, + prefix_tokens: t.Optional[t.Collection[TokenType]] = None, + ) -> t.Optional[exp.Expression]: identifier = self._parse_identifier() if identifier: return identifier + prefix = "" + + if prefix_tokens: + while self._match_set(prefix_tokens): + prefix += self._prev.text + if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS): - return exp.Identifier(this=self._prev.text, quoted=False) + return exp.Identifier(this=prefix + self._prev.text, quoted=False) return None - def _parse_string(self): + def _parse_string(self) -> t.Optional[exp.Expression]: if self._match(TokenType.STRING): return self.PRIMARY_PARSERS[TokenType.STRING](self, self._prev) return self._parse_placeholder() - def _parse_number(self): + def _parse_number(self) -> t.Optional[exp.Expression]: if self._match(TokenType.NUMBER): return self.PRIMARY_PARSERS[TokenType.NUMBER](self, self._prev) return self._parse_placeholder() - def _parse_identifier(self): + def _parse_identifier(self) -> t.Optional[exp.Expression]: if self._match(TokenType.IDENTIFIER): return self.expression(exp.Identifier, this=self._prev.text, quoted=True) return self._parse_placeholder() - def _parse_var(self, any_token=False): + def _parse_var(self, any_token: bool = False) -> t.Optional[exp.Expression]: if (any_token and self._advance_any()) or self._match(TokenType.VAR): return self.expression(exp.Var, this=self._prev.text) return self._parse_placeholder() - def _advance_any(self): + def _advance_any(self) -> t.Optional[Token]: if self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS: self._advance() return self._prev return None - def _parse_var_or_string(self): + def _parse_var_or_string(self) -> t.Optional[exp.Expression]: return self._parse_var() or self._parse_string() - def _parse_null(self): + def _parse_null(self) -> t.Optional[exp.Expression]: if self._match(TokenType.NULL): return self.PRIMARY_PARSERS[TokenType.NULL](self, self._prev) return None - def _parse_boolean(self): + def _parse_boolean(self) -> t.Optional[exp.Expression]: if self._match(TokenType.TRUE): return self.PRIMARY_PARSERS[TokenType.TRUE](self, self._prev) if self._match(TokenType.FALSE): return self.PRIMARY_PARSERS[TokenType.FALSE](self, self._prev) return None - def _parse_star(self): + def _parse_star(self) -> t.Optional[exp.Expression]: if self._match(TokenType.STAR): return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev) return None - def _parse_placeholder(self): + def _parse_placeholder(self) -> t.Optional[exp.Expression]: if self._match(TokenType.PLACEHOLDER): return self.expression(exp.Placeholder) elif self._match(TokenType.COLON): @@ -2630,18 +2868,20 @@ class Parser(metaclass=_Parser): self._advance(-1) return None - def _parse_except(self): + def _parse_except(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: if not self._match(TokenType.EXCEPT): return None return self._parse_wrapped_id_vars() - def _parse_replace(self): + def _parse_replace(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: if not self._match(TokenType.REPLACE): return None return self._parse_wrapped_csv(lambda: self._parse_alias(self._parse_expression())) - def _parse_csv(self, parse_method, sep=TokenType.COMMA): + def _parse_csv( + self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA + ) -> t.List[t.Optional[exp.Expression]]: parse_result = parse_method() items = [parse_result] if parse_result is not None else [] @@ -2655,7 +2895,9 @@ class Parser(metaclass=_Parser): return items - def _parse_tokens(self, parse_method, expressions): + def _parse_tokens( + self, parse_method: t.Callable, expressions: t.Dict + ) -> t.Optional[exp.Expression]: this = parse_method() while self._match_set(expressions): @@ -2668,22 +2910,29 @@ class Parser(metaclass=_Parser): return this - def _parse_wrapped_id_vars(self): + def _parse_wrapped_id_vars(self) -> t.List[t.Optional[exp.Expression]]: return self._parse_wrapped_csv(self._parse_id_var) - def _parse_wrapped_csv(self, parse_method, sep=TokenType.COMMA): + def _parse_wrapped_csv( + self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA + ) -> t.List[t.Optional[exp.Expression]]: return self._parse_wrapped(lambda: self._parse_csv(parse_method, sep=sep)) - def _parse_wrapped(self, parse_method): + def _parse_wrapped(self, parse_method: t.Callable) -> t.Any: self._match_l_paren() parse_result = parse_method() self._match_r_paren() return parse_result - def _parse_select_or_expression(self): + def _parse_select_or_expression(self) -> t.Optional[exp.Expression]: return self._parse_select() or self._parse_expression() - def _parse_transaction(self): + def _parse_ddl_select(self) -> t.Optional[exp.Expression]: + return self._parse_set_operations( + self._parse_select(nested=True, parse_subquery_alias=False) + ) + + def _parse_transaction(self) -> exp.Expression: this = None if self._match_texts(self.TRANSACTION_KIND): this = self._prev.text @@ -2703,7 +2952,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.Transaction, this=this, modes=modes) - def _parse_commit_or_rollback(self): + def _parse_commit_or_rollback(self) -> exp.Expression: chain = None savepoint = None is_rollback = self._prev.token_type == TokenType.ROLLBACK @@ -2722,27 +2971,30 @@ class Parser(metaclass=_Parser): return self.expression(exp.Rollback, savepoint=savepoint) return self.expression(exp.Commit, chain=chain) - def _parse_add_column(self): + def _parse_add_column(self) -> t.Optional[exp.Expression]: if not self._match_text_seq("ADD"): return None self._match(TokenType.COLUMN) exists_column = self._parse_exists(not_=True) expression = self._parse_column_def(self._parse_field(any_token=True)) - expression.set("exists", exists_column) + + if expression: + expression.set("exists", exists_column) + return expression - def _parse_drop_column(self): + def _parse_drop_column(self) -> t.Optional[exp.Expression]: return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN") - def _parse_alter(self): + def _parse_alter(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.TABLE): return None exists = self._parse_exists() this = self._parse_table(schema=True) - actions = None + actions: t.Optional[exp.Expression | t.List[t.Optional[exp.Expression]]] = None if self._match_text_seq("ADD", advance=False): actions = self._parse_csv(self._parse_add_column) elif self._match_text_seq("DROP", advance=False): @@ -2770,24 +3022,24 @@ class Parser(metaclass=_Parser): actions = ensure_list(actions) return self.expression(exp.AlterTable, this=this, exists=exists, actions=actions) - def _parse_show(self): - parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) + def _parse_show(self) -> t.Optional[exp.Expression]: + parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) # type: ignore if parser: return parser(self) self._advance() return self.expression(exp.Show, this=self._prev.text.upper()) - def _default_parse_set_item(self): + def _default_parse_set_item(self) -> exp.Expression: return self.expression( exp.SetItem, this=self._parse_statement(), ) - def _parse_set_item(self): - parser = self._find_parser(self.SET_PARSERS, self._set_trie) + def _parse_set_item(self) -> t.Optional[exp.Expression]: + parser = self._find_parser(self.SET_PARSERS, self._set_trie) # type: ignore return parser(self) if parser else self._default_parse_set_item() - def _parse_merge(self): + def _parse_merge(self) -> exp.Expression: self._match(TokenType.INTO) target = self._parse_table(schema=True) @@ -2835,10 +3087,12 @@ class Parser(metaclass=_Parser): expressions=whens, ) - def _parse_set(self): + def _parse_set(self) -> exp.Expression: return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item)) - def _find_parser(self, parsers, trie): + def _find_parser( + self, parsers: t.Dict[str, t.Callable], trie: t.Dict + ) -> t.Optional[t.Callable]: index = self._index this = [] while True: |