diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-01-31 05:44:37 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-01-31 05:44:37 +0000 |
commit | 5f8be2e0852f3c925fb873a48946caee3050899f (patch) | |
tree | 1f31666277e226f47180321c08be7ebbedc2780e /sqlglot/parser.py | |
parent | Adding upstream version 20.9.0. (diff) | |
download | sqlglot-5f8be2e0852f3c925fb873a48946caee3050899f.tar.xz sqlglot-5f8be2e0852f3c925fb873a48946caee3050899f.zip |
Adding upstream version 20.11.0.upstream/20.11.0
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/parser.py')
-rw-r--r-- | sqlglot/parser.py | 161 |
1 files changed, 120 insertions, 41 deletions
diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 790ee0d..c091605 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -12,9 +12,7 @@ from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import TrieResult, in_trie, new_trie if t.TYPE_CHECKING: - from typing_extensions import Literal - - from sqlglot._typing import E + from sqlglot._typing import E, Lit from sqlglot.dialects.dialect import Dialect, DialectType logger = logging.getLogger("sqlglot") @@ -148,6 +146,11 @@ class Parser(metaclass=_Parser): TokenType.ENUM16, } + AGGREGATE_TYPE_TOKENS = { + TokenType.AGGREGATEFUNCTION, + TokenType.SIMPLEAGGREGATEFUNCTION, + } + TYPE_TOKENS = { TokenType.BIT, TokenType.BOOLEAN, @@ -241,6 +244,7 @@ class Parser(metaclass=_Parser): TokenType.NULL, *ENUM_TYPE_TOKENS, *NESTED_TYPE_TOKENS, + *AGGREGATE_TYPE_TOKENS, } SIGNED_TO_UNSIGNED_TYPE_TOKEN = { @@ -653,9 +657,11 @@ class Parser(metaclass=_Parser): PLACEHOLDER_PARSERS = { TokenType.PLACEHOLDER: lambda self: self.expression(exp.Placeholder), TokenType.PARAMETER: lambda self: self._parse_parameter(), - TokenType.COLON: lambda self: self.expression(exp.Placeholder, this=self._prev.text) - if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS) - else None, + TokenType.COLON: lambda self: ( + self.expression(exp.Placeholder, this=self._prev.text) + if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS) + else None + ), } RANGE_PARSERS = { @@ -705,6 +711,9 @@ class Parser(metaclass=_Parser): "IMMUTABLE": lambda self: self.expression( exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE") ), + "INHERITS": lambda self: self.expression( + exp.InheritsProperty, expressions=self._parse_wrapped_csv(self._parse_table) + ), "INPUT": lambda self: self.expression(exp.InputModelProperty, this=self._parse_schema()), "JOURNAL": lambda self, **kwargs: self._parse_journal(**kwargs), "LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty), @@ -822,6 +831,7 @@ class Parser(metaclass=_Parser): ALTER_PARSERS = { "ADD": lambda self: self._parse_alter_table_add(), "ALTER": lambda self: self._parse_alter_table_alter(), + "CLUSTER BY": lambda self: self._parse_cluster(wrapped=True), "DELETE": lambda self: self.expression(exp.Delete, where=self._parse_where()), "DROP": lambda self: self._parse_alter_table_drop(), "RENAME": lambda self: self._parse_alter_table_rename(), @@ -973,6 +983,9 @@ class Parser(metaclass=_Parser): MODIFIERS_ATTACHED_TO_UNION = True UNION_MODIFIERS = {"order", "limit", "offset"} + # parses no parenthesis if statements as commands + NO_PAREN_IF_COMMANDS = True + __slots__ = ( "error_level", "error_message_context", @@ -1207,7 +1220,20 @@ class Parser(metaclass=_Parser): if index != self._index: self._advance(index - self._index) + def _warn_unsupported(self) -> None: + if len(self._tokens) <= 1: + return + + # We use _find_sql because self.sql may comprise multiple chunks, and we're only + # interested in emitting a warning for the one being currently processed. + sql = self._find_sql(self._tokens[0], self._tokens[-1])[: self.error_message_context] + + logger.warning( + f"'{sql}' contains unsupported syntax. Falling back to parsing as a 'Command'." + ) + def _parse_command(self) -> exp.Command: + self._warn_unsupported() return self.expression( exp.Command, this=self._prev.text.upper(), expression=self._parse_string() ) @@ -1329,8 +1355,10 @@ class Parser(metaclass=_Parser): start = self._prev comments = self._prev_comments - replace = start.text.upper() == "REPLACE" or self._match_pair( - TokenType.OR, TokenType.REPLACE + replace = ( + start.token_type == TokenType.REPLACE + or self._match_pair(TokenType.OR, TokenType.REPLACE) + or self._match_pair(TokenType.OR, TokenType.ALTER) ) unique = self._match(TokenType.UNIQUE) @@ -1440,6 +1468,9 @@ class Parser(metaclass=_Parser): exp.Clone, this=self._parse_table(schema=True), shallow=shallow, copy=copy ) + if self._curr: + return self._parse_as_command(start) + return self.expression( exp.Create, comments=comments, @@ -1516,11 +1547,13 @@ class Parser(metaclass=_Parser): return self.expression( exp.FileFormatProperty, - this=self.expression( - exp.InputOutputFormat, input_format=input_format, output_format=output_format - ) - if input_format or output_format - else self._parse_var_or_string() or self._parse_number() or self._parse_id_var(), + this=( + self.expression( + exp.InputOutputFormat, input_format=input_format, output_format=output_format + ) + if input_format or output_format + else self._parse_var_or_string() or self._parse_number() or self._parse_id_var() + ), ) def _parse_property_assignment(self, exp_class: t.Type[E], **kwargs: t.Any) -> E: @@ -1632,8 +1665,15 @@ class Parser(metaclass=_Parser): return self.expression(exp.ChecksumProperty, on=on, default=self._match(TokenType.DEFAULT)) - def _parse_cluster(self) -> exp.Cluster: - return self.expression(exp.Cluster, expressions=self._parse_csv(self._parse_ordered)) + def _parse_cluster(self, wrapped: bool = False) -> exp.Cluster: + return self.expression( + exp.Cluster, + expressions=( + self._parse_wrapped_csv(self._parse_ordered) + if wrapped + else self._parse_csv(self._parse_ordered) + ), + ) def _parse_clustered_by(self) -> exp.ClusteredByProperty: self._match_text_seq("BY") @@ -2681,6 +2721,8 @@ class Parser(metaclass=_Parser): else: columns = None + include = self._parse_wrapped_id_vars() if self._match_text_seq("INCLUDE") else None + return self.expression( exp.Index, this=index, @@ -2690,6 +2732,7 @@ class Parser(metaclass=_Parser): unique=unique, primary=primary, amp=amp, + include=include, partition_by=self._parse_partition_by(), where=self._parse_where(), ) @@ -3380,8 +3423,8 @@ class Parser(metaclass=_Parser): def _parse_comparison(self) -> t.Optional[exp.Expression]: return self._parse_tokens(self._parse_range, self.COMPARISON) - def _parse_range(self) -> t.Optional[exp.Expression]: - this = self._parse_bitwise() + def _parse_range(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]: + this = this or self._parse_bitwise() negate = self._match(TokenType.NOT) if self._match_set(self.RANGE_PARSERS): @@ -3535,14 +3578,21 @@ class Parser(metaclass=_Parser): return self._parse_tokens(self._parse_factor, self.TERM) def _parse_factor(self) -> t.Optional[exp.Expression]: - if self.EXPONENT: - factor = self._parse_tokens(self._parse_exponent, self.FACTOR) - else: - factor = self._parse_tokens(self._parse_unary, self.FACTOR) - if isinstance(factor, exp.Div): - factor.args["typed"] = self.dialect.TYPED_DIVISION - factor.args["safe"] = self.dialect.SAFE_DIVISION - return factor + parse_method = self._parse_exponent if self.EXPONENT else self._parse_unary + this = parse_method() + + while self._match_set(self.FACTOR): + this = self.expression( + self.FACTOR[self._prev.token_type], + this=this, + comments=self._prev_comments, + expression=parse_method(), + ) + if isinstance(this, exp.Div): + this.args["typed"] = self.dialect.TYPED_DIVISION + this.args["safe"] = self.dialect.SAFE_DIVISION + + return this def _parse_exponent(self) -> t.Optional[exp.Expression]: return self._parse_tokens(self._parse_unary, self.EXPONENT) @@ -3617,6 +3667,7 @@ class Parser(metaclass=_Parser): return exp.DataType.build(type_name, udt=True) else: + self._retreat(self._index - 1) return None else: return None @@ -3631,6 +3682,7 @@ class Parser(metaclass=_Parser): nested = type_token in self.NESTED_TYPE_TOKENS is_struct = type_token in self.STRUCT_TYPE_TOKENS + is_aggregate = type_token in self.AGGREGATE_TYPE_TOKENS expressions = None maybe_func = False @@ -3645,6 +3697,18 @@ class Parser(metaclass=_Parser): ) elif type_token in self.ENUM_TYPE_TOKENS: expressions = self._parse_csv(self._parse_equality) + elif is_aggregate: + func_or_ident = self._parse_function(anonymous=True) or self._parse_id_var( + any_token=False, tokens=(TokenType.VAR,) + ) + if not func_or_ident or not self._match(TokenType.COMMA): + return None + expressions = self._parse_csv( + lambda: self._parse_types( + check_func=check_func, schema=schema, allow_identifiers=allow_identifiers + ) + ) + expressions.insert(0, func_or_ident) else: expressions = self._parse_csv(self._parse_type_size) @@ -4413,6 +4477,10 @@ class Parser(metaclass=_Parser): self._match_r_paren() else: index = self._index - 1 + + if self.NO_PAREN_IF_COMMANDS and index == 0: + return self._parse_as_command(self._prev) + condition = self._parse_conjunction() if not condition: @@ -4624,12 +4692,10 @@ class Parser(metaclass=_Parser): return None @t.overload - def _parse_json_object(self, agg: Literal[False]) -> exp.JSONObject: - ... + def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ... @t.overload - def _parse_json_object(self, agg: Literal[True]) -> exp.JSONObjectAgg: - ... + def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ... def _parse_json_object(self, agg=False): star = self._parse_star() @@ -4974,11 +5040,12 @@ class Parser(metaclass=_Parser): if alias: this = self.expression(exp.Alias, comments=comments, this=this, alias=alias) + column = this.this # Moves the comment next to the alias in `expr /* comment */ AS alias` - if not this.comments and this.this.comments: - this.comments = this.this.comments - this.this.comments = None + if not this.comments and column and column.comments: + this.comments = column.comments + column.comments = None return this @@ -5244,7 +5311,7 @@ class Parser(metaclass=_Parser): if self._match_text_seq("CHECK"): expression = self._parse_wrapped(self._parse_conjunction) - enforced = self._match_text_seq("ENFORCED") + enforced = self._match_text_seq("ENFORCED") or False return self.expression( exp.AddConstraint, this=this, expression=expression, enforced=enforced @@ -5278,6 +5345,8 @@ class Parser(metaclass=_Parser): return self.expression(exp.AlterColumn, this=column, drop=True) if self._match_pair(TokenType.SET, TokenType.DEFAULT): return self.expression(exp.AlterColumn, this=column, default=self._parse_conjunction()) + if self._match(TokenType.COMMENT): + return self.expression(exp.AlterColumn, this=column, comment=self._parse_string()) self._match_text_seq("SET", "DATA") return self.expression( @@ -5298,7 +5367,18 @@ class Parser(metaclass=_Parser): self._retreat(index) return self._parse_csv(self._parse_drop_column) - def _parse_alter_table_rename(self) -> exp.RenameTable: + def _parse_alter_table_rename(self) -> t.Optional[exp.RenameTable | exp.RenameColumn]: + if self._match(TokenType.COLUMN): + exists = self._parse_exists() + old_column = self._parse_column() + to = self._match_text_seq("TO") + new_column = self._parse_column() + + if old_column is None or to is None or new_column is None: + return None + + return self.expression(exp.RenameColumn, this=old_column, to=new_column, exists=exists) + self._match_text_seq("TO") return self.expression(exp.RenameTable, this=self._parse_table(schema=True)) @@ -5319,7 +5399,7 @@ class Parser(metaclass=_Parser): if parser: actions = ensure_list(parser(self)) - if not self._curr: + if not self._curr and actions: return self.expression( exp.AlterTable, this=this, @@ -5467,6 +5547,7 @@ class Parser(metaclass=_Parser): self._advance() text = self._find_sql(start, self._prev) size = len(start.text) + self._warn_unsupported() return exp.Command(this=text[:size], expression=text[size:]) def _parse_dict_property(self, this: str) -> exp.DictProperty: @@ -5634,7 +5715,7 @@ class Parser(metaclass=_Parser): if advance: self._advance() return True - return False + return None def _match_text_seq(self, *texts, advance=True): index = self._index @@ -5643,7 +5724,7 @@ class Parser(metaclass=_Parser): self._advance() else: self._retreat(index) - return False + return None if not advance: self._retreat(index) @@ -5651,14 +5732,12 @@ class Parser(metaclass=_Parser): return True @t.overload - def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression: - ... + def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression: ... @t.overload def _replace_columns_with_dots( self, this: t.Optional[exp.Expression] - ) -> t.Optional[exp.Expression]: - ... + ) -> t.Optional[exp.Expression]: ... def _replace_columns_with_dots(self, this): if isinstance(this, exp.Dot): |