diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-23 07:22:20 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-23 07:22:20 +0000 |
commit | 41e67f6ce6b4b732d02e421d6825c18b8d15a59d (patch) | |
tree | 30fb0000d3e6ff11b366567bc35564842e7dbb50 /sqlglot/parser.py | |
parent | Adding upstream version 23.16.0. (diff) | |
download | sqlglot-41e67f6ce6b4b732d02e421d6825c18b8d15a59d.tar.xz sqlglot-41e67f6ce6b4b732d02e421d6825c18b8d15a59d.zip |
Adding upstream version 24.0.0. (tag: upstream/24.0.0)
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/parser.py')
-rw-r--r-- | sqlglot/parser.py | 107 |
1 files changed, 93 insertions, 14 deletions
diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 67ffd8f..3237cd1 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -41,11 +41,17 @@ def build_like(args: t.List) -> exp.Escape | exp.Like: def binary_range_parser( - expr_type: t.Type[exp.Expression], + expr_type: t.Type[exp.Expression], reverse_args: bool = False ) -> t.Callable[[Parser, t.Optional[exp.Expression]], t.Optional[exp.Expression]]: - return lambda self, this: self._parse_escape( - self.expression(expr_type, this=this, expression=self._parse_bitwise()) - ) + def _parse_binary_range( + self: Parser, this: t.Optional[exp.Expression] + ) -> t.Optional[exp.Expression]: + expression = self._parse_bitwise() + if reverse_args: + this, expression = expression, this + return self._parse_escape(self.expression(expr_type, this=this, expression=expression)) + + return _parse_binary_range def build_logarithm(args: t.List, dialect: Dialect) -> exp.Func: @@ -335,6 +341,8 @@ class Parser(metaclass=_Parser): TokenType.TABLE, TokenType.TAG, TokenType.VIEW, + TokenType.WAREHOUSE, + TokenType.STREAMLIT, } CREATABLES = { @@ -418,6 +426,7 @@ class Parser(metaclass=_Parser): TokenType.TRUE, TokenType.TRUNCATE, TokenType.UNIQUE, + TokenType.UNNEST, TokenType.UNPIVOT, TokenType.UPDATE, TokenType.USE, @@ -580,7 +589,7 @@ class Parser(metaclass=_Parser): exp.Lambda, this=self._replace_lambda( self._parse_conjunction(), - {node.name for node in expressions}, + expressions, ), expressions=expressions, ), @@ -1160,6 +1169,9 @@ class Parser(metaclass=_Parser): # Whether the -> and ->> operators expect documents of type JSON (e.g. Postgres) JSON_ARROWS_REQUIRE_JSON_TYPE = False + # Whether the `:` operator is used to extract a value from a JSON document + COLON_IS_JSON_EXTRACT = False + # Whether or not a VALUES keyword needs to be followed by '(' to form a VALUES clause. 
# If this is True and '(' is not found, the keyword will be treated as an identifier VALUES_FOLLOWED_BY_PAREN = True @@ -1631,6 +1643,7 @@ class Parser(metaclass=_Parser): extend_props(self._parse_properties()) expression = self._match(TokenType.ALIAS) and self._parse_heredoc() + extend_props(self._parse_properties()) if not expression: if self._match(TokenType.COMMAND): @@ -4155,7 +4168,9 @@ class Parser(metaclass=_Parser): return self.UNARY_PARSERS[self._prev.token_type](self) return self._parse_at_time_zone(self._parse_type()) - def _parse_type(self, parse_interval: bool = True) -> t.Optional[exp.Expression]: + def _parse_type( + self, parse_interval: bool = True, fallback_to_identifier: bool = False + ) -> t.Optional[exp.Expression]: interval = parse_interval and self._parse_interval() if interval: # Convert INTERVAL 'val_1' unit_1 [+] ... [+] 'val_n' unit_n into a sum of intervals @@ -4183,9 +4198,11 @@ class Parser(metaclass=_Parser): if parser: return parser(self, this, data_type) return self.expression(exp.Cast, this=this, to=data_type) + if not data_type.expressions: self._retreat(index) - return self._parse_column() + return self._parse_id_var() if fallback_to_identifier else self._parse_column() + return self._parse_column_ops(data_type) return this and self._parse_column_ops(this) @@ -4364,7 +4381,10 @@ class Parser(metaclass=_Parser): def _parse_struct_types(self, type_required: bool = False) -> t.Optional[exp.Expression]: index = self._index - this = self._parse_type(parse_interval=False) or self._parse_id_var() + this = ( + self._parse_type(parse_interval=False, fallback_to_identifier=True) + or self._parse_id_var() + ) self._match(TokenType.COLON) column_def = self._parse_column_def(this) @@ -4401,6 +4421,47 @@ class Parser(metaclass=_Parser): return this + def _parse_colon_as_json_extract( + self, this: t.Optional[exp.Expression] + ) -> t.Optional[exp.Expression]: + casts = [] + json_path = [] + + while self._match(TokenType.COLON): + start_index = 
self._index + path = self._parse_column_ops(self._parse_field(any_token=True)) + + # The cast :: operator has a lower precedence than the extraction operator :, so + # we rearrange the AST appropriately to avoid casting the JSON path + while isinstance(path, exp.Cast): + casts.append(path.to) + path = path.this + + if casts: + dcolon_offset = next( + i + for i, t in enumerate(self._tokens[start_index:]) + if t.token_type == TokenType.DCOLON + ) + end_token = self._tokens[start_index + dcolon_offset - 1] + else: + end_token = self._prev + + if path: + json_path.append(self._find_sql(self._tokens[start_index], end_token)) + + if json_path: + this = self.expression( + exp.JSONExtract, + this=this, + expression=self.dialect.to_json_path(exp.Literal.string(".".join(json_path))), + ) + + while casts: + this = self.expression(exp.Cast, this=this, to=casts.pop()) + + return this + def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: this = self._parse_bracket(this) @@ -4444,8 +4505,10 @@ class Parser(metaclass=_Parser): ) else: this = self.expression(exp.Dot, this=this, expression=field) + this = self._parse_bracket(this) - return this + + return self._parse_colon_as_json_extract(this) if self.COLON_IS_JSON_EXTRACT else this def _parse_primary(self) -> t.Optional[exp.Expression]: if self._match_set(self.PRIMARY_PARSERS): @@ -4680,18 +4743,21 @@ class Parser(metaclass=_Parser): return self.expression(exp.SessionParameter, this=this, kind=kind) + def _parse_lambda_arg(self) -> t.Optional[exp.Expression]: + return self._parse_id_var() + def _parse_lambda(self, alias: bool = False) -> t.Optional[exp.Expression]: index = self._index if self._match(TokenType.L_PAREN): expressions = t.cast( - t.List[t.Optional[exp.Expression]], self._parse_csv(self._parse_id_var) + t.List[t.Optional[exp.Expression]], self._parse_csv(self._parse_lambda_arg) ) if not self._match(TokenType.R_PAREN): self._retreat(index) else: - expressions = 
[self._parse_id_var()] + expressions = [self._parse_lambda_arg()] if self._match_set(self.LAMBDAS): return self.LAMBDAS[self._prev.token_type](self, expressions) @@ -6182,8 +6248,10 @@ class Parser(metaclass=_Parser): return None right = self._parse_statement() or self._parse_id_var() - this = self.expression(exp.EQ, this=left, expression=right) + if isinstance(right, (exp.Column, exp.Identifier)): + right = exp.var(right.name) + this = self.expression(exp.EQ, this=left, expression=right) return self.expression(exp.SetItem, this=this, kind=kind) def _parse_set_transaction(self, global_: bool = False) -> exp.Expression: @@ -6433,14 +6501,25 @@ class Parser(metaclass=_Parser): return True def _replace_lambda( - self, node: t.Optional[exp.Expression], lambda_variables: t.Set[str] + self, node: t.Optional[exp.Expression], expressions: t.List[exp.Expression] ) -> t.Optional[exp.Expression]: if not node: return node + lambda_types = {e.name: e.args.get("to") or False for e in expressions} + for column in node.find_all(exp.Column): - if column.parts[0].name in lambda_variables: + typ = lambda_types.get(column.parts[0].name) + if typ is not None: dot_or_id = column.to_dot() if column.table else column.this + + if typ: + dot_or_id = self.expression( + exp.Cast, + this=dot_or_id, + to=typ, + ) + parent = column.parent while isinstance(parent, exp.Dot): |