diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-01-04 07:24:05 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-01-04 07:24:05 +0000 |
commit | 621555af37594a213d91ea113d5fc7739af84d40 (patch) | |
tree | 5aaa3b586692062accffc21cfaaa5a3917ee77b3 /sqlglot/parser.py | |
parent | Adding upstream version 10.2.9. (diff) | |
download | sqlglot-621555af37594a213d91ea113d5fc7739af84d40.tar.xz sqlglot-621555af37594a213d91ea113d5fc7739af84d40.zip |
Adding upstream version 10.4.2.upstream/10.4.2
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/parser.py')
-rw-r--r-- | sqlglot/parser.py | 243 |
1 files changed, 185 insertions, 58 deletions
diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 29bc9c0..308f363 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -5,7 +5,7 @@ import typing as t from sqlglot import exp from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors -from sqlglot.helper import apply_index_offset, ensure_collection, seq_get +from sqlglot.helper import apply_index_offset, ensure_collection, ensure_list, seq_get from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import in_trie, new_trie @@ -117,6 +117,7 @@ class Parser(metaclass=_Parser): TokenType.GEOMETRY, TokenType.HLLSKETCH, TokenType.HSTORE, + TokenType.PSEUDO_TYPE, TokenType.SUPER, TokenType.SERIAL, TokenType.SMALLSERIAL, @@ -153,6 +154,7 @@ class Parser(metaclass=_Parser): TokenType.CACHE, TokenType.CASCADE, TokenType.COLLATE, + TokenType.COLUMN, TokenType.COMMAND, TokenType.COMMIT, TokenType.COMPOUND, @@ -169,6 +171,7 @@ class Parser(metaclass=_Parser): TokenType.ESCAPE, TokenType.FALSE, TokenType.FIRST, + TokenType.FILTER, TokenType.FOLLOWING, TokenType.FORMAT, TokenType.FUNCTION, @@ -188,6 +191,7 @@ class Parser(metaclass=_Parser): TokenType.MERGE, TokenType.NATURAL, TokenType.NEXT, + TokenType.OFFSET, TokenType.ONLY, TokenType.OPTIONS, TokenType.ORDINALITY, @@ -222,12 +226,18 @@ class Parser(metaclass=_Parser): TokenType.PROPERTIES, TokenType.PROCEDURE, TokenType.VOLATILE, + TokenType.WINDOW, *SUBQUERY_PREDICATES, *TYPE_TOKENS, *NO_PAREN_FUNCTIONS, } - TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY} + TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - { + TokenType.APPLY, + TokenType.NATURAL, + TokenType.OFFSET, + TokenType.WINDOW, + } UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET} @@ -257,6 +267,7 @@ class Parser(metaclass=_Parser): TokenType.TABLE, TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, + TokenType.WINDOW, *TYPE_TOKENS, *SUBQUERY_PREDICATES, } @@ -351,22 +362,27 @@ class Parser(metaclass=_Parser): TokenType.ARROW: lambda self, this, path: self.expression( exp.JSONExtract, this=this, - path=path, + expression=path, ), TokenType.DARROW: lambda self, this, path: self.expression( exp.JSONExtractScalar, this=this, - path=path, + expression=path, ), TokenType.HASH_ARROW: lambda self, this, path: self.expression( exp.JSONBExtract, this=this, - path=path, + expression=path, ), TokenType.DHASH_ARROW: lambda self, this, path: self.expression( exp.JSONBExtractScalar, this=this, - path=path, + expression=path, + ), + TokenType.PLACEHOLDER: lambda self, this, key: self.expression( + exp.JSONBContains, + this=this, + expression=key, ), } @@ -392,25 +408,27 @@ class Parser(metaclass=_Parser): exp.Ordered: lambda self: self._parse_ordered(), exp.Having: lambda self: self._parse_having(), exp.With: lambda self: self._parse_with(), + exp.Window: lambda self: self._parse_named_window(), "JOIN_TYPE": lambda self: self._parse_join_side_and_kind(), } STATEMENT_PARSERS = { + TokenType.ALTER: lambda self: self._parse_alter(), + TokenType.BEGIN: lambda self: self._parse_transaction(), + TokenType.CACHE: lambda self: self._parse_cache(), + TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(), TokenType.CREATE: lambda self: self._parse_create(), + TokenType.DELETE: lambda self: self._parse_delete(), TokenType.DESCRIBE: lambda self: self._parse_describe(), TokenType.DROP: lambda self: self._parse_drop(), + TokenType.END: lambda self: self._parse_commit_or_rollback(), TokenType.INSERT: lambda self: self._parse_insert(), TokenType.LOAD_DATA: lambda self: self._parse_load_data(), - TokenType.UPDATE: lambda self: self._parse_update(), - TokenType.DELETE: lambda self: self._parse_delete(), - TokenType.CACHE: lambda self: self._parse_cache(), + TokenType.MERGE: lambda self: self._parse_merge(), + TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(), TokenType.UNCACHE: lambda self: self._parse_uncache(), + TokenType.UPDATE: lambda self: self._parse_update(), TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()), - TokenType.BEGIN: lambda self: self._parse_transaction(), - TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(), - TokenType.END: lambda self: self._parse_commit_or_rollback(), - TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(), - TokenType.MERGE: lambda self: self._parse_merge(), } UNARY_PARSERS = { @@ -441,6 +459,7 @@ class Parser(metaclass=_Parser): TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text), TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text), TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token), + TokenType.NATIONAL: lambda self, token: self._parse_national(token), TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(), } @@ -454,6 +473,9 @@ class Parser(metaclass=_Parser): TokenType.ILIKE: lambda self, this: self._parse_escape( self.expression(exp.ILike, this=this, expression=self._parse_bitwise()) ), + TokenType.IRLIKE: lambda self, this: self.expression( + exp.RegexpILike, this=this, expression=self._parse_bitwise() + ), TokenType.RLIKE: lambda self, this: self.expression( exp.RegexpLike, this=this, expression=self._parse_bitwise() ), @@ -535,8 +557,7 @@ class Parser(metaclass=_Parser): "group": lambda self: self._parse_group(), "having": lambda self: self._parse_having(), "qualify": lambda self: self._parse_qualify(), - "window": lambda self: self._match(TokenType.WINDOW) - and self._parse_window(self._parse_id_var(), alias=True), + "windows": lambda self: self._parse_window_clause(), "distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute), "sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort), "cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster), @@ -551,18 +572,18 @@ class Parser(metaclass=_Parser): MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table) CREATABLES = { - TokenType.TABLE, - TokenType.VIEW, + TokenType.COLUMN, TokenType.FUNCTION, TokenType.INDEX, TokenType.PROCEDURE, TokenType.SCHEMA, + TokenType.TABLE, + TokenType.VIEW, } TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"} STRICT_CAST = True - LATERAL_FUNCTION_AS_VIEW = False __slots__ = ( "error_level", @@ -782,13 +803,16 @@ class Parser(metaclass=_Parser): self._parse_query_modifiers(expression) return expression - def _parse_drop(self): + def _parse_drop(self, default_kind=None): temporary = self._match(TokenType.TEMPORARY) materialized = self._match(TokenType.MATERIALIZED) kind = self._match_set(self.CREATABLES) and self._prev.text if not kind: - self.raise_error(f"Expected {self.CREATABLES}") - return + if default_kind: + kind = default_kind + else: + self.raise_error(f"Expected {self.CREATABLES}") + return return self.expression( exp.Drop, @@ -876,7 +900,7 @@ class Parser(metaclass=_Parser): ) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False) if assignment: - key = self._parse_var() or self._parse_string() + key = self._parse_var_or_string() self._match(TokenType.EQ) return self.expression(exp.Property, this=key, value=self._parse_column()) @@ -1152,18 +1176,32 @@ class Parser(metaclass=_Parser): elif (table or nested) and self._match(TokenType.L_PAREN): this = self._parse_table() if table else self._parse_select(nested=True) self._parse_query_modifiers(this) + this = self._parse_set_operations(this) self._match_r_paren() - this = self._parse_subquery(this) + # early return so that subquery unions aren't parsed again + # SELECT * FROM (SELECT 1) UNION ALL SELECT 1 + # Union ALL should be a property of the top select node, not the subquery + return self._parse_subquery(this) elif self._match(TokenType.VALUES): + if self._curr.token_type == TokenType.L_PAREN: + # We don't consume the left paren because it's consumed in _parse_value + expressions = self._parse_csv(self._parse_value) + else: + # In presto we can have VALUES 1, 2 which results in 1 column & 2 rows. + # Source: https://prestodb.io/docs/current/sql/values.html + expressions = self._parse_csv( + lambda: self.expression(exp.Tuple, expressions=[self._parse_conjunction()]) + ) + this = self.expression( exp.Values, - expressions=self._parse_csv(self._parse_value), + expressions=expressions, alias=self._parse_table_alias(), ) else: this = None - return self._parse_set_operations(this) if this else None + return self._parse_set_operations(this) def _parse_with(self, skip_with_token=False): if not skip_with_token and not self._match(TokenType.WITH): @@ -1201,11 +1239,12 @@ class Parser(metaclass=_Parser): alias = self._parse_id_var( any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS ) - columns = None if self._match(TokenType.L_PAREN): - columns = self._parse_csv(lambda: self._parse_id_var(any_token)) + columns = self._parse_csv(lambda: self._parse_column_def(self._parse_id_var())) self._match_r_paren() + else: + columns = None if not alias and not columns: return None @@ -1295,26 +1334,19 @@ class Parser(metaclass=_Parser): expression=self._parse_function() or self._parse_id_var(any_token=False), ) - columns = None - table_alias = None - if view or self.LATERAL_FUNCTION_AS_VIEW: - table_alias = self._parse_id_var(any_token=False) - if self._match(TokenType.ALIAS): - columns = self._parse_csv(self._parse_id_var) + if view: + table = self._parse_id_var(any_token=False) + columns = self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else [] + table_alias = self.expression(exp.TableAlias, this=table, columns=columns) else: - self._match(TokenType.ALIAS) - table_alias = self._parse_id_var(any_token=False) - - if self._match(TokenType.L_PAREN): - columns = self._parse_csv(self._parse_id_var) - self._match_r_paren() + table_alias = self._parse_table_alias() expression = self.expression( exp.Lateral, this=this, view=view, outer=outer, - alias=self.expression(exp.TableAlias, this=table_alias, columns=columns), + alias=table_alias, ) if outer_apply or cross_apply: @@ -1693,6 +1725,9 @@ class Parser(metaclass=_Parser): if negate: this = self.expression(exp.Not, this=this) + if self._match(TokenType.IS): + this = self._parse_is(this) + return this def _parse_is(self, this): @@ -1796,6 +1831,10 @@ class Parser(metaclass=_Parser): return None type_token = self._prev.token_type + + if type_token == TokenType.PSEUDO_TYPE: + return self.expression(exp.PseudoType, this=self._prev.text) + nested = type_token in self.NESTED_TYPE_TOKENS is_struct = type_token == TokenType.STRUCT expressions = None @@ -1851,6 +1890,8 @@ class Parser(metaclass=_Parser): if value is None: value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions) + elif type_token == TokenType.INTERVAL: + value = self.expression(exp.Interval, unit=self._parse_var()) if maybe_func and check_func: index2 = self._index @@ -1924,7 +1965,16 @@ class Parser(metaclass=_Parser): def _parse_primary(self): if self._match_set(self.PRIMARY_PARSERS): - return self.PRIMARY_PARSERS[self._prev.token_type](self, self._prev) + token_type = self._prev.token_type + primary = self.PRIMARY_PARSERS[token_type](self, self._prev) + + if token_type == TokenType.STRING: + expressions = [primary] + while self._match(TokenType.STRING): + expressions.append(exp.Literal.string(self._prev.text)) + if len(expressions) > 1: + return self.expression(exp.Concat, expressions=expressions) + return primary if self._match_pair(TokenType.DOT, TokenType.NUMBER): return exp.Literal.number(f"0.{self._prev.text}") @@ -2027,6 +2077,9 @@ class Parser(metaclass=_Parser): return self.expression(exp.Identifier, this=token.text) + def _parse_national(self, token): + return self.expression(exp.National, this=exp.Literal.string(token.text)) + def _parse_session_parameter(self): kind = None this = self._parse_id_var() or self._parse_primary() @@ -2051,7 +2104,9 @@ class Parser(metaclass=_Parser): if self._match(TokenType.L_PAREN): expressions = self._parse_csv(self._parse_id_var) - self._match(TokenType.R_PAREN) + + if not self._match(TokenType.R_PAREN): + self._retreat(index) else: expressions = [self._parse_id_var()] @@ -2065,14 +2120,14 @@ class Parser(metaclass=_Parser): exp.Distinct, expressions=self._parse_csv(self._parse_conjunction) ) else: - this = self._parse_conjunction() + this = self._parse_select_or_expression() if self._match(TokenType.IGNORE_NULLS): this = self.expression(exp.IgnoreNulls, this=this) else: self._match(TokenType.RESPECT_NULLS) - return self._parse_alias(self._parse_limit(self._parse_order(this))) + return self._parse_limit(self._parse_order(this)) def _parse_schema(self, this=None): index = self._index @@ -2081,7 +2136,8 @@ class Parser(metaclass=_Parser): return this args = self._parse_csv( - lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True)) + lambda: self._parse_constraint() + or self._parse_column_def(self._parse_field(any_token=True)) ) self._match_r_paren() return self.expression(exp.Schema, this=this, expressions=args) @@ -2120,7 +2176,7 @@ class Parser(metaclass=_Parser): elif self._match(TokenType.ENCODE): kind = self.expression(exp.EncodeColumnConstraint, this=self._parse_var()) elif self._match(TokenType.DEFAULT): - kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_conjunction()) + kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_bitwise()) elif self._match_pair(TokenType.NOT, TokenType.NULL): kind = exp.NotNullColumnConstraint() elif self._match(TokenType.NULL): @@ -2211,7 +2267,10 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.L_BRACKET): return this - expressions = self._parse_csv(self._parse_conjunction) + if self._match(TokenType.COLON): + expressions = [self.expression(exp.Slice, expression=self._parse_conjunction())] + else: + expressions = self._parse_csv(lambda: self._parse_slice(self._parse_conjunction())) if not this or this.name.upper() == "ARRAY": this = self.expression(exp.Array, expressions=expressions) @@ -2225,6 +2284,11 @@ class Parser(metaclass=_Parser): this.comments = self._prev_comments return self._parse_bracket(this) + def _parse_slice(self, this): + if self._match(TokenType.COLON): + return self.expression(exp.Slice, this=this, expression=self._parse_conjunction()) + return this + def _parse_case(self): ifs = [] default = None @@ -2386,6 +2450,12 @@ class Parser(metaclass=_Parser): collation=collation, ) + def _parse_window_clause(self): + return self._match(TokenType.WINDOW) and self._parse_csv(self._parse_named_window) + + def _parse_named_window(self): + return self._parse_window(self._parse_id_var(), alias=True) + def _parse_window(self, this, alias=False): if self._match(TokenType.FILTER): where = self._parse_wrapped(self._parse_where) @@ -2501,11 +2571,9 @@ class Parser(metaclass=_Parser): if identifier: return identifier - if any_token and self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS: - self._advance() - elif not self._match_set(tokens or self.ID_VAR_TOKENS): - return None - return exp.Identifier(this=self._prev.text, quoted=False) + if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS): + return exp.Identifier(this=self._prev.text, quoted=False) + return None def _parse_string(self): if self._match(TokenType.STRING): @@ -2522,11 +2590,17 @@ class Parser(metaclass=_Parser): return self.expression(exp.Identifier, this=self._prev.text, quoted=True) return self._parse_placeholder() - def _parse_var(self): - if self._match(TokenType.VAR): + def _parse_var(self, any_token=False): + if (any_token and self._advance_any()) or self._match(TokenType.VAR): return self.expression(exp.Var, this=self._prev.text) return self._parse_placeholder() + def _advance_any(self): + if self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS: + self._advance() + return self._prev + return None + def _parse_var_or_string(self): return self._parse_var() or self._parse_string() @@ -2551,8 +2625,9 @@ class Parser(metaclass=_Parser): if self._match(TokenType.PLACEHOLDER): return self.expression(exp.Placeholder) elif self._match(TokenType.COLON): - self._advance() - return self.expression(exp.Placeholder, this=self._prev.text) + if self._match_set((TokenType.NUMBER, TokenType.VAR)): + return self.expression(exp.Placeholder, this=self._prev.text) + self._advance(-1) return None def _parse_except(self): @@ -2647,6 +2722,54 @@ class Parser(metaclass=_Parser): return self.expression(exp.Rollback, savepoint=savepoint) return self.expression(exp.Commit, chain=chain) + def _parse_add_column(self): + if not self._match_text_seq("ADD"): + return None + + self._match(TokenType.COLUMN) + exists_column = self._parse_exists(not_=True) + expression = self._parse_column_def(self._parse_field(any_token=True)) + expression.set("exists", exists_column) + return expression + + def _parse_drop_column(self): + return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN") + + def _parse_alter(self): + if not self._match(TokenType.TABLE): + return None + + exists = self._parse_exists() + this = self._parse_table(schema=True) + + actions = None + if self._match_text_seq("ADD", advance=False): + actions = self._parse_csv(self._parse_add_column) + elif self._match_text_seq("DROP", advance=False): + actions = self._parse_csv(self._parse_drop_column) + elif self._match_text_seq("ALTER"): + self._match(TokenType.COLUMN) + column = self._parse_field(any_token=True) + + if self._match_pair(TokenType.DROP, TokenType.DEFAULT): + actions = self.expression(exp.AlterColumn, this=column, drop=True) + elif self._match_pair(TokenType.SET, TokenType.DEFAULT): + actions = self.expression( + exp.AlterColumn, this=column, default=self._parse_conjunction() + ) + else: + self._match_text_seq("SET", "DATA") + actions = self.expression( + exp.AlterColumn, + this=column, + dtype=self._match_text_seq("TYPE") and self._parse_types(), + collate=self._match(TokenType.COLLATE) and self._parse_term(), + using=self._match(TokenType.USING) and self._parse_conjunction(), + ) + + actions = ensure_list(actions) + return self.expression(exp.AlterTable, this=this, exists=exists, actions=actions) + def _parse_show(self): parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) if parser: @@ -2782,7 +2905,7 @@ class Parser(metaclass=_Parser): return True return False - def _match_text_seq(self, *texts): + def _match_text_seq(self, *texts, advance=True): index = self._index for text in texts: if self._curr and self._curr.text.upper() == text: @@ -2790,6 +2913,10 @@ class Parser(metaclass=_Parser): else: self._retreat(index) return False + + if not advance: + self._retreat(index) + return True def _replace_columns_with_dots(self, this): |