summaryrefslogtreecommitdiffstats
path: root/pre_commit/yaml_rewrite.py
diff options
context:
space:
mode:
Diffstat (limited to 'pre_commit/yaml_rewrite.py')
-rw-r--r--pre_commit/yaml_rewrite.py52
1 files changed, 52 insertions, 0 deletions
diff --git a/pre_commit/yaml_rewrite.py b/pre_commit/yaml_rewrite.py
new file mode 100644
index 0000000..8d0e8fd
--- /dev/null
+++ b/pre_commit/yaml_rewrite.py
@@ -0,0 +1,52 @@
+from __future__ import annotations
+
+from collections.abc import Generator
+from collections.abc import Iterable
+from typing import NamedTuple
+from typing import Protocol
+
+from yaml.nodes import MappingNode
+from yaml.nodes import Node
+from yaml.nodes import ScalarNode
+from yaml.nodes import SequenceNode
+
+
+class _Matcher(Protocol):
+ def match(self, n: Node) -> Generator[Node]: ...
+
+
+class MappingKey(NamedTuple):
+ k: str
+
+ def match(self, n: Node) -> Generator[Node]:
+ if isinstance(n, MappingNode):
+ for k, _ in n.value:
+ if k.value == self.k:
+ yield k
+
+
+class MappingValue(NamedTuple):
+ k: str
+
+ def match(self, n: Node) -> Generator[Node]:
+ if isinstance(n, MappingNode):
+ for k, v in n.value:
+ if k.value == self.k:
+ yield v
+
+
+class SequenceItem(NamedTuple):
+ def match(self, n: Node) -> Generator[Node]:
+ if isinstance(n, SequenceNode):
+ yield from n.value
+
+
+def _match(gen: Iterable[Node], m: _Matcher) -> Iterable[Node]:
+ return (n for src in gen for n in m.match(src))
+
+
+def match(n: Node, matcher: tuple[_Matcher, ...]) -> Generator[ScalarNode]:
+ gen: Iterable[Node] = (n,)
+ for m in matcher:
+ gen = _match(gen, m)
+ return (n for n in gen if isinstance(n, ScalarNode))