# path: root/pre_commit/yaml_rewrite.py
# blob: 8d0e8fdb27578a10d364c1e1b613a91958253c99
from __future__ import annotations

from collections.abc import Generator
from collections.abc import Iterable
from typing import NamedTuple
from typing import Protocol

from yaml.nodes import MappingNode
from yaml.nodes import Node
from yaml.nodes import ScalarNode
from yaml.nodes import SequenceNode


class _Matcher(Protocol):
    """Structural type for one step of a node path.

    Any object exposing a compatible ``match`` method satisfies it; the
    ``MappingKey``, ``MappingValue`` and ``SequenceItem`` tuples in this
    module are written to that shape.
    """

    def match(self, n: Node) -> Generator[Node]:
        """Produce the nodes selected from ``n`` by this step."""
        ...


class MappingKey(NamedTuple):
    """Path step selecting the *key* nodes of a mapping that equal ``k``."""

    # text compared against each key node's ``.value``
    k: str

    def match(self, n: Node) -> Generator[Node]:
        """Yield each key node of ``n`` whose ``.value`` equals ``self.k``.

        Nothing is yielded when ``n`` is not a ``MappingNode``.  Every
        matching key is yielded, in the order of ``n.value``.
        """
        if not isinstance(n, MappingNode):
            return
        # ``n.value`` is iterated as (key node, value node) pairs
        for key_node, _value_node in n.value:
            if key_node.value == self.k:
                yield key_node


class MappingValue(NamedTuple):
    """Path step selecting the *value* nodes stored under key ``k``."""

    # text compared against each key node's ``.value``
    k: str

    def match(self, n: Node) -> Generator[Node]:
        """Yield the value node of each entry of ``n`` keyed by ``self.k``.

        Nothing is yielded when ``n`` is not a ``MappingNode``.  Every
        entry whose key node's ``.value`` equals ``self.k`` contributes
        its value node, in the order of ``n.value``.
        """
        if not isinstance(n, MappingNode):
            return
        # ``n.value`` is iterated as (key node, value node) pairs
        for key_node, value_node in n.value:
            if key_node.value == self.k:
                yield value_node


class SequenceItem(NamedTuple):
    """Path step selecting every element of a sequence node."""

    def match(self, n: Node) -> Generator[Node]:
        """Yield each entry of ``n.value`` when ``n`` is a ``SequenceNode``.

        Nothing is yielded for any other kind of node.
        """
        if not isinstance(n, SequenceNode):
            return
        for item in n.value:
            yield item


def _match(gen: Iterable[Node], m: _Matcher) -> Iterable[Node]:
    """Lazily apply ``m`` to every node of ``gen`` and flatten the results.

    ``m.match`` is called once per node drawn from ``gen``, and only as
    the returned iterable is consumed; the nodes it produces are emitted
    in order.
    """
    flattened = (
        found
        for parent in gen
        for found in m.match(parent)
    )
    return flattened


def match(n: Node, matcher: tuple[_Matcher, ...]) -> Generator[ScalarNode]:
    """Follow ``matcher`` step by step from ``n`` and keep scalar results.

    Each element of ``matcher`` is applied, in order, to every node
    produced by the previous step (starting with just ``n``).  The
    returned generator is lazy and yields only the final nodes that are
    ``ScalarNode`` instances; with an empty ``matcher`` that is ``n``
    itself when it is a scalar.
    """
    candidates: Iterable[Node] = (n,)
    for step in matcher:
        # go through _match() so each lazy layer receives its own matcher
        # as an argument instead of closing over this loop variable
        candidates = _match(candidates, step)
    return (node for node in candidates if isinstance(node, ScalarNode))