diff options
Diffstat (limited to 'anta/device.py')
-rw-r--r-- | anta/device.py | 319 |
1 files changed, 171 insertions, 148 deletions
diff --git a/anta/device.py b/anta/device.py index d9060c9..d517b8f 100644 --- a/anta/device.py +++ b/anta/device.py @@ -1,65 +1,75 @@ # Copyright (c) 2023-2024 Arista Networks, Inc. # Use of this source code is governed by the Apache License 2.0 # that can be found in the LICENSE file. -""" -ANTA Device Abstraction Module -""" +"""ANTA Device Abstraction Module.""" + from __future__ import annotations import asyncio import logging from abc import ABC, abstractmethod from collections import defaultdict -from pathlib import Path -from typing import Any, Iterator, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal import asyncssh +import httpcore from aiocache import Cache from aiocache.plugins import HitMissRatioPlugin from asyncssh import SSHClientConnection, SSHClientConnectionOptions -from httpx import ConnectError, HTTPError +from httpx import ConnectError, HTTPError, TimeoutException from anta import __DEBUG__, aioeapi +from anta.logger import anta_log_exception, exc_to_str from anta.models import AntaCommand -from anta.tools.misc import exc_to_str + +if TYPE_CHECKING: + from collections.abc import Iterator + from pathlib import Path logger = logging.getLogger(__name__) +# Do not load the default keypairs multiple times due to a performance issue introduced in cryptography 37.0 +# https://github.com/pyca/cryptography/issues/7236#issuecomment-1131908472 +CLIENT_KEYS = asyncssh.public_key.load_default_keypairs() + class AntaDevice(ABC): - """ - Abstract class representing a device in ANTA. + """Abstract class representing a device in ANTA. + An implementation of this class must override the abstract coroutines `_collect()` and `refresh()`. - Attributes: + Attributes + ---------- name: Device name - is_online: True if the device IP is reachable and a port can be open - established: True if remote command execution succeeds - hw_model: Hardware model of the device - tags: List of tags for this device - cache: In-memory cache from aiocache library for this device (None if cache is disabled) - cache_locks: Dictionary mapping keys to asyncio locks to guarantee exclusive access to the cache if not disabled + is_online: True if the device IP is reachable and a port can be open. + established: True if remote command execution succeeds. + hw_model: Hardware model of the device. + tags: Tags for this device. + cache: In-memory cache from aiocache library for this device (None if cache is disabled). + cache_locks: Dictionary mapping keys to asyncio locks to guarantee exclusive access to the cache if not disabled. + """ - def __init__(self, name: str, tags: Optional[list[str]] = None, disable_cache: bool = False) -> None: - """ - Constructor of AntaDevice + def __init__(self, name: str, tags: set[str] | None = None, *, disable_cache: bool = False) -> None: + """Initialize an AntaDevice. Args: - name: Device name - tags: List of tags for this device - disable_cache: Disable caching for all commands for this device. Defaults to False. + ---- + name: Device name. + tags: Tags for this device. + disable_cache: Disable caching for all commands for this device. + """ self.name: str = name - self.hw_model: Optional[str] = None - self.tags: list[str] = tags if tags is not None else [] + self.hw_model: str | None = None + self.tags: set[str] = tags if tags is not None else set() # A device always has its own name as tag - self.tags.append(self.name) + self.tags.add(self.name) self.is_online: bool = False self.established: bool = False - self.cache: Optional[Cache] = None - self.cache_locks: Optional[defaultdict[str, asyncio.Lock]] = None + self.cache: Cache | None = None + self.cache_locks: defaultdict[str, asyncio.Lock] | None = None # Initialize cache if not disabled if not disable_cache: @@ -68,34 +78,24 @@ class AntaDevice(ABC): @property @abstractmethod def _keys(self) -> tuple[Any, ...]: - """ - Read-only property to implement hashing and equality for AntaDevice classes. - """ + """Read-only property to implement hashing and equality for AntaDevice classes.""" def __eq__(self, other: object) -> bool: - """ - Implement equality for AntaDevice objects. - """ + """Implement equality for AntaDevice objects.""" return self._keys == other._keys if isinstance(other, self.__class__) else False def __hash__(self) -> int: - """ - Implement hashing for AntaDevice objects. - """ + """Implement hashing for AntaDevice objects.""" return hash(self._keys) def _init_cache(self) -> None: - """ - Initialize cache for the device, can be overriden by subclasses to manipulate how it works - """ + """Initialize cache for the device, can be overridden by subclasses to manipulate how it works.""" self.cache = Cache(cache_class=Cache.MEMORY, ttl=60, namespace=self.name, plugins=[HitMissRatioPlugin()]) self.cache_locks = defaultdict(asyncio.Lock) @property def cache_statistics(self) -> dict[str, Any] | None: - """ - Returns the device cache statistics for logging purposes - """ + """Returns the device cache statistics for logging purposes.""" # Need to ignore pylint no-member as Cache is a proxy class and pylint is not smart enough # https://github.com/pylint-dev/pylint/issues/7258 if self.cache is not None: @@ -104,9 +104,9 @@ class AntaDevice(ABC): return None def __rich_repr__(self) -> Iterator[tuple[str, Any]]: - """ - Implements Rich Repr Protocol - https://rich.readthedocs.io/en/stable/pretty.html#rich-repr-protocol + """Implement Rich Repr Protocol. + + https://rich.readthedocs.io/en/stable/pretty.html#rich-repr-protocol. """ yield "name", self.name yield "tags", self.tags @@ -117,8 +117,8 @@ class AntaDevice(ABC): @abstractmethod async def _collect(self, command: AntaCommand) -> None: - """ - Collect device command output. + """Collect device command output. + This abstract coroutine can be used to implement any command collection method for a device in ANTA. @@ -130,12 +130,13 @@ class AntaDevice(ABC): `AntaCommand` object passed as argument would be `None` in this case. Args: + ---- command: the command to collect + """ async def collect(self, command: AntaCommand) -> None: - """ - Collects the output for a specified command. + """Collect the output for a specified command. When caching is activated on both the device and the command, this method prioritizes retrieving the output from the cache. In cases where the output isn't cached yet, @@ -146,7 +147,9 @@ class AntaDevice(ABC): via the private `_collect` method without interacting with the cache. Args: + ---- command (AntaCommand): The command to process. + """ # Need to ignore pylint no-member as Cache is a proxy class and pylint is not smart enough # https://github.com/pylint-dev/pylint/issues/7258 @@ -155,7 +158,7 @@ class AntaDevice(ABC): cached_output = await self.cache.get(command.uid) # pylint: disable=no-member if cached_output is not None: - logger.debug(f"Cache hit for {command.command} on {self.name}") + logger.debug("Cache hit for %s on %s", command.command, self.name) command.output = cached_output else: await self._collect(command=command) @@ -164,26 +167,18 @@ class AntaDevice(ABC): await self._collect(command=command) async def collect_commands(self, commands: list[AntaCommand]) -> None: - """ - Collect multiple commands. + """Collect multiple commands. Args: + ---- commands: the commands to collect + """ await asyncio.gather(*(self.collect(command=command) for command in commands)) - def supports(self, command: AntaCommand) -> bool: - """Returns True if the command is supported on the device hardware platform, False otherwise.""" - unsupported = any("not supported on this hardware platform" in e for e in command.errors) - logger.debug(command) - if unsupported: - logger.debug(f"{command.command} is not supported on {self.hw_model}") - return not unsupported - @abstractmethod async def refresh(self) -> None: - """ - Update attributes of an AntaDevice instance. + """Update attributes of an AntaDevice instance. This coroutine must update the following attributes of AntaDevice: - `is_online`: When the device IP is reachable and a port can be open @@ -192,63 +187,71 @@ class AntaDevice(ABC): """ async def copy(self, sources: list[Path], destination: Path, direction: Literal["to", "from"] = "from") -> None: - """ - Copy files to and from the device, usually through SCP. + """Copy files to and from the device, usually through SCP. + It is not mandatory to implement this for a valid AntaDevice subclass. Args: + ---- sources: List of files to copy to or from the device. destination: Local or remote destination when copying the files. Can be a folder. direction: Defines if this coroutine copies files to or from the device. + """ - raise NotImplementedError(f"copy() method has not been implemented in {self.__class__.__name__} definition") + _ = (sources, destination, direction) + msg = f"copy() method has not been implemented in {self.__class__.__name__} definition" + raise NotImplementedError(msg) class AsyncEOSDevice(AntaDevice): - """ - Implementation of AntaDevice for EOS using aio-eapi. + """Implementation of AntaDevice for EOS using aio-eapi. - Attributes: + Attributes + ---------- name: Device name is_online: True if the device IP is reachable and a port can be open established: True if remote command execution succeeds hw_model: Hardware model of the device - tags: List of tags for this device + tags: Tags for this device + """ - def __init__( # pylint: disable=R0913 + # pylint: disable=R0913 + def __init__( self, host: str, username: str, password: str, - name: Optional[str] = None, + name: str | None = None, + enable_password: str | None = None, + port: int | None = None, + ssh_port: int | None = 22, + tags: set[str] | None = None, + timeout: float | None = None, + proto: Literal["http", "https"] = "https", + *, enable: bool = False, - enable_password: Optional[str] = None, - port: Optional[int] = None, - ssh_port: Optional[int] = 22, - tags: Optional[list[str]] = None, - timeout: Optional[float] = None, insecure: bool = False, - proto: Literal["http", "https"] = "https", disable_cache: bool = False, ) -> None: - """ - Constructor of AsyncEOSDevice + """Instantiate an AsyncEOSDevice. Args: - host: Device FQDN or IP - username: Username to connect to eAPI and SSH - password: Password to connect to eAPI and SSH - name: Device name - enable: Device needs privileged access - enable_password: Password used to gain privileged access on EOS + ---- + host: Device FQDN or IP. + username: Username to connect to eAPI and SSH. + password: Password to connect to eAPI and SSH. + name: Device name. + enable: Collect commands using privileged mode. + enable_password: Password used to gain privileged access on EOS. port: eAPI port. Defaults to 80 is proto is 'http' or 443 if proto is 'https'. - ssh_port: SSH port - tags: List of tags for this device - timeout: Timeout value in seconds for outgoing connections. Default to 10 secs. - insecure: Disable SSH Host Key validation - proto: eAPI protocol. Value can be 'http' or 'https' - disable_cache: Disable caching for all commands for this device. Defaults to False. + ssh_port: SSH port. + tags: Tags for this device. + timeout: Timeout value in seconds for outgoing API calls. + insecure: Disable SSH Host Key validation. + proto: eAPI protocol. Value can be 'http' or 'https'. + disable_cache: Disable caching for all commands for this device. + """ if host is None: message = "'host' is required to create an AsyncEOSDevice" @@ -256,7 +259,7 @@ class AsyncEOSDevice(AntaDevice): raise ValueError(message) if name is None: name = f"{host}{f':{port}' if port else ''}" - super().__init__(name, tags, disable_cache) + super().__init__(name, tags, disable_cache=disable_cache) if username is None: message = f"'username' is required to instantiate device '{self.name}'" logger.error(message) @@ -271,12 +274,14 @@ class AsyncEOSDevice(AntaDevice): ssh_params: dict[str, Any] = {} if insecure: ssh_params["known_hosts"] = None - self._ssh_opts: SSHClientConnectionOptions = SSHClientConnectionOptions(host=host, port=ssh_port, username=username, password=password, **ssh_params) + self._ssh_opts: SSHClientConnectionOptions = SSHClientConnectionOptions( + host=host, port=ssh_port, username=username, password=password, client_keys=CLIENT_KEYS, **ssh_params + ) def __rich_repr__(self) -> Iterator[tuple[str, Any]]: - """ - Implements Rich Repr Protocol - https://rich.readthedocs.io/en/stable/pretty.html#rich-repr-protocol + """Implement Rich Repr Protocol. + + https://rich.readthedocs.io/en/stable/pretty.html#rich-repr-protocol. """ yield from super().__rich_repr__() yield ("host", self._session.host) @@ -286,107 +291,123 @@ class AsyncEOSDevice(AntaDevice): yield ("insecure", self._ssh_opts.known_hosts is None) if __DEBUG__: _ssh_opts = vars(self._ssh_opts).copy() - PASSWORD_VALUE = "<removed>" - _ssh_opts["password"] = PASSWORD_VALUE - _ssh_opts["kwargs"]["password"] = PASSWORD_VALUE + removed_pw = "<removed>" + _ssh_opts["password"] = removed_pw + _ssh_opts["kwargs"]["password"] = removed_pw yield ("_session", vars(self._session)) yield ("_ssh_opts", _ssh_opts) @property def _keys(self) -> tuple[Any, ...]: - """ - Two AsyncEOSDevice objects are equal if the hostname and the port are the same. + """Two AsyncEOSDevice objects are equal if the hostname and the port are the same. + This covers the use case of port forwarding when the host is localhost and the devices have different ports. """ return (self._session.host, self._session.port) - async def _collect(self, command: AntaCommand) -> None: - """ - Collect device command output from EOS using aio-eapi. + async def _collect(self, command: AntaCommand) -> None: # noqa: C901 function is too complex - because of many required except blocks + """Collect device command output from EOS using aio-eapi. Supports outformat `json` and `text` as output structure. Gain privileged access using the `enable_password` attribute of the `AntaDevice` instance if populated. Args: - command: the command to collect + ---- + command: the AntaCommand to collect. """ - commands = [] + commands: list[dict[str, Any]] = [] if self.enable and self._enable_password is not None: commands.append( { "cmd": "enable", "input": str(self._enable_password), - } + }, ) elif self.enable: # No password commands.append({"cmd": "enable"}) - if command.revision: - commands.append({"cmd": command.command, "revision": command.revision}) - else: - commands.append({"cmd": command.command}) + commands += [{"cmd": command.command, "revision": command.revision}] if command.revision else [{"cmd": command.command}] try: response: list[dict[str, Any]] = await self._session.cli( commands=commands, ofmt=command.ofmt, version=command.version, ) + # Do not keep response of 'enable' command + command.output = response[-1] except aioeapi.EapiCommandError as e: + # This block catches exceptions related to EOS issuing an error. command.errors = e.errors - if self.supports(command): - message = f"Command '{command.command}' failed on {self.name}" - logger.error(message) - except (HTTPError, ConnectError) as e: - command.errors = [str(e)] - message = f"Cannot connect to device {self.name}" - logger.error(message) - else: - # selecting only our command output - command.output = response[-1] - logger.debug(f"{self.name}: {command}") + if command.requires_privileges: + logger.error( + "Command '%s' requires privileged mode on %s. Verify user permissions and if the `enable` option is required.", command.command, self.name + ) + if command.supported: + logger.error("Command '%s' failed on %s: %s", command.command, self.name, e.errors[0] if len(e.errors) == 1 else e.errors) + else: + logger.debug("Command '%s' is not supported on '%s' (%s)", command.command, self.name, self.hw_model) + except TimeoutException as e: + # This block catches Timeout exceptions. + command.errors = [exc_to_str(e)] + timeouts = self._session.timeout.as_dict() + logger.error( + "%s occurred while sending a command to %s. Consider increasing the timeout.\nCurrent timeouts: Connect: %s | Read: %s | Write: %s | Pool: %s", + exc_to_str(e), + self.name, + timeouts["connect"], + timeouts["read"], + timeouts["write"], + timeouts["pool"], + ) + except (ConnectError, OSError) as e: + # This block catches OSError and socket issues related exceptions. + command.errors = [exc_to_str(e)] + if (isinstance(exc := e.__cause__, httpcore.ConnectError) and isinstance(os_error := exc.__context__, OSError)) or isinstance(os_error := e, OSError): # pylint: disable=no-member + if isinstance(os_error.__cause__, OSError): + os_error = os_error.__cause__ + logger.error("A local OS error occurred while connecting to %s: %s.", self.name, os_error) + else: + anta_log_exception(e, f"An error occurred while issuing an eAPI request to {self.name}", logger) + except HTTPError as e: + # This block catches most of the httpx Exceptions and logs a general message. + command.errors = [exc_to_str(e)] + anta_log_exception(e, f"An error occurred while issuing an eAPI request to {self.name}", logger) + logger.debug("%s: %s", self.name, command) async def refresh(self) -> None: - """ - Update attributes of an AsyncEOSDevice instance. + """Update attributes of an AsyncEOSDevice instance. This coroutine must update the following attributes of AsyncEOSDevice: - is_online: When a device IP is reachable and a port can be open - established: When a command execution succeeds - hw_model: The hardware model of the device """ - logger.debug(f"Refreshing device {self.name}") + logger.debug("Refreshing device %s", self.name) self.is_online = await self._session.check_connection() if self.is_online: - COMMAND: str = "show version" - HW_MODEL_KEY: str = "modelName" - try: - response = await self._session.cli(command=COMMAND) - except aioeapi.EapiCommandError as e: - logger.warning(f"Cannot get hardware information from device {self.name}: {e.errmsg}") - - except (HTTPError, ConnectError) as e: - logger.warning(f"Cannot get hardware information from device {self.name}: {exc_to_str(e)}") - + show_version = AntaCommand(command="show version") + await self._collect(show_version) + if not show_version.collected: + logger.warning("Cannot get hardware information from device %s", self.name) else: - if HW_MODEL_KEY in response: - self.hw_model = response[HW_MODEL_KEY] - else: - logger.warning(f"Cannot get hardware information from device {self.name}: cannot parse '{COMMAND}'") - + self.hw_model = show_version.json_output.get("modelName", None) + if self.hw_model is None: + logger.critical("Cannot parse 'show version' returned by device %s", self.name) else: - logger.warning(f"Could not connect to device {self.name}: cannot open eAPI port") + logger.warning("Could not connect to device %s: cannot open eAPI port", self.name) self.established = bool(self.is_online and self.hw_model) async def copy(self, sources: list[Path], destination: Path, direction: Literal["to", "from"] = "from") -> None: - """ - Copy files to and from the device using asyncssh.scp(). + """Copy files to and from the device using asyncssh.scp(). Args: + ---- sources: List of files to copy to or from the device. destination: Local or remote destination when copying the files. Can be a folder. direction: Defines if this coroutine copies files to or from the device. + """ async with asyncssh.connect( host=self._ssh_opts.host, @@ -396,22 +417,24 @@ class AsyncEOSDevice(AntaDevice): local_addr=self._ssh_opts.local_addr, options=self._ssh_opts, ) as conn: - src: Union[list[tuple[SSHClientConnection, Path]], list[Path]] - dst: Union[tuple[SSHClientConnection, Path], Path] + src: list[tuple[SSHClientConnection, Path]] | list[Path] + dst: tuple[SSHClientConnection, Path] | Path if direction == "from": src = [(conn, file) for file in sources] dst = destination for file in sources: - logger.info(f"Copying '{file}' from device {self.name} to '{destination}' locally") + message = f"Copying '{file}' from device {self.name} to '{destination}' locally" + logger.info(message) elif direction == "to": src = sources dst = conn, destination for file in src: - logger.info(f"Copying '{file}' to device {self.name} to '{destination}' remotely") + message = f"Copying '{file}' to device {self.name} to '{destination}' remotely" + logger.info(message) else: - logger.critical(f"'direction' argument to copy() fonction is invalid: {direction}") + logger.critical("'direction' argument to copy() function is invalid: %s", direction) return await asyncssh.scp(src, dst) |