diff options
Diffstat (limited to 'anta/tools.py')
-rw-r--r-- | anta/tools.py | 124 |
1 files changed, 121 insertions, 3 deletions
diff --git a/anta/tools.py b/anta/tools.py index d1d394a..b3760da 100644 --- a/anta/tools.py +++ b/anta/tools.py @@ -5,7 +5,26 @@ from __future__ import annotations -from typing import Any +import cProfile +import os +import pstats +from functools import wraps +from time import perf_counter +from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast + +from anta.logger import format_td + +if TYPE_CHECKING: + import sys + from logging import Logger + from types import TracebackType + + if sys.version_info >= (3, 11): + from typing import Self + else: + from typing_extensions import Self + +F = TypeVar("F", bound=Callable[..., Any]) def get_failed_logs(expected_output: dict[Any, Any], actual_output: dict[Any, Any]) -> str: @@ -28,14 +47,35 @@ def get_failed_logs(expected_output: dict[Any, Any], actual_output: dict[Any, An for element, expected_data in expected_output.items(): actual_data = actual_output.get(element) + if actual_data == expected_data: + continue if actual_data is None: failed_logs.append(f"\nExpected `{expected_data}` as the {element}, but it was not found in the actual output.") - elif actual_data != expected_data: - failed_logs.append(f"\nExpected `{expected_data}` as the {element}, but found `{actual_data}` instead.") + continue + # actual_data != expected_data: and actual_data is not None + failed_logs.append(f"\nExpected `{expected_data}` as the {element}, but found `{actual_data}` instead.") return "".join(failed_logs) +def custom_division(numerator: float, denominator: float) -> int | float: + """Get the custom division of numbers. + + Custom division that returns an integer if the result is an integer, otherwise a float. + + Parameters + ---------- + numerator: The numerator. + denominator: The denominator. + + Returns + ------- + Union[int, float]: The result of the division. + """ + result = numerator / denominator + return int(result) if result.is_integer() else result + + # pylint: disable=too-many-arguments def get_dict_superset( list_of_dicts: list[dict[Any, Any]], @@ -228,3 +268,81 @@ def get_item( if required is True: raise ValueError(custom_error_msg or var_name) return default + + +class Catchtime: + """A class working as a context to capture time differences.""" + + start: float + raw_time: float + time: str + + def __init__(self, logger: Logger | None = None, message: str | None = None) -> None: + self.logger = logger + self.message = message + + def __enter__(self) -> Self: + """__enter__ method.""" + self.start = perf_counter() + if self.logger and self.message: + self.logger.info("%s ...", self.message) + return self + + def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None) -> None: + """__exit__ method.""" + self.raw_time = perf_counter() - self.start + self.time = format_td(self.raw_time, 3) + if self.logger and self.message: + self.logger.info("%s completed in: %s.", self.message, self.time) + + +def cprofile(sort_by: str = "cumtime") -> Callable[[F], F]: + """Profile a function with cProfile. + + profile is conditionally enabled based on the presence of ANTA_CPROFILE environment variable. + Expect to decorate an async function. + + Args: + ---- + sort_by (str): The criterion to sort the profiling results. Default is 'cumtime'. + + Returns + ------- + Callable: The decorated function with conditional profiling. + """ + + def decorator(func: F) -> F: + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + """Enable cProfile or not. + + If `ANTA_CPROFILE` is set, cProfile is enabled and dumps the stats to the file. + + Args: + ---- + *args: Arbitrary positional arguments. + **kwargs: Arbitrary keyword arguments. + + Returns + ------- + The result of the function call. + """ + cprofile_file = os.environ.get("ANTA_CPROFILE") + + if cprofile_file is not None: + profiler = cProfile.Profile() + profiler.enable() + + try: + result = await func(*args, **kwargs) + finally: + if cprofile_file is not None: + profiler.disable() + stats = pstats.Stats(profiler).sort_stats(sort_by) + stats.dump_stats(cprofile_file) + + return result + + return cast(F, wrapper) + + return decorator |