summaryrefslogtreecommitdiffstats
path: root/anta/decorators.py
blob: dc57e13ec55d67e8158a5cb89b1c47d1bbf8a34f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright (c) 2023-2024 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the LICENSE file.
"""decorators for tests."""

from __future__ import annotations

from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast

from anta.models import AntaTest, logger

if TYPE_CHECKING:
    from anta.result_manager.models import TestResult

# TODO: should probably use mypy Awaitable in some places rather than this everywhere - @gmuloc
# Type variable bound to any callable. The decorator factories in this module are annotated
# as returning ``Callable[[F], F]``, i.e. the decorated function keeps the type of the
# function that was passed in.
F = TypeVar("F", bound=Callable[..., Any])


def deprecated_test(new_tests: list[str] | None = None) -> Callable[[F], F]:
    """Build a decorator that emits a WARNING log each time a deprecated test is run.

    Args:
    ----
        new_tests: Optional names of the tests that supersede the deprecated one. When
            provided and non-empty, they are listed (comma-separated) in the warning.

    Returns
    -------
        Callable[[F], F]: The decorator to apply to the test coroutine function.

    """

    def log_deprecation(function: F) -> F:
        """Wrap `function` so that a deprecation warning is logged before it is awaited.

        Args:
        ----
            function: The test coroutine function being marked as deprecated.

        Returns
        -------
            F: The wrapping coroutine function, cast back to the type of `function`.

        """

        @wraps(function)
        async def deprecation_wrapper(*args: Any, **kwargs: Any) -> Any:
            # The first positional argument is the object the wrapped function is bound to;
            # only its `name` attribute is read here.
            test_instance = args[0]
            if not new_tests:
                # No replacement advertised (None or empty list): short message only.
                logger.warning("%s test is deprecated.", test_instance.name)
            else:
                replacements = ", ".join(new_tests)
                logger.warning("%s test is deprecated. Consider using the following new tests: %s.", test_instance.name, replacements)
            # The deprecated test is still executed; only a warning is logged.
            outcome = await function(*args, **kwargs)
            return outcome

        return cast(F, deprecation_wrapper)

    return log_deprecation


def skip_on_platforms(platforms: list[str]) -> Callable[[F], F]:
    """Build a decorator that skips a test when the device hardware model is listed.

    The returned decorator wraps a test coroutine function. When the wrapped test is run,
    the hardware model of the device it targets is compared against `platforms`; on a match
    the test is marked as skipped instead of being executed.

    Args:
    ----
        platforms: Hardware models on which the test must not run.

    Returns
    -------
        Callable[[F], F]: The decorator to apply to the test coroutine function.

    """

    def decorator(function: F) -> F:
        """Wrap `function` with the hardware-model check.

        Args:
        ----
            function: The test coroutine function to guard.

        Returns
        -------
            F: The wrapping coroutine function, cast back to the type of `function`.

        """

        @wraps(function)
        async def wrapper(*args: Any, **kwargs: Any) -> TestResult:
            """Run the test, or return its result early without running it.

            The test is not executed when its result is already set to something other than
            "unset", or when the device hardware model is one of the specified platforms.
            In the latter case the result is marked as skipped.
            """
            # The first positional argument is the test instance the wrapped function is bound to.
            anta_test = args[0]

            result_still_unset = anta_test.result.result == "unset"
            if result_still_unset:
                if anta_test.device.hw_model not in platforms:
                    # Supported platform and nothing decided yet: run the actual test.
                    return await function(*args, **kwargs)
                anta_test.result.is_skipped(f"{anta_test.__class__.__name__} test is not supported on {anta_test.device.hw_model}.")

            # Shared exit for both early-return cases: the test body is not run, so the
            # progress update is triggered here before handing back the existing result.
            AntaTest.update_progress()
            return anta_test.result

        return cast(F, wrapper)

    return decorator