diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-04-07 12:35:01 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-04-07 12:35:01 +0000 |
commit | 58c11f70074708344e433080e47621091a6dcd84 (patch) | |
tree | 2589166e0e58be4947e07a956d26efa497bccaf2 /sqlglot/parser.py | |
parent | Adding upstream version 11.4.5. (diff) | |
download | sqlglot-58c11f70074708344e433080e47621091a6dcd84.tar.xz sqlglot-58c11f70074708344e433080e47621091a6dcd84.zip |
Adding upstream version 11.5.2. (tag: upstream/11.5.2)
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/parser.py')
-rw-r--r-- | sqlglot/parser.py | 96 |
1 file changed, 94 insertions, 2 deletions
diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 8269525..b3b899c 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -105,6 +105,7 @@ class Parser(metaclass=_Parser): TokenType.CURRENT_DATETIME: exp.CurrentDate, TokenType.CURRENT_TIME: exp.CurrentTime, TokenType.CURRENT_TIMESTAMP: exp.CurrentTimestamp, + TokenType.CURRENT_USER: exp.CurrentUser, } NESTED_TYPE_TOKENS = { @@ -285,6 +286,7 @@ class Parser(metaclass=_Parser): TokenType.CURRENT_DATETIME, TokenType.CURRENT_TIMESTAMP, TokenType.CURRENT_TIME, + TokenType.CURRENT_USER, TokenType.FILTER, TokenType.FIRST, TokenType.FORMAT, @@ -674,9 +676,11 @@ class Parser(metaclass=_Parser): FUNCTION_PARSERS: t.Dict[str, t.Callable] = { "CAST": lambda self: self._parse_cast(self.STRICT_CAST), "CONVERT": lambda self: self._parse_convert(self.STRICT_CAST), + "DECODE": lambda self: self._parse_decode(), "EXTRACT": lambda self: self._parse_extract(), "JSON_OBJECT": lambda self: self._parse_json_object(), "LOG": lambda self: self._parse_logarithm(), + "MATCH": lambda self: self._parse_match_against(), "POSITION": lambda self: self._parse_position(), "STRING_AGG": lambda self: self._parse_string_agg(), "SUBSTRING": lambda self: self._parse_substring(), @@ -2634,7 +2638,7 @@ class Parser(metaclass=_Parser): self._match_r_paren() maybe_func = True - if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): + if self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): this = exp.DataType( this=exp.DataType.Type.ARRAY, expressions=[exp.DataType.build(type_token.value, expressions=expressions)], @@ -2959,6 +2963,11 @@ class Parser(metaclass=_Parser): else: this = self._parse_select_or_expression() + if isinstance(this, exp.EQ): + left = this.this + if isinstance(left, exp.Column): + left.replace(exp.Var(this=left.text("this"))) + if self._match(TokenType.IGNORE_NULLS): this = self.expression(exp.IgnoreNulls, this=this) else: @@ -2968,8 +2977,16 @@ class Parser(metaclass=_Parser): def 
_parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]: index = self._index - if not self._match(TokenType.L_PAREN) or self._match(TokenType.SELECT): + + try: + if self._parse_select(nested=True): + return this + except Exception: + pass + finally: self._retreat(index) + + if not self._match(TokenType.L_PAREN): return this args = self._parse_csv( @@ -3344,6 +3361,51 @@ class Parser(metaclass=_Parser): return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) + def _parse_decode(self) -> t.Optional[exp.Expression]: + """ + There are generally two variants of the DECODE function: + + - DECODE(bin, charset) + - DECODE(expression, search, result [, search, result] ... [, default]) + + The second variant will always be parsed into a CASE expression. Note that NULL + needs special treatment, since we need to explicitly check for it with `IS NULL`, + instead of relying on pattern matching. + """ + args = self._parse_csv(self._parse_conjunction) + + if len(args) < 3: + return self.expression(exp.Decode, this=seq_get(args, 0), charset=seq_get(args, 1)) + + expression, *expressions = args + if not expression: + return None + + ifs = [] + for search, result in zip(expressions[::2], expressions[1::2]): + if not search or not result: + return None + + if isinstance(search, exp.Literal): + ifs.append( + exp.If(this=exp.EQ(this=expression.copy(), expression=search), true=result) + ) + elif isinstance(search, exp.Null): + ifs.append( + exp.If(this=exp.Is(this=expression.copy(), expression=exp.Null()), true=result) + ) + else: + cond = exp.or_( + exp.EQ(this=expression.copy(), expression=search), + exp.and_( + exp.Is(this=expression.copy(), expression=exp.Null()), + exp.Is(this=search.copy(), expression=exp.Null()), + ), + ) + ifs.append(exp.If(this=cond, true=result)) + + return exp.Case(ifs=ifs, default=expressions[-1] if len(expressions) % 2 == 1 else None) + def _parse_json_key_value(self) -> t.Optional[exp.Expression]: 
self._match_text_seq("KEY") key = self._parse_field() @@ -3398,6 +3460,28 @@ class Parser(metaclass=_Parser): exp.Ln if self.LOG_DEFAULTS_TO_LN else exp.Log, this=seq_get(args, 0) ) + def _parse_match_against(self) -> exp.Expression: + expressions = self._parse_csv(self._parse_column) + + self._match_text_seq(")", "AGAINST", "(") + + this = self._parse_string() + + if self._match_text_seq("IN", "NATURAL", "LANGUAGE", "MODE"): + modifier = "IN NATURAL LANGUAGE MODE" + if self._match_text_seq("WITH", "QUERY", "EXPANSION"): + modifier = f"{modifier} WITH QUERY EXPANSION" + elif self._match_text_seq("IN", "BOOLEAN", "MODE"): + modifier = "IN BOOLEAN MODE" + elif self._match_text_seq("WITH", "QUERY", "EXPANSION"): + modifier = "WITH QUERY EXPANSION" + else: + modifier = None + + return self.expression( + exp.MatchAgainst, this=this, expressions=expressions, modifier=modifier + ) + def _parse_position(self, haystack_first: bool = False) -> exp.Expression: args = self._parse_csv(self._parse_bitwise) @@ -3791,6 +3875,14 @@ class Parser(metaclass=_Parser): if expression: expression.set("exists", exists_column) + # https://docs.databricks.com/delta/update-schema.html#explicitly-update-schema-to-add-columns + if self._match_texts(("FIRST", "AFTER")): + position = self._prev.text + column_position = self.expression( + exp.ColumnPosition, this=self._parse_column(), position=position + ) + expression.set("position", column_position) + return expression def _parse_drop_column(self) -> t.Optional[exp.Expression]: |