summaryrefslogtreecommitdiffstats
path: root/anta/tools.py
diff options
context:
space:
mode:
Diffstat (limited to 'anta/tools.py')
-rw-r--r--anta/tools.py124
1 files changed, 121 insertions, 3 deletions
diff --git a/anta/tools.py b/anta/tools.py
index d1d394a..b3760da 100644
--- a/anta/tools.py
+++ b/anta/tools.py
@@ -5,7 +5,26 @@
from __future__ import annotations
-from typing import Any
+import cProfile
+import os
+import pstats
+from functools import wraps
+from time import perf_counter
+from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
+
+from anta.logger import format_td
+
+if TYPE_CHECKING:
+ import sys
+ from logging import Logger
+ from types import TracebackType
+
+ if sys.version_info >= (3, 11):
+ from typing import Self
+ else:
+ from typing_extensions import Self
+
+F = TypeVar("F", bound=Callable[..., Any])
def get_failed_logs(expected_output: dict[Any, Any], actual_output: dict[Any, Any]) -> str:
@@ -28,14 +47,35 @@ def get_failed_logs(expected_output: dict[Any, Any], actual_output: dict[Any, An
for element, expected_data in expected_output.items():
actual_data = actual_output.get(element)
+ if actual_data == expected_data:
+ continue
if actual_data is None:
failed_logs.append(f"\nExpected `{expected_data}` as the {element}, but it was not found in the actual output.")
- elif actual_data != expected_data:
- failed_logs.append(f"\nExpected `{expected_data}` as the {element}, but found `{actual_data}` instead.")
+ continue
+ # actual_data != expected_data: and actual_data is not None
+ failed_logs.append(f"\nExpected `{expected_data}` as the {element}, but found `{actual_data}` instead.")
return "".join(failed_logs)
+def custom_division(numerator: float, denominator: float) -> int | float:
+ """Get the custom division of numbers.
+
+ Custom division that returns an integer if the result is an integer, otherwise a float.
+
+ Parameters
+ ----------
+ numerator: The numerator.
+ denominator: The denominator.
+
+ Returns
+ -------
+ Union[int, float]: The result of the division.
+ """
+ result = numerator / denominator
+ return int(result) if result.is_integer() else result
+
+
# pylint: disable=too-many-arguments
def get_dict_superset(
list_of_dicts: list[dict[Any, Any]],
@@ -228,3 +268,81 @@ def get_item(
if required is True:
raise ValueError(custom_error_msg or var_name)
return default
+
+
+class Catchtime:
+ """A class working as a context to capture time differences."""
+
+ start: float
+ raw_time: float
+ time: str
+
+ def __init__(self, logger: Logger | None = None, message: str | None = None) -> None:
+ self.logger = logger
+ self.message = message
+
+ def __enter__(self) -> Self:
+ """__enter__ method."""
+ self.start = perf_counter()
+ if self.logger and self.message:
+ self.logger.info("%s ...", self.message)
+ return self
+
+ def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None) -> None:
+ """__exit__ method."""
+ self.raw_time = perf_counter() - self.start
+ self.time = format_td(self.raw_time, 3)
+ if self.logger and self.message:
+ self.logger.info("%s completed in: %s.", self.message, self.time)
+
+
+def cprofile(sort_by: str = "cumtime") -> Callable[[F], F]:
+ """Profile a function with cProfile.
+
+ profile is conditionally enabled based on the presence of ANTA_CPROFILE environment variable.
+ Expect to decorate an async function.
+
+ Args:
+ ----
+ sort_by (str): The criterion to sort the profiling results. Default is 'cumtime'.
+
+ Returns
+ -------
+ Callable: The decorated function with conditional profiling.
+ """
+
+ def decorator(func: F) -> F:
+ @wraps(func)
+ async def wrapper(*args: Any, **kwargs: Any) -> Any:
+ """Enable cProfile or not.
+
+ If `ANTA_CPROFILE` is set, cProfile is enabled and dumps the stats to the file.
+
+ Args:
+ ----
+ *args: Arbitrary positional arguments.
+ **kwargs: Arbitrary keyword arguments.
+
+ Returns
+ -------
+ The result of the function call.
+ """
+ cprofile_file = os.environ.get("ANTA_CPROFILE")
+
+ if cprofile_file is not None:
+ profiler = cProfile.Profile()
+ profiler.enable()
+
+ try:
+ result = await func(*args, **kwargs)
+ finally:
+ if cprofile_file is not None:
+ profiler.disable()
+ stats = pstats.Stats(profiler).sort_stats(sort_by)
+ stats.dump_stats(cprofile_file)
+
+ return result
+
+ return cast(F, wrapper)
+
+ return decorator