summaryrefslogtreecommitdiffstats
path: root/anta/device.py
diff options
context:
space:
mode:
Diffstat (limited to 'anta/device.py')
-rw-r--r--anta/device.py319
1 files changed, 171 insertions, 148 deletions
diff --git a/anta/device.py b/anta/device.py
index d9060c9..d517b8f 100644
--- a/anta/device.py
+++ b/anta/device.py
@@ -1,65 +1,75 @@
# Copyright (c) 2023-2024 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the LICENSE file.
-"""
-ANTA Device Abstraction Module
-"""
+"""ANTA Device Abstraction Module."""
+
from __future__ import annotations
import asyncio
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
-from pathlib import Path
-from typing import Any, Iterator, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Literal
import asyncssh
+import httpcore
from aiocache import Cache
from aiocache.plugins import HitMissRatioPlugin
from asyncssh import SSHClientConnection, SSHClientConnectionOptions
-from httpx import ConnectError, HTTPError
+from httpx import ConnectError, HTTPError, TimeoutException
from anta import __DEBUG__, aioeapi
+from anta.logger import anta_log_exception, exc_to_str
from anta.models import AntaCommand
-from anta.tools.misc import exc_to_str
+
+if TYPE_CHECKING:
+ from collections.abc import Iterator
+ from pathlib import Path
logger = logging.getLogger(__name__)
+# Do not load the default keypairs multiple times due to a performance issue introduced in cryptography 37.0
+# https://github.com/pyca/cryptography/issues/7236#issuecomment-1131908472
+CLIENT_KEYS = asyncssh.public_key.load_default_keypairs()
+
class AntaDevice(ABC):
- """
- Abstract class representing a device in ANTA.
+ """Abstract class representing a device in ANTA.
+
An implementation of this class must override the abstract coroutines `_collect()` and
`refresh()`.
- Attributes:
+ Attributes
+ ----------
name: Device name
- is_online: True if the device IP is reachable and a port can be open
- established: True if remote command execution succeeds
- hw_model: Hardware model of the device
- tags: List of tags for this device
- cache: In-memory cache from aiocache library for this device (None if cache is disabled)
- cache_locks: Dictionary mapping keys to asyncio locks to guarantee exclusive access to the cache if not disabled
+ is_online: True if the device IP is reachable and a port can be open.
+ established: True if remote command execution succeeds.
+ hw_model: Hardware model of the device.
+ tags: Tags for this device.
+ cache: In-memory cache from aiocache library for this device (None if cache is disabled).
+ cache_locks: Dictionary mapping keys to asyncio locks to guarantee exclusive access to the cache if not disabled.
+
"""
- def __init__(self, name: str, tags: Optional[list[str]] = None, disable_cache: bool = False) -> None:
- """
- Constructor of AntaDevice
+ def __init__(self, name: str, tags: set[str] | None = None, *, disable_cache: bool = False) -> None:
+ """Initialize an AntaDevice.
Args:
- name: Device name
- tags: List of tags for this device
- disable_cache: Disable caching for all commands for this device. Defaults to False.
+ ----
+ name: Device name.
+ tags: Tags for this device.
+ disable_cache: Disable caching for all commands for this device.
+
"""
self.name: str = name
- self.hw_model: Optional[str] = None
- self.tags: list[str] = tags if tags is not None else []
+ self.hw_model: str | None = None
+ self.tags: set[str] = tags if tags is not None else set()
# A device always has its own name as tag
- self.tags.append(self.name)
+ self.tags.add(self.name)
self.is_online: bool = False
self.established: bool = False
- self.cache: Optional[Cache] = None
- self.cache_locks: Optional[defaultdict[str, asyncio.Lock]] = None
+ self.cache: Cache | None = None
+ self.cache_locks: defaultdict[str, asyncio.Lock] | None = None
# Initialize cache if not disabled
if not disable_cache:
@@ -68,34 +78,24 @@ class AntaDevice(ABC):
@property
@abstractmethod
def _keys(self) -> tuple[Any, ...]:
- """
- Read-only property to implement hashing and equality for AntaDevice classes.
- """
+ """Read-only property to implement hashing and equality for AntaDevice classes."""
def __eq__(self, other: object) -> bool:
- """
- Implement equality for AntaDevice objects.
- """
+ """Implement equality for AntaDevice objects."""
return self._keys == other._keys if isinstance(other, self.__class__) else False
def __hash__(self) -> int:
- """
- Implement hashing for AntaDevice objects.
- """
+ """Implement hashing for AntaDevice objects."""
return hash(self._keys)
def _init_cache(self) -> None:
- """
- Initialize cache for the device, can be overriden by subclasses to manipulate how it works
- """
+ """Initialize cache for the device, can be overridden by subclasses to manipulate how it works."""
self.cache = Cache(cache_class=Cache.MEMORY, ttl=60, namespace=self.name, plugins=[HitMissRatioPlugin()])
self.cache_locks = defaultdict(asyncio.Lock)
@property
def cache_statistics(self) -> dict[str, Any] | None:
- """
- Returns the device cache statistics for logging purposes
- """
+ """Returns the device cache statistics for logging purposes."""
# Need to ignore pylint no-member as Cache is a proxy class and pylint is not smart enough
# https://github.com/pylint-dev/pylint/issues/7258
if self.cache is not None:
@@ -104,9 +104,9 @@ class AntaDevice(ABC):
return None
def __rich_repr__(self) -> Iterator[tuple[str, Any]]:
- """
- Implements Rich Repr Protocol
- https://rich.readthedocs.io/en/stable/pretty.html#rich-repr-protocol
+ """Implement Rich Repr Protocol.
+
+ https://rich.readthedocs.io/en/stable/pretty.html#rich-repr-protocol.
"""
yield "name", self.name
yield "tags", self.tags
@@ -117,8 +117,8 @@ class AntaDevice(ABC):
@abstractmethod
async def _collect(self, command: AntaCommand) -> None:
- """
- Collect device command output.
+ """Collect device command output.
+
This abstract coroutine can be used to implement any command collection method
for a device in ANTA.
@@ -130,12 +130,13 @@ class AntaDevice(ABC):
`AntaCommand` object passed as argument would be `None` in this case.
Args:
+ ----
command: the command to collect
+
"""
async def collect(self, command: AntaCommand) -> None:
- """
- Collects the output for a specified command.
+ """Collect the output for a specified command.
When caching is activated on both the device and the command,
this method prioritizes retrieving the output from the cache. In cases where the output isn't cached yet,
@@ -146,7 +147,9 @@ class AntaDevice(ABC):
via the private `_collect` method without interacting with the cache.
Args:
+ ----
command (AntaCommand): The command to process.
+
"""
# Need to ignore pylint no-member as Cache is a proxy class and pylint is not smart enough
# https://github.com/pylint-dev/pylint/issues/7258
@@ -155,7 +158,7 @@ class AntaDevice(ABC):
cached_output = await self.cache.get(command.uid) # pylint: disable=no-member
if cached_output is not None:
- logger.debug(f"Cache hit for {command.command} on {self.name}")
+ logger.debug("Cache hit for %s on %s", command.command, self.name)
command.output = cached_output
else:
await self._collect(command=command)
@@ -164,26 +167,18 @@ class AntaDevice(ABC):
await self._collect(command=command)
async def collect_commands(self, commands: list[AntaCommand]) -> None:
- """
- Collect multiple commands.
+ """Collect multiple commands.
Args:
+ ----
commands: the commands to collect
+
"""
await asyncio.gather(*(self.collect(command=command) for command in commands))
- def supports(self, command: AntaCommand) -> bool:
- """Returns True if the command is supported on the device hardware platform, False otherwise."""
- unsupported = any("not supported on this hardware platform" in e for e in command.errors)
- logger.debug(command)
- if unsupported:
- logger.debug(f"{command.command} is not supported on {self.hw_model}")
- return not unsupported
-
@abstractmethod
async def refresh(self) -> None:
- """
- Update attributes of an AntaDevice instance.
+ """Update attributes of an AntaDevice instance.
This coroutine must update the following attributes of AntaDevice:
- `is_online`: When the device IP is reachable and a port can be open
@@ -192,63 +187,71 @@ class AntaDevice(ABC):
"""
async def copy(self, sources: list[Path], destination: Path, direction: Literal["to", "from"] = "from") -> None:
- """
- Copy files to and from the device, usually through SCP.
+ """Copy files to and from the device, usually through SCP.
+
It is not mandatory to implement this for a valid AntaDevice subclass.
Args:
+ ----
sources: List of files to copy to or from the device.
destination: Local or remote destination when copying the files. Can be a folder.
direction: Defines if this coroutine copies files to or from the device.
+
"""
- raise NotImplementedError(f"copy() method has not been implemented in {self.__class__.__name__} definition")
+ _ = (sources, destination, direction)
+ msg = f"copy() method has not been implemented in {self.__class__.__name__} definition"
+ raise NotImplementedError(msg)
class AsyncEOSDevice(AntaDevice):
- """
- Implementation of AntaDevice for EOS using aio-eapi.
+ """Implementation of AntaDevice for EOS using aio-eapi.
- Attributes:
+ Attributes
+ ----------
name: Device name
is_online: True if the device IP is reachable and a port can be open
established: True if remote command execution succeeds
hw_model: Hardware model of the device
- tags: List of tags for this device
+ tags: Tags for this device
+
"""
- def __init__( # pylint: disable=R0913
+ # pylint: disable=R0913
+ def __init__(
self,
host: str,
username: str,
password: str,
- name: Optional[str] = None,
+ name: str | None = None,
+ enable_password: str | None = None,
+ port: int | None = None,
+ ssh_port: int | None = 22,
+ tags: set[str] | None = None,
+ timeout: float | None = None,
+ proto: Literal["http", "https"] = "https",
+ *,
enable: bool = False,
- enable_password: Optional[str] = None,
- port: Optional[int] = None,
- ssh_port: Optional[int] = 22,
- tags: Optional[list[str]] = None,
- timeout: Optional[float] = None,
insecure: bool = False,
- proto: Literal["http", "https"] = "https",
disable_cache: bool = False,
) -> None:
- """
- Constructor of AsyncEOSDevice
+ """Instantiate an AsyncEOSDevice.
Args:
- host: Device FQDN or IP
- username: Username to connect to eAPI and SSH
- password: Password to connect to eAPI and SSH
- name: Device name
- enable: Device needs privileged access
- enable_password: Password used to gain privileged access on EOS
+ ----
+ host: Device FQDN or IP.
+ username: Username to connect to eAPI and SSH.
+ password: Password to connect to eAPI and SSH.
+ name: Device name.
+ enable: Collect commands using privileged mode.
+ enable_password: Password used to gain privileged access on EOS.
port: eAPI port. Defaults to 80 is proto is 'http' or 443 if proto is 'https'.
- ssh_port: SSH port
- tags: List of tags for this device
- timeout: Timeout value in seconds for outgoing connections. Default to 10 secs.
- insecure: Disable SSH Host Key validation
- proto: eAPI protocol. Value can be 'http' or 'https'
- disable_cache: Disable caching for all commands for this device. Defaults to False.
+ ssh_port: SSH port.
+ tags: Tags for this device.
+ timeout: Timeout value in seconds for outgoing API calls.
+ insecure: Disable SSH Host Key validation.
+ proto: eAPI protocol. Value can be 'http' or 'https'.
+ disable_cache: Disable caching for all commands for this device.
+
"""
if host is None:
message = "'host' is required to create an AsyncEOSDevice"
@@ -256,7 +259,7 @@ class AsyncEOSDevice(AntaDevice):
raise ValueError(message)
if name is None:
name = f"{host}{f':{port}' if port else ''}"
- super().__init__(name, tags, disable_cache)
+ super().__init__(name, tags, disable_cache=disable_cache)
if username is None:
message = f"'username' is required to instantiate device '{self.name}'"
logger.error(message)
@@ -271,12 +274,14 @@ class AsyncEOSDevice(AntaDevice):
ssh_params: dict[str, Any] = {}
if insecure:
ssh_params["known_hosts"] = None
- self._ssh_opts: SSHClientConnectionOptions = SSHClientConnectionOptions(host=host, port=ssh_port, username=username, password=password, **ssh_params)
+ self._ssh_opts: SSHClientConnectionOptions = SSHClientConnectionOptions(
+ host=host, port=ssh_port, username=username, password=password, client_keys=CLIENT_KEYS, **ssh_params
+ )
def __rich_repr__(self) -> Iterator[tuple[str, Any]]:
- """
- Implements Rich Repr Protocol
- https://rich.readthedocs.io/en/stable/pretty.html#rich-repr-protocol
+ """Implement Rich Repr Protocol.
+
+ https://rich.readthedocs.io/en/stable/pretty.html#rich-repr-protocol.
"""
yield from super().__rich_repr__()
yield ("host", self._session.host)
@@ -286,107 +291,123 @@ class AsyncEOSDevice(AntaDevice):
yield ("insecure", self._ssh_opts.known_hosts is None)
if __DEBUG__:
_ssh_opts = vars(self._ssh_opts).copy()
- PASSWORD_VALUE = "<removed>"
- _ssh_opts["password"] = PASSWORD_VALUE
- _ssh_opts["kwargs"]["password"] = PASSWORD_VALUE
+ removed_pw = "<removed>"
+ _ssh_opts["password"] = removed_pw
+ _ssh_opts["kwargs"]["password"] = removed_pw
yield ("_session", vars(self._session))
yield ("_ssh_opts", _ssh_opts)
@property
def _keys(self) -> tuple[Any, ...]:
- """
- Two AsyncEOSDevice objects are equal if the hostname and the port are the same.
+ """Two AsyncEOSDevice objects are equal if the hostname and the port are the same.
+
This covers the use case of port forwarding when the host is localhost and the devices have different ports.
"""
return (self._session.host, self._session.port)
- async def _collect(self, command: AntaCommand) -> None:
- """
- Collect device command output from EOS using aio-eapi.
+ async def _collect(self, command: AntaCommand) -> None: # noqa: C901 function is too complex - because of many required except blocks
+ """Collect device command output from EOS using aio-eapi.
Supports outformat `json` and `text` as output structure.
Gain privileged access using the `enable_password` attribute
of the `AntaDevice` instance if populated.
Args:
- command: the command to collect
+ ----
+ command: the AntaCommand to collect.
"""
- commands = []
+ commands: list[dict[str, Any]] = []
if self.enable and self._enable_password is not None:
commands.append(
{
"cmd": "enable",
"input": str(self._enable_password),
- }
+ },
)
elif self.enable:
# No password
commands.append({"cmd": "enable"})
- if command.revision:
- commands.append({"cmd": command.command, "revision": command.revision})
- else:
- commands.append({"cmd": command.command})
+ commands += [{"cmd": command.command, "revision": command.revision}] if command.revision else [{"cmd": command.command}]
try:
response: list[dict[str, Any]] = await self._session.cli(
commands=commands,
ofmt=command.ofmt,
version=command.version,
)
+ # Do not keep response of 'enable' command
+ command.output = response[-1]
except aioeapi.EapiCommandError as e:
+ # This block catches exceptions related to EOS issuing an error.
command.errors = e.errors
- if self.supports(command):
- message = f"Command '{command.command}' failed on {self.name}"
- logger.error(message)
- except (HTTPError, ConnectError) as e:
- command.errors = [str(e)]
- message = f"Cannot connect to device {self.name}"
- logger.error(message)
- else:
- # selecting only our command output
- command.output = response[-1]
- logger.debug(f"{self.name}: {command}")
+ if command.requires_privileges:
+ logger.error(
+ "Command '%s' requires privileged mode on %s. Verify user permissions and if the `enable` option is required.", command.command, self.name
+ )
+ if command.supported:
+ logger.error("Command '%s' failed on %s: %s", command.command, self.name, e.errors[0] if len(e.errors) == 1 else e.errors)
+ else:
+ logger.debug("Command '%s' is not supported on '%s' (%s)", command.command, self.name, self.hw_model)
+ except TimeoutException as e:
+ # This block catches Timeout exceptions.
+ command.errors = [exc_to_str(e)]
+ timeouts = self._session.timeout.as_dict()
+ logger.error(
+ "%s occurred while sending a command to %s. Consider increasing the timeout.\nCurrent timeouts: Connect: %s | Read: %s | Write: %s | Pool: %s",
+ exc_to_str(e),
+ self.name,
+ timeouts["connect"],
+ timeouts["read"],
+ timeouts["write"],
+ timeouts["pool"],
+ )
+ except (ConnectError, OSError) as e:
+ # This block catches OSError and socket issues related exceptions.
+ command.errors = [exc_to_str(e)]
+ if (isinstance(exc := e.__cause__, httpcore.ConnectError) and isinstance(os_error := exc.__context__, OSError)) or isinstance(os_error := e, OSError): # pylint: disable=no-member
+ if isinstance(os_error.__cause__, OSError):
+ os_error = os_error.__cause__
+ logger.error("A local OS error occurred while connecting to %s: %s.", self.name, os_error)
+ else:
+ anta_log_exception(e, f"An error occurred while issuing an eAPI request to {self.name}", logger)
+ except HTTPError as e:
+ # This block catches most of the httpx Exceptions and logs a general message.
+ command.errors = [exc_to_str(e)]
+ anta_log_exception(e, f"An error occurred while issuing an eAPI request to {self.name}", logger)
+ logger.debug("%s: %s", self.name, command)
async def refresh(self) -> None:
- """
- Update attributes of an AsyncEOSDevice instance.
+ """Update attributes of an AsyncEOSDevice instance.
This coroutine must update the following attributes of AsyncEOSDevice:
- is_online: When a device IP is reachable and a port can be open
- established: When a command execution succeeds
- hw_model: The hardware model of the device
"""
- logger.debug(f"Refreshing device {self.name}")
+ logger.debug("Refreshing device %s", self.name)
self.is_online = await self._session.check_connection()
if self.is_online:
- COMMAND: str = "show version"
- HW_MODEL_KEY: str = "modelName"
- try:
- response = await self._session.cli(command=COMMAND)
- except aioeapi.EapiCommandError as e:
- logger.warning(f"Cannot get hardware information from device {self.name}: {e.errmsg}")
-
- except (HTTPError, ConnectError) as e:
- logger.warning(f"Cannot get hardware information from device {self.name}: {exc_to_str(e)}")
-
+ show_version = AntaCommand(command="show version")
+ await self._collect(show_version)
+ if not show_version.collected:
+ logger.warning("Cannot get hardware information from device %s", self.name)
else:
- if HW_MODEL_KEY in response:
- self.hw_model = response[HW_MODEL_KEY]
- else:
- logger.warning(f"Cannot get hardware information from device {self.name}: cannot parse '{COMMAND}'")
-
+ self.hw_model = show_version.json_output.get("modelName", None)
+ if self.hw_model is None:
+ logger.critical("Cannot parse 'show version' returned by device %s", self.name)
else:
- logger.warning(f"Could not connect to device {self.name}: cannot open eAPI port")
+ logger.warning("Could not connect to device %s: cannot open eAPI port", self.name)
self.established = bool(self.is_online and self.hw_model)
async def copy(self, sources: list[Path], destination: Path, direction: Literal["to", "from"] = "from") -> None:
- """
- Copy files to and from the device using asyncssh.scp().
+ """Copy files to and from the device using asyncssh.scp().
Args:
+ ----
sources: List of files to copy to or from the device.
destination: Local or remote destination when copying the files. Can be a folder.
direction: Defines if this coroutine copies files to or from the device.
+
"""
async with asyncssh.connect(
host=self._ssh_opts.host,
@@ -396,22 +417,24 @@ class AsyncEOSDevice(AntaDevice):
local_addr=self._ssh_opts.local_addr,
options=self._ssh_opts,
) as conn:
- src: Union[list[tuple[SSHClientConnection, Path]], list[Path]]
- dst: Union[tuple[SSHClientConnection, Path], Path]
+ src: list[tuple[SSHClientConnection, Path]] | list[Path]
+ dst: tuple[SSHClientConnection, Path] | Path
if direction == "from":
src = [(conn, file) for file in sources]
dst = destination
for file in sources:
- logger.info(f"Copying '{file}' from device {self.name} to '{destination}' locally")
+ message = f"Copying '{file}' from device {self.name} to '{destination}' locally"
+ logger.info(message)
elif direction == "to":
src = sources
dst = conn, destination
for file in src:
- logger.info(f"Copying '{file}' to device {self.name} to '{destination}' remotely")
+ message = f"Copying '{file}' to device {self.name} to '{destination}' remotely"
+ logger.info(message)
else:
- logger.critical(f"'direction' argument to copy() fonction is invalid: {direction}")
+ logger.critical("'direction' argument to copy() function is invalid: %s", direction)
return
await asyncssh.scp(src, dst)