diff options
Diffstat (limited to '')
-rw-r--r-- | tests/benchmark/utils.py | 164 |
1 files changed, 164 insertions, 0 deletions
diff --git a/tests/benchmark/utils.py b/tests/benchmark/utils.py new file mode 100644 index 0000000..1017cfe --- /dev/null +++ b/tests/benchmark/utils.py @@ -0,0 +1,164 @@ +# Copyright (c) 2023-2024 Arista Networks, Inc. +# Use of this source code is governed by the Apache License 2.0 +# that can be found in the LICENSE file. +"""Utils for the ANTA benchmark tests.""" + +from __future__ import annotations + +import asyncio +import copy +import importlib +import json +import pkgutil +from typing import TYPE_CHECKING, Any + +import httpx + +from anta.catalog import AntaCatalog, AntaTestDefinition +from anta.models import AntaCommand, AntaTest + +if TYPE_CHECKING: + from collections.abc import Generator + from types import ModuleType + + from anta.device import AntaDevice + + +async def collect(self: AntaTest) -> None: + """Patched anta.models.AntaTest.collect() method. + + When generating the catalog, we inject a unit test case name in the custom_field input to be able to retrieve the eos_data for this specific test. + We use this unit test case name in the eAPI request ID. + """ + if self.inputs.result_overwrite is None or self.inputs.result_overwrite.custom_field is None: + msg = f"The custom_field input is not present for test {self.name}" + raise RuntimeError(msg) + await self.device.collect_commands(self.instance_commands, collection_id=f"{self.name}:{self.inputs.result_overwrite.custom_field}") + + +async def collect_commands(self: AntaDevice, commands: list[AntaCommand], collection_id: str) -> None: + """Patched anta.device.AntaDevice.collect_commands() method. + + For the same reason as above, we inject the command index of the test to the eAPI request ID. + """ + await asyncio.gather(*(self.collect(command=command, collection_id=f"{collection_id}:{idx}") for idx, command in enumerate(commands))) + + +class AntaMockEnvironment: # pylint: disable=too-few-public-methods + """Generate an ANTA test catalog from the unit tests data. It can be accessed using the `catalog` attribute of this class instance. + + Also provide the attribute 'eos_data_catalog` with the output of all the commands used in the test catalog. + + Each module in `tests.units.anta_tests` has a `DATA` constant. + The `DATA` structure is a list of dictionaries used to parametrize the test. The list elements have the following keys: + - `name` (str): Test name as displayed by Pytest. + - `test` (AntaTest): An AntaTest subclass imported in the test module - e.g. VerifyUptime. + - `eos_data` (list[dict]): List of data mocking EOS returned data to be passed to the test. + - `inputs` (dict): Dictionary to instantiate the `test` inputs as defined in the class from `test`. + + The keys of `eos_data_catalog` is the tuple (DATA['test'], DATA['name']). The values are `eos_data`. + """ + + def __init__(self) -> None: + self._catalog, self.eos_data_catalog = self._generate_catalog() + self.tests_count = len(self._catalog.tests) + + @property + def catalog(self) -> AntaCatalog: + """AntaMockEnvironment object will always return a new AntaCatalog object based on the initial parsing. + + This is because AntaCatalog objects store indexes when tests are run and we want a new object each time a test is run. + """ + return copy.deepcopy(self._catalog) + + def _generate_catalog(self) -> tuple[AntaCatalog, dict[tuple[str, str], list[dict[str, Any]]]]: + """Generate the `catalog` and `eos_data_catalog` attributes.""" + + def import_test_modules() -> Generator[ModuleType, None, None]: + """Yield all test modules from the given package.""" + package = importlib.import_module("tests.units.anta_tests") + prefix = package.__name__ + "." + for _, module_name, is_pkg in pkgutil.walk_packages(package.__path__, prefix): + if not is_pkg and module_name.split(".")[-1].startswith("test_"): + module = importlib.import_module(module_name) + if hasattr(module, "DATA"): + yield module + + test_definitions = [] + eos_data_catalog = {} + for module in import_test_modules(): + for test_data in module.DATA: + test = test_data["test"] + result_overwrite = AntaTest.Input.ResultOverwrite(custom_field=test_data["name"]) + if test_data["inputs"] is None: + inputs = test.Input(result_overwrite=result_overwrite) + else: + inputs = test.Input(**test_data["inputs"], result_overwrite=result_overwrite) + test_definition = AntaTestDefinition( + test=test, + inputs=inputs, + ) + eos_data_catalog[(test.__name__, test_data["name"])] = test_data["eos_data"] + test_definitions.append(test_definition) + + return (AntaCatalog(tests=test_definitions), eos_data_catalog) + + def eapi_response(self, request: httpx.Request) -> httpx.Response: + """Mock eAPI response. + + If the eAPI request ID has the format `ANTA-{test name}:{unit test name}:{command index}-{command ID}`, + the function will return the eos_data from the unit test case. + + Otherwise, it will mock 'show version' command or raise an Exception. + """ + words_count = 3 + + def parse_req_id(req_id: str) -> tuple[str, str, int] | None: + """Parse the patched request ID from the eAPI request.""" + req_id = req_id.removeprefix("ANTA-").rpartition("-")[0] + words = req_id.split(":", words_count) + if len(words) == words_count: + test_name, unit_test_name, command_index = words + return test_name, unit_test_name, int(command_index) + return None + + jsonrpc = json.loads(request.content) + assert jsonrpc["method"] == "runCmds" + commands = jsonrpc["params"]["cmds"] + ofmt = jsonrpc["params"]["format"] + req_id: str = jsonrpc["id"] + result = None + + # Extract the test name, unit test name, and command index from the request ID + if (words := parse_req_id(req_id)) is not None: + test_name, unit_test_name, idx = words + + # This should never happen, but better be safe than sorry + if (test_name, unit_test_name) not in self.eos_data_catalog: + msg = f"Error while generating a mock response for unit test {unit_test_name} of test {test_name}: eos_data not found" + raise RuntimeError(msg) + + eos_data = self.eos_data_catalog[(test_name, unit_test_name)] + + # This could happen if the unit test data is not correctly defined + if idx >= len(eos_data): + msg = f"Error while generating a mock response for unit test {unit_test_name} of test {test_name}: missing test case in eos_data" + raise RuntimeError(msg) + result = {"output": eos_data[idx]} if ofmt == "text" else eos_data[idx] + elif {"cmd": "show version"} in commands and ofmt == "json": + # Mock 'show version' request performed during inventory refresh. + result = { + "modelName": "pytest", + } + + if result is not None: + return httpx.Response( + status_code=200, + json={ + "jsonrpc": "2.0", + "id": req_id, + "result": [result], + }, + ) + msg = f"The following eAPI Request has not been mocked: {jsonrpc}" + raise NotImplementedError(msg) |