diff options
Diffstat (limited to 'pre_commit/commands')
-rw-r--r-- | pre_commit/commands/migrate_config.py | 80 | ||||
-rw-r--r-- | pre_commit/commands/run.py | 6 |
2 files changed, 73 insertions, 13 deletions
diff --git a/pre_commit/commands/migrate_config.py b/pre_commit/commands/migrate_config.py index 842fb3a..c5d47a0 100644 --- a/pre_commit/commands/migrate_config.py +++ b/pre_commit/commands/migrate_config.py @@ -1,13 +1,21 @@ from __future__ import annotations -import re +import functools +import itertools import textwrap +from typing import Callable import cfgv import yaml +from yaml.nodes import ScalarNode from pre_commit.clientlib import InvalidConfigError +from pre_commit.yaml import yaml_compose from pre_commit.yaml import yaml_load +from pre_commit.yaml_rewrite import MappingKey +from pre_commit.yaml_rewrite import MappingValue +from pre_commit.yaml_rewrite import match +from pre_commit.yaml_rewrite import SequenceItem def _is_header_line(line: str) -> bool: @@ -38,16 +46,69 @@ def _migrate_map(contents: str) -> str: return contents -def _migrate_sha_to_rev(contents: str) -> str: - return re.sub(r'(\n\s+)sha:', r'\1rev:', contents) +def _preserve_style(n: ScalarNode, *, s: str) -> str: + style = n.style or '' + return f'{style}{s}{style}' -def _migrate_python_venv(contents: str) -> str: - return re.sub( - r'(\n\s+)language: python_venv\b', - r'\1language: python', - contents, +def _fix_stage(n: ScalarNode) -> str: + return _preserve_style(n, s=f'pre-{n.value}') + + +def _migrate_composed(contents: str) -> str: + tree = yaml_compose(contents) + rewrites: list[tuple[ScalarNode, Callable[[ScalarNode], str]]] = [] + + # sha -> rev + sha_to_rev_replace = functools.partial(_preserve_style, s='rev') + sha_to_rev_matcher = ( + MappingValue('repos'), + SequenceItem(), + MappingKey('sha'), + ) + for node in match(tree, sha_to_rev_matcher): + rewrites.append((node, sha_to_rev_replace)) + + # python_venv -> python + language_matcher = ( + MappingValue('repos'), + SequenceItem(), + MappingValue('hooks'), + SequenceItem(), + MappingValue('language'), ) + python_venv_replace = functools.partial(_preserve_style, s='python') + for node in match(tree, language_matcher): + if node.value == 'python_venv': + rewrites.append((node, python_venv_replace)) + + # stages rewrites + default_stages_matcher = (MappingValue('default_stages'), SequenceItem()) + default_stages_match = match(tree, default_stages_matcher) + hook_stages_matcher = ( + MappingValue('repos'), + SequenceItem(), + MappingValue('hooks'), + SequenceItem(), + MappingValue('stages'), + SequenceItem(), + ) + hook_stages_match = match(tree, hook_stages_matcher) + for node in itertools.chain(default_stages_match, hook_stages_match): + if node.value in {'commit', 'push', 'merge-commit'}: + rewrites.append((node, _fix_stage)) + + rewrites.sort(reverse=True, key=lambda nf: nf[0].start_mark.index) + + src_parts = [] + end: int | None = None + for node, func in rewrites: + src_parts.append(contents[node.end_mark.index:end]) + src_parts.append(func(node)) + end = node.start_mark.index + src_parts.append(contents[:end]) + src_parts.reverse() + return ''.join(src_parts) def migrate_config(config_file: str, quiet: bool = False) -> int: @@ -62,8 +123,7 @@ def migrate_config(config_file: str, quiet: bool = False) -> int: raise cfgv.ValidationError(str(e)) contents = _migrate_map(contents) - contents = _migrate_sha_to_rev(contents) - contents = _migrate_python_venv(contents) + contents = _migrate_composed(contents) if contents != orig_contents: with open(config_file, 'w') as f: diff --git a/pre_commit/commands/run.py b/pre_commit/commands/run.py index 2a08dff..793adbd 100644 --- a/pre_commit/commands/run.py +++ b/pre_commit/commands/run.py @@ -61,7 +61,7 @@ def filter_by_include_exclude( names: Iterable[str], include: str, exclude: str, -) -> Generator[str, None, None]: +) -> Generator[str]: include_re, exclude_re = re.compile(include), re.compile(exclude) return ( filename for filename in names @@ -84,7 +84,7 @@ class Classifier: types: Iterable[str], types_or: Iterable[str], exclude_types: Iterable[str], - ) -> Generator[str, None, None]: + ) -> Generator[str]: types = frozenset(types) types_or = frozenset(types_or) exclude_types = frozenset(exclude_types) @@ -97,7 +97,7 @@ class Classifier: ): yield filename - def filenames_for_hook(self, hook: Hook) -> Generator[str, None, None]: + def filenames_for_hook(self, hook: Hook) -> Generator[str]: return self.by_types( filter_by_include_exclude( self.filenames, |