summaryrefslogtreecommitdiffstats
path: root/tests/test_inputs.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_inputs.py')
-rw-r--r--tests/test_inputs.py225
1 files changed, 225 insertions, 0 deletions
diff --git a/tests/test_inputs.py b/tests/test_inputs.py
new file mode 100644
index 0000000..9247e7b
--- /dev/null
+++ b/tests/test_inputs.py
@@ -0,0 +1,225 @@
+import importlib
+import json
+import math
+import os
+import sys
+from collections import namedtuple
+from types import ModuleType
+from typing import (
+ Any,
+ Dict,
+ List,
+ Set,
+ Tuple,
+)
+
+import pytest
+
+import aristaproto
+from tests.inputs import config as test_input_config
+from tests.mocks import MockChannel
+from tests.util import (
+ find_module,
+ get_directories,
+ get_test_case_json_data,
+ inputs_path,
+)
+
+
+# Force pure-python implementation instead of C++, otherwise imports
+# break things because we can't properly reset the symbol database.
+os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+
+from google.protobuf.json_format import Parse
+
+
+class TestCases:
+ def __init__(
+ self,
+ path,
+ services: Set[str],
+ xfail: Set[str],
+ ):
+ _all = set(get_directories(path)) - {"__pycache__"}
+ _services = services
+ _messages = (_all - services) - {"__pycache__"}
+ _messages_with_json = {
+ test for test in _messages if get_test_case_json_data(test)
+ }
+
+ unknown_xfail_tests = xfail - _all
+ if unknown_xfail_tests:
+ raise Exception(f"Unknown test(s) in config.py: {unknown_xfail_tests}")
+
+ self.all = self.apply_xfail_marks(_all, xfail)
+ self.services = self.apply_xfail_marks(_services, xfail)
+ self.messages = self.apply_xfail_marks(_messages, xfail)
+ self.messages_with_json = self.apply_xfail_marks(_messages_with_json, xfail)
+
+ @staticmethod
+ def apply_xfail_marks(test_set: Set[str], xfail: Set[str]):
+ return [
+ pytest.param(test, marks=pytest.mark.xfail) if test in xfail else test
+ for test in test_set
+ ]
+
+
+test_cases = TestCases(
+ path=inputs_path,
+ services=test_input_config.services,
+ xfail=test_input_config.xfail,
+)
+
+plugin_output_package = "tests.output_aristaproto"
+reference_output_package = "tests.output_reference"
+
+TestData = namedtuple("TestData", ["plugin_module", "reference_module", "json_data"])
+
+
+def module_has_entry_point(module: ModuleType):
+ return any(hasattr(module, attr) for attr in ["Test", "TestStub"])
+
+
+def list_replace_nans(items: List) -> List[Any]:
+ """Replace float("nan") in a list with the string "NaN"
+
+ Parameters
+ ----------
+ items : List
+ List to update
+
+ Returns
+ -------
+ List[Any]
+ Updated list
+ """
+ result = []
+ for item in items:
+ if isinstance(item, list):
+ result.append(list_replace_nans(item))
+ elif isinstance(item, dict):
+ result.append(dict_replace_nans(item))
+ elif isinstance(item, float) and math.isnan(item):
+ result.append(aristaproto.NAN)
+ return result
+
+
+def dict_replace_nans(input_dict: Dict[Any, Any]) -> Dict[Any, Any]:
+ """Replace float("nan") in a dictionary with the string "NaN"
+
+ Parameters
+ ----------
+ input_dict : Dict[Any, Any]
+ Dictionary to update
+
+ Returns
+ -------
+ Dict[Any, Any]
+ Updated dictionary
+ """
+ result = {}
+ for key, value in input_dict.items():
+ if isinstance(value, dict):
+ value = dict_replace_nans(value)
+ elif isinstance(value, list):
+ value = list_replace_nans(value)
+ elif isinstance(value, float) and math.isnan(value):
+ value = aristaproto.NAN
+ result[key] = value
+ return result
+
+
+@pytest.fixture
+def test_data(request, reset_sys_path):
+ test_case_name = request.param
+
+ reference_module_root = os.path.join(
+ *reference_output_package.split("."), test_case_name
+ )
+ sys.path.append(reference_module_root)
+
+ plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}")
+
+ plugin_module_entry_point = find_module(plugin_module, module_has_entry_point)
+
+ if not plugin_module_entry_point:
+ raise Exception(
+ f"Test case {repr(test_case_name)} has no entry point. "
+ "Please add a proto message or service called Test and recompile."
+ )
+
+ yield (
+ TestData(
+ plugin_module=plugin_module_entry_point,
+ reference_module=lambda: importlib.import_module(
+ f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2"
+ ),
+ json_data=get_test_case_json_data(test_case_name),
+ )
+ )
+
+
+@pytest.mark.parametrize("test_data", test_cases.messages, indirect=True)
+def test_message_can_instantiated(test_data: TestData) -> None:
+ plugin_module, *_ = test_data
+ plugin_module.Test()
+
+
+@pytest.mark.parametrize("test_data", test_cases.messages, indirect=True)
+def test_message_equality(test_data: TestData) -> None:
+ plugin_module, *_ = test_data
+ message1 = plugin_module.Test()
+ message2 = plugin_module.Test()
+ assert message1 == message2
+
+
+@pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True)
+def test_message_json(repeat, test_data: TestData) -> None:
+ plugin_module, _, json_data = test_data
+
+ for _ in range(repeat):
+ for sample in json_data:
+ if sample.belongs_to(test_input_config.non_symmetrical_json):
+ continue
+
+ message: aristaproto.Message = plugin_module.Test()
+
+ message.from_json(sample.json)
+ message_json = message.to_json(0)
+
+ assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans(
+ json.loads(sample.json)
+ )
+
+
+@pytest.mark.parametrize("test_data", test_cases.services, indirect=True)
+def test_service_can_be_instantiated(test_data: TestData) -> None:
+ test_data.plugin_module.TestStub(MockChannel())
+
+
+@pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True)
+def test_binary_compatibility(repeat, test_data: TestData) -> None:
+ plugin_module, reference_module, json_data = test_data
+
+ for sample in json_data:
+ reference_instance = Parse(sample.json, reference_module().Test())
+ reference_binary_output = reference_instance.SerializeToString()
+
+ for _ in range(repeat):
+ plugin_instance_from_json: aristaproto.Message = (
+ plugin_module.Test().from_json(sample.json)
+ )
+ plugin_instance_from_binary = plugin_module.Test.FromString(
+ reference_binary_output
+ )
+
+ # Generally this can't be relied on, but here we are aiming to match the
+ # existing Python implementation and aren't doing anything tricky.
+ # https://developers.google.com/protocol-buffers/docs/encoding#implications
+ assert bytes(plugin_instance_from_json) == reference_binary_output
+ assert bytes(plugin_instance_from_binary) == reference_binary_output
+
+ assert plugin_instance_from_json == plugin_instance_from_binary
+ assert dict_replace_nans(
+ plugin_instance_from_json.to_dict()
+ ) == dict_replace_nans(plugin_instance_from_binary.to_dict())