diff options
Diffstat (limited to 'tests/test_inputs.py')
-rw-r--r-- | tests/test_inputs.py | 225 |
1 files changed, 225 insertions, 0 deletions
diff --git a/tests/test_inputs.py b/tests/test_inputs.py new file mode 100644 index 0000000..9247e7b --- /dev/null +++ b/tests/test_inputs.py @@ -0,0 +1,225 @@ +import importlib +import json +import math +import os +import sys +from collections import namedtuple +from types import ModuleType +from typing import ( + Any, + Dict, + List, + Set, + Tuple, +) + +import pytest + +import aristaproto +from tests.inputs import config as test_input_config +from tests.mocks import MockChannel +from tests.util import ( + find_module, + get_directories, + get_test_case_json_data, + inputs_path, +) + + +# Force pure-python implementation instead of C++, otherwise imports +# break things because we can't properly reset the symbol database. +os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + +from google.protobuf.json_format import Parse + + +class TestCases: + def __init__( + self, + path, + services: Set[str], + xfail: Set[str], + ): + _all = set(get_directories(path)) - {"__pycache__"} + _services = services + _messages = (_all - services) - {"__pycache__"} + _messages_with_json = { + test for test in _messages if get_test_case_json_data(test) + } + + unknown_xfail_tests = xfail - _all + if unknown_xfail_tests: + raise Exception(f"Unknown test(s) in config.py: {unknown_xfail_tests}") + + self.all = self.apply_xfail_marks(_all, xfail) + self.services = self.apply_xfail_marks(_services, xfail) + self.messages = self.apply_xfail_marks(_messages, xfail) + self.messages_with_json = self.apply_xfail_marks(_messages_with_json, xfail) + + @staticmethod + def apply_xfail_marks(test_set: Set[str], xfail: Set[str]): + return [ + pytest.param(test, marks=pytest.mark.xfail) if test in xfail else test + for test in test_set + ] + + +test_cases = TestCases( + path=inputs_path, + services=test_input_config.services, + xfail=test_input_config.xfail, +) + +plugin_output_package = "tests.output_aristaproto" +reference_output_package = "tests.output_reference" + +TestData = namedtuple("TestData", ["plugin_module", "reference_module", "json_data"]) + + +def module_has_entry_point(module: ModuleType): + return any(hasattr(module, attr) for attr in ["Test", "TestStub"]) + + +def list_replace_nans(items: List) -> List[Any]: + """Replace float("nan") in a list with the string "NaN" + + Parameters + ---------- + items : List + List to update + + Returns + ------- + List[Any] + Updated list + """ + result = [] + for item in items: + if isinstance(item, list): + result.append(list_replace_nans(item)) + elif isinstance(item, dict): + result.append(dict_replace_nans(item)) + elif isinstance(item, float) and math.isnan(item): + result.append(aristaproto.NAN) + return result + + +def dict_replace_nans(input_dict: Dict[Any, Any]) -> Dict[Any, Any]: + """Replace float("nan") in a dictionary with the string "NaN" + + Parameters + ---------- + input_dict : Dict[Any, Any] + Dictionary to update + + Returns + ------- + Dict[Any, Any] + Updated dictionary + """ + result = {} + for key, value in input_dict.items(): + if isinstance(value, dict): + value = dict_replace_nans(value) + elif isinstance(value, list): + value = list_replace_nans(value) + elif isinstance(value, float) and math.isnan(value): + value = aristaproto.NAN + result[key] = value + return result + + +@pytest.fixture +def test_data(request, reset_sys_path): + test_case_name = request.param + + reference_module_root = os.path.join( + *reference_output_package.split("."), test_case_name + ) + sys.path.append(reference_module_root) + + plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}") + + plugin_module_entry_point = find_module(plugin_module, module_has_entry_point) + + if not plugin_module_entry_point: + raise Exception( + f"Test case {repr(test_case_name)} has no entry point. " + "Please add a proto message or service called Test and recompile." + ) + + yield ( + TestData( + plugin_module=plugin_module_entry_point, + reference_module=lambda: importlib.import_module( + f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2" + ), + json_data=get_test_case_json_data(test_case_name), + ) + ) + + +@pytest.mark.parametrize("test_data", test_cases.messages, indirect=True) +def test_message_can_instantiated(test_data: TestData) -> None: + plugin_module, *_ = test_data + plugin_module.Test() + + +@pytest.mark.parametrize("test_data", test_cases.messages, indirect=True) +def test_message_equality(test_data: TestData) -> None: + plugin_module, *_ = test_data + message1 = plugin_module.Test() + message2 = plugin_module.Test() + assert message1 == message2 + + +@pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True) +def test_message_json(repeat, test_data: TestData) -> None: + plugin_module, _, json_data = test_data + + for _ in range(repeat): + for sample in json_data: + if sample.belongs_to(test_input_config.non_symmetrical_json): + continue + + message: aristaproto.Message = plugin_module.Test() + + message.from_json(sample.json) + message_json = message.to_json(0) + + assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans( + json.loads(sample.json) + ) + + +@pytest.mark.parametrize("test_data", test_cases.services, indirect=True) +def test_service_can_be_instantiated(test_data: TestData) -> None: + test_data.plugin_module.TestStub(MockChannel()) + + +@pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True) +def test_binary_compatibility(repeat, test_data: TestData) -> None: + plugin_module, reference_module, json_data = test_data + + for sample in json_data: + reference_instance = Parse(sample.json, reference_module().Test()) + reference_binary_output = reference_instance.SerializeToString() + + for _ in range(repeat): + plugin_instance_from_json: aristaproto.Message = ( + plugin_module.Test().from_json(sample.json) + ) + plugin_instance_from_binary = plugin_module.Test.FromString( + reference_binary_output + ) + + # Generally this can't be relied on, but here we are aiming to match the + # existing Python implementation and aren't doing anything tricky. + # https://developers.google.com/protocol-buffers/docs/encoding#implications + assert bytes(plugin_instance_from_json) == reference_binary_output + assert bytes(plugin_instance_from_binary) == reference_binary_output + + assert plugin_instance_from_json == plugin_instance_from_binary + assert dict_replace_nans( + plugin_instance_from_json.to_dict() + ) == dict_replace_nans(plugin_instance_from_binary.to_dict()) |