1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
|
# Copyright (c) 2023-2024 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the LICENSE file.
"""Utils for the ANTA benchmark tests."""
from __future__ import annotations
import asyncio
import copy
import importlib
import json
import pkgutil
from typing import TYPE_CHECKING, Any
import httpx
from anta.catalog import AntaCatalog, AntaTestDefinition
from anta.models import AntaCommand, AntaTest
if TYPE_CHECKING:
from collections.abc import Generator
from types import ModuleType
from anta.device import AntaDevice
async def collect(self: AntaTest) -> None:
    """Replacement for anta.models.AntaTest.collect() used by the benchmarks.

    While the catalog is generated, the unit test case name is stored in the `custom_field`
    of the `result_overwrite` input so the eos_data of that exact case can be retrieved later.
    That name is forwarded as part of the collection ID, which is then used in the eAPI request ID.

    Raises
    ------
    RuntimeError
        If `result_overwrite` is unset or its `custom_field` is None.
    """
    overwrite = self.inputs.result_overwrite
    custom_field = None if overwrite is None else overwrite.custom_field
    if custom_field is None:
        msg = f"The custom_field input is not present for test {self.name}"
        raise RuntimeError(msg)
    # Collection ID format: "{test name}:{unit test case name}".
    collection_id = f"{self.name}:{custom_field}"
    await self.device.collect_commands(self.instance_commands, collection_id=collection_id)
async def collect_commands(self: AntaDevice, commands: list[AntaCommand], collection_id: str) -> None:
    """Replacement for anta.device.AntaDevice.collect_commands() used by the benchmarks.

    Each command is collected concurrently. Its position in `commands` is appended to the
    collection ID (`"{collection_id}:{index}"`) so the matching eos_data entry can be selected.
    """
    pending = [
        self.collect(command=cmd, collection_id=f"{collection_id}:{position}")
        for position, cmd in enumerate(commands)
    ]
    await asyncio.gather(*pending)
class AntaMockEnvironment:  # pylint: disable=too-few-public-methods
    """Generate an ANTA test catalog from the unit tests data. It can be accessed using the `catalog` attribute of this class instance.

    Also provide the attribute `eos_data_catalog` with the output of all the commands used in the test catalog.

    Each module in `tests.units.anta_tests` has a `DATA` constant.
    The `DATA` structure is a list of dictionaries used to parametrize the test. The list elements have the following keys:
    - `name` (str): Test name as displayed by Pytest.
    - `test` (AntaTest): An AntaTest subclass imported in the test module - e.g. VerifyUptime.
    - `eos_data` (list[dict]): List of data mocking EOS returned data to be passed to the test.
    - `inputs` (dict): Dictionary to instantiate the `test` inputs as defined in the class from `test`.

    The keys of `eos_data_catalog` are tuples `(DATA['test'].__name__, DATA['name'])`. The values are the matching `eos_data`.
    """

    def __init__(self) -> None:
        self._catalog, self.eos_data_catalog = self._generate_catalog()
        # Number of test definitions in the generated catalog.
        self.tests_count = len(self._catalog.tests)

    @property
    def catalog(self) -> AntaCatalog:
        """AntaMockEnvironment object will always return a new AntaCatalog object based on the initial parsing.

        This is because AntaCatalog objects store indexes when tests are run and we want a new object each time a test is run.
        """
        return copy.deepcopy(self._catalog)

    def _generate_catalog(self) -> tuple[AntaCatalog, dict[tuple[str, str], list[dict[str, Any]]]]:
        """Generate the `catalog` and `eos_data_catalog` attributes.

        Returns
        -------
        tuple
            The AntaCatalog built from every unit test case, and the mapping
            `(test class name, unit test name) -> eos_data`.
        """

        def import_test_modules() -> Generator[ModuleType, None, None]:
            """Yield every `test_*` module found under `tests.units.anta_tests` that defines a `DATA` constant."""
            package = importlib.import_module("tests.units.anta_tests")
            prefix = package.__name__ + "."
            for _, module_name, is_pkg in pkgutil.walk_packages(package.__path__, prefix):
                if not is_pkg and module_name.split(".")[-1].startswith("test_"):
                    module = importlib.import_module(module_name)
                    if hasattr(module, "DATA"):
                        yield module

        test_definitions: list[AntaTestDefinition] = []
        eos_data_catalog: dict[tuple[str, str], list[dict[str, Any]]] = {}
        for module in import_test_modules():
            for test_data in module.DATA:
                test = test_data["test"]
                # The unit test case name travels in custom_field so the patched collect() can put it in the eAPI request ID.
                result_overwrite = AntaTest.Input.ResultOverwrite(custom_field=test_data["name"])
                # A missing or None `inputs` entry means no inputs besides result_overwrite.
                test_inputs = test_data.get("inputs")
                if test_inputs is None:
                    inputs = test.Input(result_overwrite=result_overwrite)
                else:
                    inputs = test.Input(**test_inputs, result_overwrite=result_overwrite)
                test_definition = AntaTestDefinition(
                    test=test,
                    inputs=inputs,
                )
                eos_data_catalog[(test.__name__, test_data["name"])] = test_data["eos_data"]
                test_definitions.append(test_definition)
        return (AntaCatalog(tests=test_definitions), eos_data_catalog)

    def eapi_response(self, request: httpx.Request) -> httpx.Response:
        """Mock eAPI response.

        If the eAPI request ID has the format `ANTA-{test name}:{unit test name}:{command index}-{command ID}`,
        the function will return the eos_data from the unit test case. The unit test name may itself contain `:`.
        Otherwise, it will mock 'show version' command or raise an Exception.

        Raises
        ------
        RuntimeError
            If the request ID targets a unit test case whose eos_data is unknown or has no entry at the command index.
        NotImplementedError
            If the request could not be mocked.
        """

        def parse_req_id(raw_id: str) -> tuple[str, str, int] | None:
            """Parse the patched request ID from the eAPI request.

            Return `(test name, unit test name, command index)`, or None if `raw_id` does not follow the patched format.
            """
            # Drop the "ANTA-" prefix and the trailing "-{command ID}" suffix.
            collection_id = raw_id.removeprefix("ANTA-").rpartition("-")[0]
            # The test name is the first ':'-separated field (assumed not to contain ':') and the command
            # index is the last one, so everything in between is the unit test name, even if it contains ':'.
            test_name, first_sep, remainder = collection_id.partition(":")
            unit_test_name, last_sep, command_index = remainder.rpartition(":")
            if not (first_sep and last_sep and command_index.isdecimal()):
                return None
            return test_name, unit_test_name, int(command_index)

        jsonrpc = json.loads(request.content)
        assert jsonrpc["method"] == "runCmds"
        commands = jsonrpc["params"]["cmds"]
        ofmt = jsonrpc["params"]["format"]
        req_id: str = jsonrpc["id"]
        result = None

        # Extract the test name, unit test name, and command index from the request ID
        if (words := parse_req_id(req_id)) is not None:
            test_name, unit_test_name, idx = words

            # This should never happen, but better be safe than sorry
            if (test_name, unit_test_name) not in self.eos_data_catalog:
                msg = f"Error while generating a mock response for unit test {unit_test_name} of test {test_name}: eos_data not found"
                raise RuntimeError(msg)

            eos_data = self.eos_data_catalog[(test_name, unit_test_name)]

            # This could happen if the unit test data is not correctly defined
            if idx >= len(eos_data):
                msg = f"Error while generating a mock response for unit test {unit_test_name} of test {test_name}: missing test case in eos_data"
                raise RuntimeError(msg)
            # Text-format eAPI results wrap the raw output in an "output" key.
            result = {"output": eos_data[idx]} if ofmt == "text" else eos_data[idx]
        elif {"cmd": "show version"} in commands and ofmt == "json":
            # Mock 'show version' request performed during inventory refresh.
            result = {
                "modelName": "pytest",
            }

        if result is not None:
            return httpx.Response(
                status_code=200,
                json={
                    "jsonrpc": "2.0",
                    "id": req_id,
                    "result": [result],
                },
            )

        msg = f"The following eAPI Request has not been mocked: {jsonrpc}"
        raise NotImplementedError(msg)
|