summary refs log tree commit diff stats
path: root/sqlglot/parser.py
diff options
context:
space:
mode:
author Daniel Baumann <daniel.baumann@progress-linux.org> 2024-02-08 05:38:39 +0000
committer Daniel Baumann <daniel.baumann@progress-linux.org> 2024-02-08 05:38:39 +0000
commit aedf35026379f52d7e2b4c1f957691410a758089 (patch)
tree 86540364259b66741173d2333387b78d6f9c31e2 /sqlglot/parser.py
parent Adding upstream version 20.11.0. (diff)
download sqlglot-aedf35026379f52d7e2b4c1f957691410a758089.tar.xz
sqlglot-aedf35026379f52d7e2b4c1f957691410a758089.zip
Adding upstream version 21.0.1. upstream/21.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/parser.py')
-rw-r--r-- sqlglot/parser.py 99
1 files changed, 71 insertions, 28 deletions
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index c091605..a89e4fa 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -60,6 +60,19 @@ def parse_logarithm(args: t.List, dialect: Dialect) -> exp.Func:
return (exp.Ln if dialect.parser_class.LOG_DEFAULTS_TO_LN else exp.Log)(this=this)
+def parse_extract_json_with_path(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]:
+ def _parser(args: t.List, dialect: Dialect) -> E:
+ expression = expr_type(
+ this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1))
+ )
+ if len(args) > 2 and expr_type is exp.JSONExtract:
+ expression.set("expressions", args[2:])
+
+ return expression
+
+ return _parser
+
+
class _Parser(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
@@ -102,6 +115,9 @@ class Parser(metaclass=_Parser):
to=exp.DataType(this=exp.DataType.Type.TEXT),
),
"GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)),
+ "JSON_EXTRACT": parse_extract_json_with_path(exp.JSONExtract),
+ "JSON_EXTRACT_SCALAR": parse_extract_json_with_path(exp.JSONExtractScalar),
+ "JSON_EXTRACT_PATH_TEXT": parse_extract_json_with_path(exp.JSONExtractScalar),
"LIKE": parse_like,
"LOG": parse_logarithm,
"TIME_TO_TIME_STR": lambda args: exp.Cast(
@@ -175,6 +191,7 @@ class Parser(metaclass=_Parser):
TokenType.NCHAR,
TokenType.VARCHAR,
TokenType.NVARCHAR,
+ TokenType.BPCHAR,
TokenType.TEXT,
TokenType.MEDIUMTEXT,
TokenType.LONGTEXT,
@@ -295,6 +312,7 @@ class Parser(metaclass=_Parser):
TokenType.ASC,
TokenType.AUTO_INCREMENT,
TokenType.BEGIN,
+ TokenType.BPCHAR,
TokenType.CACHE,
TokenType.CASE,
TokenType.COLLATE,
@@ -531,12 +549,12 @@ class Parser(metaclass=_Parser):
TokenType.ARROW: lambda self, this, path: self.expression(
exp.JSONExtract,
this=this,
- expression=path,
+ expression=self.dialect.to_json_path(path),
),
TokenType.DARROW: lambda self, this, path: self.expression(
exp.JSONExtractScalar,
this=this,
- expression=path,
+ expression=self.dialect.to_json_path(path),
),
TokenType.HASH_ARROW: lambda self, this, path: self.expression(
exp.JSONBExtract,
@@ -1334,7 +1352,9 @@ class Parser(metaclass=_Parser):
exp.Drop,
comments=start.comments,
exists=exists or self._parse_exists(),
- this=self._parse_table(schema=True),
+ this=self._parse_table(
+ schema=True, is_db_reference=self._prev.token_type == TokenType.SCHEMA
+ ),
kind=kind,
temporary=temporary,
materialized=materialized,
@@ -1422,7 +1442,9 @@ class Parser(metaclass=_Parser):
elif create_token.token_type == TokenType.INDEX:
this = self._parse_index(index=self._parse_id_var())
elif create_token.token_type in self.DB_CREATABLES:
- table_parts = self._parse_table_parts(schema=True)
+ table_parts = self._parse_table_parts(
+ schema=True, is_db_reference=create_token.token_type == TokenType.SCHEMA
+ )
# exp.Properties.Location.POST_NAME
self._match(TokenType.COMMA)
@@ -2499,11 +2521,11 @@ class Parser(metaclass=_Parser):
elif self._match_text_seq("ALL", "ROWS", "PER", "MATCH"):
text = "ALL ROWS PER MATCH"
if self._match_text_seq("SHOW", "EMPTY", "MATCHES"):
- text += f" SHOW EMPTY MATCHES"
+ text += " SHOW EMPTY MATCHES"
elif self._match_text_seq("OMIT", "EMPTY", "MATCHES"):
- text += f" OMIT EMPTY MATCHES"
+ text += " OMIT EMPTY MATCHES"
elif self._match_text_seq("WITH", "UNMATCHED", "ROWS"):
- text += f" WITH UNMATCHED ROWS"
+ text += " WITH UNMATCHED ROWS"
rows = exp.var(text)
else:
rows = None
@@ -2511,9 +2533,9 @@ class Parser(metaclass=_Parser):
if self._match_text_seq("AFTER", "MATCH", "SKIP"):
text = "AFTER MATCH SKIP"
if self._match_text_seq("PAST", "LAST", "ROW"):
- text += f" PAST LAST ROW"
+ text += " PAST LAST ROW"
elif self._match_text_seq("TO", "NEXT", "ROW"):
- text += f" TO NEXT ROW"
+ text += " TO NEXT ROW"
elif self._match_text_seq("TO", "FIRST"):
text += f" TO FIRST {self._advance_any().text}" # type: ignore
elif self._match_text_seq("TO", "LAST"):
@@ -2772,7 +2794,7 @@ class Parser(metaclass=_Parser):
or self._parse_placeholder()
)
- def _parse_table_parts(self, schema: bool = False) -> exp.Table:
+ def _parse_table_parts(self, schema: bool = False, is_db_reference: bool = False) -> exp.Table:
catalog = None
db = None
table: t.Optional[exp.Expression | str] = self._parse_table_part(schema=schema)
@@ -2788,8 +2810,15 @@ class Parser(metaclass=_Parser):
db = table
table = self._parse_table_part(schema=schema) or ""
- if not table:
+ if is_db_reference:
+ catalog = db
+ db = table
+ table = None
+
+ if not table and not is_db_reference:
self.raise_error(f"Expected table name but got {self._curr}")
+ if not db and is_db_reference:
+ self.raise_error(f"Expected database name but got {self._curr}")
return self.expression(
exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()
@@ -2801,6 +2830,7 @@ class Parser(metaclass=_Parser):
joins: bool = False,
alias_tokens: t.Optional[t.Collection[TokenType]] = None,
parse_bracket: bool = False,
+ is_db_reference: bool = False,
) -> t.Optional[exp.Expression]:
lateral = self._parse_lateral()
if lateral:
@@ -2823,7 +2853,11 @@ class Parser(metaclass=_Parser):
bracket = parse_bracket and self._parse_bracket(None)
bracket = self.expression(exp.Table, this=bracket) if bracket else None
this = t.cast(
- exp.Expression, bracket or self._parse_bracket(self._parse_table_parts(schema=schema))
+ exp.Expression,
+ bracket
+ or self._parse_bracket(
+ self._parse_table_parts(schema=schema, is_db_reference=is_db_reference)
+ ),
)
if schema:
@@ -3650,7 +3684,6 @@ class Parser(metaclass=_Parser):
identifier = allow_identifiers and self._parse_id_var(
any_token=False, tokens=(TokenType.VAR,)
)
-
if identifier:
tokens = self.dialect.tokenize(identifier.name)
@@ -3818,12 +3851,14 @@ class Parser(metaclass=_Parser):
return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary())
def _parse_column(self) -> t.Optional[exp.Expression]:
+ this = self._parse_column_reference()
+ return self._parse_column_ops(this) if this else self._parse_bracket(this)
+
+ def _parse_column_reference(self) -> t.Optional[exp.Expression]:
this = self._parse_field()
if isinstance(this, exp.Identifier):
this = self.expression(exp.Column, this=this)
- elif not this:
- return self._parse_bracket(this)
- return self._parse_column_ops(this)
+ return this
def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
this = self._parse_bracket(this)
@@ -3837,13 +3872,7 @@ class Parser(metaclass=_Parser):
if not field:
self.raise_error("Expected type")
elif op and self._curr:
- self._advance()
- value = self._prev.text
- field = (
- exp.Literal.number(value)
- if self._prev.token_type == TokenType.NUMBER
- else exp.Literal.string(value)
- )
+ field = self._parse_column_reference()
else:
field = self._parse_field(anonymous_func=True, any_token=True)
@@ -4375,7 +4404,10 @@ class Parser(metaclass=_Parser):
options[kind] = action
return self.expression(
- exp.ForeignKey, expressions=expressions, reference=reference, **options # type: ignore
+ exp.ForeignKey,
+ expressions=expressions,
+ reference=reference,
+ **options, # type: ignore
)
def _parse_primary_key_part(self) -> t.Optional[exp.Expression]:
@@ -4692,10 +4724,12 @@ class Parser(metaclass=_Parser):
return None
@t.overload
- def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ...
+ def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject:
+ ...
@t.overload
- def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ...
+ def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg:
+ ...
def _parse_json_object(self, agg=False):
star = self._parse_star()
@@ -4937,6 +4971,13 @@ class Parser(metaclass=_Parser):
# (https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/img_text/first_value.html)
# and Snowflake chose to do the same for familiarity
# https://docs.snowflake.com/en/sql-reference/functions/first_value.html#usage-notes
+ if isinstance(this, exp.AggFunc):
+ ignore_respect = this.find(exp.IgnoreNulls, exp.RespectNulls)
+
+ if ignore_respect and ignore_respect is not this:
+ ignore_respect.replace(ignore_respect.this)
+ this = self.expression(ignore_respect.__class__, this=this)
+
this = self._parse_respect_or_ignore_nulls(this)
# bigquery select from window x AS (partition by ...)
@@ -5732,12 +5773,14 @@ class Parser(metaclass=_Parser):
return True
@t.overload
- def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression: ...
+ def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression:
+ ...
@t.overload
def _replace_columns_with_dots(
self, this: t.Optional[exp.Expression]
- ) -> t.Optional[exp.Expression]: ...
+ ) -> t.Optional[exp.Expression]:
+ ...
def _replace_columns_with_dots(self, this):
if isinstance(this, exp.Dot):