diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-07-06 07:28:09 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-07-06 07:28:09 +0000 |
commit | 52f4a5e2260f3e5b919b4e270339afd670bf0b8a (patch) | |
tree | 5ca419af0e2e409018492b82f5b9847f0112b5fb /sqlglot/parser.py | |
parent | Adding upstream version 16.7.7. (diff) | |
download | sqlglot-52f4a5e2260f3e5b919b4e270339afd670bf0b8a.tar.xz sqlglot-52f4a5e2260f3e5b919b4e270339afd670bf0b8a.zip |
Adding upstream version 17.2.0. (tag: upstream/17.2.0)
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/parser.py')
-rw-r--r-- | sqlglot/parser.py | 158 |
1 file changed, 105 insertions, 53 deletions
diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 79e7cac..f7fd6ba 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -737,19 +737,29 @@ class Parser(metaclass=_Parser): } QUERY_MODIFIER_PARSERS = { - "joins": lambda self: list(iter(self._parse_join, None)), - "laterals": lambda self: list(iter(self._parse_lateral, None)), - "match": lambda self: self._parse_match_recognize(), - "where": lambda self: self._parse_where(), - "group": lambda self: self._parse_group(), - "having": lambda self: self._parse_having(), - "qualify": lambda self: self._parse_qualify(), - "windows": lambda self: self._parse_window_clause(), - "order": lambda self: self._parse_order(), - "limit": lambda self: self._parse_limit(), - "offset": lambda self: self._parse_offset(), - "locks": lambda self: self._parse_locks(), - "sample": lambda self: self._parse_table_sample(as_modifier=True), + TokenType.MATCH_RECOGNIZE: lambda self: ("match", self._parse_match_recognize()), + TokenType.WHERE: lambda self: ("where", self._parse_where()), + TokenType.GROUP_BY: lambda self: ("group", self._parse_group()), + TokenType.HAVING: lambda self: ("having", self._parse_having()), + TokenType.QUALIFY: lambda self: ("qualify", self._parse_qualify()), + TokenType.WINDOW: lambda self: ("windows", self._parse_window_clause()), + TokenType.ORDER_BY: lambda self: ("order", self._parse_order()), + TokenType.LIMIT: lambda self: ("limit", self._parse_limit()), + TokenType.FETCH: lambda self: ("limit", self._parse_limit()), + TokenType.OFFSET: lambda self: ("offset", self._parse_offset()), + TokenType.FOR: lambda self: ("locks", self._parse_locks()), + TokenType.LOCK: lambda self: ("locks", self._parse_locks()), + TokenType.TABLE_SAMPLE: lambda self: ("sample", self._parse_table_sample(as_modifier=True)), + TokenType.USING: lambda self: ("sample", self._parse_table_sample(as_modifier=True)), + TokenType.CLUSTER_BY: lambda self: ( + "cluster", + self._parse_sort(exp.Cluster, TokenType.CLUSTER_BY), + ), + 
TokenType.DISTRIBUTE_BY: lambda self: ( + "distribute", + self._parse_sort(exp.Distribute, TokenType.DISTRIBUTE_BY), + ), + TokenType.SORT_BY: lambda self: ("sort", self._parse_sort(exp.Sort, TokenType.SORT_BY)), } SET_PARSERS = { @@ -1679,6 +1689,7 @@ class Parser(metaclass=_Parser): def _parse_insert(self) -> exp.Insert: overwrite = self._match(TokenType.OVERWRITE) + ignore = self._match(TokenType.IGNORE) local = self._match_text_seq("LOCAL") alternative = None @@ -1709,6 +1720,7 @@ class Parser(metaclass=_Parser): returning=self._parse_returning(), overwrite=overwrite, alternative=alternative, + ignore=ignore, ) def _parse_on_conflict(self) -> t.Optional[exp.OnConflict]: @@ -1734,7 +1746,8 @@ class Parser(metaclass=_Parser): nothing = True else: self._match(TokenType.UPDATE) - expressions = self._match(TokenType.SET) and self._parse_csv(self._parse_equality) + self._match(TokenType.SET) + expressions = self._parse_csv(self._parse_equality) return self.expression( exp.OnConflict, @@ -1805,12 +1818,17 @@ class Parser(metaclass=_Parser): return self._parse_as_command(self._prev) def _parse_delete(self) -> exp.Delete: - self._match(TokenType.FROM) + # This handles MySQL's "Multiple-Table Syntax" + # https://dev.mysql.com/doc/refman/8.0/en/delete.html + tables = None + if not self._match(TokenType.FROM, advance=False): + tables = self._parse_csv(self._parse_table) or None return self.expression( exp.Delete, - this=self._parse_table(), - using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table()), + tables=tables, + this=self._match(TokenType.FROM) and self._parse_table(joins=True), + using=self._match(TokenType.USING) and self._parse_table(joins=True), where=self._parse_where(), returning=self._parse_returning(), limit=self._parse_limit(), @@ -1822,7 +1840,7 @@ class Parser(metaclass=_Parser): **{ # type: ignore "this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS), "expressions": self._match(TokenType.SET) and 
self._parse_csv(self._parse_equality), - "from": self._parse_from(modifiers=True), + "from": self._parse_from(joins=True), "where": self._parse_where(), "returning": self._parse_returning(), "limit": self._parse_limit(), @@ -1875,7 +1893,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.Tuple, expressions=expressions) # In presto we can have VALUES 1, 2 which results in 1 column & 2 rows. - # Source: https://prestodb.io/docs/current/sql/values.html + # https://prestodb.io/docs/current/sql/values.html return self.expression(exp.Tuple, expressions=[self._parse_conjunction()]) def _parse_select( @@ -1917,7 +1935,7 @@ class Parser(metaclass=_Parser): self.raise_error("Cannot specify both ALL and DISTINCT after SELECT") limit = self._parse_limit(top=True) - expressions = self._parse_csv(self._parse_expression) + expressions = self._parse_expressions() this = self.expression( exp.Select, @@ -2034,20 +2052,31 @@ class Parser(metaclass=_Parser): self, this: t.Optional[exp.Expression] ) -> t.Optional[exp.Expression]: if isinstance(this, self.MODIFIABLES): - for key, parser in self.QUERY_MODIFIER_PARSERS.items(): - expression = parser(self) - - if expression: - if key == "limit": - offset = expression.args.pop("offset", None) - if offset: - this.set("offset", exp.Offset(expression=offset)) - this.set(key, expression) + for join in iter(self._parse_join, None): + this.append("joins", join) + for lateral in iter(self._parse_lateral, None): + this.append("laterals", lateral) + + while True: + if self._match_set(self.QUERY_MODIFIER_PARSERS, advance=False): + parser = self.QUERY_MODIFIER_PARSERS[self._curr.token_type] + key, expression = parser(self) + + if expression: + this.set(key, expression) + if key == "limit": + offset = expression.args.pop("offset", None) + if offset: + this.set("offset", exp.Offset(expression=offset)) + continue + break return this def _parse_hint(self) -> t.Optional[exp.Hint]: if self._match(TokenType.HINT): - hints = 
self._parse_csv(self._parse_function) + hints = [] + for hint in iter(lambda: self._parse_csv(self._parse_function), []): + hints.extend(hint) if not self._match_pair(TokenType.STAR, TokenType.SLASH): self.raise_error("Expected */ after HINT") @@ -2069,18 +2098,13 @@ class Parser(metaclass=_Parser): ) def _parse_from( - self, modifiers: bool = False, skip_from_token: bool = False + self, joins: bool = False, skip_from_token: bool = False ) -> t.Optional[exp.From]: if not skip_from_token and not self._match(TokenType.FROM): return None - comments = self._prev_comments - this = self._parse_table() - return self.expression( - exp.From, - comments=comments, - this=self._parse_query_modifiers(this) if modifiers else this, + exp.From, comments=self._prev_comments, this=self._parse_table(joins=joins) ) def _parse_match_recognize(self) -> t.Optional[exp.MatchRecognize]: @@ -2091,9 +2115,7 @@ class Parser(metaclass=_Parser): partition = self._parse_partition_by() order = self._parse_order() - measures = ( - self._parse_csv(self._parse_expression) if self._match_text_seq("MEASURES") else None - ) + measures = self._parse_expressions() if self._match_text_seq("MEASURES") else None if self._match_text_seq("ONE", "ROW", "PER", "MATCH"): rows = exp.var("ONE ROW PER MATCH") @@ -2259,6 +2281,18 @@ class Parser(metaclass=_Parser): kwargs["on"] = self._parse_conjunction() elif self._match(TokenType.USING): kwargs["using"] = self._parse_wrapped_id_vars() + elif not (kind and kind.token_type == TokenType.CROSS): + index = self._index + joins = self._parse_joins() + + if joins and self._match(TokenType.ON): + kwargs["on"] = self._parse_conjunction() + elif joins and self._match(TokenType.USING): + kwargs["using"] = self._parse_wrapped_id_vars() + else: + joins = None + self._retreat(index) + kwargs["this"].set("joins", joins) return self.expression(exp.Join, **kwargs) @@ -2363,7 +2397,10 @@ class Parser(metaclass=_Parser): ) def _parse_table( - self, schema: bool = False, alias_tokens: 
t.Optional[t.Collection[TokenType]] = None + self, + schema: bool = False, + joins: bool = False, + alias_tokens: t.Optional[t.Collection[TokenType]] = None, ) -> t.Optional[exp.Expression]: lateral = self._parse_lateral() if lateral: @@ -2407,6 +2444,10 @@ class Parser(metaclass=_Parser): table_sample.set("this", this) this = table_sample + if joins: + for join in iter(self._parse_join, None): + this.append("joins", join) + return this def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]: @@ -2507,8 +2548,11 @@ class Parser(metaclass=_Parser): kind=kind, ) - def _parse_pivots(self) -> t.List[t.Optional[exp.Expression]]: - return list(iter(self._parse_pivot, None)) + def _parse_pivots(self) -> t.Optional[t.List[exp.Pivot]]: + return list(iter(self._parse_pivot, None)) or None + + def _parse_joins(self) -> t.Optional[t.List[exp.Join]]: + return list(iter(self._parse_join, None)) or None # https://duckdb.org/docs/sql/statements/pivot def _parse_simplified_pivot(self) -> exp.Pivot: @@ -2603,6 +2647,9 @@ class Parser(metaclass=_Parser): elements = defaultdict(list) + if self._match(TokenType.ALL): + return self.expression(exp.Group, all=True) + while True: expressions = self._parse_csv(self._parse_conjunction) if expressions: @@ -3171,7 +3218,7 @@ class Parser(metaclass=_Parser): if query: expressions = [query] else: - expressions = self._parse_csv(self._parse_expression) + expressions = self._parse_expressions() this = self._parse_query_modifiers(seq_get(expressions, 0)) @@ -3536,11 +3583,7 @@ class Parser(metaclass=_Parser): return None expressions = None - this = self._parse_id_var() - - if self._match(TokenType.L_PAREN, advance=False): - expressions = self._parse_wrapped_id_vars() - + this = self._parse_table(schema=True) options = self._parse_key_constraint_options() return self.expression(exp.Reference, this=this, expressions=expressions, options=options) @@ -3706,21 +3749,27 @@ class Parser(metaclass=_Parser): if 
self._match(TokenType.CHARACTER_SET): to = self.expression(exp.CharacterSet, this=self._parse_var_or_string()) elif self._match(TokenType.FORMAT): - fmt = self._parse_at_time_zone(self._parse_string()) + fmt_string = self._parse_string() + fmt = self._parse_at_time_zone(fmt_string) if to.this in exp.DataType.TEMPORAL_TYPES: - return self.expression( + this = self.expression( exp.StrToDate if to.this == exp.DataType.Type.DATE else exp.StrToTime, this=this, format=exp.Literal.string( format_time( - fmt.this if fmt else "", + fmt_string.this if fmt_string else "", self.FORMAT_MAPPING or self.TIME_MAPPING, self.FORMAT_TRIE or self.TIME_TRIE, ) ), ) + if isinstance(fmt, exp.AtTimeZone) and isinstance(this, exp.StrToTime): + this.set("zone", fmt.args["zone"]) + + return this + return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to, format=fmt) def _parse_concat(self) -> t.Optional[exp.Expression]: @@ -4223,7 +4272,7 @@ class Parser(metaclass=_Parser): return None if self._match(TokenType.L_PAREN, advance=False): return self._parse_wrapped_csv(self._parse_expression) - return self._parse_csv(self._parse_expression) + return self._parse_expressions() def _parse_csv( self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA @@ -4273,6 +4322,9 @@ class Parser(metaclass=_Parser): self._match_r_paren() return parse_result + def _parse_expressions(self) -> t.List[t.Optional[exp.Expression]]: + return self._parse_csv(self._parse_expression) + def _parse_select_or_expression(self, alias: bool = False) -> t.Optional[exp.Expression]: return self._parse_select() or self._parse_set_operations( self._parse_expression() if alias else self._parse_conjunction() |