From c66e4a33e1a07c439f03fe47f146a6c6482bf6df Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Thu, 8 Feb 2024 06:38:42 +0100 Subject: Merging upstream version 21.0.1. Signed-off-by: Daniel Baumann --- sqlglot/jsonpath.py | 132 +++++++++++++++++++++++----------------------------- 1 file changed, 57 insertions(+), 75 deletions(-) (limited to 'sqlglot/jsonpath.py') diff --git a/sqlglot/jsonpath.py b/sqlglot/jsonpath.py index c410d11..129a4e6 100644 --- a/sqlglot/jsonpath.py +++ b/sqlglot/jsonpath.py @@ -2,8 +2,8 @@ from __future__ import annotations import typing as t +import sqlglot.expressions as exp from sqlglot.errors import ParseError -from sqlglot.expressions import SAFE_IDENTIFIER_RE from sqlglot.tokens import Token, Tokenizer, TokenType if t.TYPE_CHECKING: @@ -36,20 +36,8 @@ class JSONPathTokenizer(Tokenizer): STRING_ESCAPES = ["\\"] -JSONPathNode = t.Dict[str, t.Any] - - -def _node(kind: str, value: t.Any = None, **kwargs: t.Any) -> JSONPathNode: - node = {"kind": kind, **kwargs} - - if value is not None: - node["value"] = value - - return node - - -def parse(path: str) -> t.List[JSONPathNode]: - """Takes in a JSONPath string and converts into a list of nodes.""" +def parse(path: str) -> exp.JSONPath: + """Takes in a JSON path string and parses it into a JSONPath expression.""" tokens = JSONPathTokenizer().tokenize(path) size = len(tokens) @@ -89,7 +77,7 @@ def parse(path: str) -> t.List[JSONPathNode]: if token: return token.text if _match(TokenType.STAR): - return _node("wildcard") + return exp.JSONPathWildcard() if _match(TokenType.PLACEHOLDER) or _match(TokenType.L_PAREN): script = _prev().text == "(" start = i @@ -100,9 +88,9 @@ def parse(path: str) -> t.List[JSONPathNode]: if _curr() in (TokenType.R_BRACKET, None): break _advance() - return _node( - "script" if script else "filter", path[tokens[start].start : tokens[i].end] - ) + + expr_type = exp.JSONPathScript if script else exp.JSONPathFilter + return expr_type(this=path[tokens[start].start : tokens[i].end]) number = "-" if _match(TokenType.DASH) else "" @@ -112,6 +100,7 @@ def parse(path: str) -> t.List[JSONPathNode]: if number: return int(number) + return False def _parse_slice() -> t.Any: @@ -121,9 +110,10 @@ def parse(path: str) -> t.List[JSONPathNode]: if end is None and step is None: return start - return _node("slice", start=start, end=end, step=step) - def _parse_bracket() -> JSONPathNode: + return exp.JSONPathSlice(start=start, end=end, step=step) + + def _parse_bracket() -> exp.JSONPathPart: literal = _parse_slice() if isinstance(literal, str) or literal is not False: @@ -136,13 +126,15 @@ def parse(path: str) -> t.List[JSONPathNode]: if len(indexes) == 1: if isinstance(literal, str): - node = _node("key", indexes[0]) - elif isinstance(literal, dict) and literal["kind"] in ("script", "filter"): - node = _node("selector", indexes[0]) + node: exp.JSONPathPart = exp.JSONPathKey(this=indexes[0]) + elif isinstance(literal, exp.JSONPathPart) and isinstance( + literal, (exp.JSONPathScript, exp.JSONPathFilter) + ): + node = exp.JSONPathSelector(this=indexes[0]) else: - node = _node("subscript", indexes[0]) + node = exp.JSONPathSubscript(this=indexes[0]) else: - node = _node("union", indexes) + node = exp.JSONPathUnion(expressions=indexes) else: raise ParseError(_error("Cannot have empty segment")) @@ -150,66 +142,56 @@ def parse(path: str) -> t.List[JSONPathNode]: return node - nodes = [] + # We canonicalize the JSON path AST so that it always starts with a + # "root" element, so paths like "field" will be generated as "$.field" + _match(TokenType.DOLLAR) + expressions: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] while _curr(): - if _match(TokenType.DOLLAR): - nodes.append(_node("root")) - elif _match(TokenType.DOT): + if _match(TokenType.DOT) or _match(TokenType.COLON): recursive = _prev().text == ".." - value = _match(TokenType.VAR) or _match(TokenType.STAR) - nodes.append( - _node("recursive" if recursive else "child", value=value.text if value else None) - ) + + if _match(TokenType.VAR) or _match(TokenType.IDENTIFIER): + value: t.Optional[str | exp.JSONPathWildcard] = _prev().text + elif _match(TokenType.STAR): + value = exp.JSONPathWildcard() + else: + value = None + + if recursive: + expressions.append(exp.JSONPathRecursive(this=value)) + elif value: + expressions.append(exp.JSONPathKey(this=value)) + else: + raise ParseError(_error("Expected key name or * after DOT")) elif _match(TokenType.L_BRACKET): - nodes.append(_parse_bracket()) - elif _match(TokenType.VAR): - nodes.append(_node("key", _prev().text)) + expressions.append(_parse_bracket()) + elif _match(TokenType.VAR) or _match(TokenType.IDENTIFIER): + expressions.append(exp.JSONPathKey(this=_prev().text)) elif _match(TokenType.STAR): - nodes.append(_node("wildcard")) - elif _match(TokenType.PARAMETER): - nodes.append(_node("current")) + expressions.append(exp.JSONPathWildcard()) else: raise ParseError(_error(f"Unexpected {tokens[i].token_type}")) - return nodes + return exp.JSONPath(expressions=expressions) -MAPPING = { - "child": lambda n: f".{n['value']}" if n.get("value") is not None else "", - "filter": lambda n: f"?{n['value']}", - "key": lambda n: ( - f".{n['value']}" if SAFE_IDENTIFIER_RE.match(n["value"]) else f'[{generate([n["value"]])}]' - ), - "recursive": lambda n: f"..{n['value']}" if n.get("value") is not None else "..", - "root": lambda _: "$", - "script": lambda n: f"({n['value']}", - "slice": lambda n: ":".join( - "" if p is False else generate([p]) - for p in [n["start"], n["end"], n["step"]] +JSON_PATH_PART_TRANSFORMS: t.Dict[t.Type[exp.Expression], t.Callable[..., str]] = { + exp.JSONPathFilter: lambda _, e: f"?{e.this}", + exp.JSONPathKey: lambda self, e: self._jsonpathkey_sql(e), + exp.JSONPathRecursive: lambda _, e: f"..{e.this or ''}", + exp.JSONPathRoot: lambda *_: "$", + exp.JSONPathScript: lambda _, e: f"({e.this}", + exp.JSONPathSelector: lambda self, e: f"[{self.json_path_part(e.this)}]", + exp.JSONPathSlice: lambda self, e: ":".join( + "" if p is False else self.json_path_part(p) + for p in [e.args.get("start"), e.args.get("end"), e.args.get("step")] if p is not None ), - "selector": lambda n: f"[{generate([n['value']])}]", - "subscript": lambda n: f"[{generate([n['value']])}]", - "union": lambda n: f"[{','.join(generate([p]) for p in n['value'])}]", - "wildcard": lambda _: "*", + exp.JSONPathSubscript: lambda self, e: self._jsonpathsubscript_sql(e), + exp.JSONPathUnion: lambda self, + e: f"[{','.join(self.json_path_part(p) for p in e.expressions)}]", + exp.JSONPathWildcard: lambda *_: "*", } - -def generate( - nodes: t.List[JSONPathNode], - mapping: t.Optional[t.Dict[str, t.Callable[[JSONPathNode], str]]] = None, -) -> str: - mapping = MAPPING if mapping is None else mapping - path = [] - - for node in nodes: - if isinstance(node, dict): - path.append(mapping[node["kind"]](node)) - elif isinstance(node, str): - escaped = node.replace('"', '\\"') - path.append(f'"{escaped}"') - else: - path.append(str(node)) - - return "".join(path) +ALL_JSON_PATH_PARTS = set(JSON_PATH_PART_TRANSFORMS) -- cgit v1.2.3