diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-02-08 05:38:39 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-02-08 05:38:39 +0000 |
commit | aedf35026379f52d7e2b4c1f957691410a758089 (patch) | |
tree | 86540364259b66741173d2333387b78d6f9c31e2 /sqlglot/parser.py | |
parent | Adding upstream version 20.11.0. (diff) | |
download | sqlglot-aedf35026379f52d7e2b4c1f957691410a758089.tar.xz sqlglot-aedf35026379f52d7e2b4c1f957691410a758089.zip |
Adding upstream version 21.0.1.upstream/21.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/parser.py')
-rw-r--r-- | sqlglot/parser.py | 99 |
1 files changed, 71 insertions, 28 deletions
diff --git a/sqlglot/parser.py b/sqlglot/parser.py index c091605..a89e4fa 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -60,6 +60,19 @@ def parse_logarithm(args: t.List, dialect: Dialect) -> exp.Func: return (exp.Ln if dialect.parser_class.LOG_DEFAULTS_TO_LN else exp.Log)(this=this) +def parse_extract_json_with_path(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]: + def _parser(args: t.List, dialect: Dialect) -> E: + expression = expr_type( + this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1)) + ) + if len(args) > 2 and expr_type is exp.JSONExtract: + expression.set("expressions", args[2:]) + + return expression + + return _parser + + class _Parser(type): def __new__(cls, clsname, bases, attrs): klass = super().__new__(cls, clsname, bases, attrs) @@ -102,6 +115,9 @@ class Parser(metaclass=_Parser): to=exp.DataType(this=exp.DataType.Type.TEXT), ), "GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)), + "JSON_EXTRACT": parse_extract_json_with_path(exp.JSONExtract), + "JSON_EXTRACT_SCALAR": parse_extract_json_with_path(exp.JSONExtractScalar), + "JSON_EXTRACT_PATH_TEXT": parse_extract_json_with_path(exp.JSONExtractScalar), "LIKE": parse_like, "LOG": parse_logarithm, "TIME_TO_TIME_STR": lambda args: exp.Cast( @@ -175,6 +191,7 @@ class Parser(metaclass=_Parser): TokenType.NCHAR, TokenType.VARCHAR, TokenType.NVARCHAR, + TokenType.BPCHAR, TokenType.TEXT, TokenType.MEDIUMTEXT, TokenType.LONGTEXT, @@ -295,6 +312,7 @@ class Parser(metaclass=_Parser): TokenType.ASC, TokenType.AUTO_INCREMENT, TokenType.BEGIN, + TokenType.BPCHAR, TokenType.CACHE, TokenType.CASE, TokenType.COLLATE, @@ -531,12 +549,12 @@ class Parser(metaclass=_Parser): TokenType.ARROW: lambda self, this, path: self.expression( exp.JSONExtract, this=this, - expression=path, + expression=self.dialect.to_json_path(path), ), TokenType.DARROW: lambda self, this, path: self.expression( exp.JSONExtractScalar, this=this, - expression=path, + expression=self.dialect.to_json_path(path), ), TokenType.HASH_ARROW: lambda self, this, path: self.expression( exp.JSONBExtract, @@ -1334,7 +1352,9 @@ class Parser(metaclass=_Parser): exp.Drop, comments=start.comments, exists=exists or self._parse_exists(), - this=self._parse_table(schema=True), + this=self._parse_table( + schema=True, is_db_reference=self._prev.token_type == TokenType.SCHEMA + ), kind=kind, temporary=temporary, materialized=materialized, @@ -1422,7 +1442,9 @@ class Parser(metaclass=_Parser): elif create_token.token_type == TokenType.INDEX: this = self._parse_index(index=self._parse_id_var()) elif create_token.token_type in self.DB_CREATABLES: - table_parts = self._parse_table_parts(schema=True) + table_parts = self._parse_table_parts( + schema=True, is_db_reference=create_token.token_type == TokenType.SCHEMA + ) # exp.Properties.Location.POST_NAME self._match(TokenType.COMMA) @@ -2499,11 +2521,11 @@ class Parser(metaclass=_Parser): elif self._match_text_seq("ALL", "ROWS", "PER", "MATCH"): text = "ALL ROWS PER MATCH" if self._match_text_seq("SHOW", "EMPTY", "MATCHES"): - text += f" SHOW EMPTY MATCHES" + text += " SHOW EMPTY MATCHES" elif self._match_text_seq("OMIT", "EMPTY", "MATCHES"): - text += f" OMIT EMPTY MATCHES" + text += " OMIT EMPTY MATCHES" elif self._match_text_seq("WITH", "UNMATCHED", "ROWS"): - text += f" WITH UNMATCHED ROWS" + text += " WITH UNMATCHED ROWS" rows = exp.var(text) else: rows = None @@ -2511,9 +2533,9 @@ class Parser(metaclass=_Parser): if self._match_text_seq("AFTER", "MATCH", "SKIP"): text = "AFTER MATCH SKIP" if self._match_text_seq("PAST", "LAST", "ROW"): - text += f" PAST LAST ROW" + text += " PAST LAST ROW" elif self._match_text_seq("TO", "NEXT", "ROW"): - text += f" TO NEXT ROW" + text += " TO NEXT ROW" elif self._match_text_seq("TO", "FIRST"): text += f" TO FIRST {self._advance_any().text}" # type: ignore elif self._match_text_seq("TO", "LAST"): @@ -2772,7 +2794,7 @@ class Parser(metaclass=_Parser): or self._parse_placeholder() ) - def _parse_table_parts(self, schema: bool = False) -> exp.Table: + def _parse_table_parts(self, schema: bool = False, is_db_reference: bool = False) -> exp.Table: catalog = None db = None table: t.Optional[exp.Expression | str] = self._parse_table_part(schema=schema) @@ -2788,8 +2810,15 @@ class Parser(metaclass=_Parser): db = table table = self._parse_table_part(schema=schema) or "" - if not table: + if is_db_reference: + catalog = db + db = table + table = None + + if not table and not is_db_reference: self.raise_error(f"Expected table name but got {self._curr}") + if not db and is_db_reference: + self.raise_error(f"Expected database name but got {self._curr}") return self.expression( exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots() @@ -2801,6 +2830,7 @@ class Parser(metaclass=_Parser): joins: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None, parse_bracket: bool = False, + is_db_reference: bool = False, ) -> t.Optional[exp.Expression]: lateral = self._parse_lateral() if lateral: @@ -2823,7 +2853,11 @@ class Parser(metaclass=_Parser): bracket = parse_bracket and self._parse_bracket(None) bracket = self.expression(exp.Table, this=bracket) if bracket else None this = t.cast( - exp.Expression, bracket or self._parse_bracket(self._parse_table_parts(schema=schema)) + exp.Expression, + bracket + or self._parse_bracket( + self._parse_table_parts(schema=schema, is_db_reference=is_db_reference) + ), ) if schema: @@ -3650,7 +3684,6 @@ class Parser(metaclass=_Parser): identifier = allow_identifiers and self._parse_id_var( any_token=False, tokens=(TokenType.VAR,) ) - if identifier: tokens = self.dialect.tokenize(identifier.name) @@ -3818,12 +3851,14 @@ class Parser(metaclass=_Parser): return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary()) def _parse_column(self) -> t.Optional[exp.Expression]: + this = self._parse_column_reference() + return self._parse_column_ops(this) if this else self._parse_bracket(this) + + def _parse_column_reference(self) -> t.Optional[exp.Expression]: this = self._parse_field() if isinstance(this, exp.Identifier): this = self.expression(exp.Column, this=this) - elif not this: - return self._parse_bracket(this) - return self._parse_column_ops(this) + return this def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: this = self._parse_bracket(this) @@ -3837,13 +3872,7 @@ class Parser(metaclass=_Parser): if not field: self.raise_error("Expected type") elif op and self._curr: - self._advance() - value = self._prev.text - field = ( - exp.Literal.number(value) - if self._prev.token_type == TokenType.NUMBER - else exp.Literal.string(value) - ) + field = self._parse_column_reference() else: field = self._parse_field(anonymous_func=True, any_token=True) @@ -4375,7 +4404,10 @@ class Parser(metaclass=_Parser): options[kind] = action return self.expression( - exp.ForeignKey, expressions=expressions, reference=reference, **options # type: ignore + exp.ForeignKey, + expressions=expressions, + reference=reference, + **options, # type: ignore ) def _parse_primary_key_part(self) -> t.Optional[exp.Expression]: @@ -4692,10 +4724,12 @@ class Parser(metaclass=_Parser): return None @t.overload - def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ... + def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: + ... @t.overload - def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ... + def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: + ... def _parse_json_object(self, agg=False): star = self._parse_star() @@ -4937,6 +4971,13 @@ class Parser(metaclass=_Parser): # (https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/img_text/first_value.html) # and Snowflake chose to do the same for familiarity # https://docs.snowflake.com/en/sql-reference/functions/first_value.html#usage-notes + if isinstance(this, exp.AggFunc): + ignore_respect = this.find(exp.IgnoreNulls, exp.RespectNulls) + + if ignore_respect and ignore_respect is not this: + ignore_respect.replace(ignore_respect.this) + this = self.expression(ignore_respect.__class__, this=this) + this = self._parse_respect_or_ignore_nulls(this) # bigquery select from window x AS (partition by ...) @@ -5732,12 +5773,14 @@ class Parser(metaclass=_Parser): return True @t.overload - def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression: ... + def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression: + ... @t.overload def _replace_columns_with_dots( self, this: t.Optional[exp.Expression] - ) -> t.Optional[exp.Expression]: ... + ) -> t.Optional[exp.Expression]: + ... def _replace_columns_with_dots(self, this): if isinstance(this, exp.Dot): |