diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-10-10 11:29:00 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-10-10 11:29:00 +0000 |
commit | 74b38d30f43f7005428e09fa80508c5f21324c99 (patch) | |
tree | 7a0d4e49fffdc0330fc941c6528d3c8669a2acc6 /sqlglot/parser.py | |
parent | Adding upstream version 6.2.8. (diff) | |
download | sqlglot-74b38d30f43f7005428e09fa80508c5f21324c99.tar.xz sqlglot-74b38d30f43f7005428e09fa80508c5f21324c99.zip |
Adding upstream version 6.3.1.upstream/6.3.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/parser.py')
-rw-r--r-- | sqlglot/parser.py | 121 |
1 files changed, 78 insertions, 43 deletions
diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 5f20afc..c29e520 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -8,6 +8,18 @@ from sqlglot.tokens import Token, Tokenizer, TokenType logger = logging.getLogger("sqlglot") +def parse_var_map(args): + keys = [] + values = [] + for i in range(0, len(args), 2): + keys.append(args[i]) + values.append(args[i + 1]) + return exp.VarMap( + keys=exp.Array(expressions=keys), + values=exp.Array(expressions=values), + ) + + class Parser: """ Parser consumes a list of tokens produced by the :class:`~sqlglot.tokens.Tokenizer` @@ -48,6 +60,7 @@ class Parser: start=exp.Literal.number(1), length=exp.Literal.number(10), ), + "VAR_MAP": parse_var_map, } NO_PAREN_FUNCTIONS = { @@ -117,6 +130,7 @@ class Parser: TokenType.VAR, TokenType.ALTER, TokenType.ALWAYS, + TokenType.ANTI, TokenType.BEGIN, TokenType.BOTH, TokenType.BUCKET, @@ -164,6 +178,7 @@ class Parser: TokenType.ROWS, TokenType.SCHEMA_COMMENT, TokenType.SEED, + TokenType.SEMI, TokenType.SET, TokenType.SHOW, TokenType.STABLE, @@ -273,6 +288,8 @@ class Parser: TokenType.INNER, TokenType.OUTER, TokenType.CROSS, + TokenType.SEMI, + TokenType.ANTI, } COLUMN_OPERATORS = { @@ -318,6 +335,8 @@ class Parser: exp.Properties: lambda self: self._parse_properties(), exp.Where: lambda self: self._parse_where(), exp.Ordered: lambda self: self._parse_ordered(), + exp.Having: lambda self: self._parse_having(), + exp.With: lambda self: self._parse_with(), "JOIN_TYPE": lambda self: self._parse_join_side_and_kind(), } @@ -338,7 +357,6 @@ class Parser: TokenType.NULL: lambda *_: exp.Null(), TokenType.TRUE: lambda *_: exp.Boolean(this=True), TokenType.FALSE: lambda *_: exp.Boolean(this=False), - TokenType.PLACEHOLDER: lambda *_: exp.Placeholder(), TokenType.PARAMETER: lambda self, _: exp.Parameter(this=self._parse_var() or self._parse_primary()), TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text), TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text), @@ -910,7 +928,20 @@ class Parser: return self.expression(exp.Tuple, expressions=expressions) def _parse_select(self, nested=False, table=False): - if self._match(TokenType.SELECT): + cte = self._parse_with() + if cte: + this = self._parse_statement() + + if not this: + self.raise_error("Failed to parse any statement following CTE") + return cte + + if "with" in this.arg_types: + this.set("with", cte) + else: + self.raise_error(f"{this.key} does not support CTE") + this = cte + elif self._match(TokenType.SELECT): hint = self._parse_hint() all_ = self._match(TokenType.ALL) distinct = self._match(TokenType.DISTINCT) @@ -938,39 +969,6 @@ class Parser: if from_: this.set("from", from_) self._parse_query_modifiers(this) - elif self._match(TokenType.WITH): - recursive = self._match(TokenType.RECURSIVE) - - expressions = [] - - while True: - expressions.append(self._parse_cte()) - - if not self._match(TokenType.COMMA): - break - - cte = self.expression( - exp.With, - expressions=expressions, - recursive=recursive, - ) - this = self._parse_statement() - - if not this: - self.raise_error("Failed to parse any statement following CTE") - return cte - - if "with" in this.arg_types: - this.set( - "with", - self.expression( - exp.With, - expressions=expressions, - recursive=recursive, - ), - ) - else: - self.raise_error(f"{this.key} does not support CTE") elif (table or nested) and self._match(TokenType.L_PAREN): this = self._parse_table() if table else self._parse_select(nested=True) self._parse_query_modifiers(this) @@ -986,6 +984,26 @@ class Parser: return self._parse_set_operations(this) if this else None + def _parse_with(self): + if not self._match(TokenType.WITH): + return None + + recursive = self._match(TokenType.RECURSIVE) + + expressions = [] + + while True: + expressions.append(self._parse_cte()) + + if not self._match(TokenType.COMMA): + break + + return self.expression( + exp.With, + expressions=expressions, + recursive=recursive, + ) + def _parse_cte(self): alias = self._parse_table_alias() if not alias or not alias.this: @@ -1485,8 +1503,7 @@ class Parser: unnest = self._parse_unnest() if unnest: this = self.expression(exp.In, this=this, unnest=unnest) - else: - self._match_l_paren() + elif self._match(TokenType.L_PAREN): expressions = self._parse_csv(self._parse_select_or_expression) if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable): @@ -1495,6 +1512,9 @@ class Parser: this = self.expression(exp.In, this=this, expressions=expressions) self._match_r_paren() + else: + this = self.expression(exp.In, this=this, field=self._parse_field()) + return this def _parse_between(self, this): @@ -1591,7 +1611,7 @@ class Parser: elif nested: expressions = self._parse_csv(self._parse_types) else: - expressions = self._parse_csv(self._parse_number) + expressions = self._parse_csv(self._parse_type) if not expressions: self._retreat(index) @@ -1706,7 +1726,7 @@ class Parser: def _parse_field(self, any_token=False): return self._parse_primary() or self._parse_function() or self._parse_id_var(any_token) - def _parse_function(self): + def _parse_function(self, functions=None): if not self._curr: return None @@ -1742,7 +1762,9 @@ class Parser: self._match_r_paren() return this - function = self.FUNCTIONS.get(upper) + if functions is None: + functions = self.FUNCTIONS + function = functions.get(upper) args = self._parse_csv(self._parse_lambda) if function: @@ -2025,10 +2047,20 @@ class Parser: return self.expression(exp.Cast, this=this, to=to) def _parse_position(self): - substr = self._parse_bitwise() + args = self._parse_csv(self._parse_bitwise) + if self._match(TokenType.IN): - string = self._parse_bitwise() - return self.expression(exp.StrPosition, this=string, substr=substr) + args.append(self._parse_bitwise()) + + # Note: we're parsing in order needle, haystack, position + this = exp.StrPosition.from_arg_list(args) + self.validate_expression(this, args) + + return this + + def _parse_join_hint(self, func_name): + args = self._parse_csv(self._parse_table) + return exp.JoinHint(this=func_name.upper(), expressions=args) def _parse_substring(self): # Postgres supports the form: substring(string [from int] [for int]) @@ -2247,6 +2279,9 @@ class Parser: def _parse_placeholder(self): if self._match(TokenType.PLACEHOLDER): return exp.Placeholder() + elif self._match(TokenType.COLON): + self._advance() + return exp.Placeholder(this=self._prev.text) return None def _parse_except(self): |