"""Helpers for AST (Abstract Syntax Tree).""" from __future__ import annotations import ast from typing import overload OPERATORS: dict[type[ast.AST], str] = { ast.Add: "+", ast.And: "and", ast.BitAnd: "&", ast.BitOr: "|", ast.BitXor: "^", ast.Div: "/", ast.FloorDiv: "//", ast.Invert: "~", ast.LShift: "<<", ast.MatMult: "@", ast.Mult: "*", ast.Mod: "%", ast.Not: "not", ast.Pow: "**", ast.Or: "or", ast.RShift: ">>", ast.Sub: "-", ast.UAdd: "+", ast.USub: "-", } @overload def unparse(node: None, code: str = '') -> None: ... @overload def unparse(node: ast.AST, code: str = '') -> str: ... def unparse(node: ast.AST | None, code: str = '') -> str | None: """Unparse an AST to string.""" if node is None: return None elif isinstance(node, str): return node return _UnparseVisitor(code).visit(node) # a greatly cut-down version of `ast._Unparser` class _UnparseVisitor(ast.NodeVisitor): def __init__(self, code: str = '') -> None: self.code = code def _visit_op(self, node: ast.AST) -> str: return OPERATORS[node.__class__] for _op in OPERATORS: locals()[f'visit_{_op.__name__}'] = _visit_op def visit_arg(self, node: ast.arg) -> str: if node.annotation: return f"{node.arg}: {self.visit(node.annotation)}" else: return node.arg def _visit_arg_with_default(self, arg: ast.arg, default: ast.AST | None) -> str: """Unparse a single argument to a string.""" name = self.visit(arg) if default: if arg.annotation: name += " = %s" % self.visit(default) else: name += "=%s" % self.visit(default) return name def visit_arguments(self, node: ast.arguments) -> str: defaults: list[ast.expr | None] = list(node.defaults) positionals = len(node.args) posonlyargs = len(node.posonlyargs) positionals += posonlyargs for _ in range(len(defaults), positionals): defaults.insert(0, None) kw_defaults: list[ast.expr | None] = list(node.kw_defaults) for _ in range(len(kw_defaults), len(node.kwonlyargs)): kw_defaults.insert(0, None) args: list[str] = [] for i, arg in enumerate(node.posonlyargs): args.append(self._visit_arg_with_default(arg, defaults[i])) if node.posonlyargs: args.append('/') for i, arg in enumerate(node.args): args.append(self._visit_arg_with_default(arg, defaults[i + posonlyargs])) if node.vararg: args.append("*" + self.visit(node.vararg)) if node.kwonlyargs and not node.vararg: args.append('*') for i, arg in enumerate(node.kwonlyargs): args.append(self._visit_arg_with_default(arg, kw_defaults[i])) if node.kwarg: args.append("**" + self.visit(node.kwarg)) return ", ".join(args) def visit_Attribute(self, node: ast.Attribute) -> str: return f"{self.visit(node.value)}.{node.attr}" def visit_BinOp(self, node: ast.BinOp) -> str: # Special case ``**`` to not have surrounding spaces. if isinstance(node.op, ast.Pow): return "".join(map(self.visit, (node.left, node.op, node.right))) return " ".join(self.visit(e) for e in [node.left, node.op, node.right]) def visit_BoolOp(self, node: ast.BoolOp) -> str: op = " %s " % self.visit(node.op) return op.join(self.visit(e) for e in node.values) def visit_Call(self, node: ast.Call) -> str: args = ', '.join([self.visit(e) for e in node.args] + [f"{k.arg}={self.visit(k.value)}" for k in node.keywords]) return f"{self.visit(node.func)}({args})" def visit_Constant(self, node: ast.Constant) -> str: if node.value is Ellipsis: return "..." elif isinstance(node.value, (int, float, complex)): if self.code: return ast.get_source_segment(self.code, node) or repr(node.value) else: return repr(node.value) else: return repr(node.value) def visit_Dict(self, node: ast.Dict) -> str: keys = (self.visit(k) for k in node.keys if k is not None) values = (self.visit(v) for v in node.values) items = (k + ": " + v for k, v in zip(keys, values)) return "{" + ", ".join(items) + "}" def visit_Lambda(self, node: ast.Lambda) -> str: return "lambda %s: ..." % self.visit(node.args) def visit_List(self, node: ast.List) -> str: return "[" + ", ".join(self.visit(e) for e in node.elts) + "]" def visit_Name(self, node: ast.Name) -> str: return node.id def visit_Set(self, node: ast.Set) -> str: return "{" + ", ".join(self.visit(e) for e in node.elts) + "}" def visit_Subscript(self, node: ast.Subscript) -> str: def is_simple_tuple(value: ast.expr) -> bool: return ( isinstance(value, ast.Tuple) and bool(value.elts) and not any(isinstance(elt, ast.Starred) for elt in value.elts) ) if is_simple_tuple(node.slice): elts = ", ".join(self.visit(e) for e in node.slice.elts) # type: ignore[attr-defined] return f"{self.visit(node.value)}[{elts}]" return f"{self.visit(node.value)}[{self.visit(node.slice)}]" def visit_UnaryOp(self, node: ast.UnaryOp) -> str: # UnaryOp is one of {UAdd, USub, Invert, Not}, which refer to ``+x``, # ``-x``, ``~x``, and ``not x``. Only Not needs a space. if isinstance(node.op, ast.Not): return f"{self.visit(node.op)} {self.visit(node.operand)}" return f"{self.visit(node.op)}{self.visit(node.operand)}" def visit_Tuple(self, node: ast.Tuple) -> str: if len(node.elts) == 0: return "()" elif len(node.elts) == 1: return "(%s,)" % self.visit(node.elts[0]) else: return "(" + ", ".join(self.visit(e) for e in node.elts) + ")" def generic_visit(self, node): raise NotImplementedError('Unable to parse %s object' % type(node).__name__)