summaryrefslogtreecommitdiffstats
path: root/anta/models.py
diff options
context:
space:
mode:
Diffstat (limited to 'anta/models.py')
-rw-r--r--anta/models.py164
1 files changed, 84 insertions, 80 deletions
diff --git a/anta/models.py b/anta/models.py
index f963dc0..c44f7e8 100644
--- a/anta/models.py
+++ b/anta/models.py
@@ -8,10 +8,7 @@ from __future__ import annotations
import hashlib
import logging
import re
-import time
from abc import ABC, abstractmethod
-from copy import deepcopy
-from datetime import timedelta
from functools import wraps
from string import Formatter
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, TypeVar
@@ -19,7 +16,7 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, TypeVar
from pydantic import BaseModel, ConfigDict, ValidationError, create_model
from anta import GITHUB_SUGGESTION
-from anta.custom_types import Revision
+from anta.custom_types import REGEXP_EOS_BLACKLIST_CMDS, Revision
from anta.logger import anta_log_exception, exc_to_str
from anta.result_manager.models import TestResult
@@ -35,9 +32,6 @@ F = TypeVar("F", bound=Callable[..., Any])
# This would imply overhead to define classes
# https://stackoverflow.com/questions/74103528/type-hinting-an-instance-of-a-nested-class
-# TODO: make this configurable - with an env var maybe?
-BLACKLIST_REGEX = [r"^reload.*", r"^conf\w*\s*(terminal|session)*", r"^wr\w*\s*\w+"]
-
logger = logging.getLogger(__name__)
@@ -46,19 +40,8 @@ class AntaParamsBaseModel(BaseModel):
model_config = ConfigDict(extra="forbid")
- if not TYPE_CHECKING:
- # Following pydantic declaration and keeping __getattr__ only when TYPE_CHECKING is false.
- # Disabling 1 Dynamically typed expressions (typing.Any) are disallowed in `__getattr__
- # ruff: noqa: ANN401
- def __getattr__(self, item: str) -> Any:
- """For AntaParams if we try to access an attribute that is not present We want it to be None."""
- try:
- return super().__getattr__(item)
- except AttributeError:
- return None
-
-class AntaTemplate(BaseModel):
+class AntaTemplate:
"""Class to define a command template as Python f-string.
Can render a command from parameters.
@@ -70,14 +53,42 @@ class AntaTemplate(BaseModel):
revision: Revision of the command. Valid values are 1 to 99. Revision has precedence over version.
ofmt: eAPI output - json or text.
use_cache: Enable or disable caching for this AntaTemplate if the AntaDevice supports it.
-
"""
- template: str
- version: Literal[1, "latest"] = "latest"
- revision: Revision | None = None
- ofmt: Literal["json", "text"] = "json"
- use_cache: bool = True
+ # pylint: disable=too-few-public-methods
+
+ def __init__( # noqa: PLR0913
+ self,
+ template: str,
+ version: Literal[1, "latest"] = "latest",
+ revision: Revision | None = None,
+ ofmt: Literal["json", "text"] = "json",
+ *,
+ use_cache: bool = True,
+ ) -> None:
+ # pylint: disable=too-many-arguments
+ self.template = template
+ self.version = version
+ self.revision = revision
+ self.ofmt = ofmt
+ self.use_cache = use_cache
+
+ # Create a AntaTemplateParams model to elegantly store AntaTemplate variables
+ field_names = [fname for _, fname, _, _ in Formatter().parse(self.template) if fname]
+ # Extracting the type from the params based on the expected field_names from the template
+ fields: dict[str, Any] = {key: (Any, ...) for key in field_names}
+ self.params_schema = create_model(
+ "AntaParams",
+ __base__=AntaParamsBaseModel,
+ **fields,
+ )
+
+ def __repr__(self) -> str:
+ """Return the representation of the class.
+
+ Copying pydantic model style, excluding `params_schema`
+ """
+ return " ".join(f"{a}={v!r}" for a, v in vars(self).items() if a != "params_schema")
def render(self, **params: str | int | bool) -> AntaCommand:
"""Render an AntaCommand from an AntaTemplate instance.
@@ -90,34 +101,28 @@ class AntaTemplate(BaseModel):
Returns
-------
- command: The rendered AntaCommand.
- This AntaCommand instance have a template attribute that references this
- AntaTemplate instance.
+ The rendered AntaCommand.
+ This AntaCommand instance have a template attribute that references this
+ AntaTemplate instance.
+ Raises
+ ------
+ AntaTemplateRenderError
+ If a parameter is missing to render the AntaTemplate instance.
"""
- # Create params schema on the fly
- field_names = [fname for _, fname, _, _ in Formatter().parse(self.template) if fname]
- # Extracting the type from the params based on the expected field_names from the template
- fields: dict[str, Any] = {key: (type(params.get(key)), ...) for key in field_names}
- # Accepting ParamsSchema as non lowercase variable
- ParamsSchema = create_model( # noqa: N806
- "ParamsSchema",
- __base__=AntaParamsBaseModel,
- **fields,
- )
-
try:
- return AntaCommand(
- command=self.template.format(**params),
- ofmt=self.ofmt,
- version=self.version,
- revision=self.revision,
- template=self,
- params=ParamsSchema(**params),
- use_cache=self.use_cache,
- )
- except KeyError as e:
+ command = self.template.format(**params)
+ except (KeyError, SyntaxError) as e:
raise AntaTemplateRenderError(self, e.args[0]) from e
+ return AntaCommand(
+ command=command,
+ ofmt=self.ofmt,
+ version=self.version,
+ revision=self.revision,
+ template=self,
+ params=self.params_schema(**params),
+ use_cache=self.use_cache,
+ )
class AntaCommand(BaseModel):
@@ -148,6 +153,8 @@ class AntaCommand(BaseModel):
"""
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+
command: str
version: Literal[1, "latest"] = "latest"
revision: Revision | None = None
@@ -273,14 +280,13 @@ class AntaTest(ABC):
vrf: str = "default"
def render(self, template: AntaTemplate) -> list[AntaCommand]:
- return [template.render({"dst": host.dst, "src": host.src, "vrf": host.vrf}) for host in self.inputs.hosts]
+ return [template.render(dst=host.dst, src=host.src, vrf=host.vrf) for host in self.inputs.hosts]
@AntaTest.anta_test
def test(self) -> None:
failures = []
for command in self.instance_commands:
- if command.params and ("src" and "dst") in command.params:
- src, dst = command.params["src"], command.params["dst"]
+ src, dst = command.params.src, command.params.dst
if "2 received" not in command.json_output["messages"][0]:
failures.append((str(src), str(dst)))
if not failures:
@@ -288,13 +294,14 @@ class AntaTest(ABC):
else:
self.result.is_failure(f"Connectivity test failed for the following source-destination pairs: {failures}")
```
- Attributes:
+
+ Attributes
+ ----------
device: AntaDevice instance on which this test is run
inputs: AntaTest.Input instance carrying the test inputs
instance_commands: List of AntaCommand instances of this test
result: TestResult instance representing the result of this test
logger: Python logger for this test instance
-
"""
# Mandatory class attributes
@@ -322,9 +329,10 @@ class AntaTest(ABC):
description: "Test with overwritten description"
custom_field: "Test run by John Doe"
```
- Attributes:
- result_overwrite: Define fields to overwrite in the TestResult object
+ Attributes
+ ----------
+ result_overwrite: Define fields to overwrite in the TestResult object
"""
model_config = ConfigDict(extra="forbid")
@@ -360,7 +368,6 @@ class AntaTest(ABC):
Attributes
----------
tags: Tag of devices on which to run the test.
-
"""
model_config = ConfigDict(extra="forbid")
@@ -380,9 +387,8 @@ class AntaTest(ABC):
inputs: dictionary of attributes used to instantiate the AntaTest.Input instance
eos_data: Populate outputs of the test commands instead of collecting from devices.
This list must have the same length and order than the `instance_commands` instance attribute.
-
"""
- self.logger: logging.Logger = logging.getLogger(f"{self.__module__}.{self.__class__.__name__}")
+ self.logger: logging.Logger = logging.getLogger(f"{self.module}.{self.__class__.__name__}")
self.device: AntaDevice = device
self.inputs: AntaTest.Input
self.instance_commands: list[AntaCommand] = []
@@ -411,7 +417,7 @@ class AntaTest(ABC):
elif isinstance(inputs, dict):
self.inputs = self.Input(**inputs)
except ValidationError as e:
- message = f"{self.__module__}.{self.__class__.__name__}: Inputs are not valid\n{e}"
+ message = f"{self.module}.{self.name}: Inputs are not valid\n{e}"
self.logger.error(message)
self.result.is_error(message=message)
return
@@ -434,7 +440,7 @@ class AntaTest(ABC):
if self.__class__.commands:
for cmd in self.__class__.commands:
if isinstance(cmd, AntaCommand):
- self.instance_commands.append(deepcopy(cmd))
+ self.instance_commands.append(cmd.model_copy())
elif isinstance(cmd, AntaTemplate):
try:
self.instance_commands.extend(self.render(cmd))
@@ -448,7 +454,7 @@ class AntaTest(ABC):
# render() is user-defined code.
# We need to catch everything if we want the AntaTest object
# to live until the reporting
- message = f"Exception in {self.__module__}.{self.__class__.__name__}.render()"
+ message = f"Exception in {self.module}.{self.__class__.__name__}.render()"
anta_log_exception(e, message, self.logger)
self.result.is_error(message=f"{message}: {exc_to_str(e)}")
return
@@ -477,13 +483,18 @@ class AntaTest(ABC):
raise NotImplementedError(msg)
@property
+ def module(self) -> str:
+ """Return the Python module in which this AntaTest class is defined."""
+ return self.__module__
+
+ @property
def collected(self) -> bool:
- """Returns True if all commands for this test have been collected."""
+ """Return True if all commands for this test have been collected."""
return all(command.collected for command in self.instance_commands)
@property
def failed_commands(self) -> list[AntaCommand]:
- """Returns a list of all the commands that have failed."""
+ """Return a list of all the commands that have failed."""
return [command for command in self.instance_commands if command.error]
def render(self, template: AntaTemplate) -> list[AntaCommand]:
@@ -493,7 +504,7 @@ class AntaTest(ABC):
no AntaTemplate for this test.
"""
_ = template
- msg = f"AntaTemplate are provided but render() method has not been implemented for {self.__module__}.{self.name}"
+ msg = f"AntaTemplate are provided but render() method has not been implemented for {self.module}.{self.__class__.__name__}"
raise NotImplementedError(msg)
@property
@@ -501,12 +512,12 @@ class AntaTest(ABC):
"""Check if CLI commands contain a blocked keyword."""
state = False
for command in self.instance_commands:
- for pattern in BLACKLIST_REGEX:
+ for pattern in REGEXP_EOS_BLACKLIST_CMDS:
if re.match(pattern, command.command):
self.logger.error(
"Command <%s> is blocked for security reason matching %s",
command.command,
- BLACKLIST_REGEX,
+ REGEXP_EOS_BLACKLIST_CMDS,
)
self.result.is_error(f"<{command.command}> is blocked for security reason")
state = True
@@ -516,7 +527,7 @@ class AntaTest(ABC):
"""Collect outputs of all commands of this test class from the device of this test instance."""
try:
if self.blocked is False:
- await self.device.collect_commands(self.instance_commands)
+ await self.device.collect_commands(self.instance_commands, collection_id=self.name)
except Exception as e: # pylint: disable=broad-exception-caught
# device._collect() is user-defined code.
# We need to catch everything if we want the AntaTest object
@@ -557,12 +568,6 @@ class AntaTest(ABC):
result: TestResult instance attribute populated with error status if any
"""
-
- def format_td(seconds: float, digits: int = 3) -> str:
- isec, fsec = divmod(round(seconds * 10**digits), 10**digits)
- return f"{timedelta(seconds=isec)}.{fsec:0{digits}.0f}"
-
- start_time = time.time()
if self.result.result != "unset":
return self.result
@@ -575,6 +580,7 @@ class AntaTest(ABC):
if not self.collected:
await self.collect()
if self.result.result != "unset":
+ AntaTest.update_progress()
return self.result
if cmds := self.failed_commands:
@@ -583,8 +589,9 @@ class AntaTest(ABC):
msg = f"Test {self.name} has been skipped because it is not supported on {self.device.hw_model}: {GITHUB_SUGGESTION}"
self.logger.warning(msg)
self.result.is_skipped("\n".join(unsupported_commands))
- return self.result
- self.result.is_error(message="\n".join([f"{c.command} has failed: {', '.join(c.errors)}" for c in cmds]))
+ else:
+ self.result.is_error(message="\n".join([f"{c.command} has failed: {', '.join(c.errors)}" for c in cmds]))
+ AntaTest.update_progress()
return self.result
try:
@@ -597,10 +604,7 @@ class AntaTest(ABC):
anta_log_exception(e, message, self.logger)
self.result.is_error(message=exc_to_str(e))
- test_duration = time.time() - start_time
- msg = f"Executing test {self.name} on device {self.device.name} took {format_td(test_duration)}"
- self.logger.debug(msg)
-
+ # TODO: find a correct way to time test execution
AntaTest.update_progress()
return self.result