diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-11-11 08:54:30 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-11-11 08:54:30 +0000 |
commit | 9ebe8c99ba4be74ccebf1b013f4e56ec09e023c1 (patch) | |
tree | 7ab2f39fbb6fd832aeea5cef45b54bfd59ba5ba5 /sqlglot/parser.py | |
parent | Adding upstream version 9.0.6. (diff) | |
download | sqlglot-9ebe8c99ba4be74ccebf1b013f4e56ec09e023c1.tar.xz sqlglot-9ebe8c99ba4be74ccebf1b013f4e56ec09e023c1.zip |
Adding upstream version 10.0.1. (ref: upstream/10.0.1)
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/parser.py')
-rw-r--r-- | sqlglot/parser.py | 410 |
1 files changed, 293 insertions, 117 deletions
diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 79a1d90..bbea0e5 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -1,9 +1,13 @@ +from __future__ import annotations + import logging +import typing as t from sqlglot import exp from sqlglot.errors import ErrorLevel, ParseError, concat_errors -from sqlglot.helper import apply_index_offset, ensure_list, list_get +from sqlglot.helper import apply_index_offset, ensure_collection, seq_get from sqlglot.tokens import Token, Tokenizer, TokenType +from sqlglot.trie import in_trie, new_trie logger = logging.getLogger("sqlglot") @@ -20,7 +24,15 @@ def parse_var_map(args): ) -class Parser: +class _Parser(type): + def __new__(cls, clsname, bases, attrs): + klass = super().__new__(cls, clsname, bases, attrs) + klass._show_trie = new_trie(key.split(" ") for key in klass.SHOW_PARSERS) + klass._set_trie = new_trie(key.split(" ") for key in klass.SET_PARSERS) + return klass + + +class Parser(metaclass=_Parser): """ Parser consumes a list of tokens produced by the :class:`~sqlglot.tokens.Tokenizer` and produces a parsed syntax tree. 
@@ -45,16 +57,16 @@ class Parser: FUNCTIONS = { **{name: f.from_arg_list for f in exp.ALL_FUNCTIONS for name in f.sql_names()}, "DATE_TO_DATE_STR": lambda args: exp.Cast( - this=list_get(args, 0), + this=seq_get(args, 0), to=exp.DataType(this=exp.DataType.Type.TEXT), ), "TIME_TO_TIME_STR": lambda args: exp.Cast( - this=list_get(args, 0), + this=seq_get(args, 0), to=exp.DataType(this=exp.DataType.Type.TEXT), ), "TS_OR_DS_TO_DATE_STR": lambda args: exp.Substring( this=exp.Cast( - this=list_get(args, 0), + this=seq_get(args, 0), to=exp.DataType(this=exp.DataType.Type.TEXT), ), start=exp.Literal.number(1), @@ -90,6 +102,7 @@ class Parser: TokenType.NVARCHAR, TokenType.TEXT, TokenType.BINARY, + TokenType.VARBINARY, TokenType.JSON, TokenType.INTERVAL, TokenType.TIMESTAMP, @@ -243,6 +256,7 @@ class Parser: EQUALITY = { TokenType.EQ: exp.EQ, TokenType.NEQ: exp.NEQ, + TokenType.NULLSAFE_EQ: exp.NullSafeEQ, } COMPARISON = { @@ -298,6 +312,21 @@ class Parser: TokenType.ANTI, } + LAMBDAS = { + TokenType.ARROW: lambda self, expressions: self.expression( + exp.Lambda, + this=self._parse_conjunction().transform( + self._replace_lambda, {node.name for node in expressions} + ), + expressions=expressions, + ), + TokenType.FARROW: lambda self, expressions: self.expression( + exp.Kwarg, + this=exp.Var(this=expressions[0].name), + expression=self._parse_conjunction(), + ), + } + COLUMN_OPERATORS = { TokenType.DOT: None, TokenType.DCOLON: lambda self, this, to: self.expression( @@ -362,20 +391,30 @@ class Parser: TokenType.DELETE: lambda self: self._parse_delete(), TokenType.CACHE: lambda self: self._parse_cache(), TokenType.UNCACHE: lambda self: self._parse_uncache(), + TokenType.USE: lambda self: self._parse_use(), } PRIMARY_PARSERS = { - TokenType.STRING: lambda _, token: exp.Literal.string(token.text), - TokenType.NUMBER: lambda _, token: exp.Literal.number(token.text), - TokenType.STAR: lambda self, _: exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()}), 
- TokenType.NULL: lambda *_: exp.Null(), - TokenType.TRUE: lambda *_: exp.Boolean(this=True), - TokenType.FALSE: lambda *_: exp.Boolean(this=False), - TokenType.PARAMETER: lambda self, _: exp.Parameter(this=self._parse_var() or self._parse_primary()), - TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text), - TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text), - TokenType.BYTE_STRING: lambda _, token: exp.ByteString(this=token.text), + TokenType.STRING: lambda self, token: self.expression( + exp.Literal, this=token.text, is_string=True + ), + TokenType.NUMBER: lambda self, token: self.expression( + exp.Literal, this=token.text, is_string=False + ), + TokenType.STAR: lambda self, _: self.expression( + exp.Star, **{"except": self._parse_except(), "replace": self._parse_replace()} + ), + TokenType.NULL: lambda self, _: self.expression(exp.Null), + TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True), + TokenType.FALSE: lambda self, _: self.expression(exp.Boolean, this=False), + TokenType.PARAMETER: lambda self, _: self.expression( + exp.Parameter, this=self._parse_var() or self._parse_primary() + ), + TokenType.BIT_STRING: lambda self, token: self.expression(exp.BitString, this=token.text), + TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text), + TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text), TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token), + TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(), } RANGE_PARSERS = { @@ -411,16 +450,24 @@ class Parser: TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty), TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty), TokenType.FORMAT: lambda self: self._parse_property_assignment(exp.FileFormatProperty), - TokenType.TABLE_FORMAT: lambda self: 
self._parse_property_assignment(exp.TableFormatProperty), + TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment( + exp.TableFormatProperty + ), TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty), TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty), TokenType.EXECUTE: lambda self: self._parse_execute_as(), TokenType.DETERMINISTIC: lambda self: self.expression( exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE") ), - TokenType.IMMUTABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")), - TokenType.STABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("STABLE")), - TokenType.VOLATILE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")), + TokenType.IMMUTABLE: lambda self: self.expression( + exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE") + ), + TokenType.STABLE: lambda self: self.expression( + exp.VolatilityProperty, this=exp.Literal.string("STABLE") + ), + TokenType.VOLATILE: lambda self: self.expression( + exp.VolatilityProperty, this=exp.Literal.string("VOLATILE") + ), } CONSTRAINT_PARSERS = { @@ -450,7 +497,8 @@ class Parser: "group": lambda self: self._parse_group(), "having": lambda self: self._parse_having(), "qualify": lambda self: self._parse_qualify(), - "window": lambda self: self._match(TokenType.WINDOW) and self._parse_window(self._parse_id_var(), alias=True), + "window": lambda self: self._match(TokenType.WINDOW) + and self._parse_window(self._parse_id_var(), alias=True), "distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute), "sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort), "cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster), @@ -459,6 +507,9 @@ class Parser: "offset": lambda self: self._parse_offset(), } + SHOW_PARSERS: t.Dict[str, t.Callable] = {} + SET_PARSERS: 
t.Dict[str, t.Callable] = {} + MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table) CREATABLES = { @@ -488,7 +539,9 @@ class Parser: "_curr", "_next", "_prev", - "_greedy_subqueries", + "_prev_comment", + "_show_trie", + "_set_trie", ) def __init__( @@ -519,7 +572,7 @@ class Parser: self._curr = None self._next = None self._prev = None - self._greedy_subqueries = False + self._prev_comment = None def parse(self, raw_tokens, sql=None): """ @@ -533,10 +586,12 @@ class Parser: Returns the list of syntax trees (:class:`~sqlglot.expressions.Expression`). """ - return self._parse(parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql) + return self._parse( + parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql + ) def parse_into(self, expression_types, raw_tokens, sql=None): - for expression_type in ensure_list(expression_types): + for expression_type in ensure_collection(expression_types): parser = self.EXPRESSION_PARSERS.get(expression_type) if not parser: raise TypeError(f"No parser registered for {expression_type}") @@ -597,6 +652,9 @@ class Parser: def expression(self, exp_class, **kwargs): instance = exp_class(**kwargs) + if self._prev_comment: + instance.comment = self._prev_comment + self._prev_comment = None self.validate_expression(instance) return instance @@ -633,14 +691,16 @@ class Parser: return index - def _get_token(self, index): - return list_get(self._tokens, index) - def _advance(self, times=1): self._index += times - self._curr = self._get_token(self._index) - self._next = self._get_token(self._index + 1) - self._prev = self._get_token(self._index - 1) if self._index > 0 else None + self._curr = seq_get(self._tokens, self._index) + self._next = seq_get(self._tokens, self._index + 1) + if self._index > 0: + self._prev = self._tokens[self._index - 1] + self._prev_comment = self._prev.comment + else: + self._prev = None + self._prev_comment = None def _retreat(self, index): self._advance(index - 
self._index) @@ -661,6 +721,7 @@ class Parser: expression = self._parse_expression() expression = self._parse_set_operations(expression) if expression else self._parse_select() + self._parse_query_modifiers(expression) return expression @@ -682,7 +743,11 @@ class Parser: ) def _parse_exists(self, not_=False): - return self._match(TokenType.IF) and (not not_ or self._match(TokenType.NOT)) and self._match(TokenType.EXISTS) + return ( + self._match(TokenType.IF) + and (not not_ or self._match(TokenType.NOT)) + and self._match(TokenType.EXISTS) + ) def _parse_create(self): replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE) @@ -931,7 +996,9 @@ class Parser: return self.expression( exp.Delete, this=self._parse_table(schema=True), - using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table(schema=True)), + using=self._parse_csv( + lambda: self._match(TokenType.USING) and self._parse_table(schema=True) + ), where=self._parse_where(), ) @@ -983,11 +1050,13 @@ class Parser: return None def parse_values(): - k = self._parse_var() + key = self._parse_var() + value = None + if self._match(TokenType.EQ): - v = self._parse_string() - return (k, v) - return (k, None) + value = self._parse_string() + + return exp.Property(this=key, value=value) self._match_l_paren() values = self._parse_csv(parse_values) @@ -1019,6 +1088,8 @@ class Parser: self.raise_error(f"{this.key} does not support CTE") this = cte elif self._match(TokenType.SELECT): + comment = self._prev_comment + hint = self._parse_hint() all_ = self._match(TokenType.ALL) distinct = self._match(TokenType.DISTINCT) @@ -1033,7 +1104,7 @@ class Parser: self.raise_error("Cannot specify both ALL and DISTINCT after SELECT") limit = self._parse_limit(top=True) - expressions = self._parse_csv(lambda: self._parse_annotation(self._parse_expression())) + expressions = self._parse_csv(self._parse_expression) this = self.expression( exp.Select, @@ -1042,6 +1113,7 @@ class Parser: 
expressions=expressions, limit=limit, ) + this.comment = comment from_ = self._parse_from() if from_: this.set("from", from_) @@ -1072,8 +1144,10 @@ class Parser: while True: expressions.append(self._parse_cte()) - if not self._match(TokenType.COMMA): + if not self._match(TokenType.COMMA) and not self._match(TokenType.WITH): break + else: + self._match(TokenType.WITH) return self.expression( exp.With, @@ -1111,11 +1185,7 @@ class Parser: if not alias and not columns: return None - return self.expression( - exp.TableAlias, - this=alias, - columns=columns, - ) + return self.expression(exp.TableAlias, this=alias, columns=columns) def _parse_subquery(self, this): return self.expression( @@ -1150,12 +1220,6 @@ class Parser: if expression: this.set(key, expression) - def _parse_annotation(self, expression): - if self._match(TokenType.ANNOTATION): - return self.expression(exp.Annotation, this=self._prev.text.strip(), expression=expression) - - return expression - def _parse_hint(self): if self._match(TokenType.HINT): hints = self._parse_csv(self._parse_function) @@ -1295,7 +1359,9 @@ class Parser: if not table: self.raise_error("Expected table name") - this = self.expression(exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()) + this = self.expression( + exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots() + ) if schema: return self._parse_schema(this=this) @@ -1500,7 +1566,9 @@ class Parser: if not skip_order_token and not self._match(TokenType.ORDER_BY): return this - return self.expression(exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)) + return self.expression( + exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered) + ) def _parse_sort(self, token_type, exp_class): if not self._match(token_type): @@ -1521,7 +1589,8 @@ class Parser: if ( not explicitly_null_ordered and ( - (asc and self.null_ordering == "nulls_are_small") or (desc and self.null_ordering != "nulls_are_small") + (asc 
and self.null_ordering == "nulls_are_small") + or (desc and self.null_ordering != "nulls_are_small") ) and self.null_ordering != "nulls_are_last" ): @@ -1606,6 +1675,9 @@ class Parser: def _parse_is(self, this): negate = self._match(TokenType.NOT) + if self._match(TokenType.DISTINCT_FROM): + klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ + return self.expression(klass, this=this, expression=self._parse_expression()) this = self.expression( exp.Is, this=this, @@ -1653,9 +1725,13 @@ class Parser: expression=self._parse_term(), ) elif self._match_pair(TokenType.LT, TokenType.LT): - this = self.expression(exp.BitwiseLeftShift, this=this, expression=self._parse_term()) + this = self.expression( + exp.BitwiseLeftShift, this=this, expression=self._parse_term() + ) elif self._match_pair(TokenType.GT, TokenType.GT): - this = self.expression(exp.BitwiseRightShift, this=this, expression=self._parse_term()) + this = self.expression( + exp.BitwiseRightShift, this=this, expression=self._parse_term() + ) else: break @@ -1685,7 +1761,7 @@ class Parser: ) index = self._index - type_token = self._parse_types() + type_token = self._parse_types(check_func=True) this = self._parse_column() if type_token: @@ -1698,7 +1774,7 @@ class Parser: return this - def _parse_types(self): + def _parse_types(self, check_func=False): index = self._index if not self._match_set(self.TYPE_TOKENS): @@ -1708,10 +1784,13 @@ class Parser: nested = type_token in self.NESTED_TYPE_TOKENS is_struct = type_token == TokenType.STRUCT expressions = None + maybe_func = False if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): return exp.DataType( - this=exp.DataType.Type.ARRAY, expressions=[exp.DataType.build(type_token.value)], nested=True + this=exp.DataType.Type.ARRAY, + expressions=[exp.DataType.build(type_token.value)], + nested=True, ) if self._match(TokenType.L_BRACKET): @@ -1731,6 +1810,7 @@ class Parser: return None self._match_r_paren() + maybe_func = True if nested and 
self._match(TokenType.LT): if is_struct: @@ -1741,25 +1821,46 @@ class Parser: if not self._match(TokenType.GT): self.raise_error("Expecting >") + value = None if type_token in self.TIMESTAMPS: - tz = self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ - if tz: - return exp.DataType( + if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ: + value = exp.DataType( this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions, ) - ltz = self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ - if ltz: - return exp.DataType( + elif ( + self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ + ): + value = exp.DataType( this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions, ) - self._match(TokenType.WITHOUT_TIME_ZONE) + elif self._match(TokenType.WITHOUT_TIME_ZONE): + value = exp.DataType( + this=exp.DataType.Type.TIMESTAMP, + expressions=expressions, + ) - return exp.DataType( - this=exp.DataType.Type.TIMESTAMP, - expressions=expressions, - ) + maybe_func = maybe_func and value is None + + if value is None: + value = exp.DataType( + this=exp.DataType.Type.TIMESTAMP, + expressions=expressions, + ) + + if maybe_func and check_func: + index2 = self._index + peek = self._parse_string() + + if not peek: + self._retreat(index) + return None + + self._retreat(index2) + + if value: + return value return exp.DataType( this=exp.DataType.Type[type_token.value.upper()], @@ -1826,22 +1927,29 @@ class Parser: return exp.Literal.number(f"0.{self._prev.text}") if self._match(TokenType.L_PAREN): + comment = self._prev_comment query = self._parse_select() if query: expressions = [query] else: - expressions = self._parse_csv(lambda: self._parse_alias(self._parse_conjunction(), explicit=True)) + expressions = self._parse_csv( + lambda: self._parse_alias(self._parse_conjunction(), explicit=True) + ) - this = list_get(expressions, 0) + this = seq_get(expressions, 0) 
self._parse_query_modifiers(this) self._match_r_paren() if isinstance(this, exp.Subqueryable): - return self._parse_set_operations(self._parse_subquery(this)) - if len(expressions) > 1: - return self.expression(exp.Tuple, expressions=expressions) - return self.expression(exp.Paren, this=this) + this = self._parse_set_operations(self._parse_subquery(this)) + elif len(expressions) > 1: + this = self.expression(exp.Tuple, expressions=expressions) + else: + this = self.expression(exp.Paren, this=this) + if comment: + this.comment = comment + return this return None @@ -1894,7 +2002,8 @@ class Parser: self.validate_expression(this, args) else: this = self.expression(exp.Anonymous, this=this, expressions=args) - self._match_r_paren() + + self._match_r_paren(this) return self._parse_window(this) def _parse_user_defined_function(self): @@ -1920,6 +2029,18 @@ class Parser: return self.expression(exp.Identifier, this=token.text) + def _parse_session_parameter(self): + kind = None + this = self._parse_id_var() or self._parse_primary() + if self._match(TokenType.DOT): + kind = this.name + this = self._parse_var() or self._parse_primary() + return self.expression( + exp.SessionParameter, + this=this, + kind=kind, + ) + def _parse_udf_kwarg(self): this = self._parse_id_var() kind = self._parse_types() @@ -1938,27 +2059,24 @@ class Parser: else: expressions = [self._parse_id_var()] - if not self._match(TokenType.ARROW): - self._retreat(index) + if self._match_set(self.LAMBDAS): + return self.LAMBDAS[self._prev.token_type](self, expressions) - if self._match(TokenType.DISTINCT): - this = self.expression(exp.Distinct, expressions=self._parse_csv(self._parse_conjunction)) - else: - this = self._parse_conjunction() + self._retreat(index) - if self._match(TokenType.IGNORE_NULLS): - this = self.expression(exp.IgnoreNulls, this=this) - else: - self._match(TokenType.RESPECT_NULLS) + if self._match(TokenType.DISTINCT): + this = self.expression( + exp.Distinct, 
expressions=self._parse_csv(self._parse_conjunction) + ) + else: + this = self._parse_conjunction() - return self._parse_alias(self._parse_limit(self._parse_order(this))) + if self._match(TokenType.IGNORE_NULLS): + this = self.expression(exp.IgnoreNulls, this=this) + else: + self._match(TokenType.RESPECT_NULLS) - conjunction = self._parse_conjunction().transform(self._replace_lambda, {node.name for node in expressions}) - return self.expression( - exp.Lambda, - this=conjunction, - expressions=expressions, - ) + return self._parse_alias(self._parse_limit(self._parse_order(this))) def _parse_schema(self, this=None): index = self._index @@ -1966,7 +2084,9 @@ class Parser: self._retreat(index) return this - args = self._parse_csv(lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True))) + args = self._parse_csv( + lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True)) + ) self._match_r_paren() return self.expression(exp.Schema, this=this, expressions=args) @@ -2104,6 +2224,7 @@ class Parser: if not self._match(TokenType.R_BRACKET): self.raise_error("Expected ]") + this.comment = self._prev_comment return self._parse_bracket(this) def _parse_case(self): @@ -2124,7 +2245,9 @@ class Parser: if not self._match(TokenType.END): self.raise_error("Expected END after CASE", self._prev) - return self._parse_window(self.expression(exp.Case, this=expression, ifs=ifs, default=default)) + return self._parse_window( + self.expression(exp.Case, this=expression, ifs=ifs, default=default) + ) def _parse_if(self): if self._match(TokenType.L_PAREN): @@ -2331,7 +2454,9 @@ class Parser: self._match(TokenType.BETWEEN) return { - "value": (self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text) + "value": ( + self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text + ) or self._parse_bitwise(), "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) and self._prev.text, } @@ 
-2348,7 +2473,7 @@ class Parser: this=this, expressions=self._parse_csv(lambda: self._parse_id_var(any_token)), ) - self._match_r_paren() + self._match_r_paren(aliases) return aliases alias = self._parse_id_var(any_token) @@ -2365,28 +2490,29 @@ class Parser: return identifier if any_token and self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS: - return self._advance() or exp.Identifier(this=self._prev.text, quoted=False) - - return self._match_set(tokens or self.ID_VAR_TOKENS) and exp.Identifier(this=self._prev.text, quoted=False) + self._advance() + elif not self._match_set(tokens or self.ID_VAR_TOKENS): + return None + return exp.Identifier(this=self._prev.text, quoted=False) def _parse_string(self): if self._match(TokenType.STRING): - return exp.Literal.string(self._prev.text) + return self.PRIMARY_PARSERS[TokenType.STRING](self, self._prev) return self._parse_placeholder() def _parse_number(self): if self._match(TokenType.NUMBER): - return exp.Literal.number(self._prev.text) + return self.PRIMARY_PARSERS[TokenType.NUMBER](self, self._prev) return self._parse_placeholder() def _parse_identifier(self): if self._match(TokenType.IDENTIFIER): - return exp.Identifier(this=self._prev.text, quoted=True) + return self.expression(exp.Identifier, this=self._prev.text, quoted=True) return self._parse_placeholder() def _parse_var(self): if self._match(TokenType.VAR): - return exp.Var(this=self._prev.text) + return self.expression(exp.Var, this=self._prev.text) return self._parse_placeholder() def _parse_var_or_string(self): @@ -2394,27 +2520,27 @@ class Parser: def _parse_null(self): if self._match(TokenType.NULL): - return exp.Null() + return self.PRIMARY_PARSERS[TokenType.NULL](self, self._prev) return None def _parse_boolean(self): if self._match(TokenType.TRUE): - return exp.Boolean(this=True) + return self.PRIMARY_PARSERS[TokenType.TRUE](self, self._prev) if self._match(TokenType.FALSE): - return exp.Boolean(this=False) + return 
self.PRIMARY_PARSERS[TokenType.FALSE](self, self._prev) return None def _parse_star(self): if self._match(TokenType.STAR): - return exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()}) + return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev) return None def _parse_placeholder(self): if self._match(TokenType.PLACEHOLDER): - return exp.Placeholder() + return self.expression(exp.Placeholder) elif self._match(TokenType.COLON): self._advance() - return exp.Placeholder(this=self._prev.text) + return self.expression(exp.Placeholder, this=self._prev.text) return None def _parse_except(self): @@ -2432,22 +2558,27 @@ class Parser: self._match_r_paren() return columns - def _parse_csv(self, parse): - parse_result = parse() + def _parse_csv(self, parse_method): + parse_result = parse_method() items = [parse_result] if parse_result is not None else [] while self._match(TokenType.COMMA): - parse_result = parse() + if parse_result and self._prev_comment is not None: + parse_result.comment = self._prev_comment + + parse_result = parse_method() if parse_result is not None: items.append(parse_result) return items - def _parse_tokens(self, parse, expressions): - this = parse() + def _parse_tokens(self, parse_method, expressions): + this = parse_method() while self._match_set(expressions): - this = self.expression(expressions[self._prev.token_type], this=this, expression=parse()) + this = self.expression( + expressions[self._prev.token_type], this=this, expression=parse_method() + ) return this @@ -2460,6 +2591,47 @@ class Parser: def _parse_select_or_expression(self): return self._parse_select() or self._parse_expression() + def _parse_use(self): + return self.expression(exp.Use, this=self._parse_id_var()) + + def _parse_show(self): + parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) + if parser: + return parser(self) + self._advance() + return self.expression(exp.Show, this=self._prev.text.upper()) + + def _default_parse_set_item(self): 
+ return self.expression( + exp.SetItem, + this=self._parse_statement(), + ) + + def _parse_set_item(self): + parser = self._find_parser(self.SET_PARSERS, self._set_trie) + return parser(self) if parser else self._default_parse_set_item() + + def _parse_set(self): + return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item)) + + def _find_parser(self, parsers, trie): + index = self._index + this = [] + while True: + # The current token might be multiple words + curr = self._curr.text.upper() + key = curr.split(" ") + this.append(curr) + self._advance() + result, trie = in_trie(trie, key) + if result == 0: + break + if result == 2: + subparser = parsers[" ".join(this)] + return subparser + self._retreat(index) + return None + def _match(self, token_type): if not self._curr: return None @@ -2491,13 +2663,17 @@ class Parser: return None - def _match_l_paren(self): + def _match_l_paren(self, expression=None): if not self._match(TokenType.L_PAREN): self.raise_error("Expecting (") + if expression and self._prev_comment: + expression.comment = self._prev_comment - def _match_r_paren(self): + def _match_r_paren(self, expression=None): if not self._match(TokenType.R_PAREN): self.raise_error("Expecting )") + if expression and self._prev_comment: + expression.comment = self._prev_comment def _match_text(self, *texts): index = self._index |