summaryrefslogtreecommitdiffstats
path: root/test/lib/ansible_test/_internal/commands/coverage/analyze/targets/filter.py
blob: 29a8ee5b8126427f52f6070d1571a971ee2af0e4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""Filter an aggregated coverage file, keeping only the specified targets."""
from __future__ import annotations

import collections.abc as c
import re
import typing as t

from .....executor import (
    Delegate,
)

from .....provisioning import (
    prepare_profiles,
)

from . import (
    CoverageAnalyzeTargetsConfig,
    expand_indexes,
    generate_indexes,
    make_report,
    read_report,
    write_report,
)

from . import (
    NamedPoints,
    TargetKey,
    TargetIndexes,
)


class CoverageAnalyzeTargetsFilterConfig(CoverageAnalyzeTargetsConfig):
    """Settings for `coverage analyze targets filter`: input/output reports plus target and path filters."""
    def __init__(self, args: t.Any) -> None:
        super().__init__(args)

        # Aggregated report to read, and where the filtered report is written.
        self.input_file: str = args.input_file
        self.output_file: str = args.output_file

        # Target names to keep / drop.
        # NOTE(review): callers treat a falsy value as "no filter"; presumably None when the option is not given.
        self.include_targets: list[str] = args.include_targets
        self.exclude_targets: list[str] = args.exclude_targets

        # Optional regular expressions matched against source file paths.
        self.include_path: t.Optional[str] = args.include_path
        self.exclude_path: t.Optional[str] = args.exclude_path


def command_coverage_analyze_targets_filter(args: CoverageAnalyzeTargetsFilterConfig) -> None:
    """Filter target names in an aggregated coverage file.

    Reads the report at ``args.input_file``, keeps only the paths matching the include/exclude
    path expressions and only the targets allowed by the include/exclude target lists, rebuilds
    the target index and writes the new report to ``args.output_file``.

    Raises ``Delegate`` when ``args.delegate`` is set.
    """
    host_state = prepare_profiles(args)  # coverage analyze targets filter

    if args.delegate:
        raise Delegate(host_state=host_state)

    covered_targets, covered_path_arcs, covered_path_lines = read_report(args.input_file)

    def pass_target_key(value: TargetKey) -> TargetKey:
        """Return the given target key unmodified."""
        return value

    # Swap target indexes for target names so the filters below can work on names directly.
    filtered_path_arcs = expand_indexes(covered_path_arcs, covered_targets, pass_target_key)
    filtered_path_lines = expand_indexes(covered_path_lines, covered_targets, pass_target_key)

    # A missing or empty option means "no constraint", so it becomes None rather than an empty set.
    include_targets = set(args.include_targets) if args.include_targets else None
    exclude_targets = set(args.exclude_targets) if args.exclude_targets else None

    include_path = re.compile(args.include_path) if args.include_path else None
    exclude_path = re.compile(args.exclude_path) if args.exclude_path else None

    def path_filter_func(path: str) -> bool:
        """Return True if the given path should be included, otherwise return False."""
        if include_path is not None and not include_path.search(path):
            return False

        if exclude_path is not None and exclude_path.search(path):
            return False

        return True

    def target_filter_func(targets: set[str]) -> set[str]:
        """Return the targets allowed by the defined includes and excludes, leaving the argument unmodified."""
        # Non-mutating set operators: the argument belongs to the data being filtered.
        if include_targets:
            targets = targets & include_targets

        if exclude_targets:
            targets = targets - exclude_targets

        return targets

    filtered_path_arcs = filter_data(filtered_path_arcs, path_filter_func, target_filter_func)
    filtered_path_lines = filter_data(filtered_path_lines, path_filter_func, target_filter_func)

    # Arcs and lines share one index table so the report has a single consistent target list.
    target_indexes: TargetIndexes = {}
    indexed_path_arcs = generate_indexes(target_indexes, filtered_path_arcs)
    indexed_path_lines = generate_indexes(target_indexes, filtered_path_lines)

    report = make_report(target_indexes, indexed_path_arcs, indexed_path_lines)

    write_report(args, report, args.output_file)


def filter_data(
    data: NamedPoints,
    path_filter_func: c.Callable[[str], bool],
    target_filter_func: c.Callable[[set[str]], set[str]],
) -> NamedPoints:
    """Return a filtered copy of the given data set.

    Paths for which ``path_filter_func`` returns False are dropped entirely. For every point of a
    remaining path, ``target_filter_func`` receives a copy of that point's target set and returns
    the targets to keep. Points left with no targets are dropped, and so are paths left with no
    points. ``data`` itself is not modified.
    """
    result: NamedPoints = {}

    for src_path, src_points in data.items():
        if not path_filter_func(src_path):
            continue

        dst_points = {}

        for src_point, src_targets in src_points.items():
            # Hand the filter a copy: a filter that narrows the set in place (e.g. with &= or -=)
            # must not rewrite the caller's data or alias it into the result.
            dst_targets = target_filter_func(set(src_targets))

            if dst_targets:
                dst_points[src_point] = dst_targets

        if dst_points:
            result[src_path] = dst_points

    return result