diff options
Diffstat (limited to 'anta/catalog.py')
-rw-r--r-- | anta/catalog.py | 145 |
1 files changed, 117 insertions, 28 deletions
diff --git a/anta/catalog.py b/anta/catalog.py index a04e159..142640e 100644 --- a/anta/catalog.py +++ b/anta/catalog.py @@ -7,11 +7,14 @@ from __future__ import annotations import importlib import logging +import math +from collections import defaultdict from inspect import isclass from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, Union -from pydantic import BaseModel, ConfigDict, RootModel, ValidationError, ValidationInfo, field_validator, model_validator +import yaml +from pydantic import BaseModel, ConfigDict, RootModel, ValidationError, ValidationInfo, field_validator, model_serializer, model_validator from pydantic.types import ImportString from pydantic_core import PydanticCustomError from yaml import YAMLError, safe_load @@ -43,6 +46,22 @@ class AntaTestDefinition(BaseModel): test: type[AntaTest] inputs: AntaTest.Input + @model_serializer() + def serialize_model(self) -> dict[str, AntaTest.Input]: + """Serialize the AntaTestDefinition model. + + The dictionary representing the model will be look like: + ``` + <AntaTest subclass name>: + <AntaTest.Input compliant dictionary> + ``` + + Returns + ------- + A dictionary representing the model. + """ + return {self.test.__name__: self.inputs} + def __init__(self, **data: type[AntaTest] | AntaTest.Input | dict[str, Any] | None) -> None: """Inject test in the context to allow to instantiate Input in the BeforeValidator. @@ -157,12 +176,12 @@ class AntaCatalogFile(RootModel[dict[ImportString[Any], list[AntaTestDefinition] if isinstance(tests, dict): # This is an inner Python module modules.update(AntaCatalogFile.flatten_modules(data=tests, package=module.__name__)) - else: - if not isinstance(tests, list): - msg = f"Syntax error when parsing: {tests}\nIt must be a list of ANTA tests. Check the test catalog." - raise ValueError(msg) # noqa: TRY004 pydantic catches ValueError or AssertionError, no TypeError + elif isinstance(tests, list): # This is a list of AntaTestDefinition modules[module] = tests + else: + msg = f"Syntax error when parsing: {tests}\nIt must be a list of ANTA tests. Check the test catalog." + raise ValueError(msg) # noqa: TRY004 pydantic catches ValueError or AssertionError, no TypeError return modules # ANN401 - Any ok for this validator as we are validating the received data @@ -177,10 +196,15 @@ class AntaCatalogFile(RootModel[dict[ImportString[Any], list[AntaTestDefinition] with provided value to validate test inputs. """ if isinstance(data, dict): + if not data: + return data typed_data: dict[ModuleType, list[Any]] = AntaCatalogFile.flatten_modules(data) for module, tests in typed_data.items(): test_definitions: list[AntaTestDefinition] = [] for test_definition in tests: + if isinstance(test_definition, AntaTestDefinition): + test_definitions.append(test_definition) + continue if not isinstance(test_definition, dict): msg = f"Syntax error when parsing: {test_definition}\nIt must be a dictionary. Check the test catalog." raise ValueError(msg) # noqa: TRY004 pydantic catches ValueError or AssertionError, no TypeError @@ -200,7 +224,21 @@ class AntaCatalogFile(RootModel[dict[ImportString[Any], list[AntaTestDefinition] raise ValueError(msg) test_definitions.append(AntaTestDefinition(test=test, inputs=test_inputs)) typed_data[module] = test_definitions - return typed_data + return typed_data + return data + + def yaml(self) -> str: + """Return a YAML representation string of this model. + + Returns + ------- + The YAML representation string of this model. + """ + # TODO: Pydantic and YAML serialization/deserialization is not supported natively. + # This could be improved. + # https://github.com/pydantic/pydantic/issues/1043 + # Explore if this worth using this: https://github.com/NowanIlfideme/pydantic-yaml + return yaml.safe_dump(yaml.safe_load(self.model_dump_json(serialize_as_any=True, exclude_unset=True)), indent=2, width=math.inf) class AntaCatalog: @@ -232,6 +270,12 @@ class AntaCatalog: else: self._filename = Path(filename) + # Default indexes for faster access + self.tag_to_tests: defaultdict[str | None, set[AntaTestDefinition]] = defaultdict(set) + self.tests_without_tags: set[AntaTestDefinition] = set() + self.indexes_built: bool = False + self.final_tests_count: int = 0 + @property def filename(self) -> Path | None: """Path of the file used to create this AntaCatalog instance.""" @@ -297,7 +341,7 @@ class AntaCatalog: raise TypeError(msg) try: - catalog_data = AntaCatalogFile(**data) # type: ignore[arg-type] + catalog_data = AntaCatalogFile(data) # type: ignore[arg-type] except ValidationError as e: anta_log_exception( e, @@ -328,40 +372,85 @@ class AntaCatalog: raise return AntaCatalog(tests) - def get_tests_by_tags(self, tags: set[str], *, strict: bool = False) -> list[AntaTestDefinition]: - """Return all the tests that have matching tags in their input filters. - - If strict=True, return only tests that match all the tags provided as input. - If strict=False, return all the tests that match at least one tag provided as input. + def merge(self, catalog: AntaCatalog) -> AntaCatalog: + """Merge two AntaCatalog instances. Args: ---- - tags: Tags of the tests to get. - strict: Specify if the returned tests must match all the tags provided. + catalog: AntaCatalog instance to merge to this instance. Returns ------- - List of AntaTestDefinition that match the tags + A new AntaCatalog instance containing the tests of the two instances. """ - result: list[AntaTestDefinition] = [] + return AntaCatalog(tests=self.tests + catalog.tests) + + def dump(self) -> AntaCatalogFile: + """Return an AntaCatalogFile instance from this AntaCatalog instance. + + Returns + ------- + An AntaCatalogFile instance containing tests of this AntaCatalog instance. + """ + root: dict[ImportString[Any], list[AntaTestDefinition]] = {} for test in self.tests: - if test.inputs.filters and (f := test.inputs.filters.tags): - if strict: - if all(t in tags for t in f): - result.append(test) - elif any(t in tags for t in f): - result.append(test) - return result + # Cannot use AntaTest.module property as the class is not instantiated + root.setdefault(test.test.__module__, []).append(test) + return AntaCatalogFile(root=root) - def get_tests_by_names(self, names: set[str]) -> list[AntaTestDefinition]: - """Return all the tests that have matching a list of tests names. + def build_indexes(self, filtered_tests: set[str] | None = None) -> None: + """Indexes tests by their tags for quick access during filtering operations. + + If a `filtered_tests` set is provided, only the tests in this set will be indexed. + + This method populates two attributes: + - tag_to_tests: A dictionary mapping each tag to a set of tests that contain it. + - tests_without_tags: A set of tests that do not have any tags. + + Once the indexes are built, the `indexes_built` attribute is set to True. + """ + for test in self.tests: + # Skip tests that are not in the specified filtered_tests set + if filtered_tests and test.test.name not in filtered_tests: + continue + + # Indexing by tag + if test.inputs.filters and (test_tags := test.inputs.filters.tags): + for tag in test_tags: + self.tag_to_tests[tag].add(test) + else: + self.tests_without_tags.add(test) + + self.tag_to_tests[None] = self.tests_without_tags + self.indexes_built = True + + def get_tests_by_tags(self, tags: set[str], *, strict: bool = False) -> set[AntaTestDefinition]: + """Return all tests that match a given set of tags, according to the specified strictness. Args: ---- - names: Names of the tests to get. + tags: The tags to filter tests by. If empty, return all tests without tags. + strict: If True, returns only tests that contain all specified tags (intersection). + If False, returns tests that contain any of the specified tags (union). Returns ------- - List of AntaTestDefinition that match the names + set[AntaTestDefinition]: A set of tests that match the given tags. + + Raises + ------ + ValueError: If the indexes have not been built prior to method call. """ - return [test for test in self.tests if test.test.name in names] + if not self.indexes_built: + msg = "Indexes have not been built yet. Call build_indexes() first." + raise ValueError(msg) + if not tags: + return self.tag_to_tests[None] + + filtered_sets = [self.tag_to_tests[tag] for tag in tags if tag in self.tag_to_tests] + if not filtered_sets: + return set() + + if strict: + return set.intersection(*filtered_sets) + return set.union(*filtered_sets) |