diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-15 17:25:40 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-15 17:25:40 +0000 |
commit | cf7da1843c45a4c2df7a749f7886a2d2ba0ee92a (patch) | |
tree | 18dcde1a8d1f5570a77cd0c361de3b490d02c789 /sphinx/pycode/ast.py | |
parent | Initial commit. (diff) | |
download | sphinx-upstream/7.2.6.tar.xz sphinx-upstream/7.2.6.zip |
Adding upstream version 7.2.6.upstream/7.2.6
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sphinx/pycode/ast.py')
-rw-r--r-- | sphinx/pycode/ast.py | 188 |
1 files changed, 188 insertions, 0 deletions
diff --git a/sphinx/pycode/ast.py b/sphinx/pycode/ast.py new file mode 100644 index 0000000..e5914cc --- /dev/null +++ b/sphinx/pycode/ast.py @@ -0,0 +1,188 @@ +"""Helpers for AST (Abstract Syntax Tree).""" + +from __future__ import annotations + +import ast +from typing import overload + +OPERATORS: dict[type[ast.AST], str] = { + ast.Add: "+", + ast.And: "and", + ast.BitAnd: "&", + ast.BitOr: "|", + ast.BitXor: "^", + ast.Div: "/", + ast.FloorDiv: "//", + ast.Invert: "~", + ast.LShift: "<<", + ast.MatMult: "@", + ast.Mult: "*", + ast.Mod: "%", + ast.Not: "not", + ast.Pow: "**", + ast.Or: "or", + ast.RShift: ">>", + ast.Sub: "-", + ast.UAdd: "+", + ast.USub: "-", +} + + +@overload +def unparse(node: None, code: str = '') -> None: + ... + + +@overload +def unparse(node: ast.AST, code: str = '') -> str: + ... + + +def unparse(node: ast.AST | None, code: str = '') -> str | None: + """Unparse an AST to string.""" + if node is None: + return None + elif isinstance(node, str): + return node + return _UnparseVisitor(code).visit(node) + + +# a greatly cut-down version of `ast._Unparser` +class _UnparseVisitor(ast.NodeVisitor): + def __init__(self, code: str = '') -> None: + self.code = code + + def _visit_op(self, node: ast.AST) -> str: + return OPERATORS[node.__class__] + for _op in OPERATORS: + locals()[f'visit_{_op.__name__}'] = _visit_op + + def visit_arg(self, node: ast.arg) -> str: + if node.annotation: + return f"{node.arg}: {self.visit(node.annotation)}" + else: + return node.arg + + def _visit_arg_with_default(self, arg: ast.arg, default: ast.AST | None) -> str: + """Unparse a single argument to a string.""" + name = self.visit(arg) + if default: + if arg.annotation: + name += " = %s" % self.visit(default) + else: + name += "=%s" % self.visit(default) + return name + + def visit_arguments(self, node: ast.arguments) -> str: + defaults: list[ast.expr | None] = list(node.defaults) + positionals = len(node.args) + posonlyargs = len(node.posonlyargs) + positionals += posonlyargs + for _ in range(len(defaults), positionals): + defaults.insert(0, None) + + kw_defaults: list[ast.expr | None] = list(node.kw_defaults) + for _ in range(len(kw_defaults), len(node.kwonlyargs)): + kw_defaults.insert(0, None) + + args: list[str] = [] + for i, arg in enumerate(node.posonlyargs): + args.append(self._visit_arg_with_default(arg, defaults[i])) + + if node.posonlyargs: + args.append('/') + + for i, arg in enumerate(node.args): + args.append(self._visit_arg_with_default(arg, defaults[i + posonlyargs])) + + if node.vararg: + args.append("*" + self.visit(node.vararg)) + + if node.kwonlyargs and not node.vararg: + args.append('*') + for i, arg in enumerate(node.kwonlyargs): + args.append(self._visit_arg_with_default(arg, kw_defaults[i])) + + if node.kwarg: + args.append("**" + self.visit(node.kwarg)) + + return ", ".join(args) + + def visit_Attribute(self, node: ast.Attribute) -> str: + return f"{self.visit(node.value)}.{node.attr}" + + def visit_BinOp(self, node: ast.BinOp) -> str: + # Special case ``**`` to not have surrounding spaces. + if isinstance(node.op, ast.Pow): + return "".join(map(self.visit, (node.left, node.op, node.right))) + return " ".join(self.visit(e) for e in [node.left, node.op, node.right]) + + def visit_BoolOp(self, node: ast.BoolOp) -> str: + op = " %s " % self.visit(node.op) + return op.join(self.visit(e) for e in node.values) + + def visit_Call(self, node: ast.Call) -> str: + args = ', '.join([self.visit(e) for e in node.args] + + [f"{k.arg}={self.visit(k.value)}" for k in node.keywords]) + return f"{self.visit(node.func)}({args})" + + def visit_Constant(self, node: ast.Constant) -> str: + if node.value is Ellipsis: + return "..." + elif isinstance(node.value, (int, float, complex)): + if self.code: + return ast.get_source_segment(self.code, node) or repr(node.value) + else: + return repr(node.value) + else: + return repr(node.value) + + def visit_Dict(self, node: ast.Dict) -> str: + keys = (self.visit(k) for k in node.keys if k is not None) + values = (self.visit(v) for v in node.values) + items = (k + ": " + v for k, v in zip(keys, values)) + return "{" + ", ".join(items) + "}" + + def visit_Lambda(self, node: ast.Lambda) -> str: + return "lambda %s: ..." % self.visit(node.args) + + def visit_List(self, node: ast.List) -> str: + return "[" + ", ".join(self.visit(e) for e in node.elts) + "]" + + def visit_Name(self, node: ast.Name) -> str: + return node.id + + def visit_Set(self, node: ast.Set) -> str: + return "{" + ", ".join(self.visit(e) for e in node.elts) + "}" + + def visit_Subscript(self, node: ast.Subscript) -> str: + def is_simple_tuple(value: ast.expr) -> bool: + return ( + isinstance(value, ast.Tuple) + and bool(value.elts) + and not any(isinstance(elt, ast.Starred) for elt in value.elts) + ) + + if is_simple_tuple(node.slice): + elts = ", ".join(self.visit(e) + for e in node.slice.elts) # type: ignore[attr-defined] + return f"{self.visit(node.value)}[{elts}]" + return f"{self.visit(node.value)}[{self.visit(node.slice)}]" + + def visit_UnaryOp(self, node: ast.UnaryOp) -> str: + # UnaryOp is one of {UAdd, USub, Invert, Not}, which refer to ``+x``, + # ``-x``, ``~x``, and ``not x``. Only Not needs a space. + if isinstance(node.op, ast.Not): + return f"{self.visit(node.op)} {self.visit(node.operand)}" + return f"{self.visit(node.op)}{self.visit(node.operand)}" + + def visit_Tuple(self, node: ast.Tuple) -> str: + if len(node.elts) == 0: + return "()" + elif len(node.elts) == 1: + return "(%s,)" % self.visit(node.elts[0]) + else: + return "(" + ", ".join(self.visit(e) for e in node.elts) + ")" + + def generic_visit(self, node): + raise NotImplementedError('Unable to parse %s object' % type(node).__name__) |