summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/parser.py
diff options
context:
space:
mode:
author Daniel Baumann <daniel.baumann@progress-linux.org> 2024-05-23 07:22:20 +0000
committer Daniel Baumann <daniel.baumann@progress-linux.org> 2024-05-23 07:22:20 +0000
commit 41e67f6ce6b4b732d02e421d6825c18b8d15a59d (patch)
tree 30fb0000d3e6ff11b366567bc35564842e7dbb50 /sqlglot/parser.py
parent Adding upstream version 23.16.0. (diff)
download sqlglot-41e67f6ce6b4b732d02e421d6825c18b8d15a59d.tar.xz
sqlglot-41e67f6ce6b4b732d02e421d6825c18b8d15a59d.zip
Adding upstream version 24.0.0. (tag: upstream/24.0.0)
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/parser.py')
-rw-r--r-- sqlglot/parser.py 107
1 files changed, 93 insertions, 14 deletions
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 67ffd8f..3237cd1 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -41,11 +41,17 @@ def build_like(args: t.List) -> exp.Escape | exp.Like:
def binary_range_parser(
- expr_type: t.Type[exp.Expression],
+ expr_type: t.Type[exp.Expression], reverse_args: bool = False
) -> t.Callable[[Parser, t.Optional[exp.Expression]], t.Optional[exp.Expression]]:
- return lambda self, this: self._parse_escape(
- self.expression(expr_type, this=this, expression=self._parse_bitwise())
- )
+ def _parse_binary_range(
+ self: Parser, this: t.Optional[exp.Expression]
+ ) -> t.Optional[exp.Expression]:
+ expression = self._parse_bitwise()
+ if reverse_args:
+ this, expression = expression, this
+ return self._parse_escape(self.expression(expr_type, this=this, expression=expression))
+
+ return _parse_binary_range
def build_logarithm(args: t.List, dialect: Dialect) -> exp.Func:
@@ -335,6 +341,8 @@ class Parser(metaclass=_Parser):
TokenType.TABLE,
TokenType.TAG,
TokenType.VIEW,
+ TokenType.WAREHOUSE,
+ TokenType.STREAMLIT,
}
CREATABLES = {
@@ -418,6 +426,7 @@ class Parser(metaclass=_Parser):
TokenType.TRUE,
TokenType.TRUNCATE,
TokenType.UNIQUE,
+ TokenType.UNNEST,
TokenType.UNPIVOT,
TokenType.UPDATE,
TokenType.USE,
@@ -580,7 +589,7 @@ class Parser(metaclass=_Parser):
exp.Lambda,
this=self._replace_lambda(
self._parse_conjunction(),
- {node.name for node in expressions},
+ expressions,
),
expressions=expressions,
),
@@ -1160,6 +1169,9 @@ class Parser(metaclass=_Parser):
# Whether the -> and ->> operators expect documents of type JSON (e.g. Postgres)
JSON_ARROWS_REQUIRE_JSON_TYPE = False
+ # Whether the `:` operator is used to extract a value from a JSON document
+ COLON_IS_JSON_EXTRACT = False
+
# Whether or not a VALUES keyword needs to be followed by '(' to form a VALUES clause.
# If this is True and '(' is not found, the keyword will be treated as an identifier
VALUES_FOLLOWED_BY_PAREN = True
@@ -1631,6 +1643,7 @@ class Parser(metaclass=_Parser):
extend_props(self._parse_properties())
expression = self._match(TokenType.ALIAS) and self._parse_heredoc()
+ extend_props(self._parse_properties())
if not expression:
if self._match(TokenType.COMMAND):
@@ -4155,7 +4168,9 @@ class Parser(metaclass=_Parser):
return self.UNARY_PARSERS[self._prev.token_type](self)
return self._parse_at_time_zone(self._parse_type())
- def _parse_type(self, parse_interval: bool = True) -> t.Optional[exp.Expression]:
+ def _parse_type(
+ self, parse_interval: bool = True, fallback_to_identifier: bool = False
+ ) -> t.Optional[exp.Expression]:
interval = parse_interval and self._parse_interval()
if interval:
# Convert INTERVAL 'val_1' unit_1 [+] ... [+] 'val_n' unit_n into a sum of intervals
@@ -4183,9 +4198,11 @@ class Parser(metaclass=_Parser):
if parser:
return parser(self, this, data_type)
return self.expression(exp.Cast, this=this, to=data_type)
+
if not data_type.expressions:
self._retreat(index)
- return self._parse_column()
+ return self._parse_id_var() if fallback_to_identifier else self._parse_column()
+
return self._parse_column_ops(data_type)
return this and self._parse_column_ops(this)
@@ -4364,7 +4381,10 @@ class Parser(metaclass=_Parser):
def _parse_struct_types(self, type_required: bool = False) -> t.Optional[exp.Expression]:
index = self._index
- this = self._parse_type(parse_interval=False) or self._parse_id_var()
+ this = (
+ self._parse_type(parse_interval=False, fallback_to_identifier=True)
+ or self._parse_id_var()
+ )
self._match(TokenType.COLON)
column_def = self._parse_column_def(this)
@@ -4401,6 +4421,47 @@ class Parser(metaclass=_Parser):
return this
+ def _parse_colon_as_json_extract(
+ self, this: t.Optional[exp.Expression]
+ ) -> t.Optional[exp.Expression]:
+ casts = []
+ json_path = []
+
+ while self._match(TokenType.COLON):
+ start_index = self._index
+ path = self._parse_column_ops(self._parse_field(any_token=True))
+
+ # The cast :: operator has a lower precedence than the extraction operator :, so
+ # we rearrange the AST appropriately to avoid casting the JSON path
+ while isinstance(path, exp.Cast):
+ casts.append(path.to)
+ path = path.this
+
+ if casts:
+ dcolon_offset = next(
+ i
+ for i, t in enumerate(self._tokens[start_index:])
+ if t.token_type == TokenType.DCOLON
+ )
+ end_token = self._tokens[start_index + dcolon_offset - 1]
+ else:
+ end_token = self._prev
+
+ if path:
+ json_path.append(self._find_sql(self._tokens[start_index], end_token))
+
+ if json_path:
+ this = self.expression(
+ exp.JSONExtract,
+ this=this,
+ expression=self.dialect.to_json_path(exp.Literal.string(".".join(json_path))),
+ )
+
+ while casts:
+ this = self.expression(exp.Cast, this=this, to=casts.pop())
+
+ return this
+
def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
this = self._parse_bracket(this)
@@ -4444,8 +4505,10 @@ class Parser(metaclass=_Parser):
)
else:
this = self.expression(exp.Dot, this=this, expression=field)
+
this = self._parse_bracket(this)
- return this
+
+ return self._parse_colon_as_json_extract(this) if self.COLON_IS_JSON_EXTRACT else this
def _parse_primary(self) -> t.Optional[exp.Expression]:
if self._match_set(self.PRIMARY_PARSERS):
@@ -4680,18 +4743,21 @@ class Parser(metaclass=_Parser):
return self.expression(exp.SessionParameter, this=this, kind=kind)
+ def _parse_lambda_arg(self) -> t.Optional[exp.Expression]:
+ return self._parse_id_var()
+
def _parse_lambda(self, alias: bool = False) -> t.Optional[exp.Expression]:
index = self._index
if self._match(TokenType.L_PAREN):
expressions = t.cast(
- t.List[t.Optional[exp.Expression]], self._parse_csv(self._parse_id_var)
+ t.List[t.Optional[exp.Expression]], self._parse_csv(self._parse_lambda_arg)
)
if not self._match(TokenType.R_PAREN):
self._retreat(index)
else:
- expressions = [self._parse_id_var()]
+ expressions = [self._parse_lambda_arg()]
if self._match_set(self.LAMBDAS):
return self.LAMBDAS[self._prev.token_type](self, expressions)
@@ -6182,8 +6248,10 @@ class Parser(metaclass=_Parser):
return None
right = self._parse_statement() or self._parse_id_var()
- this = self.expression(exp.EQ, this=left, expression=right)
+ if isinstance(right, (exp.Column, exp.Identifier)):
+ right = exp.var(right.name)
+ this = self.expression(exp.EQ, this=left, expression=right)
return self.expression(exp.SetItem, this=this, kind=kind)
def _parse_set_transaction(self, global_: bool = False) -> exp.Expression:
@@ -6433,14 +6501,25 @@ class Parser(metaclass=_Parser):
return True
def _replace_lambda(
- self, node: t.Optional[exp.Expression], lambda_variables: t.Set[str]
+ self, node: t.Optional[exp.Expression], expressions: t.List[exp.Expression]
) -> t.Optional[exp.Expression]:
if not node:
return node
+ lambda_types = {e.name: e.args.get("to") or False for e in expressions}
+
for column in node.find_all(exp.Column):
- if column.parts[0].name in lambda_variables:
+ typ = lambda_types.get(column.parts[0].name)
+ if typ is not None:
dot_or_id = column.to_dot() if column.table else column.this
+
+ if typ:
+ dot_or_id = self.expression(
+ exp.Cast,
+ this=dot_or_id,
+ to=typ,
+ )
+
parent = column.parent
while isinstance(parent, exp.Dot):