diff options
Diffstat (limited to 'anta/models.py')
-rw-r--r-- | anta/models.py | 164 |
1 files changed, 84 insertions, 80 deletions
diff --git a/anta/models.py b/anta/models.py index f963dc0..c44f7e8 100644 --- a/anta/models.py +++ b/anta/models.py @@ -8,10 +8,7 @@ from __future__ import annotations import hashlib import logging import re -import time from abc import ABC, abstractmethod -from copy import deepcopy -from datetime import timedelta from functools import wraps from string import Formatter from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, TypeVar @@ -19,7 +16,7 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, TypeVar from pydantic import BaseModel, ConfigDict, ValidationError, create_model from anta import GITHUB_SUGGESTION -from anta.custom_types import Revision +from anta.custom_types import REGEXP_EOS_BLACKLIST_CMDS, Revision from anta.logger import anta_log_exception, exc_to_str from anta.result_manager.models import TestResult @@ -35,9 +32,6 @@ F = TypeVar("F", bound=Callable[..., Any]) # This would imply overhead to define classes # https://stackoverflow.com/questions/74103528/type-hinting-an-instance-of-a-nested-class -# TODO: make this configurable - with an env var maybe? -BLACKLIST_REGEX = [r"^reload.*", r"^conf\w*\s*(terminal|session)*", r"^wr\w*\s*\w+"] - logger = logging.getLogger(__name__) @@ -46,19 +40,8 @@ class AntaParamsBaseModel(BaseModel): model_config = ConfigDict(extra="forbid") - if not TYPE_CHECKING: - # Following pydantic declaration and keeping __getattr__ only when TYPE_CHECKING is false. - # Disabling 1 Dynamically typed expressions (typing.Any) are disallowed in `__getattr__ - # ruff: noqa: ANN401 - def __getattr__(self, item: str) -> Any: - """For AntaParams if we try to access an attribute that is not present We want it to be None.""" - try: - return super().__getattr__(item) - except AttributeError: - return None - -class AntaTemplate(BaseModel): +class AntaTemplate: """Class to define a command template as Python f-string. Can render a command from parameters. @@ -70,14 +53,42 @@ class AntaTemplate(BaseModel): revision: Revision of the command. Valid values are 1 to 99. Revision has precedence over version. ofmt: eAPI output - json or text. use_cache: Enable or disable caching for this AntaTemplate if the AntaDevice supports it. - """ - template: str - version: Literal[1, "latest"] = "latest" - revision: Revision | None = None - ofmt: Literal["json", "text"] = "json" - use_cache: bool = True + # pylint: disable=too-few-public-methods + + def __init__( # noqa: PLR0913 + self, + template: str, + version: Literal[1, "latest"] = "latest", + revision: Revision | None = None, + ofmt: Literal["json", "text"] = "json", + *, + use_cache: bool = True, + ) -> None: + # pylint: disable=too-many-arguments + self.template = template + self.version = version + self.revision = revision + self.ofmt = ofmt + self.use_cache = use_cache + + # Create a AntaTemplateParams model to elegantly store AntaTemplate variables + field_names = [fname for _, fname, _, _ in Formatter().parse(self.template) if fname] + # Extracting the type from the params based on the expected field_names from the template + fields: dict[str, Any] = {key: (Any, ...) for key in field_names} + self.params_schema = create_model( + "AntaParams", + __base__=AntaParamsBaseModel, + **fields, + ) + + def __repr__(self) -> str: + """Return the representation of the class. + + Copying pydantic model style, excluding `params_schema` + """ + return " ".join(f"{a}={v!r}" for a, v in vars(self).items() if a != "params_schema") def render(self, **params: str | int | bool) -> AntaCommand: """Render an AntaCommand from an AntaTemplate instance. @@ -90,34 +101,28 @@ class AntaTemplate(BaseModel): Returns ------- - command: The rendered AntaCommand. - This AntaCommand instance have a template attribute that references this - AntaTemplate instance. + The rendered AntaCommand. + This AntaCommand instance have a template attribute that references this + AntaTemplate instance. + Raises + ------ + AntaTemplateRenderError + If a parameter is missing to render the AntaTemplate instance. """ - # Create params schema on the fly - field_names = [fname for _, fname, _, _ in Formatter().parse(self.template) if fname] - # Extracting the type from the params based on the expected field_names from the template - fields: dict[str, Any] = {key: (type(params.get(key)), ...) for key in field_names} - # Accepting ParamsSchema as non lowercase variable - ParamsSchema = create_model( # noqa: N806 - "ParamsSchema", - __base__=AntaParamsBaseModel, - **fields, - ) - try: - return AntaCommand( - command=self.template.format(**params), - ofmt=self.ofmt, - version=self.version, - revision=self.revision, - template=self, - params=ParamsSchema(**params), - use_cache=self.use_cache, - ) - except KeyError as e: + command = self.template.format(**params) + except (KeyError, SyntaxError) as e: raise AntaTemplateRenderError(self, e.args[0]) from e + return AntaCommand( + command=command, + ofmt=self.ofmt, + version=self.version, + revision=self.revision, + template=self, + params=self.params_schema(**params), + use_cache=self.use_cache, + ) class AntaCommand(BaseModel): @@ -148,6 +153,8 @@ class AntaCommand(BaseModel): """ + model_config = ConfigDict(arbitrary_types_allowed=True) + command: str version: Literal[1, "latest"] = "latest" revision: Revision | None = None @@ -273,14 +280,13 @@ class AntaTest(ABC): vrf: str = "default" def render(self, template: AntaTemplate) -> list[AntaCommand]: - return [template.render({"dst": host.dst, "src": host.src, "vrf": host.vrf}) for host in self.inputs.hosts] + return [template.render(dst=host.dst, src=host.src, vrf=host.vrf) for host in self.inputs.hosts] @AntaTest.anta_test def test(self) -> None: failures = [] for command in self.instance_commands: - if command.params and ("src" and "dst") in command.params: - src, dst = command.params["src"], command.params["dst"] + src, dst = command.params.src, command.params.dst if "2 received" not in command.json_output["messages"][0]: failures.append((str(src), str(dst))) if not failures: @@ -288,13 +294,14 @@ class AntaTest(ABC): else: self.result.is_failure(f"Connectivity test failed for the following source-destination pairs: {failures}") ``` - Attributes: + + Attributes + ---------- device: AntaDevice instance on which this test is run inputs: AntaTest.Input instance carrying the test inputs instance_commands: List of AntaCommand instances of this test result: TestResult instance representing the result of this test logger: Python logger for this test instance - """ # Mandatory class attributes @@ -322,9 +329,10 @@ class AntaTest(ABC): description: "Test with overwritten description" custom_field: "Test run by John Doe" ``` - Attributes: - result_overwrite: Define fields to overwrite in the TestResult object + Attributes + ---------- + result_overwrite: Define fields to overwrite in the TestResult object """ model_config = ConfigDict(extra="forbid") @@ -360,7 +368,6 @@ class AntaTest(ABC): Attributes ---------- tags: Tag of devices on which to run the test. - """ model_config = ConfigDict(extra="forbid") @@ -380,9 +387,8 @@ class AntaTest(ABC): inputs: dictionary of attributes used to instantiate the AntaTest.Input instance eos_data: Populate outputs of the test commands instead of collecting from devices. This list must have the same length and order than the `instance_commands` instance attribute. - """ - self.logger: logging.Logger = logging.getLogger(f"{self.__module__}.{self.__class__.__name__}") + self.logger: logging.Logger = logging.getLogger(f"{self.module}.{self.__class__.__name__}") self.device: AntaDevice = device self.inputs: AntaTest.Input self.instance_commands: list[AntaCommand] = [] @@ -411,7 +417,7 @@ class AntaTest(ABC): elif isinstance(inputs, dict): self.inputs = self.Input(**inputs) except ValidationError as e: - message = f"{self.__module__}.{self.__class__.__name__}: Inputs are not valid\n{e}" + message = f"{self.module}.{self.name}: Inputs are not valid\n{e}" self.logger.error(message) self.result.is_error(message=message) return @@ -434,7 +440,7 @@ class AntaTest(ABC): if self.__class__.commands: for cmd in self.__class__.commands: if isinstance(cmd, AntaCommand): - self.instance_commands.append(deepcopy(cmd)) + self.instance_commands.append(cmd.model_copy()) elif isinstance(cmd, AntaTemplate): try: self.instance_commands.extend(self.render(cmd)) @@ -448,7 +454,7 @@ class AntaTest(ABC): # render() is user-defined code. # We need to catch everything if we want the AntaTest object # to live until the reporting - message = f"Exception in {self.__module__}.{self.__class__.__name__}.render()" + message = f"Exception in {self.module}.{self.__class__.__name__}.render()" anta_log_exception(e, message, self.logger) self.result.is_error(message=f"{message}: {exc_to_str(e)}") return @@ -477,13 +483,18 @@ class AntaTest(ABC): raise NotImplementedError(msg) @property + def module(self) -> str: + """Return the Python module in which this AntaTest class is defined.""" + return self.__module__ + + @property def collected(self) -> bool: - """Returns True if all commands for this test have been collected.""" + """Return True if all commands for this test have been collected.""" return all(command.collected for command in self.instance_commands) @property def failed_commands(self) -> list[AntaCommand]: - """Returns a list of all the commands that have failed.""" + """Return a list of all the commands that have failed.""" return [command for command in self.instance_commands if command.error] def render(self, template: AntaTemplate) -> list[AntaCommand]: @@ -493,7 +504,7 @@ class AntaTest(ABC): no AntaTemplate for this test. """ _ = template - msg = f"AntaTemplate are provided but render() method has not been implemented for {self.__module__}.{self.name}" + msg = f"AntaTemplate are provided but render() method has not been implemented for {self.module}.{self.__class__.__name__}" raise NotImplementedError(msg) @property @@ -501,12 +512,12 @@ class AntaTest(ABC): """Check if CLI commands contain a blocked keyword.""" state = False for command in self.instance_commands: - for pattern in BLACKLIST_REGEX: + for pattern in REGEXP_EOS_BLACKLIST_CMDS: if re.match(pattern, command.command): self.logger.error( "Command <%s> is blocked for security reason matching %s", command.command, - BLACKLIST_REGEX, + REGEXP_EOS_BLACKLIST_CMDS, ) self.result.is_error(f"<{command.command}> is blocked for security reason") state = True @@ -516,7 +527,7 @@ class AntaTest(ABC): """Collect outputs of all commands of this test class from the device of this test instance.""" try: if self.blocked is False: - await self.device.collect_commands(self.instance_commands) + await self.device.collect_commands(self.instance_commands, collection_id=self.name) except Exception as e: # pylint: disable=broad-exception-caught # device._collect() is user-defined code. # We need to catch everything if we want the AntaTest object @@ -557,12 +568,6 @@ class AntaTest(ABC): result: TestResult instance attribute populated with error status if any """ - - def format_td(seconds: float, digits: int = 3) -> str: - isec, fsec = divmod(round(seconds * 10**digits), 10**digits) - return f"{timedelta(seconds=isec)}.{fsec:0{digits}.0f}" - - start_time = time.time() if self.result.result != "unset": return self.result @@ -575,6 +580,7 @@ class AntaTest(ABC): if not self.collected: await self.collect() if self.result.result != "unset": + AntaTest.update_progress() return self.result if cmds := self.failed_commands: @@ -583,8 +589,9 @@ class AntaTest(ABC): msg = f"Test {self.name} has been skipped because it is not supported on {self.device.hw_model}: {GITHUB_SUGGESTION}" self.logger.warning(msg) self.result.is_skipped("\n".join(unsupported_commands)) - return self.result - self.result.is_error(message="\n".join([f"{c.command} has failed: {', '.join(c.errors)}" for c in cmds])) + else: + self.result.is_error(message="\n".join([f"{c.command} has failed: {', '.join(c.errors)}" for c in cmds])) + AntaTest.update_progress() return self.result try: @@ -597,10 +604,7 @@ class AntaTest(ABC): anta_log_exception(e, message, self.logger) self.result.is_error(message=exc_to_str(e)) - test_duration = time.time() - start_time - msg = f"Executing test {self.name} on device {self.device.name} took {format_td(test_duration)}" - self.logger.debug(msg) - + # TODO: find a correct way to time test execution AntaTest.update_progress() return self.result |