diff options
Diffstat (limited to 'pre_commit_hooks')
35 files changed, 2426 insertions, 0 deletions
diff --git a/pre_commit_hooks/__init__.py b/pre_commit_hooks/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/pre_commit_hooks/__init__.py diff --git a/pre_commit_hooks/check_added_large_files.py b/pre_commit_hooks/check_added_large_files.py new file mode 100644 index 0000000..9e0619b --- /dev/null +++ b/pre_commit_hooks/check_added_large_files.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import argparse +import math +import os +import subprocess +from typing import Sequence + +from pre_commit_hooks.util import added_files +from pre_commit_hooks.util import zsplit + + +def filter_lfs_files(filenames: set[str]) -> None: # pragma: no cover (lfs) + """Remove files tracked by git-lfs from the set.""" + if not filenames: + return + + check_attr = subprocess.run( + ('git', 'check-attr', 'filter', '-z', '--stdin'), + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + encoding='utf-8', + check=True, + input='\0'.join(filenames), + ) + stdout = zsplit(check_attr.stdout) + for i in range(0, len(stdout), 3): + filename, filter_tag = stdout[i], stdout[i + 2] + if filter_tag == 'lfs': + filenames.remove(filename) + + +def find_large_added_files( + filenames: Sequence[str], + maxkb: int, + *, + enforce_all: bool = False, +) -> int: + # Find all added files that are also in the list of files pre-commit tells + # us about + retv = 0 + filenames_filtered = set(filenames) + filter_lfs_files(filenames_filtered) + + if not enforce_all: + filenames_filtered &= added_files() + + for filename in filenames_filtered: + kb = math.ceil(os.stat(filename).st_size / 1024) + if kb > maxkb: + print(f'{filename} ({kb} KB) exceeds {maxkb} KB.') + retv = 1 + + return retv + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument( + 'filenames', nargs='*', + help='Filenames pre-commit believes are changed.', + ) + parser.add_argument( + '--enforce-all', action='store_true', + help='Enforce all files are checked, not just staged files.', + ) + parser.add_argument( + '--maxkb', type=int, default=500, + help='Maximum allowable KB for added files', + ) + args = parser.parse_args(argv) + + return find_large_added_files( + args.filenames, + args.maxkb, + enforce_all=args.enforce_all, + ) + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/check_ast.py b/pre_commit_hooks/check_ast.py new file mode 100644 index 0000000..fdac361 --- /dev/null +++ b/pre_commit_hooks/check_ast.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import argparse +import ast +import platform +import sys +import traceback +from typing import Sequence + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument('filenames', nargs='*') + args = parser.parse_args(argv) + + retval = 0 + for filename in args.filenames: + + try: + with open(filename, 'rb') as f: + ast.parse(f.read(), filename=filename) + except SyntaxError: + impl = platform.python_implementation() + version = sys.version.split()[0] + print(f'{filename}: failed parsing with {impl} {version}:') + tb = ' ' + traceback.format_exc().replace('\n', '\n ') + print(f'\n{tb}') + retval = 1 + return retval + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/check_builtin_literals.py b/pre_commit_hooks/check_builtin_literals.py new file mode 100644 index 0000000..d3054aa --- /dev/null +++ b/pre_commit_hooks/check_builtin_literals.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import argparse +import ast +from typing import NamedTuple +from typing import Sequence + + +BUILTIN_TYPES = { + 'complex': '0j', + 'dict': '{}', + 'float': '0.0', + 'int': '0', + 'list': '[]', + 'str': "''", + 'tuple': '()', +} + + +class Call(NamedTuple): + name: str + line: int + column: int + + +class Visitor(ast.NodeVisitor): + def __init__( + self, + ignore: Sequence[str] | None = None, + allow_dict_kwargs: bool = True, + ) -> None: + self.builtin_type_calls: list[Call] = [] + self.ignore = set(ignore) if ignore else set() + self.allow_dict_kwargs = allow_dict_kwargs + + def _check_dict_call(self, node: ast.Call) -> bool: + return self.allow_dict_kwargs and bool(node.keywords) + + def visit_Call(self, node: ast.Call) -> None: + if not isinstance(node.func, ast.Name): + # Ignore functions that are object attributes (`foo.bar()`). + # Assume that if the user calls `builtins.list()`, they know what + # they're doing. + return + if node.func.id not in set(BUILTIN_TYPES).difference(self.ignore): + return + if node.func.id == 'dict' and self._check_dict_call(node): + return + elif node.args: + return + self.builtin_type_calls.append( + Call(node.func.id, node.lineno, node.col_offset), + ) + + +def check_file( + filename: str, + ignore: Sequence[str] | None = None, + allow_dict_kwargs: bool = True, +) -> list[Call]: + with open(filename, 'rb') as f: + tree = ast.parse(f.read(), filename=filename) + visitor = Visitor(ignore=ignore, allow_dict_kwargs=allow_dict_kwargs) + visitor.visit(tree) + return visitor.builtin_type_calls + + +def parse_ignore(value: str) -> set[str]: + return set(value.split(',')) + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument('filenames', nargs='*') + parser.add_argument('--ignore', type=parse_ignore, default=set()) + + mutex = parser.add_mutually_exclusive_group(required=False) + mutex.add_argument('--allow-dict-kwargs', action='store_true') + mutex.add_argument( + '--no-allow-dict-kwargs', + dest='allow_dict_kwargs', action='store_false', + ) + mutex.set_defaults(allow_dict_kwargs=True) + + args = parser.parse_args(argv) + + rc = 0 + for filename in args.filenames: + calls = check_file( + filename, + ignore=args.ignore, + allow_dict_kwargs=args.allow_dict_kwargs, + ) + if calls: + rc = rc or 1 + for call in calls: + print( + f'{filename}:{call.line}:{call.column}: ' + f'replace {call.name}() with {BUILTIN_TYPES[call.name]}', + ) + return rc + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/check_byte_order_marker.py b/pre_commit_hooks/check_byte_order_marker.py new file mode 100644 index 0000000..59cc561 --- /dev/null +++ b/pre_commit_hooks/check_byte_order_marker.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +import argparse +from typing import Sequence + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument('filenames', nargs='*', help='Filenames to check') + args = parser.parse_args(argv) + + retv = 0 + + for filename in args.filenames: + with open(filename, 'rb') as f: + if f.read(3) == b'\xef\xbb\xbf': + retv = 1 + print(f'{filename}: Has a byte-order marker') + + return retv + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/check_case_conflict.py b/pre_commit_hooks/check_case_conflict.py new file mode 100644 index 0000000..33a13f1 --- /dev/null +++ b/pre_commit_hooks/check_case_conflict.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import argparse +from typing import Iterable +from typing import Iterator +from typing import Sequence + +from pre_commit_hooks.util import added_files +from pre_commit_hooks.util import cmd_output + + +def lower_set(iterable: Iterable[str]) -> set[str]: + return {x.lower() for x in iterable} + + +def parents(file: str) -> Iterator[str]: + path_parts = file.split('/') + path_parts.pop() + while path_parts: + yield '/'.join(path_parts) + path_parts.pop() + + +def directories_for(files: set[str]) -> set[str]: + return {parent for file in files for parent in parents(file)} + + +def find_conflicting_filenames(filenames: Sequence[str]) -> int: + repo_files = set(cmd_output('git', 'ls-files').splitlines()) + repo_files |= directories_for(repo_files) + relevant_files = set(filenames) | added_files() + relevant_files |= directories_for(relevant_files) + repo_files -= relevant_files + retv = 0 + + # new file conflicts with existing file + conflicts = lower_set(repo_files) & lower_set(relevant_files) + + # new file conflicts with other new file + lowercase_relevant_files = lower_set(relevant_files) + for filename in set(relevant_files): + if filename.lower() in lowercase_relevant_files: + lowercase_relevant_files.remove(filename.lower()) + else: + conflicts.add(filename.lower()) + + if conflicts: + conflicting_files = [ + x for x in repo_files | relevant_files + if x.lower() in conflicts + ] + for filename in sorted(conflicting_files): + print(f'Case-insensitivity conflict found: {filename}') + retv = 1 + + return retv + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument( + 'filenames', nargs='*', + help='Filenames pre-commit believes are changed.', + ) + + args = parser.parse_args(argv) + + return find_conflicting_filenames(args.filenames) + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/check_docstring_first.py b/pre_commit_hooks/check_docstring_first.py new file mode 100644 index 0000000..d55f08a --- /dev/null +++ b/pre_commit_hooks/check_docstring_first.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import argparse +import io +import tokenize +from tokenize import tokenize as tokenize_tokenize +from typing import Sequence + +NON_CODE_TOKENS = frozenset(( + tokenize.COMMENT, tokenize.ENDMARKER, tokenize.NEWLINE, tokenize.NL, + tokenize.ENCODING, +)) + + +def check_docstring_first(src: bytes, filename: str = '<unknown>') -> int: + """Returns nonzero if the source has what looks like a docstring that is + not at the beginning of the source. + + A string will be considered a docstring if it is a STRING token with a + col offset of 0. + """ + found_docstring_line = None + found_code_line = None + + tok_gen = tokenize_tokenize(io.BytesIO(src).readline) + for tok_type, _, (sline, scol), _, _ in tok_gen: + # Looks like a docstring! + if tok_type == tokenize.STRING and scol == 0: + if found_docstring_line is not None: + print( + f'{filename}:{sline}: Multiple module docstrings ' + f'(first docstring on line {found_docstring_line}).', + ) + return 1 + elif found_code_line is not None: + print( + f'{filename}:{sline}: Module docstring appears after code ' + f'(code seen on line {found_code_line}).', + ) + return 1 + else: + found_docstring_line = sline + elif tok_type not in NON_CODE_TOKENS and found_code_line is None: + found_code_line = sline + + return 0 + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument('filenames', nargs='*') + args = parser.parse_args(argv) + + retv = 0 + + for filename in args.filenames: + with open(filename, 'rb') as f: + contents = f.read() + retv |= check_docstring_first(contents, filename=filename) + + return retv diff --git a/pre_commit_hooks/check_executables_have_shebangs.py b/pre_commit_hooks/check_executables_have_shebangs.py new file mode 100644 index 0000000..d8e4f49 --- /dev/null +++ b/pre_commit_hooks/check_executables_have_shebangs.py @@ -0,0 +1,85 @@ +"""Check that executable text files have a shebang.""" +from __future__ import annotations + +import argparse +import shlex +import sys +from typing import Generator +from typing import NamedTuple +from typing import Sequence + +from pre_commit_hooks.util import cmd_output +from pre_commit_hooks.util import zsplit + +EXECUTABLE_VALUES = frozenset(('1', '3', '5', '7')) + + +def check_executables(paths: list[str]) -> int: + fs_tracks_executable_bit = cmd_output( + 'git', 'config', 'core.fileMode', retcode=None, + ).strip() + if fs_tracks_executable_bit == 'false': # pragma: win32 cover + return _check_git_filemode(paths) + else: # pragma: win32 no cover + retv = 0 + for path in paths: + if not has_shebang(path): + _message(path) + retv = 1 + + return retv + + +class GitLsFile(NamedTuple): + mode: str + filename: str + + +def git_ls_files(paths: Sequence[str]) -> Generator[GitLsFile, None, None]: + outs = cmd_output('git', 'ls-files', '-z', '--stage', '--', *paths) + for out in zsplit(outs): + metadata, filename = out.split('\t') + mode, _, _ = metadata.split() + yield GitLsFile(mode, filename) + + +def _check_git_filemode(paths: Sequence[str]) -> int: + seen: set[str] = set() + for ls_file in git_ls_files(paths): + is_executable = any(b in EXECUTABLE_VALUES for b in ls_file.mode[-3:]) + if is_executable and not has_shebang(ls_file.filename): + _message(ls_file.filename) + seen.add(ls_file.filename) + + return int(bool(seen)) + + +def has_shebang(path: str) -> int: + with open(path, 'rb') as f: + first_bytes = f.read(2) + + return first_bytes == b'#!' + + +def _message(path: str) -> None: + print( + f'{path}: marked executable but has no (or invalid) shebang!\n' + f" If it isn't supposed to be executable, try: " + f'`chmod -x {shlex.quote(path)}`\n' + f' If on Windows, you may also need to: ' + f'`git add --chmod=-x {shlex.quote(path)}`\n' + f' If it is supposed to be executable, double-check its shebang.', + file=sys.stderr, + ) + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument('filenames', nargs='*') + args = parser.parse_args(argv) + + return check_executables(args.filenames) + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/check_json.py b/pre_commit_hooks/check_json.py new file mode 100644 index 0000000..6a679fe --- /dev/null +++ b/pre_commit_hooks/check_json.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import argparse +import json +from typing import Any +from typing import Sequence + + +def raise_duplicate_keys( + ordered_pairs: list[tuple[str, Any]], +) -> dict[str, Any]: + d = {} + for key, val in ordered_pairs: + if key in d: + raise ValueError(f'Duplicate key: {key}') + else: + d[key] = val + return d + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument('filenames', nargs='*', help='Filenames to check.') + args = parser.parse_args(argv) + + retval = 0 + for filename in args.filenames: + with open(filename, 'rb') as f: + try: + json.load(f, object_pairs_hook=raise_duplicate_keys) + except ValueError as exc: + print(f'{filename}: Failed to json decode ({exc})') + retval = 1 + return retval + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/check_merge_conflict.py b/pre_commit_hooks/check_merge_conflict.py new file mode 100644 index 0000000..15ec284 --- /dev/null +++ b/pre_commit_hooks/check_merge_conflict.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import argparse +import os.path +from typing import Sequence + +from pre_commit_hooks.util import cmd_output + + +CONFLICT_PATTERNS = [ + b'<<<<<<< ', + b'======= ', + b'=======\r\n', + b'=======\n', + b'>>>>>>> ', +] + + +def is_in_merge() -> bool: + git_dir = cmd_output('git', 'rev-parse', '--git-dir').rstrip() + return ( + os.path.exists(os.path.join(git_dir, 'MERGE_MSG')) and + ( + os.path.exists(os.path.join(git_dir, 'MERGE_HEAD')) or + os.path.exists(os.path.join(git_dir, 'rebase-apply')) or + os.path.exists(os.path.join(git_dir, 'rebase-merge')) + ) + ) + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument('filenames', nargs='*') + parser.add_argument('--assume-in-merge', action='store_true') + args = parser.parse_args(argv) + + if not is_in_merge() and not args.assume_in_merge: + return 0 + + retcode = 0 + for filename in args.filenames: + with open(filename, 'rb') as inputfile: + for i, line in enumerate(inputfile, start=1): + for pattern in CONFLICT_PATTERNS: + if line.startswith(pattern): + print( + f'{filename}:{i}: Merge conflict string ' + f'{pattern.strip().decode()!r} found', + ) + retcode = 1 + + return retcode + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/check_shebang_scripts_are_executable.py b/pre_commit_hooks/check_shebang_scripts_are_executable.py new file mode 100644 index 0000000..621696c --- /dev/null +++ b/pre_commit_hooks/check_shebang_scripts_are_executable.py @@ -0,0 +1,54 @@ +"""Check that text files with a shebang are executable.""" +from __future__ import annotations + +import argparse +import shlex +import sys +from typing import Sequence + +from pre_commit_hooks.check_executables_have_shebangs import EXECUTABLE_VALUES +from pre_commit_hooks.check_executables_have_shebangs import git_ls_files +from pre_commit_hooks.check_executables_have_shebangs import has_shebang + + +def check_shebangs(paths: list[str]) -> int: + # Cannot optimize on non-executability here if we intend this check to + # work on win32 -- and that's where problems caused by non-executability + # (elsewhere) are most likely to arise from. + return _check_git_filemode(paths) + + +def _check_git_filemode(paths: Sequence[str]) -> int: + seen: set[str] = set() + for ls_file in git_ls_files(paths): + is_executable = any(b in EXECUTABLE_VALUES for b in ls_file.mode[-3:]) + if not is_executable and has_shebang(ls_file.filename): + _message(ls_file.filename) + seen.add(ls_file.filename) + + return int(bool(seen)) + + +def _message(path: str) -> None: + print( + f'{path}: has a shebang but is not marked executable!\n' + f' If it is supposed to be executable, try: ' + f'`chmod +x {shlex.quote(path)}`\n' + f' If on Windows, you may also need to: ' + f'`git add --chmod=+x {shlex.quote(path)}`\n' + f' If it not supposed to be executable, double-check its shebang ' + f'is wanted.\n', + file=sys.stderr, + ) + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument('filenames', nargs='*') + args = parser.parse_args(argv) + + return check_shebangs(args.filenames) + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/check_symlinks.py b/pre_commit_hooks/check_symlinks.py new file mode 100644 index 0000000..a85c82a --- /dev/null +++ b/pre_commit_hooks/check_symlinks.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +import argparse +import os.path +from typing import Sequence + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser(description='Checks for broken symlinks.') + parser.add_argument('filenames', nargs='*', help='Filenames to check') + args = parser.parse_args(argv) + + retv = 0 + + for filename in args.filenames: + if ( + os.path.islink(filename) and + not os.path.exists(filename) + ): # pragma: no cover (symlink support required) + print(f'{filename}: Broken symlink') + retv = 1 + + return retv + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/check_toml.py b/pre_commit_hooks/check_toml.py new file mode 100644 index 0000000..0407371 --- /dev/null +++ b/pre_commit_hooks/check_toml.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import argparse +import sys +from typing import Sequence + +if sys.version_info >= (3, 11): # pragma: >=3.11 cover + import tomllib +else: # pragma: <3.11 cover + import tomli as tomllib + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument('filenames', nargs='*', help='Filenames to check.') + args = parser.parse_args(argv) + + retval = 0 + for filename in args.filenames: + try: + with open(filename, mode='rb') as fp: + tomllib.load(fp) + except tomllib.TOMLDecodeError as exc: + print(f'{filename}: {exc}') + retval = 1 + return retval + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/check_vcs_permalinks.py b/pre_commit_hooks/check_vcs_permalinks.py new file mode 100644 index 0000000..68639bd --- /dev/null +++ b/pre_commit_hooks/check_vcs_permalinks.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import argparse +import re +import sys +from typing import Pattern +from typing import Sequence + + +def _get_pattern(domain: str) -> Pattern[bytes]: + regex = ( + rf'https://{domain}/[^/ ]+/[^/ ]+/blob/' + r'(?![a-fA-F0-9]{4,64}/)([^/. ]+)/[^# ]+#L\d+' + ) + return re.compile(regex.encode()) + + +def _check_filename(filename: str, patterns: list[Pattern[bytes]]) -> int: + retv = 0 + with open(filename, 'rb') as f: + for i, line in enumerate(f, 1): + for pattern in patterns: + if pattern.search(line): + sys.stdout.write(f'{filename}:{i}:') + sys.stdout.flush() + sys.stdout.buffer.write(line) + retv = 1 + return retv + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument('filenames', nargs='*') + parser.add_argument( + '--additional-github-domain', + dest='additional_github_domains', + action='append', + default=['github.com'], + ) + args = parser.parse_args(argv) + + patterns = [ + _get_pattern(domain) + for domain in args.additional_github_domains + ] + + retv = 0 + + for filename in args.filenames: + retv |= _check_filename(filename, patterns) + + if retv: + print() + print('Non-permanent github link detected.') + print('On any page on github press [y] to load a permalink.') + return retv + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/check_xml.py b/pre_commit_hooks/check_xml.py new file mode 100644 index 0000000..c256af9 --- /dev/null +++ b/pre_commit_hooks/check_xml.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import argparse +import xml.sax.handler +from typing import Sequence + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument('filenames', nargs='*', help='XML filenames to check.') + args = parser.parse_args(argv) + + retval = 0 + handler = xml.sax.handler.ContentHandler() + for filename in args.filenames: + try: + with open(filename, 'rb') as xml_file: + xml.sax.parse(xml_file, handler) + except xml.sax.SAXException as exc: + print(f'{filename}: Failed to xml parse ({exc})') + retval = 1 + return retval + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/check_yaml.py b/pre_commit_hooks/check_yaml.py new file mode 100644 index 0000000..9563347 --- /dev/null +++ b/pre_commit_hooks/check_yaml.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import argparse +from typing import Any +from typing import Generator +from typing import NamedTuple +from typing import Sequence + +import ruamel.yaml + +yaml = ruamel.yaml.YAML(typ='safe') + + +def _exhaust(gen: Generator[str, None, None]) -> None: + for _ in gen: + pass + + +def _parse_unsafe(*args: Any, **kwargs: Any) -> None: + _exhaust(yaml.parse(*args, **kwargs)) + + +def _load_all(*args: Any, **kwargs: Any) -> None: + _exhaust(yaml.load_all(*args, **kwargs)) + + +class Key(NamedTuple): + multi: bool + unsafe: bool + + +LOAD_FNS = { + Key(multi=False, unsafe=False): yaml.load, + Key(multi=False, unsafe=True): _parse_unsafe, + Key(multi=True, unsafe=False): _load_all, + Key(multi=True, unsafe=True): _parse_unsafe, +} + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument( + '-m', '--multi', '--allow-multiple-documents', action='store_true', + ) + parser.add_argument( + '--unsafe', action='store_true', + help=( + 'Instead of loading the files, simply parse them for syntax. ' + 'A syntax-only check enables extensions and unsafe constructs ' + 'which would otherwise be forbidden. Using this option removes ' + 'all guarantees of portability to other yaml implementations. ' + 'Implies --allow-multiple-documents' + ), + ) + parser.add_argument('filenames', nargs='*', help='Filenames to check.') + args = parser.parse_args(argv) + + load_fn = LOAD_FNS[Key(multi=args.multi, unsafe=args.unsafe)] + + retval = 0 + for filename in args.filenames: + try: + with open(filename, encoding='UTF-8') as f: + load_fn(f) + except ruamel.yaml.YAMLError as exc: + print(exc) + retval = 1 + return retval + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/debug_statement_hook.py b/pre_commit_hooks/debug_statement_hook.py new file mode 100644 index 0000000..cf544c7 --- /dev/null +++ b/pre_commit_hooks/debug_statement_hook.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import argparse +import ast +import traceback +from typing import NamedTuple +from typing import Sequence + + +DEBUG_STATEMENTS = { + 'bpdb', + 'ipdb', + 'pdb', + 'pdbr', + 'pudb', + 'pydevd_pycharm', + 'q', + 'rdb', + 'rpdb', + 'wdb', +} + + +class Debug(NamedTuple): + line: int + col: int + name: str + reason: str + + +class DebugStatementParser(ast.NodeVisitor): + def __init__(self) -> None: + self.breakpoints: list[Debug] = [] + + def visit_Import(self, node: ast.Import) -> None: + for name in node.names: + if name.name in DEBUG_STATEMENTS: + st = Debug(node.lineno, node.col_offset, name.name, 'imported') + self.breakpoints.append(st) + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + if node.module in DEBUG_STATEMENTS: + st = Debug(node.lineno, node.col_offset, node.module, 'imported') + self.breakpoints.append(st) + + def visit_Call(self, node: ast.Call) -> None: + """python3.7+ breakpoint()""" + if isinstance(node.func, ast.Name) and node.func.id == 'breakpoint': + st = Debug(node.lineno, node.col_offset, node.func.id, 'called') + self.breakpoints.append(st) + self.generic_visit(node) + + +def check_file(filename: str) -> int: + try: + with open(filename, 'rb') as f: + ast_obj = ast.parse(f.read(), filename=filename) + except SyntaxError: + print(f'{filename} - Could not parse ast') + print() + print('\t' + traceback.format_exc().replace('\n', '\n\t')) + print() + return 1 + + visitor = DebugStatementParser() + visitor.visit(ast_obj) + + for bp in visitor.breakpoints: + print(f'{filename}:{bp.line}:{bp.col}: {bp.name} {bp.reason}') + + return int(bool(visitor.breakpoints)) + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument('filenames', nargs='*', help='Filenames to run') + args = parser.parse_args(argv) + + retv = 0 + for filename in args.filenames: + retv |= check_file(filename) + return retv + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/destroyed_symlinks.py b/pre_commit_hooks/destroyed_symlinks.py new file mode 100644 index 0000000..f256908 --- /dev/null +++ b/pre_commit_hooks/destroyed_symlinks.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import argparse +import shlex +import subprocess +from typing import Sequence + +from pre_commit_hooks.util import cmd_output +from pre_commit_hooks.util import zsplit + +ORDINARY_CHANGED_ENTRIES_MARKER = '1' +PERMS_LINK = '120000' +PERMS_NONEXIST = '000000' + + +def find_destroyed_symlinks(files: Sequence[str]) -> list[str]: + destroyed_links: list[str] = [] + if not files: + return destroyed_links + for line in zsplit( + cmd_output('git', 'status', '--porcelain=v2', '-z', '--', *files), + ): + splitted = line.split(' ') + if splitted and splitted[0] == ORDINARY_CHANGED_ENTRIES_MARKER: + # https://git-scm.com/docs/git-status#_changed_tracked_entries + ( + _, _, _, + mode_HEAD, + mode_index, + _, + hash_HEAD, + hash_index, + *path_splitted, + ) = splitted + path = ' '.join(path_splitted) + if ( + mode_HEAD == PERMS_LINK and + mode_index != PERMS_LINK and + mode_index != PERMS_NONEXIST + ): + if hash_HEAD == hash_index: + # if old and new hashes are equal, it's not needed to check + # anything more, we've found a destroyed symlink for sure + destroyed_links.append(path) + else: + # if old and new hashes are *not* equal, it doesn't mean + # that everything is OK - new file may be altered + # by something like trailing-whitespace and/or + # mixed-line-ending hooks so we need to go deeper + SIZE_CMD = ('git', 'cat-file', '-s') + size_index = int(cmd_output(*SIZE_CMD, hash_index).strip()) + size_HEAD = int(cmd_output(*SIZE_CMD, hash_HEAD).strip()) + + # in the worst case new file may have CRLF added + # so check content only if new file is bigger + # not more than 2 bytes compared to the old one + if size_index <= size_HEAD + 2: + head_content = subprocess.check_output( + ('git', 'cat-file', '-p', hash_HEAD), + ).rstrip() + index_content = subprocess.check_output( + ('git', 'cat-file', '-p', hash_index), + ).rstrip() + if head_content == index_content: + destroyed_links.append(path) + return destroyed_links + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument('filenames', nargs='*', help='Filenames to check.') + args = parser.parse_args(argv) + destroyed_links = find_destroyed_symlinks(files=args.filenames) + if destroyed_links: + print('Destroyed symlinks:') + for destroyed_link in destroyed_links: + print(f'- {destroyed_link}') + print('You should unstage affected files:') + print(f'\tgit reset HEAD -- {shlex.join(destroyed_links)}') + print( + 'And retry commit. As a long term solution ' + 'you may try to explicitly tell git that your ' + 'environment does not support symlinks:', + ) + print('\tgit config core.symlinks false') + return 1 + else: + return 0 + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/detect_aws_credentials.py b/pre_commit_hooks/detect_aws_credentials.py new file mode 100644 index 0000000..4f59d9c --- /dev/null +++ b/pre_commit_hooks/detect_aws_credentials.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +import argparse +import configparser +import os +from typing import NamedTuple +from typing import Sequence + + +class BadFile(NamedTuple): + filename: str + key: str + + +def get_aws_cred_files_from_env() -> set[str]: + """Extract credential file paths from environment variables.""" + return { + os.environ[env_var] + for env_var in ( + 'AWS_CONFIG_FILE', 'AWS_CREDENTIAL_FILE', + 'AWS_SHARED_CREDENTIALS_FILE', 'BOTO_CONFIG', + ) + if env_var in os.environ + } + + +def get_aws_secrets_from_env() -> set[str]: + """Extract AWS secrets from environment variables.""" + keys = set() + for env_var in ( + 'AWS_SECRET_ACCESS_KEY', 'AWS_SECURITY_TOKEN', 'AWS_SESSION_TOKEN', + ): + if os.environ.get(env_var): + keys.add(os.environ[env_var]) + return keys + + +def get_aws_secrets_from_file(credentials_file: str) -> set[str]: + """Extract AWS secrets from configuration files. + + Read an ini-style configuration file and return a set with all found AWS + secret access keys. + """ + aws_credentials_file_path = os.path.expanduser(credentials_file) + if not os.path.exists(aws_credentials_file_path): + return set() + + parser = configparser.ConfigParser() + try: + parser.read(aws_credentials_file_path) + except configparser.MissingSectionHeaderError: + return set() + + keys = set() + for section in parser.sections(): + for var in ( + 'aws_secret_access_key', 'aws_security_token', + 'aws_session_token', + ): + try: + key = parser.get(section, var).strip() + if key: + keys.add(key) + except configparser.NoOptionError: + pass + return keys + + +def check_file_for_aws_keys( + filenames: Sequence[str], + keys: set[bytes], +) -> list[BadFile]: + """Check if files contain AWS secrets. + + Return a list of all files containing AWS secrets and keys found, with all + but the first four characters obfuscated to ease debugging. + """ + bad_files = [] + + for filename in filenames: + with open(filename, 'rb') as content: + text_body = content.read() + for key in keys: + # naively match the entire file, low chance of incorrect + # collision + if key in text_body: + key_hidden = key.decode()[:4].ljust(28, '*') + bad_files.append(BadFile(filename, key_hidden)) + return bad_files + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument('filenames', nargs='+', help='Filenames to run') + parser.add_argument( + '--credentials-file', + dest='credentials_file', + action='append', + default=[ + '~/.aws/config', '~/.aws/credentials', '/etc/boto.cfg', '~/.boto', + ], + help=( + 'Location of additional AWS credential file from which to get ' + 'secret keys. Can be passed multiple times.' + ), + ) + parser.add_argument( + '--allow-missing-credentials', + dest='allow_missing_credentials', + action='store_true', + help='Allow hook to pass when no credentials are detected.', + ) + args = parser.parse_args(argv) + + credential_files = set(args.credentials_file) + + # Add the credentials files configured via environment variables to the set + # of files to to gather AWS secrets from. + credential_files |= get_aws_cred_files_from_env() + + keys: set[str] = set() + for credential_file in credential_files: + keys |= get_aws_secrets_from_file(credential_file) + + # Secrets might be part of environment variables, so add such secrets to + # the set of keys. + keys |= get_aws_secrets_from_env() + + if not keys and args.allow_missing_credentials: + return 0 + + if not keys: + print( + 'No AWS keys were found in the configured credential files and ' + 'environment variables.\nPlease ensure you have the correct ' + 'setting for --credentials-file', + ) + return 2 + + keys_b = {key.encode() for key in keys} + bad_filenames = check_file_for_aws_keys(args.filenames, keys_b) + if bad_filenames: + for bad_file in bad_filenames: + print(f'AWS secret found in {bad_file.filename}: {bad_file.key}') + return 1 + else: + return 0 + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/detect_private_key.py b/pre_commit_hooks/detect_private_key.py new file mode 100644 index 0000000..cd51f90 --- /dev/null +++ b/pre_commit_hooks/detect_private_key.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import argparse +from typing import Sequence + +BLACKLIST = [ + b'BEGIN RSA PRIVATE KEY', + b'BEGIN DSA PRIVATE KEY', + b'BEGIN EC PRIVATE KEY', + b'BEGIN OPENSSH PRIVATE KEY', + b'BEGIN PRIVATE KEY', + b'PuTTY-User-Key-File-2', + b'BEGIN SSH2 ENCRYPTED PRIVATE KEY', + b'BEGIN PGP PRIVATE KEY BLOCK', + b'BEGIN ENCRYPTED PRIVATE KEY', + b'BEGIN OpenVPN Static key V1', +] + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument('filenames', nargs='*', help='Filenames to check') + args = parser.parse_args(argv) + + private_key_files = [] + + for filename in args.filenames: + with open(filename, 'rb') as f: + content = f.read() + if any(line in content for line in BLACKLIST): + private_key_files.append(filename) + + if private_key_files: + for private_key_file in private_key_files: + print(f'Private key found: {private_key_file}') + return 1 + else: + return 0 + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/end_of_file_fixer.py b/pre_commit_hooks/end_of_file_fixer.py new file mode 100644 index 0000000..a30dce9 --- /dev/null +++ b/pre_commit_hooks/end_of_file_fixer.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import argparse +import os +from typing import IO +from typing import Sequence + + +def fix_file(file_obj: IO[bytes]) -> int: + # Test for newline at end of file + # Empty files will throw IOError here + try: + file_obj.seek(-1, os.SEEK_END) + except OSError: + return 0 + last_character = file_obj.read(1) + # last_character will be '' for an empty file + if last_character not in {b'\n', b'\r'} and last_character != b'': + # Needs this seek for windows, otherwise IOError + file_obj.seek(0, os.SEEK_END) + file_obj.write(b'\n') + return 1 + + while last_character in {b'\n', b'\r'}: + # Deal with the beginning of the file + if file_obj.tell() == 1: + # If we've reached the beginning of the file and it is all + # linebreaks then we can make this file empty + file_obj.seek(0) + file_obj.truncate() + return 1 + + # Go back two bytes and read a character + file_obj.seek(-2, os.SEEK_CUR) + last_character = file_obj.read(1) + + # Our current position is at the end of the file just before any amount of + # newlines. If we find extraneous newlines, then backtrack and trim them. + position = file_obj.tell() + remaining = file_obj.read() + for sequence in (b'\n', b'\r\n', b'\r'): + if remaining == sequence: + return 0 + elif remaining.startswith(sequence): + file_obj.seek(position + len(sequence)) + file_obj.truncate() + return 1 + + return 0 + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument('filenames', nargs='*', help='Filenames to fix') + args = parser.parse_args(argv) + + retv = 0 + + for filename in args.filenames: + # Read as binary so we can read byte-by-byte + with open(filename, 'rb+') as file_obj: + ret_for_file = fix_file(file_obj) + if ret_for_file: + print(f'Fixing {filename}') + retv |= ret_for_file + + return retv + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/file_contents_sorter.py b/pre_commit_hooks/file_contents_sorter.py new file mode 100644 index 0000000..02bdbcc --- /dev/null +++ b/pre_commit_hooks/file_contents_sorter.py @@ -0,0 +1,88 @@ +""" +A very simple pre-commit hook that, when passed one or more filenames +as arguments, will sort the lines in those files. + +An example use case for this: you have a deploy-allowlist.txt file +in a repo that contains a list of filenames that is used to specify +files to be included in a docker container. This file has one filename +per line. Various users are adding/removing lines from this file; using +this hook on that file should reduce the instances of git merge +conflicts and keep the file nicely ordered. +""" +from __future__ import annotations + +import argparse +from typing import Any +from typing import Callable +from typing import IO +from typing import Iterable +from typing import Sequence + +PASS = 0 +FAIL = 1 + + +def sort_file_contents( + f: IO[bytes], + key: Callable[[bytes], Any] | None, + *, + unique: bool = False, +) -> int: + before = list(f) + lines: Iterable[bytes] = ( + line.rstrip(b'\n\r') for line in before if line.strip() + ) + if unique: + lines = set(lines) + after = sorted(lines, key=key) + + before_string = b''.join(before) + after_string = b'\n'.join(after) + + if after_string: + after_string += b'\n' + + if before_string == after_string: + return PASS + else: + f.seek(0) + f.write(after_string) + f.truncate() + return FAIL + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument('filenames', nargs='+', help='Files to sort') + parser.add_argument( + '--ignore-case', + action='store_const', + const=bytes.lower, + default=None, + help='fold lower case to upper case characters', + ) + parser.add_argument( + '--unique', + action='store_true', + help='ensure each line is unique', + ) + args = parser.parse_args(argv) + + retv = PASS + + for arg in args.filenames: + with open(arg, 'rb+') as file_obj: + ret_for_file = sort_file_contents( + file_obj, key=args.ignore_case, unique=args.unique, + ) + + if ret_for_file: + print(f'Sorting {arg}') + + retv |= ret_for_file + + return retv + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/fix_byte_order_marker.py b/pre_commit_hooks/fix_byte_order_marker.py new file mode 100644 index 0000000..22a4990 --- /dev/null +++ b/pre_commit_hooks/fix_byte_order_marker.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import argparse +from typing import Sequence + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument('filenames', nargs='*', help='Filenames to check') + args = parser.parse_args(argv) + + retv = 0 + + for filename in args.filenames: + with open(filename, 'rb') as f_b: + bts = f_b.read(3) + + if bts == b'\xef\xbb\xbf': + with open(filename, newline='', encoding='utf-8-sig') as f: + contents = f.read() + with open(filename, 'w', newline='', encoding='utf-8') as f: + f.write(contents) + + print(f'{filename}: removed byte-order marker') + retv = 1 + + return retv + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/fix_encoding_pragma.py b/pre_commit_hooks/fix_encoding_pragma.py new file mode 100644 index 0000000..60c71ee --- /dev/null +++ b/pre_commit_hooks/fix_encoding_pragma.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +import argparse +from typing import IO +from typing import NamedTuple +from typing import Sequence + +DEFAULT_PRAGMA = b'# -*- coding: utf-8 -*-' + + +def has_coding(line: bytes) -> bool: + if not line.strip(): + return False + return ( + line.lstrip()[:1] == b'#' and ( + b'unicode' in line or + b'encoding' in line or + b'coding:' in line or + b'coding=' in line + ) + ) + + +class ExpectedContents(NamedTuple): + shebang: bytes + rest: bytes + # True: has exactly the coding pragma expected + # False: missing coding pragma entirely + # None: has a coding pragma, but it does not match + pragma_status: bool | None + ending: bytes + + @property + def has_any_pragma(self) -> bool: + return self.pragma_status is not False + + def is_expected_pragma(self, remove: bool) -> bool: + expected_pragma_status = not remove + return self.pragma_status is expected_pragma_status + + +def _get_expected_contents( + first_line: bytes, + second_line: bytes, + rest: bytes, + expected_pragma: bytes, +) -> ExpectedContents: + ending = b'\r\n' if first_line.endswith(b'\r\n') else b'\n' + + if first_line.startswith(b'#!'): + shebang = first_line + potential_coding = second_line + else: + shebang = b'' + potential_coding = first_line + rest = second_line + rest + + if potential_coding.rstrip(b'\r\n') == expected_pragma: + pragma_status: bool | None = True + elif has_coding(potential_coding): + pragma_status = None + else: + pragma_status = False + rest = potential_coding + rest + + return ExpectedContents( + shebang=shebang, rest=rest, pragma_status=pragma_status, ending=ending, + ) + + +def fix_encoding_pragma( + f: IO[bytes], + remove: bool = False, + expected_pragma: bytes = DEFAULT_PRAGMA, +) -> int: + expected = _get_expected_contents( + f.readline(), f.readline(), f.read(), expected_pragma, + ) + + # Special cases for empty files + if not expected.rest.strip(): + # If a file only has a shebang or a coding pragma, remove it + if expected.has_any_pragma or expected.shebang: + f.seek(0) + f.truncate() + f.write(b'') + return 1 + else: + return 0 + + if expected.is_expected_pragma(remove): + return 0 + + # Otherwise, write out the new file + f.seek(0) + f.truncate() + f.write(expected.shebang) + if not remove: + f.write(expected_pragma + expected.ending) + f.write(expected.rest) + + return 1 + + +def _normalize_pragma(pragma: str) -> bytes: + return pragma.encode().rstrip() + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser( + 'Fixes the encoding pragma of python files', + ) + parser.add_argument('filenames', nargs='*', help='Filenames to fix') + parser.add_argument( + '--pragma', default=DEFAULT_PRAGMA, type=_normalize_pragma, + help=( + f'The encoding pragma to use. ' + f'Default: {DEFAULT_PRAGMA.decode()}' + ), + ) + parser.add_argument( + '--remove', action='store_true', + help='Remove the encoding pragma (Useful in a python3-only codebase)', + ) + args = parser.parse_args(argv) + + retv = 0 + + if args.remove: + fmt = 'Removed encoding pragma from {filename}' + else: + fmt = 'Added `{pragma}` to {filename}' + + for filename in args.filenames: + with open(filename, 'r+b') as f: + file_ret = fix_encoding_pragma( + f, remove=args.remove, expected_pragma=args.pragma, + ) + retv |= file_ret + if file_ret: + print( + fmt.format(pragma=args.pragma.decode(), filename=filename), + ) + + return retv + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/forbid_new_submodules.py b/pre_commit_hooks/forbid_new_submodules.py new file mode 100644 index 0000000..b806cad --- /dev/null +++ b/pre_commit_hooks/forbid_new_submodules.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import argparse +import os +from typing import Sequence + +from pre_commit_hooks.util import cmd_output + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument('filenames', nargs='*') + args = parser.parse_args(argv) + + if ( + 'PRE_COMMIT_FROM_REF' in os.environ and + 'PRE_COMMIT_TO_REF' in os.environ + ): + diff_arg = '...'.join(( + os.environ['PRE_COMMIT_FROM_REF'], + os.environ['PRE_COMMIT_TO_REF'], + )) + else: + diff_arg = '--staged' + added_diff = cmd_output( + 'git', 'diff', '--diff-filter=A', '--raw', diff_arg, '--', + *args.filenames, + ) + retv = 0 + for line in added_diff.splitlines(): + metadata, filename = line.split('\t', 1) + new_mode = metadata.split(' ')[1] + if new_mode == '160000': + print(f'{filename}: new submodule introduced') + retv = 1 + + if retv: + print() + print('This commit introduces new submodules.') + print('Did you unintentionally `git add .`?') + print('To fix: git rm {thesubmodule} # no trailing slash') + print('Also check .gitmodules') + + return retv + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/mixed_line_ending.py b/pre_commit_hooks/mixed_line_ending.py new file mode 100644 index 0000000..0328e86 --- /dev/null +++ b/pre_commit_hooks/mixed_line_ending.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import argparse +import collections +from typing import Sequence + + +CRLF = b'\r\n' +LF = b'\n' +CR = b'\r' +# Prefer LF to CRLF to CR, but detect CRLF before LF +ALL_ENDINGS = (CR, CRLF, LF) +FIX_TO_LINE_ENDING = {'cr': CR, 'crlf': CRLF, 'lf': LF} + + +def _fix(filename: str, contents: bytes, ending: bytes) -> None: + new_contents = b''.join( + line.rstrip(b'\r\n') + ending for line in contents.splitlines(True) + ) + with open(filename, 'wb') as f: + f.write(new_contents) + + +def fix_filename(filename: str, fix: str) -> int: + with open(filename, 'rb') as f: + contents = f.read() + + counts: dict[bytes, int] = collections.defaultdict(int) + + for line in contents.splitlines(True): + for ending in ALL_ENDINGS: + if line.endswith(ending): + counts[ending] += 1 + break + + # Some amount of mixed line endings + mixed = sum(bool(x) for x in counts.values()) > 1 + + if fix == 'no' or (fix == 'auto' and not mixed): + return mixed + + if fix == 'auto': + max_ending = LF + max_lines = 0 + # ordering is important here such that lf > crlf > cr + for ending_type in ALL_ENDINGS: + # also important, using >= to find a max that prefers the last + if counts[ending_type] >= max_lines: + max_ending = ending_type + max_lines = counts[ending_type] + + _fix(filename, contents, max_ending) + return 1 + else: + target_ending = FIX_TO_LINE_ENDING[fix] + # find if there are lines with *other* endings + # It's possible there's no line endings of the target type + counts.pop(target_ending, None) + other_endings = bool(sum(counts.values())) + if other_endings: + _fix(filename, contents, target_ending) + return other_endings + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument( + '-f', '--fix', + choices=('auto', 'no') + tuple(FIX_TO_LINE_ENDING), + default='auto', + help='Replace line ending with the specified. Default is "auto"', + ) + parser.add_argument('filenames', nargs='*', help='Filenames to fix') + args = parser.parse_args(argv) + + retv = 0 + for filename in args.filenames: + if fix_filename(filename, args.fix): + if args.fix == 'no': + print(f'{filename}: mixed line endings') + else: + print(f'{filename}: fixed mixed line endings') + retv = 1 + return retv + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/no_commit_to_branch.py b/pre_commit_hooks/no_commit_to_branch.py new file mode 100644 index 0000000..741f726 --- /dev/null +++ b/pre_commit_hooks/no_commit_to_branch.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import argparse +import re +from typing import AbstractSet +from typing import Sequence + +from pre_commit_hooks.util import CalledProcessError +from pre_commit_hooks.util import cmd_output + + +def is_on_branch( + protected: AbstractSet[str], + patterns: AbstractSet[str] = frozenset(), +) -> bool: + try: + ref_name = cmd_output('git', 'symbolic-ref', 'HEAD') + except CalledProcessError: + return False + chunks = ref_name.strip().split('/') + branch_name = '/'.join(chunks[2:]) + return branch_name in protected or any( + re.match(p, branch_name) for p in patterns + ) + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument( + '-b', '--branch', action='append', + help='branch to disallow commits to, may be specified multiple times', + ) + parser.add_argument( + '-p', '--pattern', action='append', + help=( + 'regex pattern for branch name to disallow commits to, ' + 'may be specified multiple times' + ), + ) + args = parser.parse_args(argv) + + protected = frozenset(args.branch or ('master', 'main')) + patterns = frozenset(args.pattern or ()) + return int(is_on_branch(protected, patterns)) + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/pretty_format_json.py b/pre_commit_hooks/pretty_format_json.py new file mode 100644 index 0000000..627a11c --- /dev/null +++ b/pre_commit_hooks/pretty_format_json.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +import argparse +import json +import sys +from difflib import unified_diff +from typing import Mapping +from typing import Sequence + + +def _get_pretty_format( + contents: str, + indent: str, + ensure_ascii: bool = True, + sort_keys: bool = True, + top_keys: Sequence[str] = (), +) -> str: + def pairs_first(pairs: Sequence[tuple[str, str]]) -> Mapping[str, str]: + before = [pair for pair in pairs if pair[0] in top_keys] + before = sorted(before, key=lambda x: top_keys.index(x[0])) + after = [pair for pair in pairs if pair[0] not in top_keys] + if sort_keys: + after.sort() + return dict(before + after) + json_pretty = json.dumps( + json.loads(contents, object_pairs_hook=pairs_first), + indent=indent, + ensure_ascii=ensure_ascii, + ) + return f'{json_pretty}\n' + + +def _autofix(filename: str, new_contents: str) -> None: + print(f'Fixing file {filename}') + with open(filename, 'w', encoding='UTF-8') as f: + f.write(new_contents) + + +def parse_num_to_int(s: str) -> int | str: + """Convert string numbers to int, leaving strings as is.""" + try: + return int(s) + except ValueError: + return s + + +def parse_topkeys(s: str) -> list[str]: + return s.split(',') + + +def get_diff(source: str, target: str, file: str) -> str: + source_lines = source.splitlines(True) + target_lines = target.splitlines(True) + diff = unified_diff(source_lines, target_lines, fromfile=file, tofile=file) + return ''.join(diff) + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument( + '--autofix', + action='store_true', + dest='autofix', + help='Automatically fixes encountered not-pretty-formatted files', + ) + parser.add_argument( + '--indent', + type=parse_num_to_int, + default='2', + help=( + 'The number of indent spaces or a string to be used as delimiter' + ' for indentation level e.g. 4 or "\t" (Default: 2)' + ), + ) + parser.add_argument( + '--no-ensure-ascii', + action='store_true', + dest='no_ensure_ascii', + default=False, + help=( + 'Do NOT convert non-ASCII characters to Unicode escape sequences ' + '(\\uXXXX)' + ), + ) + parser.add_argument( + '--no-sort-keys', + action='store_true', + dest='no_sort_keys', + default=False, + help='Keep JSON nodes in the same order', + ) + parser.add_argument( + '--top-keys', + type=parse_topkeys, + dest='top_keys', + default=[], + help='Ordered list of keys to keep at the top of JSON hashes', + ) + parser.add_argument('filenames', nargs='*', help='Filenames to fix') + args = parser.parse_args(argv) + + status = 0 + + for json_file in args.filenames: + with open(json_file, encoding='UTF-8') as f: + contents = f.read() + + try: + pretty_contents = _get_pretty_format( + contents, args.indent, ensure_ascii=not args.no_ensure_ascii, + sort_keys=not args.no_sort_keys, top_keys=args.top_keys, + ) + except ValueError: + print( + f'Input File {json_file} is not a valid JSON, consider using ' + f'check-json', + ) + return 1 + + if contents != pretty_contents: + if args.autofix: + _autofix(json_file, pretty_contents) + else: + diff_output = get_diff(contents, pretty_contents, json_file) + sys.stdout.buffer.write(diff_output.encode()) + + status = 1 + + return status + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/removed.py b/pre_commit_hooks/removed.py new file mode 100644 index 0000000..6f6c7b7 --- /dev/null +++ b/pre_commit_hooks/removed.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +import sys +from typing import Sequence + + +def main(argv: Sequence[str] | None = None) -> int: + argv = argv if argv is not None else sys.argv[1:] + hookid, new_hookid, url = argv[:3] + raise SystemExit( + f'`{hookid}` has been removed -- use `{new_hookid}` from {url}', + ) + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/requirements_txt_fixer.py b/pre_commit_hooks/requirements_txt_fixer.py new file mode 100644 index 0000000..5884394 --- /dev/null +++ b/pre_commit_hooks/requirements_txt_fixer.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +import argparse +import re +from typing import IO +from typing import Sequence + + +PASS = 0 +FAIL = 1 + + +class Requirement: + UNTIL_COMPARISON = re.compile(b'={2,3}|!=|~=|>=?|<=?') + UNTIL_SEP = re.compile(rb'[^;\s]+') + + def __init__(self) -> None: + self.value: bytes | None = None + self.comments: list[bytes] = [] + + @property + def name(self) -> bytes: + assert self.value is not None, self.value + name = self.value.lower() + for egg in (b'#egg=', b'&egg='): + if egg in self.value: + return name.partition(egg)[-1] + + m = self.UNTIL_SEP.match(name) + assert m is not None + + name = m.group() + m = self.UNTIL_COMPARISON.search(name) + if not m: + return name + + return name[:m.start()] + + def __lt__(self, requirement: Requirement) -> bool: + # \n means top of file comment, so always return True, + # otherwise just do a string comparison with value. + assert self.value is not None, self.value + if self.value == b'\n': + return True + elif requirement.value == b'\n': + return False + else: + return self.name < requirement.name + + def is_complete(self) -> bool: + return ( + self.value is not None and + not self.value.rstrip(b'\r\n').endswith(b'\\') + ) + + def append_value(self, value: bytes) -> None: + if self.value is not None: + self.value += value + else: + self.value = value + + +def fix_requirements(f: IO[bytes]) -> int: + requirements: list[Requirement] = [] + before = list(f) + after: list[bytes] = [] + + before_string = b''.join(before) + + # adds new line in case one is missing + # AND a change to the requirements file is needed regardless: + if before and not before[-1].endswith(b'\n'): + before[-1] += b'\n' + + # If the file is empty (i.e. only whitespace/newlines) exit early + if before_string.strip() == b'': + return PASS + + for line in before: + # If the most recent requirement object has a value, then it's + # time to start building the next requirement object. + + if not len(requirements) or requirements[-1].is_complete(): + requirements.append(Requirement()) + + requirement = requirements[-1] + + # If we see a newline before any requirements, then this is a + # top of file comment. + if len(requirements) == 1 and line.strip() == b'': + if ( + len(requirement.comments) and + requirement.comments[0].startswith(b'#') + ): + requirement.value = b'\n' + else: + requirement.comments.append(line) + elif line.lstrip().startswith(b'#') or line.strip() == b'': + requirement.comments.append(line) + else: + requirement.append_value(line) + + # if a file ends in a comment, preserve it at the end + if requirements[-1].value is None: + rest = requirements.pop().comments + else: + rest = [] + + # find and remove pkg-resources==0.0.0 + # which is automatically added by broken pip package under Debian + requirements = [ + req for req in requirements + if req.value != b'pkg-resources==0.0.0\n' + ] + + for requirement in sorted(requirements): + after.extend(requirement.comments) + assert requirement.value, requirement.value + after.append(requirement.value) + after.extend(rest) + + after_string = b''.join(after) + + if before_string == after_string: + return PASS + else: + f.seek(0) + f.write(after_string) + f.truncate() + return FAIL + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument('filenames', nargs='*', help='Filenames to fix') + args = parser.parse_args(argv) + + retv = PASS + + for arg in args.filenames: + with open(arg, 'rb+') as file_obj: + ret_for_file = fix_requirements(file_obj) + + if ret_for_file: + print(f'Sorting {arg}') + + retv |= ret_for_file + + return retv + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/sort_simple_yaml.py b/pre_commit_hooks/sort_simple_yaml.py new file mode 100644 index 0000000..116b5c1 --- /dev/null +++ b/pre_commit_hooks/sort_simple_yaml.py @@ -0,0 +1,125 @@ +"""Sort a simple YAML file, keeping blocks of comments and definitions +together. + +We assume a strict subset of YAML that looks like: + + # block of header comments + # here that should always + # be at the top of the file + + # optional comments + # can go here + key: value + key: value + + key: value + +In other words, we don't sort deeper than the top layer, and might corrupt +complicated YAML files. +""" +from __future__ import annotations + +import argparse +from typing import Sequence + + +QUOTES = ["'", '"'] + + +def sort(lines: list[str]) -> list[str]: + """Sort a YAML file in alphabetical order, keeping blocks together. + + :param lines: array of strings (without newlines) + :return: sorted array of strings + """ + # make a copy of lines since we will clobber it + lines = list(lines) + new_lines = parse_block(lines, header=True) + + for block in sorted(parse_blocks(lines), key=first_key): + if new_lines: + new_lines.append('') + new_lines.extend(block) + + return new_lines + + +def parse_block(lines: list[str], header: bool = False) -> list[str]: + """Parse and return a single block, popping off the start of `lines`. + + If parsing a header block, we stop after we reach a line that is not a + comment. Otherwise, we stop after reaching an empty line. + + :param lines: list of lines + :param header: whether we are parsing a header block + :return: list of lines that form the single block + """ + block_lines = [] + while lines and lines[0] and (not header or lines[0].startswith('#')): + block_lines.append(lines.pop(0)) + return block_lines + + +def parse_blocks(lines: list[str]) -> list[list[str]]: + """Parse and return all possible blocks, popping off the start of `lines`. + + :param lines: list of lines + :return: list of blocks, where each block is a list of lines + """ + blocks = [] + + while lines: + if lines[0] == '': + lines.pop(0) + else: + blocks.append(parse_block(lines)) + + return blocks + + +def first_key(lines: list[str]) -> str: + """Returns a string representing the sort key of a block. + + The sort key is the first YAML key we encounter, ignoring comments, and + stripping leading quotes. + + >>> print(test) + # some comment + 'foo': true + >>> first_key(test) + 'foo' + """ + for line in lines: + if line.startswith('#'): + continue + if any(line.startswith(quote) for quote in QUOTES): + return line[1:] + return line + else: + return '' # not actually reached in reality + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument('filenames', nargs='*', help='Filenames to fix') + args = parser.parse_args(argv) + + retval = 0 + + for filename in args.filenames: + with open(filename, 'r+') as f: + lines = [line.rstrip() for line in f.readlines()] + new_lines = sort(lines) + + if lines != new_lines: + print(f'Fixing file `{filename}`') + f.seek(0) + f.write('\n'.join(new_lines) + '\n') + f.truncate() + retval = 1 + + return retval + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/string_fixer.py b/pre_commit_hooks/string_fixer.py new file mode 100644 index 0000000..d1b1c4a --- /dev/null +++ b/pre_commit_hooks/string_fixer.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import argparse +import io +import re +import sys +import tokenize +from typing import Sequence + +if sys.version_info >= (3, 12): # pragma: >=3.12 cover + FSTRING_START = tokenize.FSTRING_START + FSTRING_END = tokenize.FSTRING_END +else: # pragma: <3.12 cover + FSTRING_START = FSTRING_END = -1 + +START_QUOTE_RE = re.compile('^[a-zA-Z]*"') + + +def handle_match(token_text: str) -> str: + if '"""' in token_text or "'''" in token_text: + return token_text + + match = START_QUOTE_RE.match(token_text) + if match is not None: + meat = token_text[match.end():-1] + if '"' in meat or "'" in meat: + return token_text + else: + return match.group().replace('"', "'") + meat + "'" + else: + return token_text + + +def get_line_offsets_by_line_no(src: str) -> list[int]: + # Padded so we can index with line number + offsets = [-1, 0] + for line in src.splitlines(True): + offsets.append(offsets[-1] + len(line)) + return offsets + + +def fix_strings(filename: str) -> int: + with open(filename, encoding='UTF-8', newline='') as f: + contents = f.read() + line_offsets = get_line_offsets_by_line_no(contents) + + # Basically a mutable string + splitcontents = list(contents) + + fstring_depth = 0 + + # Iterate in reverse so the offsets are always correct + tokens_l = list(tokenize.generate_tokens(io.StringIO(contents).readline)) + tokens = reversed(tokens_l) + for token_type, token_text, (srow, scol), (erow, ecol), _ in tokens: + if token_type == FSTRING_START: # pragma: >=3.12 cover + fstring_depth += 1 + elif token_type == FSTRING_END: # pragma: >=3.12 cover + fstring_depth -= 1 + elif fstring_depth == 0 and token_type == tokenize.STRING: + new_text = handle_match(token_text) + splitcontents[ + line_offsets[srow] + scol: + line_offsets[erow] + ecol + ] = new_text + + new_contents = ''.join(splitcontents) + if contents != new_contents: + with open(filename, 'w', encoding='UTF-8', newline='') as f: + f.write(new_contents) + return 1 + else: + return 0 + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument('filenames', nargs='*', help='Filenames to fix') + args = parser.parse_args(argv) + + retv = 0 + + for filename in args.filenames: + return_value = fix_strings(filename) + if return_value != 0: + print(f'Fixing strings in {filename}') + retv |= return_value + + return retv + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/tests_should_end_in_test.py b/pre_commit_hooks/tests_should_end_in_test.py new file mode 100644 index 0000000..e7842af --- /dev/null +++ b/pre_commit_hooks/tests_should_end_in_test.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +import argparse +import os.path +import re +from typing import Sequence + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument('filenames', nargs='*') + mutex = parser.add_mutually_exclusive_group() + mutex.add_argument( + '--pytest', + dest='pattern', + action='store_const', + const=r'.*_test\.py', + default=r'.*_test\.py', + help='(the default) ensure tests match %(const)s', + ) + mutex.add_argument( + '--pytest-test-first', + dest='pattern', + action='store_const', + const=r'test_.*\.py', + help='ensure tests match %(const)s', + ) + mutex.add_argument( + '--django', '--unittest', + dest='pattern', + action='store_const', + const=r'test.*\.py', + help='ensure tests match %(const)s', + ) + args = parser.parse_args(argv) + + retcode = 0 + reg = re.compile(args.pattern) + for filename in args.filenames: + base = os.path.basename(filename) + if ( + not reg.fullmatch(base) and + not base == '__init__.py' and + not base == 'conftest.py' + ): + retcode = 1 + print(f'{filename} does not match pattern "{args.pattern}"') + + return retcode + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/trailing_whitespace_fixer.py b/pre_commit_hooks/trailing_whitespace_fixer.py new file mode 100644 index 0000000..84f5067 --- /dev/null +++ b/pre_commit_hooks/trailing_whitespace_fixer.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import argparse +import os +from typing import Sequence + + +def _fix_file( + filename: str, + is_markdown: bool, + chars: bytes | None, +) -> bool: + with open(filename, mode='rb') as file_processed: + lines = file_processed.readlines() + newlines = [_process_line(line, is_markdown, chars) for line in lines] + if newlines != lines: + with open(filename, mode='wb') as file_processed: + for line in newlines: + file_processed.write(line) + return True + else: + return False + + +def _process_line( + line: bytes, + is_markdown: bool, + chars: bytes | None, +) -> bytes: + if line[-2:] == b'\r\n': + eol = b'\r\n' + line = line[:-2] + elif line[-1:] == b'\n': + eol = b'\n' + line = line[:-1] + else: + eol = b'' + # preserve trailing two-space for non-blank lines in markdown files + if is_markdown and (not line.isspace()) and line.endswith(b' '): + return line[:-2].rstrip(chars) + b' ' + eol + return line.rstrip(chars) + eol + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument( + '--no-markdown-linebreak-ext', + action='store_true', + help=argparse.SUPPRESS, + ) + parser.add_argument( + '--markdown-linebreak-ext', + action='append', + default=[], + metavar='*|EXT[,EXT,...]', + help=( + 'Markdown extensions (or *) to not strip linebreak spaces. ' + 'default: %(default)s' + ), + ) + parser.add_argument( + '--chars', + help=( + 'The set of characters to strip from the end of lines. ' + 'Defaults to all whitespace characters.' + ), + ) + parser.add_argument('filenames', nargs='*', help='Filenames to fix') + args = parser.parse_args(argv) + + if args.no_markdown_linebreak_ext: + print('--no-markdown-linebreak-ext now does nothing!') + + md_args = args.markdown_linebreak_ext + if '' in md_args: + parser.error('--markdown-linebreak-ext requires a non-empty argument') + all_markdown = '*' in md_args + # normalize extensions; split at ',', lowercase, and force 1 leading '.' + md_exts = [ + '.' + x.lower().lstrip('.') for x in ','.join(md_args).split(',') + ] + + # reject probable "eaten" filename as extension: skip leading '.' with [1:] + for ext in md_exts: + if any(c in ext[1:] for c in r'./\:'): + parser.error( + f'bad --markdown-linebreak-ext extension ' + f'{ext!r} (has . / \\ :)\n' + f" (probably filename; use '--markdown-linebreak-ext=EXT')", + ) + chars = None if args.chars is None else args.chars.encode() + return_code = 0 + for filename in args.filenames: + _, extension = os.path.splitext(filename.lower()) + md = all_markdown or extension in md_exts + if _fix_file(filename, md, chars): + print(f'Fixing {filename}') + return_code = 1 + return return_code + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/pre_commit_hooks/util.py b/pre_commit_hooks/util.py new file mode 100644 index 0000000..d6c90ae --- /dev/null +++ b/pre_commit_hooks/util.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import subprocess +from typing import Any + + +class CalledProcessError(RuntimeError): + pass + + +def added_files() -> set[str]: + cmd = ('git', 'diff', '--staged', '--name-only', '--diff-filter=A') + return set(cmd_output(*cmd).splitlines()) + + +def cmd_output(*cmd: str, retcode: int | None = 0, **kwargs: Any) -> str: + kwargs.setdefault('stdout', subprocess.PIPE) + kwargs.setdefault('stderr', subprocess.PIPE) + proc = subprocess.Popen(cmd, **kwargs) + stdout, stderr = proc.communicate() + stdout = stdout.decode() + if retcode is not None and proc.returncode != retcode: + raise CalledProcessError(cmd, retcode, proc.returncode, stdout, stderr) + return stdout + + +def zsplit(s: str) -> list[str]: + s = s.strip('\0') + if s: + return s.split('\0') + else: + return [] |