summaryrefslogtreecommitdiffstats
path: root/pre_commit/commands
diff options
context:
space:
mode:
Diffstat (limited to 'pre_commit/commands')
-rw-r--r--pre_commit/commands/migrate_config.py80
-rw-r--r--pre_commit/commands/run.py6
2 files changed, 73 insertions, 13 deletions
diff --git a/pre_commit/commands/migrate_config.py b/pre_commit/commands/migrate_config.py
index 842fb3a..c5d47a0 100644
--- a/pre_commit/commands/migrate_config.py
+++ b/pre_commit/commands/migrate_config.py
@@ -1,13 +1,21 @@
from __future__ import annotations
-import re
+import functools
+import itertools
import textwrap
+from typing import Callable
import cfgv
import yaml
+from yaml.nodes import ScalarNode
from pre_commit.clientlib import InvalidConfigError
+from pre_commit.yaml import yaml_compose
from pre_commit.yaml import yaml_load
+from pre_commit.yaml_rewrite import MappingKey
+from pre_commit.yaml_rewrite import MappingValue
+from pre_commit.yaml_rewrite import match
+from pre_commit.yaml_rewrite import SequenceItem
def _is_header_line(line: str) -> bool:
@@ -38,16 +46,69 @@ def _migrate_map(contents: str) -> str:
return contents
-def _migrate_sha_to_rev(contents: str) -> str:
- return re.sub(r'(\n\s+)sha:', r'\1rev:', contents)
+def _preserve_style(n: ScalarNode, *, s: str) -> str:
+ style = n.style or ''
+ return f'{style}{s}{style}'
-def _migrate_python_venv(contents: str) -> str:
- return re.sub(
- r'(\n\s+)language: python_venv\b',
- r'\1language: python',
- contents,
+def _fix_stage(n: ScalarNode) -> str:
+ return _preserve_style(n, s=f'pre-{n.value}')
+
+
+def _migrate_composed(contents: str) -> str:
+ tree = yaml_compose(contents)
+ rewrites: list[tuple[ScalarNode, Callable[[ScalarNode], str]]] = []
+
+ # sha -> rev
+ sha_to_rev_replace = functools.partial(_preserve_style, s='rev')
+ sha_to_rev_matcher = (
+ MappingValue('repos'),
+ SequenceItem(),
+ MappingKey('sha'),
+ )
+ for node in match(tree, sha_to_rev_matcher):
+ rewrites.append((node, sha_to_rev_replace))
+
+ # python_venv -> python
+ language_matcher = (
+ MappingValue('repos'),
+ SequenceItem(),
+ MappingValue('hooks'),
+ SequenceItem(),
+ MappingValue('language'),
)
+ python_venv_replace = functools.partial(_preserve_style, s='python')
+ for node in match(tree, language_matcher):
+ if node.value == 'python_venv':
+ rewrites.append((node, python_venv_replace))
+
+ # stages rewrites
+ default_stages_matcher = (MappingValue('default_stages'), SequenceItem())
+ default_stages_match = match(tree, default_stages_matcher)
+ hook_stages_matcher = (
+ MappingValue('repos'),
+ SequenceItem(),
+ MappingValue('hooks'),
+ SequenceItem(),
+ MappingValue('stages'),
+ SequenceItem(),
+ )
+ hook_stages_match = match(tree, hook_stages_matcher)
+ for node in itertools.chain(default_stages_match, hook_stages_match):
+ if node.value in {'commit', 'push', 'merge-commit'}:
+ rewrites.append((node, _fix_stage))
+
+ rewrites.sort(reverse=True, key=lambda nf: nf[0].start_mark.index)
+
+ src_parts = []
+ end: int | None = None
+ for node, func in rewrites:
+ src_parts.append(contents[node.end_mark.index:end])
+ src_parts.append(func(node))
+ end = node.start_mark.index
+ src_parts.append(contents[:end])
+ src_parts.reverse()
+ return ''.join(src_parts)
def migrate_config(config_file: str, quiet: bool = False) -> int:
@@ -62,8 +123,7 @@ def migrate_config(config_file: str, quiet: bool = False) -> int:
raise cfgv.ValidationError(str(e))
contents = _migrate_map(contents)
- contents = _migrate_sha_to_rev(contents)
- contents = _migrate_python_venv(contents)
+ contents = _migrate_composed(contents)
if contents != orig_contents:
with open(config_file, 'w') as f:
diff --git a/pre_commit/commands/run.py b/pre_commit/commands/run.py
index 2a08dff..793adbd 100644
--- a/pre_commit/commands/run.py
+++ b/pre_commit/commands/run.py
@@ -61,7 +61,7 @@ def filter_by_include_exclude(
names: Iterable[str],
include: str,
exclude: str,
-) -> Generator[str, None, None]:
+) -> Generator[str]:
include_re, exclude_re = re.compile(include), re.compile(exclude)
return (
filename for filename in names
@@ -84,7 +84,7 @@ class Classifier:
types: Iterable[str],
types_or: Iterable[str],
exclude_types: Iterable[str],
- ) -> Generator[str, None, None]:
+ ) -> Generator[str]:
types = frozenset(types)
types_or = frozenset(types_or)
exclude_types = frozenset(exclude_types)
@@ -97,7 +97,7 @@ class Classifier:
):
yield filename
- def filenames_for_hook(self, hook: Hook) -> Generator[str, None, None]:
+ def filenames_for_hook(self, hook: Hook) -> Generator[str]:
return self.by_types(
filter_by_include_exclude(
self.filenames,