summaryrefslogtreecommitdiffstats
path: root/sphinx/pycode/ast.py
diff options
context:
space:
mode:
Diffstat (limited to 'sphinx/pycode/ast.py')
-rw-r--r--sphinx/pycode/ast.py188
1 files changed, 188 insertions, 0 deletions
diff --git a/sphinx/pycode/ast.py b/sphinx/pycode/ast.py
new file mode 100644
index 0000000..e5914cc
--- /dev/null
+++ b/sphinx/pycode/ast.py
@@ -0,0 +1,188 @@
+"""Helpers for AST (Abstract Syntax Tree)."""
+
+from __future__ import annotations
+
+import ast
+from typing import overload
+
+OPERATORS: dict[type[ast.AST], str] = {
+ ast.Add: "+",
+ ast.And: "and",
+ ast.BitAnd: "&",
+ ast.BitOr: "|",
+ ast.BitXor: "^",
+ ast.Div: "/",
+ ast.FloorDiv: "//",
+ ast.Invert: "~",
+ ast.LShift: "<<",
+ ast.MatMult: "@",
+ ast.Mult: "*",
+ ast.Mod: "%",
+ ast.Not: "not",
+ ast.Pow: "**",
+ ast.Or: "or",
+ ast.RShift: ">>",
+ ast.Sub: "-",
+ ast.UAdd: "+",
+ ast.USub: "-",
+}
+
+
+@overload
+def unparse(node: None, code: str = '') -> None:
+ ...
+
+
+@overload
+def unparse(node: ast.AST, code: str = '') -> str:
+ ...
+
+
+def unparse(node: ast.AST | None, code: str = '') -> str | None:
+ """Unparse an AST to string."""
+ if node is None:
+ return None
+ elif isinstance(node, str):
+ return node
+ return _UnparseVisitor(code).visit(node)
+
+
+# a greatly cut-down version of `ast._Unparser`
+class _UnparseVisitor(ast.NodeVisitor):
+ def __init__(self, code: str = '') -> None:
+ self.code = code
+
+ def _visit_op(self, node: ast.AST) -> str:
+ return OPERATORS[node.__class__]
+ for _op in OPERATORS:
+ locals()[f'visit_{_op.__name__}'] = _visit_op
+
+ def visit_arg(self, node: ast.arg) -> str:
+ if node.annotation:
+ return f"{node.arg}: {self.visit(node.annotation)}"
+ else:
+ return node.arg
+
+ def _visit_arg_with_default(self, arg: ast.arg, default: ast.AST | None) -> str:
+ """Unparse a single argument to a string."""
+ name = self.visit(arg)
+ if default:
+ if arg.annotation:
+ name += " = %s" % self.visit(default)
+ else:
+ name += "=%s" % self.visit(default)
+ return name
+
+ def visit_arguments(self, node: ast.arguments) -> str:
+ defaults: list[ast.expr | None] = list(node.defaults)
+ positionals = len(node.args)
+ posonlyargs = len(node.posonlyargs)
+ positionals += posonlyargs
+ for _ in range(len(defaults), positionals):
+ defaults.insert(0, None)
+
+ kw_defaults: list[ast.expr | None] = list(node.kw_defaults)
+ for _ in range(len(kw_defaults), len(node.kwonlyargs)):
+ kw_defaults.insert(0, None)
+
+ args: list[str] = []
+ for i, arg in enumerate(node.posonlyargs):
+ args.append(self._visit_arg_with_default(arg, defaults[i]))
+
+ if node.posonlyargs:
+ args.append('/')
+
+ for i, arg in enumerate(node.args):
+ args.append(self._visit_arg_with_default(arg, defaults[i + posonlyargs]))
+
+ if node.vararg:
+ args.append("*" + self.visit(node.vararg))
+
+ if node.kwonlyargs and not node.vararg:
+ args.append('*')
+ for i, arg in enumerate(node.kwonlyargs):
+ args.append(self._visit_arg_with_default(arg, kw_defaults[i]))
+
+ if node.kwarg:
+ args.append("**" + self.visit(node.kwarg))
+
+ return ", ".join(args)
+
+ def visit_Attribute(self, node: ast.Attribute) -> str:
+ return f"{self.visit(node.value)}.{node.attr}"
+
+ def visit_BinOp(self, node: ast.BinOp) -> str:
+ # Special case ``**`` to not have surrounding spaces.
+ if isinstance(node.op, ast.Pow):
+ return "".join(map(self.visit, (node.left, node.op, node.right)))
+ return " ".join(self.visit(e) for e in [node.left, node.op, node.right])
+
+ def visit_BoolOp(self, node: ast.BoolOp) -> str:
+ op = " %s " % self.visit(node.op)
+ return op.join(self.visit(e) for e in node.values)
+
+ def visit_Call(self, node: ast.Call) -> str:
+ args = ', '.join([self.visit(e) for e in node.args]
+ + [f"{k.arg}={self.visit(k.value)}" for k in node.keywords])
+ return f"{self.visit(node.func)}({args})"
+
+ def visit_Constant(self, node: ast.Constant) -> str:
+ if node.value is Ellipsis:
+ return "..."
+ elif isinstance(node.value, (int, float, complex)):
+ if self.code:
+ return ast.get_source_segment(self.code, node) or repr(node.value)
+ else:
+ return repr(node.value)
+ else:
+ return repr(node.value)
+
+ def visit_Dict(self, node: ast.Dict) -> str:
+ keys = (self.visit(k) for k in node.keys if k is not None)
+ values = (self.visit(v) for v in node.values)
+ items = (k + ": " + v for k, v in zip(keys, values))
+ return "{" + ", ".join(items) + "}"
+
+ def visit_Lambda(self, node: ast.Lambda) -> str:
+ return "lambda %s: ..." % self.visit(node.args)
+
+ def visit_List(self, node: ast.List) -> str:
+ return "[" + ", ".join(self.visit(e) for e in node.elts) + "]"
+
+ def visit_Name(self, node: ast.Name) -> str:
+ return node.id
+
+ def visit_Set(self, node: ast.Set) -> str:
+ return "{" + ", ".join(self.visit(e) for e in node.elts) + "}"
+
+ def visit_Subscript(self, node: ast.Subscript) -> str:
+ def is_simple_tuple(value: ast.expr) -> bool:
+ return (
+ isinstance(value, ast.Tuple)
+ and bool(value.elts)
+ and not any(isinstance(elt, ast.Starred) for elt in value.elts)
+ )
+
+ if is_simple_tuple(node.slice):
+ elts = ", ".join(self.visit(e)
+ for e in node.slice.elts) # type: ignore[attr-defined]
+ return f"{self.visit(node.value)}[{elts}]"
+ return f"{self.visit(node.value)}[{self.visit(node.slice)}]"
+
+ def visit_UnaryOp(self, node: ast.UnaryOp) -> str:
+ # UnaryOp is one of {UAdd, USub, Invert, Not}, which refer to ``+x``,
+ # ``-x``, ``~x``, and ``not x``. Only Not needs a space.
+ if isinstance(node.op, ast.Not):
+ return f"{self.visit(node.op)} {self.visit(node.operand)}"
+ return f"{self.visit(node.op)}{self.visit(node.operand)}"
+
+ def visit_Tuple(self, node: ast.Tuple) -> str:
+ if len(node.elts) == 0:
+ return "()"
+ elif len(node.elts) == 1:
+ return "(%s,)" % self.visit(node.elts[0])
+ else:
+ return "(" + ", ".join(self.visit(e) for e in node.elts) + ")"
+
+ def generic_visit(self, node):
+ raise NotImplementedError('Unable to parse %s object' % type(node).__name__)