summaryrefslogtreecommitdiffstats
path: root/tests/test_pycode_ast.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/test_pycode_ast.py70
1 files changed, 70 insertions, 0 deletions
diff --git a/tests/test_pycode_ast.py b/tests/test_pycode_ast.py
new file mode 100644
index 0000000..31018ba
--- /dev/null
+++ b/tests/test_pycode_ast.py
@@ -0,0 +1,70 @@
+"""Test pycode.ast"""
+
+import sys
+
+import pytest
+
+from sphinx.pycode import ast
+
+
+@pytest.mark.parametrize('source,expected', [
+ ("a + b", "a + b"), # Add
+ ("a and b", "a and b"), # And
+ ("os.path", "os.path"), # Attribute
+ ("1 * 2", "1 * 2"), # BinOp
+ ("a & b", "a & b"), # BitAnd
+ ("a | b", "a | b"), # BitOr
+ ("a ^ b", "a ^ b"), # BitXor
+ ("a and b and c", "a and b and c"), # BoolOp
+ ("b'bytes'", "b'bytes'"), # Bytes
+ ("object()", "object()"), # Call
+ ("1234", "1234"), # Constant
+ ("{'key1': 'value1', 'key2': 'value2'}",
+ "{'key1': 'value1', 'key2': 'value2'}"), # Dict
+ ("a / b", "a / b"), # Div
+ ("...", "..."), # Ellipsis
+ ("a // b", "a // b"), # FloorDiv
+ ("Tuple[int, int]", "Tuple[int, int]"), # Index, Subscript
+ ("~1", "~1"), # Invert
+ ("lambda x, y: x + y",
+ "lambda x, y: ..."), # Lambda
+ ("[1, 2, 3]", "[1, 2, 3]"), # List
+ ("a << b", "a << b"), # LShift
+ ("a @ b", "a @ b"), # MatMult
+ ("a % b", "a % b"), # Mod
+ ("a * b", "a * b"), # Mult
+ ("sys", "sys"), # Name, NameConstant
+ ("1234", "1234"), # Num
+ ("not a", "not a"), # Not
+ ("a or b", "a or b"), # Or
+ ("a**b", "a**b"), # Pow
+ ("a >> b", "a >> b"), # RShift
+ ("{1, 2, 3}", "{1, 2, 3}"), # Set
+ ("a - b", "a - b"), # Sub
+ ("'str'", "'str'"), # Str
+ ("+a", "+a"), # UAdd
+ ("-1", "-1"), # UnaryOp
+ ("-a", "-a"), # USub
+ ("(1, 2, 3)", "(1, 2, 3)"), # Tuple
+ ("()", "()"), # Tuple (empty)
+ ("(1,)", "(1,)"), # Tuple (single item)
+])
+def test_unparse(source, expected):
+ module = ast.parse(source)
+ assert ast.unparse(module.body[0].value, source) == expected
+
+
+def test_unparse_None():
+ assert ast.unparse(None) is None
+
+
+@pytest.mark.skipif(sys.version_info < (3, 8), reason='python 3.8+ is required.')
+@pytest.mark.parametrize('source,expected', [
+ ("lambda x=0, /, y=1, *args, z, **kwargs: x + y + z",
+ "lambda x=0, /, y=1, *args, z, **kwargs: ..."), # posonlyargs
+ ("0x1234", "0x1234"), # Constant
+ ("1_000_000", "1_000_000"), # Constant
+])
+def test_unparse_py38(source, expected):
+ module = ast.parse(source)
+ assert ast.unparse(module.body[0].value, source) == expected