summaryrefslogtreecommitdiffstats
path: root/tests/util.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/util.py169
1 files changed, 169 insertions, 0 deletions
diff --git a/tests/util.py b/tests/util.py
new file mode 100644
index 0000000..2ba7cab
--- /dev/null
+++ b/tests/util.py
@@ -0,0 +1,169 @@
+import asyncio
+import atexit
+import importlib
+import os
+import platform
+import sys
+import tempfile
+from dataclasses import dataclass
+from pathlib import Path
+from types import ModuleType
+from typing import (
+ Callable,
+ Dict,
+ Generator,
+ List,
+ Optional,
+ Tuple,
+ Union,
+)
+
+
+os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+
+root_path = Path(__file__).resolve().parent
+inputs_path = root_path.joinpath("inputs")
+output_path_reference = root_path.joinpath("output_reference")
+output_path_aristaproto = root_path.joinpath("output_aristaproto")
+output_path_aristaproto_pydantic = root_path.joinpath("output_aristaproto_pydantic")
+
+
+def get_files(path, suffix: str) -> Generator[str, None, None]:
+ for r, dirs, files in os.walk(path):
+ for filename in [f for f in files if f.endswith(suffix)]:
+ yield os.path.join(r, filename)
+
+
+def get_directories(path):
+ for root, directories, files in os.walk(path):
+ yield from directories
+
+
+async def protoc(
+ path: Union[str, Path],
+ output_dir: Union[str, Path],
+ reference: bool = False,
+ pydantic_dataclasses: bool = False,
+):
+ path: Path = Path(path).resolve()
+ output_dir: Path = Path(output_dir).resolve()
+ python_out_option: str = "python_aristaproto_out" if not reference else "python_out"
+
+ if pydantic_dataclasses:
+ plugin_path = Path("src/aristaproto/plugin/main.py")
+
+ if "Win" in platform.system():
+ with tempfile.NamedTemporaryFile(
+ "w", encoding="UTF-8", suffix=".bat", delete=False
+ ) as tf:
+ # See https://stackoverflow.com/a/42622705
+ tf.writelines(
+ [
+ "@echo off",
+ f"\nchdir {os.getcwd()}",
+ f"\n{sys.executable} -u {plugin_path.as_posix()}",
+ ]
+ )
+
+ tf.flush()
+
+ plugin_path = Path(tf.name)
+ atexit.register(os.remove, plugin_path)
+
+ command = [
+ sys.executable,
+ "-m",
+ "grpc.tools.protoc",
+ f"--plugin=protoc-gen-custom={plugin_path.as_posix()}",
+ "--experimental_allow_proto3_optional",
+ "--custom_opt=pydantic_dataclasses",
+ f"--proto_path={path.as_posix()}",
+ f"--custom_out={output_dir.as_posix()}",
+ *[p.as_posix() for p in path.glob("*.proto")],
+ ]
+ else:
+ command = [
+ sys.executable,
+ "-m",
+ "grpc.tools.protoc",
+ f"--proto_path={path.as_posix()}",
+ f"--{python_out_option}={output_dir.as_posix()}",
+ *[p.as_posix() for p in path.glob("*.proto")],
+ ]
+ proc = await asyncio.create_subprocess_exec(
+ *command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
+ )
+ stdout, stderr = await proc.communicate()
+ return stdout, stderr, proc.returncode
+
+
+@dataclass
+class TestCaseJsonFile:
+ json: str
+ test_name: str
+ file_name: str
+
+ def belongs_to(self, non_symmetrical_json: Dict[str, Tuple[str, ...]]):
+ return self.file_name in non_symmetrical_json.get(self.test_name, tuple())
+
+
+def get_test_case_json_data(
+ test_case_name: str, *json_file_names: str
+) -> List[TestCaseJsonFile]:
+ """
+ :return:
+ A list of all files found in "{inputs_path}/test_case_name" with names matching
+ f"{test_case_name}.json" or f"{test_case_name}_*.json", OR given by
+ json_file_names
+ """
+ test_case_dir = inputs_path.joinpath(test_case_name)
+ possible_file_paths = [
+ *(test_case_dir.joinpath(json_file_name) for json_file_name in json_file_names),
+ test_case_dir.joinpath(f"{test_case_name}.json"),
+ *test_case_dir.glob(f"{test_case_name}_*.json"),
+ ]
+
+ result = []
+ for test_data_file_path in possible_file_paths:
+ if not test_data_file_path.exists():
+ continue
+ with test_data_file_path.open("r") as fh:
+ result.append(
+ TestCaseJsonFile(
+ fh.read(), test_case_name, test_data_file_path.name.split(".")[0]
+ )
+ )
+
+ return result
+
+
+def find_module(
+ module: ModuleType, predicate: Callable[[ModuleType], bool]
+) -> Optional[ModuleType]:
+ """
+ Recursively search module tree for a module that matches the search predicate.
+ Assumes that the submodules are directories containing __init__.py.
+
+ Example:
+
+ # find module inside foo that contains Test
+ import foo
+ test_module = find_module(foo, lambda m: hasattr(m, 'Test'))
+ """
+ if predicate(module):
+ return module
+
+ module_path = Path(*module.__path__)
+
+ for sub in [sub.parent for sub in module_path.glob("**/__init__.py")]:
+ if sub == module_path:
+ continue
+ sub_module_path = sub.relative_to(module_path)
+ sub_module_name = ".".join(sub_module_path.parts)
+
+ sub_module = importlib.import_module(f".{sub_module_name}", module.__name__)
+
+ if predicate(sub_module):
+ return sub_module
+
+ return None