diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-01-30 17:08:33 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-01-30 17:08:33 +0000 |
commit | 75d158890b303b701c51f12b34c422fb823ba9aa (patch) | |
tree | 5f10a4a1eb612918ea94a934cfc9b9893ea19442 /sqlglot/parser.py | |
parent | Adding upstream version 10.5.6. (diff) | |
download | sqlglot-75d158890b303b701c51f12b34c422fb823ba9aa.tar.xz sqlglot-75d158890b303b701c51f12b34c422fb823ba9aa.zip |
Adding upstream version 10.5.10.
upstream/10.5.10
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/parser.py')
-rw-r--r-- | sqlglot/parser.py | 272 |
1 files changed, 170 insertions, 102 deletions
diff --git a/sqlglot/parser.py b/sqlglot/parser.py index c97b19a..42777d1 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -40,22 +40,23 @@ class _Parser(type): class Parser(metaclass=_Parser): """ - Parser consumes a list of tokens produced by the :class:`~sqlglot.tokens.Tokenizer` - and produces a parsed syntax tree. - - Args - error_level (ErrorLevel): the desired error level. Default: ErrorLevel.RAISE. - error_message_context (int): determines the amount of context to capture from - a query string when displaying the error message (in number of characters). + Parser consumes a list of tokens produced by the `sqlglot.tokens.Tokenizer` and produces + a parsed syntax tree. + + Args: + error_level: the desired error level. + Default: ErrorLevel.RAISE + error_message_context: determines the amount of context to capture from a + query string when displaying the error message (in number of characters). Default: 50. - index_offset (int): Index offset for arrays eg ARRAY[0] vs ARRAY[1] as the head of a list + index_offset: Index offset for arrays eg ARRAY[0] vs ARRAY[1] as the head of a list. Default: 0 - alias_post_tablesample (bool): If the table alias comes after tablesample + alias_post_tablesample: If the table alias comes after tablesample. Default: False - max_errors (int): Maximum number of error messages to include in a raised ParseError. + max_errors: Maximum number of error messages to include in a raised ParseError. This is only relevant if error_level is ErrorLevel.RAISE. Default: 3 - null_ordering (str): Indicates the default null ordering method to use if not explicitly set. + null_ordering: Indicates the default null ordering method to use if not explicitly set. Options are "nulls_are_small", "nulls_are_large", "nulls_are_last". 
Default: "nulls_are_small" """ @@ -109,6 +110,8 @@ class Parser(metaclass=_Parser): TokenType.TEXT, TokenType.MEDIUMTEXT, TokenType.LONGTEXT, + TokenType.MEDIUMBLOB, + TokenType.LONGBLOB, TokenType.BINARY, TokenType.VARBINARY, TokenType.JSON, @@ -176,6 +179,7 @@ class Parser(metaclass=_Parser): TokenType.DIV, TokenType.DISTKEY, TokenType.DISTSTYLE, + TokenType.END, TokenType.EXECUTE, TokenType.ENGINE, TokenType.ESCAPE, @@ -468,9 +472,6 @@ class Parser(metaclass=_Parser): TokenType.NULL: lambda self, _: self.expression(exp.Null), TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True), TokenType.FALSE: lambda self, _: self.expression(exp.Boolean, this=False), - TokenType.PARAMETER: lambda self, _: self.expression( - exp.Parameter, this=self._parse_var() or self._parse_primary() - ), TokenType.BIT_STRING: lambda self, token: self.expression(exp.BitString, this=token.text), TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text), TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text), @@ -479,6 +480,16 @@ class Parser(metaclass=_Parser): TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(), } + PLACEHOLDER_PARSERS = { + TokenType.PLACEHOLDER: lambda self: self.expression(exp.Placeholder), + TokenType.PARAMETER: lambda self: self.expression( + exp.Parameter, this=self._parse_var() or self._parse_primary() + ), + TokenType.COLON: lambda self: self.expression(exp.Placeholder, this=self._prev.text) + if self._match_set((TokenType.NUMBER, TokenType.VAR)) + else None, + } + RANGE_PARSERS = { TokenType.BETWEEN: lambda self, this: self._parse_between(this), TokenType.IN: lambda self, this: self._parse_in(this), @@ -601,8 +612,7 @@ class Parser(metaclass=_Parser): WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS} - # allows tables to have special tokens as prefixes - TABLE_PREFIX_TOKENS: t.Set[TokenType] = set() + ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, 
TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY} STRICT_CAST = True @@ -677,7 +687,7 @@ class Parser(metaclass=_Parser): def parse_into( self, - expression_types: str | exp.Expression | t.Collection[exp.Expression | str], + expression_types: exp.IntoType, raw_tokens: t.List[Token], sql: t.Optional[str] = None, ) -> t.List[t.Optional[exp.Expression]]: @@ -820,24 +830,8 @@ class Parser(metaclass=_Parser): if self.error_level == ErrorLevel.IGNORE: return - for k in expression.args: - if k not in expression.arg_types: - self.raise_error(f"Unexpected keyword: '{k}' for {expression.__class__}") - for k, mandatory in expression.arg_types.items(): - v = expression.args.get(k) - if mandatory and (v is None or (isinstance(v, list) and not v)): - self.raise_error(f"Required keyword: '{k}' missing for {expression.__class__}") - - if ( - args - and isinstance(expression, exp.Func) - and len(args) > len(expression.arg_types) - and not expression.is_var_len_args - ): - self.raise_error( - f"The number of provided arguments ({len(args)}) is greater than " - f"the maximum number of supported arguments ({len(expression.arg_types)})" - ) + for error_message in expression.error_messages(args): + self.raise_error(error_message) def _find_token(self, token: Token, sql: str) -> int: line = 1 @@ -868,6 +862,9 @@ class Parser(metaclass=_Parser): def _retreat(self, index: int) -> None: self._advance(index - self._index) + def _parse_command(self) -> exp.Expression: + return self.expression(exp.Command, this=self._prev.text, expression=self._parse_string()) + def _parse_statement(self) -> t.Optional[exp.Expression]: if self._curr is None: return None @@ -876,11 +873,7 @@ class Parser(metaclass=_Parser): return self.STATEMENT_PARSERS[self._prev.token_type](self) if self._match_set(Tokenizer.COMMANDS): - return self.expression( - exp.Command, - this=self._prev.text, - expression=self._parse_string(), - ) + return self._parse_command() expression = self._parse_expression() expression = 
self._parse_set_operations(expression) if expression else self._parse_select() @@ -942,12 +935,18 @@ class Parser(metaclass=_Parser): no_primary_index = None indexes = None no_schema_binding = None + begin = None if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE): - this = self._parse_user_defined_function() + this = self._parse_user_defined_function(kind=create_token.token_type) properties = self._parse_properties() if self._match(TokenType.ALIAS): - expression = self._parse_select_or_expression() + begin = self._match(TokenType.BEGIN) + return_ = self._match_text_seq("RETURN") + expression = self._parse_statement() + + if return_: + expression = self.expression(exp.Return, this=expression) elif create_token.token_type == TokenType.INDEX: this = self._parse_index() elif create_token.token_type in ( @@ -1002,6 +1001,7 @@ class Parser(metaclass=_Parser): no_primary_index=no_primary_index, indexes=indexes, no_schema_binding=no_schema_binding, + begin=begin, ) def _parse_property(self) -> t.Optional[exp.Expression]: @@ -1087,7 +1087,7 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.GT): self.raise_error("Expecting >") else: - value = self._parse_schema(exp.Literal.string("TABLE")) + value = self._parse_schema(exp.Var(this="TABLE")) else: value = self._parse_types() @@ -1550,7 +1550,7 @@ class Parser(metaclass=_Parser): return None index = self._parse_id_var() columns = None - if self._curr and self._curr.token_type == TokenType.L_PAREN: + if self._match(TokenType.L_PAREN, advance=False): columns = self._parse_wrapped_csv(self._parse_column) return self.expression( exp.Index, @@ -1561,6 +1561,27 @@ class Parser(metaclass=_Parser): amp=amp, ) + def _parse_table_parts(self, schema: bool = False) -> exp.Expression: + catalog = None + db = None + table = (not schema and self._parse_function()) or self._parse_id_var(any_token=False) + + while self._match(TokenType.DOT): + if catalog: + # This allows nesting the table in arbitrarily many 
dot expressions if needed + table = self.expression(exp.Dot, this=table, expression=self._parse_id_var()) + else: + catalog = db + db = table + table = self._parse_id_var() + + if not table: + self.raise_error(f"Expected table name but got {self._curr}") + + return self.expression( + exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots() + ) + def _parse_table( self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None ) -> t.Optional[exp.Expression]: @@ -1584,27 +1605,7 @@ class Parser(metaclass=_Parser): if subquery: return subquery - catalog = None - db = None - table = (not schema and self._parse_function()) or self._parse_id_var( - any_token=False, prefix_tokens=self.TABLE_PREFIX_TOKENS - ) - - while self._match(TokenType.DOT): - if catalog: - # This allows nesting the table in arbitrarily many dot expressions if needed - table = self.expression(exp.Dot, this=table, expression=self._parse_id_var()) - else: - catalog = db - db = table - table = self._parse_id_var() - - if not table: - self.raise_error(f"Expected table name but got {self._curr}") - - this = self.expression( - exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots() - ) + this = self._parse_table_parts(schema=schema) if schema: return self._parse_schema(this=this) @@ -1889,7 +1890,7 @@ class Parser(metaclass=_Parser): expression, this=this, distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL), - expression=self._parse_select(nested=True), + expression=self._parse_set_operations(self._parse_select(nested=True)), ) def _parse_expression(self) -> t.Optional[exp.Expression]: @@ -2286,7 +2287,9 @@ class Parser(metaclass=_Parser): self._match_r_paren(this) return self._parse_window(this) - def _parse_user_defined_function(self) -> t.Optional[exp.Expression]: + def _parse_user_defined_function( + self, kind: t.Optional[TokenType] = None + ) -> t.Optional[exp.Expression]: this = self._parse_id_var() while 
self._match(TokenType.DOT): @@ -2297,7 +2300,9 @@ class Parser(metaclass=_Parser): expressions = self._parse_csv(self._parse_udf_kwarg) self._match_r_paren() - return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions) + return self.expression( + exp.UserDefinedFunction, this=this, expressions=expressions, wrapped=True + ) def _parse_introducer(self, token: Token) -> t.Optional[exp.Expression]: literal = self._parse_primary() @@ -2371,10 +2376,6 @@ class Parser(metaclass=_Parser): or self._parse_column_def(self._parse_field(any_token=True)) ) self._match_r_paren() - - if isinstance(this, exp.Literal): - this = this.name - return self.expression(exp.Schema, this=this, expressions=args) def _parse_column_def(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: @@ -2470,15 +2471,43 @@ class Parser(metaclass=_Parser): def _parse_unique(self) -> exp.Expression: return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars()) + def _parse_key_constraint_options(self) -> t.List[str]: + options = [] + while True: + if not self._curr: + break + + if self._match_text_seq("NOT", "ENFORCED"): + options.append("NOT ENFORCED") + elif self._match_text_seq("DEFERRABLE"): + options.append("DEFERRABLE") + elif self._match_text_seq("INITIALLY", "DEFERRED"): + options.append("INITIALLY DEFERRED") + elif self._match_text_seq("NORELY"): + options.append("NORELY") + elif self._match_text_seq("MATCH", "FULL"): + options.append("MATCH FULL") + elif self._match_text_seq("ON", "UPDATE", "NO ACTION"): + options.append("ON UPDATE NO ACTION") + elif self._match_text_seq("ON", "DELETE", "NO ACTION"): + options.append("ON DELETE NO ACTION") + else: + break + + return options + def _parse_references(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.REFERENCES): return None - return self.expression( - exp.Reference, - this=self._parse_id_var(), - expressions=self._parse_wrapped_id_vars(), - ) + expressions = None + this = 
self._parse_id_var() + + if self._match(TokenType.L_PAREN, advance=False): + expressions = self._parse_wrapped_id_vars() + + options = self._parse_key_constraint_options() + return self.expression(exp.Reference, this=this, expressions=expressions, options=options) def _parse_foreign_key(self) -> exp.Expression: expressions = self._parse_wrapped_id_vars() @@ -2503,12 +2532,14 @@ class Parser(metaclass=_Parser): options[kind] = action return self.expression( - exp.ForeignKey, - expressions=expressions, - reference=reference, - **options, # type: ignore + exp.ForeignKey, expressions=expressions, reference=reference, **options # type: ignore ) + def _parse_primary_key(self) -> exp.Expression: + expressions = self._parse_wrapped_id_vars() + options = self._parse_key_constraint_options() + return self.expression(exp.PrimaryKey, expressions=expressions, options=options) + def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if not self._match(TokenType.L_BRACKET): return this @@ -2631,7 +2662,7 @@ class Parser(metaclass=_Parser): order = self._parse_order(this=expression) return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1)) - def _parse_convert(self, strict: bool) -> exp.Expression: + def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]: to: t.Optional[exp.Expression] this = self._parse_column() @@ -2641,19 +2672,25 @@ class Parser(metaclass=_Parser): to = self._parse_types() else: to = None + return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) - def _parse_position(self) -> exp.Expression: + def _parse_position(self, haystack_first: bool = False) -> exp.Expression: args = self._parse_csv(self._parse_bitwise) if self._match(TokenType.IN): - args.append(self._parse_bitwise()) + return self.expression( + exp.StrPosition, this=self._parse_bitwise(), substr=seq_get(args, 0) + ) - this = exp.StrPosition( - this=seq_get(args, 1), - substr=seq_get(args, 0), - 
position=seq_get(args, 2), - ) + if haystack_first: + haystack = seq_get(args, 0) + needle = seq_get(args, 1) + else: + needle = seq_get(args, 0) + haystack = seq_get(args, 1) + + this = exp.StrPosition(this=haystack, substr=needle, position=seq_get(args, 2)) self.validate_expression(this, args) @@ -2894,24 +2931,26 @@ class Parser(metaclass=_Parser): return None def _parse_placeholder(self) -> t.Optional[exp.Expression]: - if self._match(TokenType.PLACEHOLDER): - return self.expression(exp.Placeholder) - elif self._match(TokenType.COLON): - if self._match_set((TokenType.NUMBER, TokenType.VAR)): - return self.expression(exp.Placeholder, this=self._prev.text) + if self._match_set(self.PLACEHOLDER_PARSERS): + placeholder = self.PLACEHOLDER_PARSERS[self._prev.token_type](self) + if placeholder: + return placeholder self._advance(-1) return None def _parse_except(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: if not self._match(TokenType.EXCEPT): return None - - return self._parse_wrapped_id_vars() + if self._match(TokenType.L_PAREN, advance=False): + return self._parse_wrapped_id_vars() + return self._parse_csv(self._parse_id_var) def _parse_replace(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: if not self._match(TokenType.REPLACE): return None - return self._parse_wrapped_csv(lambda: self._parse_alias(self._parse_expression())) + if self._match(TokenType.L_PAREN, advance=False): + return self._parse_wrapped_csv(self._parse_expression) + return self._parse_csv(self._parse_expression) def _parse_csv( self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA @@ -3021,6 +3060,28 @@ class Parser(metaclass=_Parser): def _parse_drop_column(self) -> t.Optional[exp.Expression]: return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN") + def _parse_add_constraint(self) -> t.Optional[exp.Expression]: + this = None + kind = self._prev.token_type + + if kind == TokenType.CONSTRAINT: + this = self._parse_id_var() + + if 
self._match(TokenType.CHECK): + expression = self._parse_wrapped(self._parse_conjunction) + enforced = self._match_text_seq("ENFORCED") + + return self.expression( + exp.AddConstraint, this=this, expression=expression, enforced=enforced + ) + + if kind == TokenType.FOREIGN_KEY or self._match(TokenType.FOREIGN_KEY): + expression = self._parse_foreign_key() + elif kind == TokenType.PRIMARY_KEY or self._match(TokenType.PRIMARY_KEY): + expression = self._parse_primary_key() + + return self.expression(exp.AddConstraint, this=this, expression=expression) + def _parse_alter(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.TABLE): return None @@ -3029,8 +3090,14 @@ class Parser(metaclass=_Parser): this = self._parse_table(schema=True) actions: t.Optional[exp.Expression | t.List[t.Optional[exp.Expression]]] = None - if self._match_text_seq("ADD", advance=False): - actions = self._parse_csv(self._parse_add_column) + + index = self._index + if self._match_text_seq("ADD"): + if self._match_set(self.ADD_CONSTRAINT_TOKENS): + actions = self._parse_csv(self._parse_add_constraint) + else: + self._retreat(index) + actions = self._parse_csv(self._parse_add_column) elif self._match_text_seq("DROP", advance=False): actions = self._parse_csv(self._parse_drop_column) elif self._match_text_seq("RENAME", "TO"): @@ -3077,7 +3144,7 @@ class Parser(metaclass=_Parser): def _parse_merge(self) -> exp.Expression: self._match(TokenType.INTO) - target = self._parse_table(schema=True) + target = self._parse_table() self._match(TokenType.USING) using = self._parse_table() @@ -3146,12 +3213,13 @@ class Parser(metaclass=_Parser): self._retreat(index) return None - def _match(self, token_type): + def _match(self, token_type, advance=True): if not self._curr: return None if self._curr.token_type == token_type: - self._advance() + if advance: + self._advance() return True return None |