summaryrefslogtreecommitdiffstats
path: root/anta/catalog.py
diff options
context:
space:
mode:
Diffstat (limited to 'anta/catalog.py')
-rw-r--r--anta/catalog.py145
1 files changed, 117 insertions, 28 deletions
diff --git a/anta/catalog.py b/anta/catalog.py
index a04e159..142640e 100644
--- a/anta/catalog.py
+++ b/anta/catalog.py
@@ -7,11 +7,14 @@ from __future__ import annotations
import importlib
import logging
+import math
+from collections import defaultdict
from inspect import isclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union
-from pydantic import BaseModel, ConfigDict, RootModel, ValidationError, ValidationInfo, field_validator, model_validator
+import yaml
+from pydantic import BaseModel, ConfigDict, RootModel, ValidationError, ValidationInfo, field_validator, model_serializer, model_validator
from pydantic.types import ImportString
from pydantic_core import PydanticCustomError
from yaml import YAMLError, safe_load
@@ -43,6 +46,22 @@ class AntaTestDefinition(BaseModel):
test: type[AntaTest]
inputs: AntaTest.Input
+ @model_serializer()
+ def serialize_model(self) -> dict[str, AntaTest.Input]:
+ """Serialize the AntaTestDefinition model.
+
+ The dictionary representing the model will be look like:
+ ```
+ <AntaTest subclass name>:
+ <AntaTest.Input compliant dictionary>
+ ```
+
+ Returns
+ -------
+ A dictionary representing the model.
+ """
+ return {self.test.__name__: self.inputs}
+
def __init__(self, **data: type[AntaTest] | AntaTest.Input | dict[str, Any] | None) -> None:
"""Inject test in the context to allow to instantiate Input in the BeforeValidator.
@@ -157,12 +176,12 @@ class AntaCatalogFile(RootModel[dict[ImportString[Any], list[AntaTestDefinition]
if isinstance(tests, dict):
# This is an inner Python module
modules.update(AntaCatalogFile.flatten_modules(data=tests, package=module.__name__))
- else:
- if not isinstance(tests, list):
- msg = f"Syntax error when parsing: {tests}\nIt must be a list of ANTA tests. Check the test catalog."
- raise ValueError(msg) # noqa: TRY004 pydantic catches ValueError or AssertionError, no TypeError
+ elif isinstance(tests, list):
# This is a list of AntaTestDefinition
modules[module] = tests
+ else:
+ msg = f"Syntax error when parsing: {tests}\nIt must be a list of ANTA tests. Check the test catalog."
+ raise ValueError(msg) # noqa: TRY004 pydantic catches ValueError or AssertionError, no TypeError
return modules
# ANN401 - Any ok for this validator as we are validating the received data
@@ -177,10 +196,15 @@ class AntaCatalogFile(RootModel[dict[ImportString[Any], list[AntaTestDefinition]
with provided value to validate test inputs.
"""
if isinstance(data, dict):
+ if not data:
+ return data
typed_data: dict[ModuleType, list[Any]] = AntaCatalogFile.flatten_modules(data)
for module, tests in typed_data.items():
test_definitions: list[AntaTestDefinition] = []
for test_definition in tests:
+ if isinstance(test_definition, AntaTestDefinition):
+ test_definitions.append(test_definition)
+ continue
if not isinstance(test_definition, dict):
msg = f"Syntax error when parsing: {test_definition}\nIt must be a dictionary. Check the test catalog."
raise ValueError(msg) # noqa: TRY004 pydantic catches ValueError or AssertionError, no TypeError
@@ -200,7 +224,21 @@ class AntaCatalogFile(RootModel[dict[ImportString[Any], list[AntaTestDefinition]
raise ValueError(msg)
test_definitions.append(AntaTestDefinition(test=test, inputs=test_inputs))
typed_data[module] = test_definitions
- return typed_data
+ return typed_data
+ return data
+
+ def yaml(self) -> str:
+ """Return a YAML representation string of this model.
+
+ Returns
+ -------
+ The YAML representation string of this model.
+ """
+ # TODO: Pydantic and YAML serialization/deserialization is not supported natively.
+ # This could be improved.
+ # https://github.com/pydantic/pydantic/issues/1043
+ # Explore if this worth using this: https://github.com/NowanIlfideme/pydantic-yaml
+ return yaml.safe_dump(yaml.safe_load(self.model_dump_json(serialize_as_any=True, exclude_unset=True)), indent=2, width=math.inf)
class AntaCatalog:
@@ -232,6 +270,12 @@ class AntaCatalog:
else:
self._filename = Path(filename)
+ # Default indexes for faster access
+ self.tag_to_tests: defaultdict[str | None, set[AntaTestDefinition]] = defaultdict(set)
+ self.tests_without_tags: set[AntaTestDefinition] = set()
+ self.indexes_built: bool = False
+ self.final_tests_count: int = 0
+
@property
def filename(self) -> Path | None:
"""Path of the file used to create this AntaCatalog instance."""
@@ -297,7 +341,7 @@ class AntaCatalog:
raise TypeError(msg)
try:
- catalog_data = AntaCatalogFile(**data) # type: ignore[arg-type]
+ catalog_data = AntaCatalogFile(data) # type: ignore[arg-type]
except ValidationError as e:
anta_log_exception(
e,
@@ -328,40 +372,85 @@ class AntaCatalog:
raise
return AntaCatalog(tests)
- def get_tests_by_tags(self, tags: set[str], *, strict: bool = False) -> list[AntaTestDefinition]:
- """Return all the tests that have matching tags in their input filters.
-
- If strict=True, return only tests that match all the tags provided as input.
- If strict=False, return all the tests that match at least one tag provided as input.
+ def merge(self, catalog: AntaCatalog) -> AntaCatalog:
+ """Merge two AntaCatalog instances.
Args:
----
- tags: Tags of the tests to get.
- strict: Specify if the returned tests must match all the tags provided.
+ catalog: AntaCatalog instance to merge to this instance.
Returns
-------
- List of AntaTestDefinition that match the tags
+ A new AntaCatalog instance containing the tests of the two instances.
"""
- result: list[AntaTestDefinition] = []
+ return AntaCatalog(tests=self.tests + catalog.tests)
+
+ def dump(self) -> AntaCatalogFile:
+ """Return an AntaCatalogFile instance from this AntaCatalog instance.
+
+ Returns
+ -------
+ An AntaCatalogFile instance containing tests of this AntaCatalog instance.
+ """
+ root: dict[ImportString[Any], list[AntaTestDefinition]] = {}
for test in self.tests:
- if test.inputs.filters and (f := test.inputs.filters.tags):
- if strict:
- if all(t in tags for t in f):
- result.append(test)
- elif any(t in tags for t in f):
- result.append(test)
- return result
+ # Cannot use AntaTest.module property as the class is not instantiated
+ root.setdefault(test.test.__module__, []).append(test)
+ return AntaCatalogFile(root=root)
- def get_tests_by_names(self, names: set[str]) -> list[AntaTestDefinition]:
- """Return all the tests that have matching a list of tests names.
+ def build_indexes(self, filtered_tests: set[str] | None = None) -> None:
+ """Indexes tests by their tags for quick access during filtering operations.
+
+ If a `filtered_tests` set is provided, only the tests in this set will be indexed.
+
+ This method populates two attributes:
+ - tag_to_tests: A dictionary mapping each tag to a set of tests that contain it.
+ - tests_without_tags: A set of tests that do not have any tags.
+
+ Once the indexes are built, the `indexes_built` attribute is set to True.
+ """
+ for test in self.tests:
+ # Skip tests that are not in the specified filtered_tests set
+ if filtered_tests and test.test.name not in filtered_tests:
+ continue
+
+ # Indexing by tag
+ if test.inputs.filters and (test_tags := test.inputs.filters.tags):
+ for tag in test_tags:
+ self.tag_to_tests[tag].add(test)
+ else:
+ self.tests_without_tags.add(test)
+
+ self.tag_to_tests[None] = self.tests_without_tags
+ self.indexes_built = True
+
+ def get_tests_by_tags(self, tags: set[str], *, strict: bool = False) -> set[AntaTestDefinition]:
+ """Return all tests that match a given set of tags, according to the specified strictness.
Args:
----
- names: Names of the tests to get.
+ tags: The tags to filter tests by. If empty, return all tests without tags.
+ strict: If True, returns only tests that contain all specified tags (intersection).
+ If False, returns tests that contain any of the specified tags (union).
Returns
-------
- List of AntaTestDefinition that match the names
+ set[AntaTestDefinition]: A set of tests that match the given tags.
+
+ Raises
+ ------
+ ValueError: If the indexes have not been built prior to method call.
"""
- return [test for test in self.tests if test.test.name in names]
+ if not self.indexes_built:
+ msg = "Indexes have not been built yet. Call build_indexes() first."
+ raise ValueError(msg)
+ if not tags:
+ return self.tag_to_tests[None]
+
+ filtered_sets = [self.tag_to_tests[tag] for tag in tags if tag in self.tag_to_tests]
+ if not filtered_sets:
+ return set()
+
+ if strict:
+ return set.intersection(*filtered_sets)
+ return set.union(*filtered_sets)