Diffstat
 pre_commit_hooks/__init__.py                             |   0
 pre_commit_hooks/check_added_large_files.py              |  81
 pre_commit_hooks/check_ast.py                            |  33
 pre_commit_hooks/check_builtin_literals.py               | 105
 pre_commit_hooks/check_byte_order_marker.py              |  24
 pre_commit_hooks/check_case_conflict.py                  |  72
 pre_commit_hooks/check_docstring_first.py                |  61
 pre_commit_hooks/check_executables_have_shebangs.py      |  85
 pre_commit_hooks/check_json.py                           |  38
 pre_commit_hooks/check_merge_conflict.py                 |  56
 pre_commit_hooks/check_shebang_scripts_are_executable.py |  54
 pre_commit_hooks/check_symlinks.py                       |  27
 pre_commit_hooks/check_toml.py                           |  30
 pre_commit_hooks/check_vcs_permalinks.py                 |  60
 pre_commit_hooks/check_xml.py                            |  26
 pre_commit_hooks/check_yaml.py                           |  72
 pre_commit_hooks/debug_statement_hook.py                 |  86
 pre_commit_hooks/destroyed_symlinks.py                   |  92
 pre_commit_hooks/detect_aws_credentials.py               | 151
 pre_commit_hooks/detect_private_key.py                   |  42
 pre_commit_hooks/end_of_file_fixer.py                    |  71
 pre_commit_hooks/file_contents_sorter.py                 |  88
 pre_commit_hooks/fix_byte_order_marker.py                |  31
 pre_commit_hooks/fix_encoding_pragma.py                  | 149
 pre_commit_hooks/forbid_new_submodules.py                |  48
 pre_commit_hooks/mixed_line_ending.py                    |  88
 pre_commit_hooks/no_commit_to_branch.py                  |  48
 pre_commit_hooks/pretty_format_json.py                   | 133
 pre_commit_hooks/removed.py                              |  16
 pre_commit_hooks/requirements_txt_fixer.py               | 153
 pre_commit_hooks/sort_simple_yaml.py                     | 125
 pre_commit_hooks/string_fixer.py                         |  93
 pre_commit_hooks/tests_should_end_in_test.py             |  53
 pre_commit_hooks/trailing_whitespace_fixer.py            | 103
 pre_commit_hooks/util.py                                 |  32
 35 files changed, 2426 insertions(+), 0 deletions(-)
diff --git a/pre_commit_hooks/__init__.py b/pre_commit_hooks/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/pre_commit_hooks/__init__.py
diff --git a/pre_commit_hooks/check_added_large_files.py b/pre_commit_hooks/check_added_large_files.py
new file mode 100644
index 0000000..9e0619b
--- /dev/null
+++ b/pre_commit_hooks/check_added_large_files.py
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+import argparse
+import math
+import os
+import subprocess
+from typing import Sequence
+
+from pre_commit_hooks.util import added_files
+from pre_commit_hooks.util import zsplit
+
+
+def filter_lfs_files(filenames: set[str]) -> None: # pragma: no cover (lfs)
+ """Remove files tracked by git-lfs from the set."""
+ if not filenames:
+ return
+
+ check_attr = subprocess.run(
+ ('git', 'check-attr', 'filter', '-z', '--stdin'),
+ stdout=subprocess.PIPE,
+ stderr=subprocess.DEVNULL,
+ encoding='utf-8',
+ check=True,
+ input='\0'.join(filenames),
+ )
+ stdout = zsplit(check_attr.stdout)
+ for i in range(0, len(stdout), 3):
+ filename, filter_tag = stdout[i], stdout[i + 2]
+ if filter_tag == 'lfs':
+ filenames.remove(filename)
+
+
+def find_large_added_files(
+ filenames: Sequence[str],
+ maxkb: int,
+ *,
+ enforce_all: bool = False,
+) -> int:
+ # Find all added files that are also in the list of files pre-commit tells
+ # us about
+ retv = 0
+ filenames_filtered = set(filenames)
+ filter_lfs_files(filenames_filtered)
+
+ if not enforce_all:
+ filenames_filtered &= added_files()
+
+ for filename in filenames_filtered:
+ kb = math.ceil(os.stat(filename).st_size / 1024)
+ if kb > maxkb:
+ print(f'{filename} ({kb} KB) exceeds {maxkb} KB.')
+ retv = 1
+
+ return retv
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ 'filenames', nargs='*',
+ help='Filenames pre-commit believes are changed.',
+ )
+ parser.add_argument(
+ '--enforce-all', action='store_true',
+ help='Enforce all files are checked, not just staged files.',
+ )
+ parser.add_argument(
+ '--maxkb', type=int, default=500,
+ help='Maximum allowable KB for added files',
+ )
+ args = parser.parse_args(argv)
+
+ return find_large_added_files(
+ args.filenames,
+ args.maxkb,
+ enforce_all=args.enforce_all,
+ )
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
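For illustration, a minimal sketch of the size rounding used above; the `kb_of` helper is hypothetical, not part of the hook:

    import math

    def kb_of(size_bytes: int) -> int:
        # mirrors the hook: sizes are rounded *up* to whole KB
        return math.ceil(size_bytes / 1024)

    assert kb_of(0) == 0     # empty files never trip the limit
    assert kb_of(1) == 1     # any non-empty file counts as >= 1 KB
    assert kb_of(1025) == 2  # so a 1025-byte file fails --maxkb=1
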
diff --git a/pre_commit_hooks/check_ast.py b/pre_commit_hooks/check_ast.py
new file mode 100644
index 0000000..fdac361
--- /dev/null
+++ b/pre_commit_hooks/check_ast.py
@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+import argparse
+import ast
+import platform
+import sys
+import traceback
+from typing import Sequence
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('filenames', nargs='*')
+ args = parser.parse_args(argv)
+
+ retval = 0
+ for filename in args.filenames:
+
+ try:
+ with open(filename, 'rb') as f:
+ ast.parse(f.read(), filename=filename)
+ except SyntaxError:
+ impl = platform.python_implementation()
+ version = sys.version.split()[0]
+ print(f'{filename}: failed parsing with {impl} {version}:')
+ tb = ' ' + traceback.format_exc().replace('\n', '\n ')
+ print(f'\n{tb}')
+ retval = 1
+ return retval
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
diff --git a/pre_commit_hooks/check_builtin_literals.py b/pre_commit_hooks/check_builtin_literals.py
new file mode 100644
index 0000000..d3054aa
--- /dev/null
+++ b/pre_commit_hooks/check_builtin_literals.py
@@ -0,0 +1,105 @@
+from __future__ import annotations
+
+import argparse
+import ast
+from typing import NamedTuple
+from typing import Sequence
+
+
+BUILTIN_TYPES = {
+ 'complex': '0j',
+ 'dict': '{}',
+ 'float': '0.0',
+ 'int': '0',
+ 'list': '[]',
+ 'str': "''",
+ 'tuple': '()',
+}
+
+
+class Call(NamedTuple):
+ name: str
+ line: int
+ column: int
+
+
+class Visitor(ast.NodeVisitor):
+ def __init__(
+ self,
+ ignore: Sequence[str] | None = None,
+ allow_dict_kwargs: bool = True,
+ ) -> None:
+ self.builtin_type_calls: list[Call] = []
+ self.ignore = set(ignore) if ignore else set()
+ self.allow_dict_kwargs = allow_dict_kwargs
+
+ def _check_dict_call(self, node: ast.Call) -> bool:
+ return self.allow_dict_kwargs and bool(node.keywords)
+
+ def visit_Call(self, node: ast.Call) -> None:
+ if not isinstance(node.func, ast.Name):
+ # Ignore functions that are object attributes (`foo.bar()`).
+ # Assume that if the user calls `builtins.list()`, they know what
+ # they're doing.
+ return
+ if node.func.id not in set(BUILTIN_TYPES).difference(self.ignore):
+ return
+ if node.func.id == 'dict' and self._check_dict_call(node):
+ return
+ elif node.args:
+ return
+ self.builtin_type_calls.append(
+ Call(node.func.id, node.lineno, node.col_offset),
+ )
+
+
+def check_file(
+ filename: str,
+ ignore: Sequence[str] | None = None,
+ allow_dict_kwargs: bool = True,
+) -> list[Call]:
+ with open(filename, 'rb') as f:
+ tree = ast.parse(f.read(), filename=filename)
+ visitor = Visitor(ignore=ignore, allow_dict_kwargs=allow_dict_kwargs)
+ visitor.visit(tree)
+ return visitor.builtin_type_calls
+
+
+def parse_ignore(value: str) -> set[str]:
+ return set(value.split(','))
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('filenames', nargs='*')
+ parser.add_argument('--ignore', type=parse_ignore, default=set())
+
+ mutex = parser.add_mutually_exclusive_group(required=False)
+ mutex.add_argument('--allow-dict-kwargs', action='store_true')
+ mutex.add_argument(
+ '--no-allow-dict-kwargs',
+ dest='allow_dict_kwargs', action='store_false',
+ )
+ mutex.set_defaults(allow_dict_kwargs=True)
+
+ args = parser.parse_args(argv)
+
+ rc = 0
+ for filename in args.filenames:
+ calls = check_file(
+ filename,
+ ignore=args.ignore,
+ allow_dict_kwargs=args.allow_dict_kwargs,
+ )
+ if calls:
+ rc = rc or 1
+ for call in calls:
+ print(
+ f'{filename}:{call.line}:{call.column}: '
+ f'replace {call.name}() with {BUILTIN_TYPES[call.name]}',
+ )
+ return rc
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
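For illustration, a sketch of what `check_file` reports, assuming the package above is importable:

    import tempfile

    from pre_commit_hooks.check_builtin_literals import check_file

    with tempfile.NamedTemporaryFile('w', suffix='.py', delete=False) as f:
        f.write('x = dict()\ny = dict(a=1)\nz = list("ab")\n')

    # only the bare dict() is flagged: dict(a=1) passes because
    # allow_dict_kwargs defaults to True, and list("ab") has args
    print(check_file(f.name))  # [Call(name='dict', line=1, column=4)]
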
diff --git a/pre_commit_hooks/check_byte_order_marker.py b/pre_commit_hooks/check_byte_order_marker.py
new file mode 100644
index 0000000..59cc561
--- /dev/null
+++ b/pre_commit_hooks/check_byte_order_marker.py
@@ -0,0 +1,24 @@
+from __future__ import annotations
+
+import argparse
+from typing import Sequence
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('filenames', nargs='*', help='Filenames to check')
+ args = parser.parse_args(argv)
+
+ retv = 0
+
+ for filename in args.filenames:
+ with open(filename, 'rb') as f:
+ if f.read(3) == b'\xef\xbb\xbf':
+ retv = 1
+ print(f'{filename}: Has a byte-order marker')
+
+ return retv
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
diff --git a/pre_commit_hooks/check_case_conflict.py b/pre_commit_hooks/check_case_conflict.py
new file mode 100644
index 0000000..33a13f1
--- /dev/null
+++ b/pre_commit_hooks/check_case_conflict.py
@@ -0,0 +1,72 @@
+from __future__ import annotations
+
+import argparse
+from typing import Iterable
+from typing import Iterator
+from typing import Sequence
+
+from pre_commit_hooks.util import added_files
+from pre_commit_hooks.util import cmd_output
+
+
+def lower_set(iterable: Iterable[str]) -> set[str]:
+ return {x.lower() for x in iterable}
+
+
+def parents(file: str) -> Iterator[str]:
+ path_parts = file.split('/')
+ path_parts.pop()
+ while path_parts:
+ yield '/'.join(path_parts)
+ path_parts.pop()
+
+
+def directories_for(files: set[str]) -> set[str]:
+ return {parent for file in files for parent in parents(file)}
+
+
+def find_conflicting_filenames(filenames: Sequence[str]) -> int:
+ repo_files = set(cmd_output('git', 'ls-files').splitlines())
+ repo_files |= directories_for(repo_files)
+ relevant_files = set(filenames) | added_files()
+ relevant_files |= directories_for(relevant_files)
+ repo_files -= relevant_files
+ retv = 0
+
+ # new file conflicts with existing file
+ conflicts = lower_set(repo_files) & lower_set(relevant_files)
+
+ # new file conflicts with other new file
+ lowercase_relevant_files = lower_set(relevant_files)
+ for filename in set(relevant_files):
+ if filename.lower() in lowercase_relevant_files:
+ lowercase_relevant_files.remove(filename.lower())
+ else:
+ conflicts.add(filename.lower())
+
+ if conflicts:
+ conflicting_files = [
+ x for x in repo_files | relevant_files
+ if x.lower() in conflicts
+ ]
+ for filename in sorted(conflicting_files):
+ print(f'Case-insensitivity conflict found: {filename}')
+ retv = 1
+
+ return retv
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ 'filenames', nargs='*',
+ help='Filenames pre-commit believes are changed.',
+ )
+
+ args = parser.parse_args(argv)
+
+ return find_conflicting_filenames(args.filenames)
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
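A minimal sketch of the two helpers driving the conflict detection above:

    from pre_commit_hooks.check_case_conflict import lower_set, parents

    # every ancestor directory participates in the comparison
    assert list(parents('a/b/c.txt')) == ['a/b', 'a']

    # conflicts are detected by case-folding both sides
    assert lower_set({'README.md'}) & lower_set({'ReadMe.MD'}) == {'readme.md'}
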
diff --git a/pre_commit_hooks/check_docstring_first.py b/pre_commit_hooks/check_docstring_first.py
new file mode 100644
index 0000000..d55f08a
--- /dev/null
+++ b/pre_commit_hooks/check_docstring_first.py
@@ -0,0 +1,61 @@
+from __future__ import annotations
+
+import argparse
+import io
+import tokenize
+from tokenize import tokenize as tokenize_tokenize
+from typing import Sequence
+
+NON_CODE_TOKENS = frozenset((
+ tokenize.COMMENT, tokenize.ENDMARKER, tokenize.NEWLINE, tokenize.NL,
+ tokenize.ENCODING,
+))
+
+
+def check_docstring_first(src: bytes, filename: str = '<unknown>') -> int:
+ """Returns nonzero if the source has what looks like a docstring that is
+ not at the beginning of the source.
+
+ A string will be considered a docstring if it is a STRING token with a
+ col offset of 0.
+ """
+ found_docstring_line = None
+ found_code_line = None
+
+ tok_gen = tokenize_tokenize(io.BytesIO(src).readline)
+ for tok_type, _, (sline, scol), _, _ in tok_gen:
+ # Looks like a docstring!
+ if tok_type == tokenize.STRING and scol == 0:
+ if found_docstring_line is not None:
+ print(
+ f'{filename}:{sline}: Multiple module docstrings '
+ f'(first docstring on line {found_docstring_line}).',
+ )
+ return 1
+ elif found_code_line is not None:
+ print(
+ f'{filename}:{sline}: Module docstring appears after code '
+ f'(code seen on line {found_code_line}).',
+ )
+ return 1
+ else:
+ found_docstring_line = sline
+ elif tok_type not in NON_CODE_TOKENS and found_code_line is None:
+ found_code_line = sline
+
+ return 0
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('filenames', nargs='*')
+ args = parser.parse_args(argv)
+
+ retv = 0
+
+ for filename in args.filenames:
+ with open(filename, 'rb') as f:
+ contents = f.read()
+ retv |= check_docstring_first(contents, filename=filename)
+
+ return retv
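For illustration, how `check_docstring_first` judges a module; a sketch, assuming the package above is importable:

    from pre_commit_hooks.check_docstring_first import check_docstring_first

    # a module-level string after code is reported (returns 1)...
    assert check_docstring_first(b'x = 1\n"""late docstring"""\n') == 1
    # ...while a leading docstring passes (returns 0)
    assert check_docstring_first(b'"""docstring"""\nx = 1\n') == 0
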
diff --git a/pre_commit_hooks/check_executables_have_shebangs.py b/pre_commit_hooks/check_executables_have_shebangs.py
new file mode 100644
index 0000000..d8e4f49
--- /dev/null
+++ b/pre_commit_hooks/check_executables_have_shebangs.py
@@ -0,0 +1,85 @@
+"""Check that executable text files have a shebang."""
+from __future__ import annotations
+
+import argparse
+import shlex
+import sys
+from typing import Generator
+from typing import NamedTuple
+from typing import Sequence
+
+from pre_commit_hooks.util import cmd_output
+from pre_commit_hooks.util import zsplit
+
+EXECUTABLE_VALUES = frozenset(('1', '3', '5', '7'))
+
+
+def check_executables(paths: list[str]) -> int:
+ fs_tracks_executable_bit = cmd_output(
+ 'git', 'config', 'core.fileMode', retcode=None,
+ ).strip()
+ if fs_tracks_executable_bit == 'false': # pragma: win32 cover
+ return _check_git_filemode(paths)
+ else: # pragma: win32 no cover
+ retv = 0
+ for path in paths:
+ if not has_shebang(path):
+ _message(path)
+ retv = 1
+
+ return retv
+
+
+class GitLsFile(NamedTuple):
+ mode: str
+ filename: str
+
+
+def git_ls_files(paths: Sequence[str]) -> Generator[GitLsFile, None, None]:
+ outs = cmd_output('git', 'ls-files', '-z', '--stage', '--', *paths)
+ for out in zsplit(outs):
+ metadata, filename = out.split('\t')
+ mode, _, _ = metadata.split()
+ yield GitLsFile(mode, filename)
+
+
+def _check_git_filemode(paths: Sequence[str]) -> int:
+ seen: set[str] = set()
+ for ls_file in git_ls_files(paths):
+ is_executable = any(b in EXECUTABLE_VALUES for b in ls_file.mode[-3:])
+ if is_executable and not has_shebang(ls_file.filename):
+ _message(ls_file.filename)
+ seen.add(ls_file.filename)
+
+ return int(bool(seen))
+
+
+def has_shebang(path: str) -> bool:
+ with open(path, 'rb') as f:
+ first_bytes = f.read(2)
+
+ return first_bytes == b'#!'
+
+
+def _message(path: str) -> None:
+ print(
+ f'{path}: marked executable but has no (or invalid) shebang!\n'
+ f" If it isn't supposed to be executable, try: "
+ f'`chmod -x {shlex.quote(path)}`\n'
+ f' If on Windows, you may also need to: '
+ f'`git add --chmod=-x {shlex.quote(path)}`\n'
+ f' If it is supposed to be executable, double-check its shebang.',
+ file=sys.stderr,
+ )
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument('filenames', nargs='*')
+ args = parser.parse_args(argv)
+
+ return check_executables(args.filenames)
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
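A minimal sketch of the mode-bit test used by `_check_git_filemode` above: the last three octal digits of a `git ls-files --stage` mode (e.g. `100755`) carry the permission bits, and any digit with the execute bit set appears in `EXECUTABLE_VALUES`:

    EXECUTABLE_VALUES = frozenset(('1', '3', '5', '7'))

    def is_executable(mode: str) -> bool:
        # true if user, group, or other has the execute bit set
        return any(b in EXECUTABLE_VALUES for b in mode[-3:])

    assert is_executable('100755')      # rwxr-xr-x
    assert not is_executable('100644')  # rw-r--r--
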
diff --git a/pre_commit_hooks/check_json.py b/pre_commit_hooks/check_json.py
new file mode 100644
index 0000000..6a679fe
--- /dev/null
+++ b/pre_commit_hooks/check_json.py
@@ -0,0 +1,38 @@
+from __future__ import annotations
+
+import argparse
+import json
+from typing import Any
+from typing import Sequence
+
+
+def raise_duplicate_keys(
+ ordered_pairs: list[tuple[str, Any]],
+) -> dict[str, Any]:
+ d = {}
+ for key, val in ordered_pairs:
+ if key in d:
+ raise ValueError(f'Duplicate key: {key}')
+ else:
+ d[key] = val
+ return d
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('filenames', nargs='*', help='Filenames to check.')
+ args = parser.parse_args(argv)
+
+ retval = 0
+ for filename in args.filenames:
+ with open(filename, 'rb') as f:
+ try:
+ json.load(f, object_pairs_hook=raise_duplicate_keys)
+ except ValueError as exc:
+ print(f'{filename}: Failed to json decode ({exc})')
+ retval = 1
+ return retval
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
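For illustration, the duplicate-key behaviour the hook adds on top of plain `json.loads` (a sketch):

    import json

    from pre_commit_hooks.check_json import raise_duplicate_keys

    # plain json silently keeps the last value...
    assert json.loads('{"a": 1, "a": 2}') == {'a': 2}
    # ...while the hook's object_pairs_hook raises instead
    try:
        json.loads('{"a": 1, "a": 2}', object_pairs_hook=raise_duplicate_keys)
    except ValueError as exc:
        print(exc)  # Duplicate key: a
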
diff --git a/pre_commit_hooks/check_merge_conflict.py b/pre_commit_hooks/check_merge_conflict.py
new file mode 100644
index 0000000..15ec284
--- /dev/null
+++ b/pre_commit_hooks/check_merge_conflict.py
@@ -0,0 +1,56 @@
+from __future__ import annotations
+
+import argparse
+import os.path
+from typing import Sequence
+
+from pre_commit_hooks.util import cmd_output
+
+
+CONFLICT_PATTERNS = [
+ b'<<<<<<< ',
+ b'======= ',
+ b'=======\r\n',
+ b'=======\n',
+ b'>>>>>>> ',
+]
+
+
+def is_in_merge() -> bool:
+ git_dir = cmd_output('git', 'rev-parse', '--git-dir').rstrip()
+ return (
+ os.path.exists(os.path.join(git_dir, 'MERGE_MSG')) and
+ (
+ os.path.exists(os.path.join(git_dir, 'MERGE_HEAD')) or
+ os.path.exists(os.path.join(git_dir, 'rebase-apply')) or
+ os.path.exists(os.path.join(git_dir, 'rebase-merge'))
+ )
+ )
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('filenames', nargs='*')
+ parser.add_argument('--assume-in-merge', action='store_true')
+ args = parser.parse_args(argv)
+
+ if not is_in_merge() and not args.assume_in_merge:
+ return 0
+
+ retcode = 0
+ for filename in args.filenames:
+ with open(filename, 'rb') as inputfile:
+ for i, line in enumerate(inputfile, start=1):
+ for pattern in CONFLICT_PATTERNS:
+ if line.startswith(pattern):
+ print(
+ f'{filename}:{i}: Merge conflict string '
+ f'{pattern.strip().decode()!r} found',
+ )
+ retcode = 1
+
+ return retcode
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
diff --git a/pre_commit_hooks/check_shebang_scripts_are_executable.py b/pre_commit_hooks/check_shebang_scripts_are_executable.py
new file mode 100644
index 0000000..621696c
--- /dev/null
+++ b/pre_commit_hooks/check_shebang_scripts_are_executable.py
@@ -0,0 +1,54 @@
+"""Check that text files with a shebang are executable."""
+from __future__ import annotations
+
+import argparse
+import shlex
+import sys
+from typing import Sequence
+
+from pre_commit_hooks.check_executables_have_shebangs import EXECUTABLE_VALUES
+from pre_commit_hooks.check_executables_have_shebangs import git_ls_files
+from pre_commit_hooks.check_executables_have_shebangs import has_shebang
+
+
+def check_shebangs(paths: list[str]) -> int:
+ # Cannot optimize on non-executability here if we intend this check to
+ # work on win32 -- and that's where problems caused by non-executability
+ # (elsewhere) are most likely to arise from.
+ return _check_git_filemode(paths)
+
+
+def _check_git_filemode(paths: Sequence[str]) -> int:
+ seen: set[str] = set()
+ for ls_file in git_ls_files(paths):
+ is_executable = any(b in EXECUTABLE_VALUES for b in ls_file.mode[-3:])
+ if not is_executable and has_shebang(ls_file.filename):
+ _message(ls_file.filename)
+ seen.add(ls_file.filename)
+
+ return int(bool(seen))
+
+
+def _message(path: str) -> None:
+ print(
+ f'{path}: has a shebang but is not marked executable!\n'
+ f' If it is supposed to be executable, try: '
+ f'`chmod +x {shlex.quote(path)}`\n'
+ f' If on Windows, you may also need to: '
+ f'`git add --chmod=+x {shlex.quote(path)}`\n'
+ f' If it is not supposed to be executable, double-check that its '
+ f'shebang is wanted.\n',
+ file=sys.stderr,
+ )
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument('filenames', nargs='*')
+ args = parser.parse_args(argv)
+
+ return check_shebangs(args.filenames)
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
diff --git a/pre_commit_hooks/check_symlinks.py b/pre_commit_hooks/check_symlinks.py
new file mode 100644
index 0000000..a85c82a
--- /dev/null
+++ b/pre_commit_hooks/check_symlinks.py
@@ -0,0 +1,27 @@
+from __future__ import annotations
+
+import argparse
+import os.path
+from typing import Sequence
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser(description='Checks for broken symlinks.')
+ parser.add_argument('filenames', nargs='*', help='Filenames to check')
+ args = parser.parse_args(argv)
+
+ retv = 0
+
+ for filename in args.filenames:
+ if (
+ os.path.islink(filename) and
+ not os.path.exists(filename)
+ ): # pragma: no cover (symlink support required)
+ print(f'{filename}: Broken symlink')
+ retv = 1
+
+ return retv
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
diff --git a/pre_commit_hooks/check_toml.py b/pre_commit_hooks/check_toml.py
new file mode 100644
index 0000000..0407371
--- /dev/null
+++ b/pre_commit_hooks/check_toml.py
@@ -0,0 +1,30 @@
+from __future__ import annotations
+
+import argparse
+import sys
+from typing import Sequence
+
+if sys.version_info >= (3, 11): # pragma: >=3.11 cover
+ import tomllib
+else: # pragma: <3.11 cover
+ import tomli as tomllib
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('filenames', nargs='*', help='Filenames to check.')
+ args = parser.parse_args(argv)
+
+ retval = 0
+ for filename in args.filenames:
+ try:
+ with open(filename, mode='rb') as fp:
+ tomllib.load(fp)
+ except tomllib.TOMLDecodeError as exc:
+ print(f'{filename}: {exc}')
+ retval = 1
+ return retval
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
diff --git a/pre_commit_hooks/check_vcs_permalinks.py b/pre_commit_hooks/check_vcs_permalinks.py
new file mode 100644
index 0000000..68639bd
--- /dev/null
+++ b/pre_commit_hooks/check_vcs_permalinks.py
@@ -0,0 +1,60 @@
+from __future__ import annotations
+
+import argparse
+import re
+import sys
+from typing import Pattern
+from typing import Sequence
+
+
+def _get_pattern(domain: str) -> Pattern[bytes]:
+ regex = (
+ rf'https://{domain}/[^/ ]+/[^/ ]+/blob/'
+ r'(?![a-fA-F0-9]{4,64}/)([^/. ]+)/[^# ]+#L\d+'
+ )
+ return re.compile(regex.encode())
+
+
+def _check_filename(filename: str, patterns: list[Pattern[bytes]]) -> int:
+ retv = 0
+ with open(filename, 'rb') as f:
+ for i, line in enumerate(f, 1):
+ for pattern in patterns:
+ if pattern.search(line):
+ sys.stdout.write(f'{filename}:{i}:')
+ sys.stdout.flush()
+ sys.stdout.buffer.write(line)
+ retv = 1
+ return retv
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('filenames', nargs='*')
+ parser.add_argument(
+ '--additional-github-domain',
+ dest='additional_github_domains',
+ action='append',
+ default=['github.com'],
+ )
+ args = parser.parse_args(argv)
+
+ patterns = [
+ _get_pattern(domain)
+ for domain in args.additional_github_domains
+ ]
+
+ retv = 0
+
+ for filename in args.filenames:
+ retv |= _check_filename(filename, patterns)
+
+ if retv:
+ print()
+ print('Non-permanent github link detected.')
+ print('On any page on github press [y] to load a permalink.')
+ return retv
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
diff --git a/pre_commit_hooks/check_xml.py b/pre_commit_hooks/check_xml.py
new file mode 100644
index 0000000..c256af9
--- /dev/null
+++ b/pre_commit_hooks/check_xml.py
@@ -0,0 +1,26 @@
+from __future__ import annotations
+
+import argparse
+import xml.sax.handler
+from typing import Sequence
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('filenames', nargs='*', help='XML filenames to check.')
+ args = parser.parse_args(argv)
+
+ retval = 0
+ handler = xml.sax.handler.ContentHandler()
+ for filename in args.filenames:
+ try:
+ with open(filename, 'rb') as xml_file:
+ xml.sax.parse(xml_file, handler)
+ except xml.sax.SAXException as exc:
+ print(f'{filename}: Failed to xml parse ({exc})')
+ retval = 1
+ return retval
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
diff --git a/pre_commit_hooks/check_yaml.py b/pre_commit_hooks/check_yaml.py
new file mode 100644
index 0000000..9563347
--- /dev/null
+++ b/pre_commit_hooks/check_yaml.py
@@ -0,0 +1,72 @@
+from __future__ import annotations
+
+import argparse
+from typing import Any
+from typing import Generator
+from typing import NamedTuple
+from typing import Sequence
+
+import ruamel.yaml
+
+yaml = ruamel.yaml.YAML(typ='safe')
+
+
+def _exhaust(gen: Generator[str, None, None]) -> None:
+ for _ in gen:
+ pass
+
+
+def _parse_unsafe(*args: Any, **kwargs: Any) -> None:
+ _exhaust(yaml.parse(*args, **kwargs))
+
+
+def _load_all(*args: Any, **kwargs: Any) -> None:
+ _exhaust(yaml.load_all(*args, **kwargs))
+
+
+class Key(NamedTuple):
+ multi: bool
+ unsafe: bool
+
+
+LOAD_FNS = {
+ Key(multi=False, unsafe=False): yaml.load,
+ Key(multi=False, unsafe=True): _parse_unsafe,
+ Key(multi=True, unsafe=False): _load_all,
+ Key(multi=True, unsafe=True): _parse_unsafe,
+}
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '-m', '--multi', '--allow-multiple-documents', action='store_true',
+ )
+ parser.add_argument(
+ '--unsafe', action='store_true',
+ help=(
+ 'Instead of loading the files, simply parse them for syntax. '
+ 'A syntax-only check enables extensions and unsafe constructs '
+ 'which would otherwise be forbidden. Using this option removes '
+ 'all guarantees of portability to other yaml implementations. '
+ 'Implies --allow-multiple-documents'
+ ),
+ )
+ parser.add_argument('filenames', nargs='*', help='Filenames to check.')
+ args = parser.parse_args(argv)
+
+ load_fn = LOAD_FNS[Key(multi=args.multi, unsafe=args.unsafe)]
+
+ retval = 0
+ for filename in args.filenames:
+ try:
+ with open(filename, encoding='UTF-8') as f:
+ load_fn(f)
+ except ruamel.yaml.YAMLError as exc:
+ print(exc)
+ retval = 1
+ return retval
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
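A minimal sketch of the loader dispatch above: `--unsafe` always routes to the parse-only function, regardless of `--multi` (note `_parse_unsafe` is module-private):

    from pre_commit_hooks.check_yaml import LOAD_FNS, Key, _parse_unsafe

    assert LOAD_FNS[Key(multi=False, unsafe=True)] is _parse_unsafe
    assert LOAD_FNS[Key(multi=True, unsafe=True)] is _parse_unsafe
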
diff --git a/pre_commit_hooks/debug_statement_hook.py b/pre_commit_hooks/debug_statement_hook.py
new file mode 100644
index 0000000..cf544c7
--- /dev/null
+++ b/pre_commit_hooks/debug_statement_hook.py
@@ -0,0 +1,86 @@
+from __future__ import annotations
+
+import argparse
+import ast
+import traceback
+from typing import NamedTuple
+from typing import Sequence
+
+
+DEBUG_STATEMENTS = {
+ 'bpdb',
+ 'ipdb',
+ 'pdb',
+ 'pdbr',
+ 'pudb',
+ 'pydevd_pycharm',
+ 'q',
+ 'rdb',
+ 'rpdb',
+ 'wdb',
+}
+
+
+class Debug(NamedTuple):
+ line: int
+ col: int
+ name: str
+ reason: str
+
+
+class DebugStatementParser(ast.NodeVisitor):
+ def __init__(self) -> None:
+ self.breakpoints: list[Debug] = []
+
+ def visit_Import(self, node: ast.Import) -> None:
+ for name in node.names:
+ if name.name in DEBUG_STATEMENTS:
+ st = Debug(node.lineno, node.col_offset, name.name, 'imported')
+ self.breakpoints.append(st)
+
+ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
+ if node.module in DEBUG_STATEMENTS:
+ st = Debug(node.lineno, node.col_offset, node.module, 'imported')
+ self.breakpoints.append(st)
+
+ def visit_Call(self, node: ast.Call) -> None:
+ """python3.7+ breakpoint()"""
+ if isinstance(node.func, ast.Name) and node.func.id == 'breakpoint':
+ st = Debug(node.lineno, node.col_offset, node.func.id, 'called')
+ self.breakpoints.append(st)
+ self.generic_visit(node)
+
+
+def check_file(filename: str) -> int:
+ try:
+ with open(filename, 'rb') as f:
+ ast_obj = ast.parse(f.read(), filename=filename)
+ except SyntaxError:
+ print(f'{filename} - Could not parse ast')
+ print()
+ print('\t' + traceback.format_exc().replace('\n', '\n\t'))
+ print()
+ return 1
+
+ visitor = DebugStatementParser()
+ visitor.visit(ast_obj)
+
+ for bp in visitor.breakpoints:
+ print(f'{filename}:{bp.line}:{bp.col}: {bp.name} {bp.reason}')
+
+ return int(bool(visitor.breakpoints))
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('filenames', nargs='*', help='Filenames to run')
+ args = parser.parse_args(argv)
+
+ retv = 0
+ for filename in args.filenames:
+ retv |= check_file(filename)
+ return retv
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
diff --git a/pre_commit_hooks/destroyed_symlinks.py b/pre_commit_hooks/destroyed_symlinks.py
new file mode 100644
index 0000000..f256908
--- /dev/null
+++ b/pre_commit_hooks/destroyed_symlinks.py
@@ -0,0 +1,92 @@
+from __future__ import annotations
+
+import argparse
+import shlex
+import subprocess
+from typing import Sequence
+
+from pre_commit_hooks.util import cmd_output
+from pre_commit_hooks.util import zsplit
+
+ORDINARY_CHANGED_ENTRIES_MARKER = '1'
+PERMS_LINK = '120000'
+PERMS_NONEXIST = '000000'
+
+
+def find_destroyed_symlinks(files: Sequence[str]) -> list[str]:
+ destroyed_links: list[str] = []
+ if not files:
+ return destroyed_links
+ for line in zsplit(
+ cmd_output('git', 'status', '--porcelain=v2', '-z', '--', *files),
+ ):
+ splitted = line.split(' ')
+ if splitted and splitted[0] == ORDINARY_CHANGED_ENTRIES_MARKER:
+ # https://git-scm.com/docs/git-status#_changed_tracked_entries
+ (
+ _, _, _,
+ mode_HEAD,
+ mode_index,
+ _,
+ hash_HEAD,
+ hash_index,
+ *path_splitted,
+ ) = splitted
+ path = ' '.join(path_splitted)
+ if (
+ mode_HEAD == PERMS_LINK and
+ mode_index != PERMS_LINK and
+ mode_index != PERMS_NONEXIST
+ ):
+ if hash_HEAD == hash_index:
+ # if the old and new hashes are equal, no further checks are
+ # needed: we have definitely found a destroyed symlink
+ destroyed_links.append(path)
+ else:
+ # if the old and new hashes are *not* equal, that alone does
+ # not mean everything is OK: the new file may have been
+ # altered by hooks such as trailing-whitespace and/or
+ # mixed-line-ending, so we need to look deeper
+ SIZE_CMD = ('git', 'cat-file', '-s')
+ size_index = int(cmd_output(*SIZE_CMD, hash_index).strip())
+ size_HEAD = int(cmd_output(*SIZE_CMD, hash_HEAD).strip())
+
+ # in the worst case the new file may have had CRLF added,
+ # so compare contents only if the new file is at most
+ # 2 bytes bigger than the old one
+ if size_index <= size_HEAD + 2:
+ head_content = subprocess.check_output(
+ ('git', 'cat-file', '-p', hash_HEAD),
+ ).rstrip()
+ index_content = subprocess.check_output(
+ ('git', 'cat-file', '-p', hash_index),
+ ).rstrip()
+ if head_content == index_content:
+ destroyed_links.append(path)
+ return destroyed_links
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('filenames', nargs='*', help='Filenames to check.')
+ args = parser.parse_args(argv)
+ destroyed_links = find_destroyed_symlinks(files=args.filenames)
+ if destroyed_links:
+ print('Destroyed symlinks:')
+ for destroyed_link in destroyed_links:
+ print(f'- {destroyed_link}')
+ print('You should unstage the affected files:')
+ print(f'\tgit reset HEAD -- {shlex.join(destroyed_links)}')
+ print(
+ 'And retry the commit. As a long-term solution, '
+ 'you may try explicitly telling git that your '
+ 'environment does not support symlinks:',
+ )
+ print('\tgit config core.symlinks false')
+ return 1
+ else:
+ return 0
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
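For illustration, how one `--porcelain=v2` entry is unpacked above; the sample line is hypothetical but follows the documented field order (`1 XY sub mH mI mW hH hI path`):

    line = '1 .M N... 120000 100644 100644 abc123 def456 some link'
    _, _, _, mode_HEAD, mode_index, _, hash_HEAD, hash_index, *path = (
        line.split(' ')
    )
    assert mode_HEAD == '120000'          # a symlink at HEAD...
    assert mode_index == '100644'         # ...now a regular file in the index
    assert ' '.join(path) == 'some link'  # paths may contain spaces
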
diff --git a/pre_commit_hooks/detect_aws_credentials.py b/pre_commit_hooks/detect_aws_credentials.py
new file mode 100644
index 0000000..4f59d9c
--- /dev/null
+++ b/pre_commit_hooks/detect_aws_credentials.py
@@ -0,0 +1,151 @@
+from __future__ import annotations
+
+import argparse
+import configparser
+import os
+from typing import NamedTuple
+from typing import Sequence
+
+
+class BadFile(NamedTuple):
+ filename: str
+ key: str
+
+
+def get_aws_cred_files_from_env() -> set[str]:
+ """Extract credential file paths from environment variables."""
+ return {
+ os.environ[env_var]
+ for env_var in (
+ 'AWS_CONFIG_FILE', 'AWS_CREDENTIAL_FILE',
+ 'AWS_SHARED_CREDENTIALS_FILE', 'BOTO_CONFIG',
+ )
+ if env_var in os.environ
+ }
+
+
+def get_aws_secrets_from_env() -> set[str]:
+ """Extract AWS secrets from environment variables."""
+ keys = set()
+ for env_var in (
+ 'AWS_SECRET_ACCESS_KEY', 'AWS_SECURITY_TOKEN', 'AWS_SESSION_TOKEN',
+ ):
+ if os.environ.get(env_var):
+ keys.add(os.environ[env_var])
+ return keys
+
+
+def get_aws_secrets_from_file(credentials_file: str) -> set[str]:
+ """Extract AWS secrets from configuration files.
+
+ Read an ini-style configuration file and return a set with all found AWS
+ secret access keys.
+ """
+ aws_credentials_file_path = os.path.expanduser(credentials_file)
+ if not os.path.exists(aws_credentials_file_path):
+ return set()
+
+ parser = configparser.ConfigParser()
+ try:
+ parser.read(aws_credentials_file_path)
+ except configparser.MissingSectionHeaderError:
+ return set()
+
+ keys = set()
+ for section in parser.sections():
+ for var in (
+ 'aws_secret_access_key', 'aws_security_token',
+ 'aws_session_token',
+ ):
+ try:
+ key = parser.get(section, var).strip()
+ if key:
+ keys.add(key)
+ except configparser.NoOptionError:
+ pass
+ return keys
+
+
+def check_file_for_aws_keys(
+ filenames: Sequence[str],
+ keys: set[bytes],
+) -> list[BadFile]:
+ """Check if files contain AWS secrets.
+
+ Return a list of all files containing AWS secrets and keys found, with all
+ but the first four characters obfuscated to ease debugging.
+ """
+ bad_files = []
+
+ for filename in filenames:
+ with open(filename, 'rb') as content:
+ text_body = content.read()
+ for key in keys:
+ # naively match the entire file, low chance of incorrect
+ # collision
+ if key in text_body:
+ key_hidden = key.decode()[:4].ljust(28, '*')
+ bad_files.append(BadFile(filename, key_hidden))
+ return bad_files
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('filenames', nargs='+', help='Filenames to run')
+ parser.add_argument(
+ '--credentials-file',
+ dest='credentials_file',
+ action='append',
+ default=[
+ '~/.aws/config', '~/.aws/credentials', '/etc/boto.cfg', '~/.boto',
+ ],
+ help=(
+ 'Location of additional AWS credential file from which to get '
+ 'secret keys. Can be passed multiple times.'
+ ),
+ )
+ parser.add_argument(
+ '--allow-missing-credentials',
+ dest='allow_missing_credentials',
+ action='store_true',
+ help='Allow hook to pass when no credentials are detected.',
+ )
+ args = parser.parse_args(argv)
+
+ credential_files = set(args.credentials_file)
+
+ # Add the credentials files configured via environment variables to the
+ # set of files to gather AWS secrets from.
+ credential_files |= get_aws_cred_files_from_env()
+
+ keys: set[str] = set()
+ for credential_file in credential_files:
+ keys |= get_aws_secrets_from_file(credential_file)
+
+ # Secrets might be part of environment variables, so add such secrets to
+ # the set of keys.
+ keys |= get_aws_secrets_from_env()
+
+ if not keys and args.allow_missing_credentials:
+ return 0
+
+ if not keys:
+ print(
+ 'No AWS keys were found in the configured credential files and '
+ 'environment variables.\nPlease ensure you have the correct '
+ 'setting for --credentials-file',
+ )
+ return 2
+
+ keys_b = {key.encode() for key in keys}
+ bad_filenames = check_file_for_aws_keys(args.filenames, keys_b)
+ if bad_filenames:
+ for bad_file in bad_filenames:
+ print(f'AWS secret found in {bad_file.filename}: {bad_file.key}')
+ return 1
+ else:
+ return 0
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
diff --git a/pre_commit_hooks/detect_private_key.py b/pre_commit_hooks/detect_private_key.py
new file mode 100644
index 0000000..cd51f90
--- /dev/null
+++ b/pre_commit_hooks/detect_private_key.py
@@ -0,0 +1,42 @@
+from __future__ import annotations
+
+import argparse
+from typing import Sequence
+
+BLACKLIST = [
+ b'BEGIN RSA PRIVATE KEY',
+ b'BEGIN DSA PRIVATE KEY',
+ b'BEGIN EC PRIVATE KEY',
+ b'BEGIN OPENSSH PRIVATE KEY',
+ b'BEGIN PRIVATE KEY',
+ b'PuTTY-User-Key-File-2',
+ b'BEGIN SSH2 ENCRYPTED PRIVATE KEY',
+ b'BEGIN PGP PRIVATE KEY BLOCK',
+ b'BEGIN ENCRYPTED PRIVATE KEY',
+ b'BEGIN OpenVPN Static key V1',
+]
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('filenames', nargs='*', help='Filenames to check')
+ args = parser.parse_args(argv)
+
+ private_key_files = []
+
+ for filename in args.filenames:
+ with open(filename, 'rb') as f:
+ content = f.read()
+ if any(line in content for line in BLACKLIST):
+ private_key_files.append(filename)
+
+ if private_key_files:
+ for private_key_file in private_key_files:
+ print(f'Private key found: {private_key_file}')
+ return 1
+ else:
+ return 0
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
diff --git a/pre_commit_hooks/end_of_file_fixer.py b/pre_commit_hooks/end_of_file_fixer.py
new file mode 100644
index 0000000..a30dce9
--- /dev/null
+++ b/pre_commit_hooks/end_of_file_fixer.py
@@ -0,0 +1,71 @@
+from __future__ import annotations
+
+import argparse
+import os
+from typing import IO
+from typing import Sequence
+
+
+def fix_file(file_obj: IO[bytes]) -> int:
+ # Test for newline at end of file
+ # Empty files will throw IOError here
+ try:
+ file_obj.seek(-1, os.SEEK_END)
+ except OSError:
+ return 0
+ last_character = file_obj.read(1)
+ # last_character will be '' for an empty file
+ if last_character not in {b'\n', b'\r'} and last_character != b'':
+ # Needs this seek for windows, otherwise IOError
+ file_obj.seek(0, os.SEEK_END)
+ file_obj.write(b'\n')
+ return 1
+
+ while last_character in {b'\n', b'\r'}:
+ # Deal with the beginning of the file
+ if file_obj.tell() == 1:
+ # If we've reached the beginning of the file and it is all
+ # linebreaks then we can make this file empty
+ file_obj.seek(0)
+ file_obj.truncate()
+ return 1
+
+ # Go back two bytes and read a character
+ file_obj.seek(-2, os.SEEK_CUR)
+ last_character = file_obj.read(1)
+
+ # Our current position is at the end of the file just before any amount of
+ # newlines. If we find extraneous newlines, then backtrack and trim them.
+ position = file_obj.tell()
+ remaining = file_obj.read()
+ for sequence in (b'\n', b'\r\n', b'\r'):
+ if remaining == sequence:
+ return 0
+ elif remaining.startswith(sequence):
+ file_obj.seek(position + len(sequence))
+ file_obj.truncate()
+ return 1
+
+ return 0
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('filenames', nargs='*', help='Filenames to fix')
+ args = parser.parse_args(argv)
+
+ retv = 0
+
+ for filename in args.filenames:
+ # Read as binary so we can read byte-by-byte
+ with open(filename, 'rb+') as file_obj:
+ ret_for_file = fix_file(file_obj)
+ if ret_for_file:
+ print(f'Fixing {filename}')
+ retv |= ret_for_file
+
+ return retv
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
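A minimal sketch of `fix_file` on in-memory buffers (it only needs a seekable binary file object):

    import io

    from pre_commit_hooks.end_of_file_fixer import fix_file

    # a missing final newline is appended...
    buf = io.BytesIO(b'hello')
    assert fix_file(buf) == 1 and buf.getvalue() == b'hello\n'

    # ...and runs of trailing newlines are trimmed to one
    buf = io.BytesIO(b'hello\n\n\n')
    assert fix_file(buf) == 1 and buf.getvalue() == b'hello\n'
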
diff --git a/pre_commit_hooks/file_contents_sorter.py b/pre_commit_hooks/file_contents_sorter.py
new file mode 100644
index 0000000..02bdbcc
--- /dev/null
+++ b/pre_commit_hooks/file_contents_sorter.py
@@ -0,0 +1,88 @@
+"""
+A very simple pre-commit hook that, when passed one or more filenames
+as arguments, will sort the lines in those files.
+
+An example use case for this: you have a deploy-allowlist.txt file
+in a repo that contains a list of filenames that is used to specify
+files to be included in a docker container. This file has one filename
+per line. Various users are adding/removing lines from this file; using
+this hook on that file should reduce the instances of git merge
+conflicts and keep the file nicely ordered.
+"""
+from __future__ import annotations
+
+import argparse
+from typing import Any
+from typing import Callable
+from typing import IO
+from typing import Iterable
+from typing import Sequence
+
+PASS = 0
+FAIL = 1
+
+
+def sort_file_contents(
+ f: IO[bytes],
+ key: Callable[[bytes], Any] | None,
+ *,
+ unique: bool = False,
+) -> int:
+ before = list(f)
+ lines: Iterable[bytes] = (
+ line.rstrip(b'\n\r') for line in before if line.strip()
+ )
+ if unique:
+ lines = set(lines)
+ after = sorted(lines, key=key)
+
+ before_string = b''.join(before)
+ after_string = b'\n'.join(after)
+
+ if after_string:
+ after_string += b'\n'
+
+ if before_string == after_string:
+ return PASS
+ else:
+ f.seek(0)
+ f.write(after_string)
+ f.truncate()
+ return FAIL
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('filenames', nargs='+', help='Files to sort')
+ parser.add_argument(
+ '--ignore-case',
+ action='store_const',
+ const=bytes.lower,
+ default=None,
+ help='compare lines case-insensitively when sorting',
+ )
+ parser.add_argument(
+ '--unique',
+ action='store_true',
+ help='ensure each line is unique',
+ )
+ args = parser.parse_args(argv)
+
+ retv = PASS
+
+ for arg in args.filenames:
+ with open(arg, 'rb+') as file_obj:
+ ret_for_file = sort_file_contents(
+ file_obj, key=args.ignore_case, unique=args.unique,
+ )
+
+ if ret_for_file:
+ print(f'Sorting {arg}')
+
+ retv |= ret_for_file
+
+ return retv
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
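For illustration, `sort_file_contents` on an in-memory buffer; a sketch, where `key=None` gives a plain bytewise sort:

    import io

    from pre_commit_hooks.file_contents_sorter import sort_file_contents

    buf = io.BytesIO(b'pear\napple\npear\n')
    assert sort_file_contents(buf, key=None, unique=True) == 1  # FAIL: changed
    assert buf.getvalue() == b'apple\npear\n'
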
diff --git a/pre_commit_hooks/fix_byte_order_marker.py b/pre_commit_hooks/fix_byte_order_marker.py
new file mode 100644
index 0000000..22a4990
--- /dev/null
+++ b/pre_commit_hooks/fix_byte_order_marker.py
@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+import argparse
+from typing import Sequence
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('filenames', nargs='*', help='Filenames to check')
+ args = parser.parse_args(argv)
+
+ retv = 0
+
+ for filename in args.filenames:
+ with open(filename, 'rb') as f_b:
+ bts = f_b.read(3)
+
+ if bts == b'\xef\xbb\xbf':
+ with open(filename, newline='', encoding='utf-8-sig') as f:
+ contents = f.read()
+ with open(filename, 'w', newline='', encoding='utf-8') as f:
+ f.write(contents)
+
+ print(f'{filename}: removed byte-order marker')
+ retv = 1
+
+ return retv
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
diff --git a/pre_commit_hooks/fix_encoding_pragma.py b/pre_commit_hooks/fix_encoding_pragma.py
new file mode 100644
index 0000000..60c71ee
--- /dev/null
+++ b/pre_commit_hooks/fix_encoding_pragma.py
@@ -0,0 +1,149 @@
+from __future__ import annotations
+
+import argparse
+from typing import IO
+from typing import NamedTuple
+from typing import Sequence
+
+DEFAULT_PRAGMA = b'# -*- coding: utf-8 -*-'
+
+
+def has_coding(line: bytes) -> bool:
+ if not line.strip():
+ return False
+ return (
+ line.lstrip()[:1] == b'#' and (
+ b'unicode' in line or
+ b'encoding' in line or
+ b'coding:' in line or
+ b'coding=' in line
+ )
+ )
+
+
+class ExpectedContents(NamedTuple):
+ shebang: bytes
+ rest: bytes
+ # True: has exactly the coding pragma expected
+ # False: missing coding pragma entirely
+ # None: has a coding pragma, but it does not match
+ pragma_status: bool | None
+ ending: bytes
+
+ @property
+ def has_any_pragma(self) -> bool:
+ return self.pragma_status is not False
+
+ def is_expected_pragma(self, remove: bool) -> bool:
+ expected_pragma_status = not remove
+ return self.pragma_status is expected_pragma_status
+
+
+def _get_expected_contents(
+ first_line: bytes,
+ second_line: bytes,
+ rest: bytes,
+ expected_pragma: bytes,
+) -> ExpectedContents:
+ ending = b'\r\n' if first_line.endswith(b'\r\n') else b'\n'
+
+ if first_line.startswith(b'#!'):
+ shebang = first_line
+ potential_coding = second_line
+ else:
+ shebang = b''
+ potential_coding = first_line
+ rest = second_line + rest
+
+ if potential_coding.rstrip(b'\r\n') == expected_pragma:
+ pragma_status: bool | None = True
+ elif has_coding(potential_coding):
+ pragma_status = None
+ else:
+ pragma_status = False
+ rest = potential_coding + rest
+
+ return ExpectedContents(
+ shebang=shebang, rest=rest, pragma_status=pragma_status, ending=ending,
+ )
+
+
+def fix_encoding_pragma(
+ f: IO[bytes],
+ remove: bool = False,
+ expected_pragma: bytes = DEFAULT_PRAGMA,
+) -> int:
+ expected = _get_expected_contents(
+ f.readline(), f.readline(), f.read(), expected_pragma,
+ )
+
+ # Special cases for empty files
+ if not expected.rest.strip():
+ # If a file only has a shebang or a coding pragma, remove it
+ if expected.has_any_pragma or expected.shebang:
+ f.seek(0)
+ f.truncate()
+ f.write(b'')
+ return 1
+ else:
+ return 0
+
+ if expected.is_expected_pragma(remove):
+ return 0
+
+ # Otherwise, write out the new file
+ f.seek(0)
+ f.truncate()
+ f.write(expected.shebang)
+ if not remove:
+ f.write(expected_pragma + expected.ending)
+ f.write(expected.rest)
+
+ return 1
+
+
+def _normalize_pragma(pragma: str) -> bytes:
+ return pragma.encode().rstrip()
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser(
+ description='Fixes the encoding pragma of python files',
+ )
+ parser.add_argument('filenames', nargs='*', help='Filenames to fix')
+ parser.add_argument(
+ '--pragma', default=DEFAULT_PRAGMA, type=_normalize_pragma,
+ help=(
+ f'The encoding pragma to use. '
+ f'Default: {DEFAULT_PRAGMA.decode()}'
+ ),
+ )
+ parser.add_argument(
+ '--remove', action='store_true',
+ help='Remove the encoding pragma (Useful in a python3-only codebase)',
+ )
+ args = parser.parse_args(argv)
+
+ retv = 0
+
+ if args.remove:
+ fmt = 'Removed encoding pragma from {filename}'
+ else:
+ fmt = 'Added `{pragma}` to {filename}'
+
+ for filename in args.filenames:
+ with open(filename, 'r+b') as f:
+ file_ret = fix_encoding_pragma(
+ f, remove=args.remove, expected_pragma=args.pragma,
+ )
+ retv |= file_ret
+ if file_ret:
+ print(
+ fmt.format(pragma=args.pragma.decode(), filename=filename),
+ )
+
+ return retv
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
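A minimal sketch of `fix_encoding_pragma`: the pragma lands after the shebang, before the rest of the file:

    import io

    from pre_commit_hooks.fix_encoding_pragma import fix_encoding_pragma

    buf = io.BytesIO(b'#!/usr/bin/env python\nprint("hi")\n')
    assert fix_encoding_pragma(buf) == 1
    assert buf.getvalue() == (
        b'#!/usr/bin/env python\n'
        b'# -*- coding: utf-8 -*-\n'
        b'print("hi")\n'
    )
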
diff --git a/pre_commit_hooks/forbid_new_submodules.py b/pre_commit_hooks/forbid_new_submodules.py
new file mode 100644
index 0000000..b806cad
--- /dev/null
+++ b/pre_commit_hooks/forbid_new_submodules.py
@@ -0,0 +1,48 @@
+from __future__ import annotations
+
+import argparse
+import os
+from typing import Sequence
+
+from pre_commit_hooks.util import cmd_output
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('filenames', nargs='*')
+ args = parser.parse_args(argv)
+
+ if (
+ 'PRE_COMMIT_FROM_REF' in os.environ and
+ 'PRE_COMMIT_TO_REF' in os.environ
+ ):
+ diff_arg = '...'.join((
+ os.environ['PRE_COMMIT_FROM_REF'],
+ os.environ['PRE_COMMIT_TO_REF'],
+ ))
+ else:
+ diff_arg = '--staged'
+ added_diff = cmd_output(
+ 'git', 'diff', '--diff-filter=A', '--raw', diff_arg, '--',
+ *args.filenames,
+ )
+ retv = 0
+ for line in added_diff.splitlines():
+ metadata, filename = line.split('\t', 1)
+ new_mode = metadata.split(' ')[1]
+ if new_mode == '160000':
+ print(f'{filename}: new submodule introduced')
+ retv = 1
+
+ if retv:
+ print()
+ print('This commit introduces new submodules.')
+ print('Did you unintentionally `git add .`?')
+ print('To fix: git rm {thesubmodule} # no trailing slash')
+ print('Also check .gitmodules')
+
+ return retv
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
diff --git a/pre_commit_hooks/mixed_line_ending.py b/pre_commit_hooks/mixed_line_ending.py
new file mode 100644
index 0000000..0328e86
--- /dev/null
+++ b/pre_commit_hooks/mixed_line_ending.py
@@ -0,0 +1,88 @@
+from __future__ import annotations
+
+import argparse
+import collections
+from typing import Sequence
+
+
+CRLF = b'\r\n'
+LF = b'\n'
+CR = b'\r'
+# Prefer LF to CRLF to CR, but detect CRLF before LF
+ALL_ENDINGS = (CR, CRLF, LF)
+FIX_TO_LINE_ENDING = {'cr': CR, 'crlf': CRLF, 'lf': LF}
+
+
+def _fix(filename: str, contents: bytes, ending: bytes) -> None:
+ new_contents = b''.join(
+ line.rstrip(b'\r\n') + ending for line in contents.splitlines(True)
+ )
+ with open(filename, 'wb') as f:
+ f.write(new_contents)
+
+
+def fix_filename(filename: str, fix: str) -> int:
+ with open(filename, 'rb') as f:
+ contents = f.read()
+
+ counts: dict[bytes, int] = collections.defaultdict(int)
+
+ for line in contents.splitlines(True):
+ for ending in ALL_ENDINGS:
+ if line.endswith(ending):
+ counts[ending] += 1
+ break
+
+ # Some amount of mixed line endings
+ mixed = sum(bool(x) for x in counts.values()) > 1
+
+ if fix == 'no' or (fix == 'auto' and not mixed):
+ return mixed
+
+ if fix == 'auto':
+ max_ending = LF
+ max_lines = 0
+ # ordering is important here such that lf > crlf > cr
+ for ending_type in ALL_ENDINGS:
+ # also important, using >= to find a max that prefers the last
+ if counts[ending_type] >= max_lines:
+ max_ending = ending_type
+ max_lines = counts[ending_type]
+
+ _fix(filename, contents, max_ending)
+ return 1
+ else:
+ target_ending = FIX_TO_LINE_ENDING[fix]
+ # find if there are lines with *other* endings
+ # It's possible there's no line endings of the target type
+ counts.pop(target_ending, None)
+ other_endings = bool(sum(counts.values()))
+ if other_endings:
+ _fix(filename, contents, target_ending)
+ return other_endings
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '-f', '--fix',
+ choices=('auto', 'no') + tuple(FIX_TO_LINE_ENDING),
+ default='auto',
+ help='Replace line endings with the specified one. Default is "auto"',
+ )
+ parser.add_argument('filenames', nargs='*', help='Filenames to fix')
+ args = parser.parse_args(argv)
+
+ retv = 0
+ for filename in args.filenames:
+ if fix_filename(filename, args.fix):
+ if args.fix == 'no':
+ print(f'{filename}: mixed line endings')
+ else:
+ print(f'{filename}: fixed mixed line endings')
+ retv = 1
+ return retv
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
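For illustration, the `auto` mode above rewrites to the most frequent ending; a sketch using a temporary file:

    import tempfile

    from pre_commit_hooks.mixed_line_ending import fix_filename

    with tempfile.NamedTemporaryFile(delete=False) as f:
        f.write(b'a\r\nb\r\nc\n')  # CRLF wins 2 to 1

    assert fix_filename(f.name, 'auto') == 1
    with open(f.name, 'rb') as f2:
        assert f2.read() == b'a\r\nb\r\nc\r\n'
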
diff --git a/pre_commit_hooks/no_commit_to_branch.py b/pre_commit_hooks/no_commit_to_branch.py
new file mode 100644
index 0000000..741f726
--- /dev/null
+++ b/pre_commit_hooks/no_commit_to_branch.py
@@ -0,0 +1,48 @@
+from __future__ import annotations
+
+import argparse
+import re
+from typing import AbstractSet
+from typing import Sequence
+
+from pre_commit_hooks.util import CalledProcessError
+from pre_commit_hooks.util import cmd_output
+
+
+def is_on_branch(
+ protected: AbstractSet[str],
+ patterns: AbstractSet[str] = frozenset(),
+) -> bool:
+ try:
+ ref_name = cmd_output('git', 'symbolic-ref', 'HEAD')
+ except CalledProcessError:
+ return False
+ chunks = ref_name.strip().split('/')
+ branch_name = '/'.join(chunks[2:])
+ return branch_name in protected or any(
+ re.match(p, branch_name) for p in patterns
+ )
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '-b', '--branch', action='append',
+ help='branch to disallow commits to, may be specified multiple times',
+ )
+ parser.add_argument(
+ '-p', '--pattern', action='append',
+ help=(
+ 'regex pattern for branch name to disallow commits to, '
+ 'may be specified multiple times'
+ ),
+ )
+ args = parser.parse_args(argv)
+
+ protected = frozenset(args.branch or ('master', 'main'))
+ patterns = frozenset(args.pattern or ())
+ return int(is_on_branch(protected, patterns))
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
diff --git a/pre_commit_hooks/pretty_format_json.py b/pre_commit_hooks/pretty_format_json.py
new file mode 100644
index 0000000..627a11c
--- /dev/null
+++ b/pre_commit_hooks/pretty_format_json.py
@@ -0,0 +1,133 @@
+from __future__ import annotations
+
+import argparse
+import json
+import sys
+from difflib import unified_diff
+from typing import Mapping
+from typing import Sequence
+
+
+def _get_pretty_format(
+ contents: str,
+ indent: int | str,
+ ensure_ascii: bool = True,
+ sort_keys: bool = True,
+ top_keys: Sequence[str] = (),
+) -> str:
+ def pairs_first(pairs: Sequence[tuple[str, str]]) -> Mapping[str, str]:
+ before = [pair for pair in pairs if pair[0] in top_keys]
+ before = sorted(before, key=lambda x: top_keys.index(x[0]))
+ after = [pair for pair in pairs if pair[0] not in top_keys]
+ if sort_keys:
+ after.sort()
+ return dict(before + after)
+ json_pretty = json.dumps(
+ json.loads(contents, object_pairs_hook=pairs_first),
+ indent=indent,
+ ensure_ascii=ensure_ascii,
+ )
+ return f'{json_pretty}\n'
+
+
+def _autofix(filename: str, new_contents: str) -> None:
+ print(f'Fixing file {filename}')
+ with open(filename, 'w', encoding='UTF-8') as f:
+ f.write(new_contents)
+
+
+def parse_num_to_int(s: str) -> int | str:
+ """Convert string numbers to int, leaving strings as is."""
+ try:
+ return int(s)
+ except ValueError:
+ return s
+
+
+def parse_topkeys(s: str) -> list[str]:
+ return s.split(',')
+
+
+def get_diff(source: str, target: str, file: str) -> str:
+ source_lines = source.splitlines(True)
+ target_lines = target.splitlines(True)
+ diff = unified_diff(source_lines, target_lines, fromfile=file, tofile=file)
+ return ''.join(diff)
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--autofix',
+ action='store_true',
+ dest='autofix',
+ help='Automatically fix files that are not pretty-formatted',
+ )
+ parser.add_argument(
+ '--indent',
+ type=parse_num_to_int,
+ default='2',
+ help=(
+ 'The number of spaces to indent, or a string to use as the'
+ ' indentation delimiter, e.g. 4 or "\t" (default: 2)'
+ ),
+ )
+ parser.add_argument(
+ '--no-ensure-ascii',
+ action='store_true',
+ dest='no_ensure_ascii',
+ default=False,
+ help=(
+ 'Do NOT convert non-ASCII characters to Unicode escape sequences '
+ '(\\uXXXX)'
+ ),
+ )
+ parser.add_argument(
+ '--no-sort-keys',
+ action='store_true',
+ dest='no_sort_keys',
+ default=False,
+ help='Keep JSON nodes in the same order',
+ )
+ parser.add_argument(
+ '--top-keys',
+ type=parse_topkeys,
+ dest='top_keys',
+ default=[],
+ help='Ordered list of keys to keep at the top of JSON hashes',
+ )
+ parser.add_argument('filenames', nargs='*', help='Filenames to fix')
+ args = parser.parse_args(argv)
+
+ status = 0
+
+ for json_file in args.filenames:
+ with open(json_file, encoding='UTF-8') as f:
+ contents = f.read()
+
+ try:
+ pretty_contents = _get_pretty_format(
+ contents, args.indent, ensure_ascii=not args.no_ensure_ascii,
+ sort_keys=not args.no_sort_keys, top_keys=args.top_keys,
+ )
+ except ValueError:
+ print(
+ f'Input file {json_file} is not valid JSON; consider using '
+ f'check-json',
+ )
+ return 1
+
+ if contents != pretty_contents:
+ if args.autofix:
+ _autofix(json_file, pretty_contents)
+ else:
+ diff_output = get_diff(contents, pretty_contents, json_file)
+ sys.stdout.buffer.write(diff_output.encode())
+
+ status = 1
+
+ return status
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
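A minimal sketch of the `--top-keys` ordering implemented by `pairs_first` above (note `_get_pretty_format` is module-private):

    from pre_commit_hooks.pretty_format_json import _get_pretty_format

    out = _get_pretty_format(
        '{"b": 1, "version": 2, "a": 3}', indent='  ', top_keys=('version',),
    )
    # top keys come first in the given order; the rest are sorted:
    # {
    #   "version": 2,
    #   "a": 3,
    #   "b": 1
    # }
    print(out)
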
diff --git a/pre_commit_hooks/removed.py b/pre_commit_hooks/removed.py
new file mode 100644
index 0000000..6f6c7b7
--- /dev/null
+++ b/pre_commit_hooks/removed.py
@@ -0,0 +1,16 @@
+from __future__ import annotations
+
+import sys
+from typing import Sequence
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ argv = argv if argv is not None else sys.argv[1:]
+ hookid, new_hookid, url = argv[:3]
+ raise SystemExit(
+ f'`{hookid}` has been removed -- use `{new_hookid}` from {url}',
+ )
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
diff --git a/pre_commit_hooks/requirements_txt_fixer.py b/pre_commit_hooks/requirements_txt_fixer.py
new file mode 100644
index 0000000..5884394
--- /dev/null
+++ b/pre_commit_hooks/requirements_txt_fixer.py
@@ -0,0 +1,153 @@
+from __future__ import annotations
+
+import argparse
+import re
+from typing import IO
+from typing import Sequence
+
+
+PASS = 0
+FAIL = 1
+
+
+class Requirement:
+ UNTIL_COMPARISON = re.compile(b'={2,3}|!=|~=|>=?|<=?')
+ UNTIL_SEP = re.compile(rb'[^;\s]+')
+
+ def __init__(self) -> None:
+ self.value: bytes | None = None
+ self.comments: list[bytes] = []
+
+ @property
+ def name(self) -> bytes:
+ assert self.value is not None, self.value
+ name = self.value.lower()
+ for egg in (b'#egg=', b'&egg='):
+ if egg in self.value:
+ return name.partition(egg)[-1]
+
+ m = self.UNTIL_SEP.match(name)
+ assert m is not None
+
+ name = m.group()
+ m = self.UNTIL_COMPARISON.search(name)
+ if not m:
+ return name
+
+ return name[:m.start()]
+
+ def __lt__(self, requirement: Requirement) -> bool:
+ # a value of \n marks the top-of-file comment block, which always
+ # sorts first; otherwise compare by requirement name.
+ assert self.value is not None, self.value
+ if self.value == b'\n':
+ return True
+ elif requirement.value == b'\n':
+ return False
+ else:
+ return self.name < requirement.name
+
+ def is_complete(self) -> bool:
+ return (
+ self.value is not None and
+ not self.value.rstrip(b'\r\n').endswith(b'\\')
+ )
+
+ def append_value(self, value: bytes) -> None:
+ if self.value is not None:
+ self.value += value
+ else:
+ self.value = value
+
+
+def fix_requirements(f: IO[bytes]) -> int:
+ requirements: list[Requirement] = []
+ before = list(f)
+ after: list[bytes] = []
+
+ before_string = b''.join(before)
+
+    # add a trailing newline if one is missing; because before_string was
+    # captured above, this alone is enough to mark the file as changed:
+ if before and not before[-1].endswith(b'\n'):
+ before[-1] += b'\n'
+
+ # If the file is empty (i.e. only whitespace/newlines) exit early
+ if before_string.strip() == b'':
+ return PASS
+
+ for line in before:
+        # If the most recent requirement is complete (has a value and no
+        # trailing line continuation), start building the next one.
+
+ if not len(requirements) or requirements[-1].is_complete():
+ requirements.append(Requirement())
+
+ requirement = requirements[-1]
+
+ # If we see a newline before any requirements, then this is a
+ # top of file comment.
+ if len(requirements) == 1 and line.strip() == b'':
+ if (
+ len(requirement.comments) and
+ requirement.comments[0].startswith(b'#')
+ ):
+ requirement.value = b'\n'
+ else:
+ requirement.comments.append(line)
+ elif line.lstrip().startswith(b'#') or line.strip() == b'':
+ requirement.comments.append(line)
+ else:
+ requirement.append_value(line)
+
+ # if a file ends in a comment, preserve it at the end
+ if requirements[-1].value is None:
+ rest = requirements.pop().comments
+ else:
+ rest = []
+
+ # find and remove pkg-resources==0.0.0
+ # which is automatically added by broken pip package under Debian
+ requirements = [
+ req for req in requirements
+ if req.value != b'pkg-resources==0.0.0\n'
+ ]
+
+ for requirement in sorted(requirements):
+ after.extend(requirement.comments)
+ assert requirement.value, requirement.value
+ after.append(requirement.value)
+ after.extend(rest)
+
+ after_string = b''.join(after)
+
+ if before_string == after_string:
+ return PASS
+ else:
+ f.seek(0)
+ f.write(after_string)
+ f.truncate()
+ return FAIL
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('filenames', nargs='*', help='Filenames to fix')
+ args = parser.parse_args(argv)
+
+ retv = PASS
+
+ for arg in args.filenames:
+ with open(arg, 'rb+') as file_obj:
+ ret_for_file = fix_requirements(file_obj)
+
+ if ret_for_file:
+ print(f'Sorting {arg}')
+
+ retv |= ret_for_file
+
+ return retv
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
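
fix_requirements() only needs a readable and writable binary stream, so its
behavior can be sketched without touching disk (package names illustrative):

    import io

    from pre_commit_hooks.requirements_txt_fixer import FAIL, fix_requirements

    buf = io.BytesIO(b'# header\n\nrequests\nflask\npkg-resources==0.0.0\n')
    assert fix_requirements(buf) == FAIL  # out of order: stream is rewritten
    buf.seek(0)
    print(buf.read().decode())
    # # header
    #
    # flask
    # requests
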
diff --git a/pre_commit_hooks/sort_simple_yaml.py b/pre_commit_hooks/sort_simple_yaml.py
new file mode 100644
index 0000000..116b5c1
--- /dev/null
+++ b/pre_commit_hooks/sort_simple_yaml.py
@@ -0,0 +1,125 @@
+"""Sort a simple YAML file, keeping blocks of comments and definitions
+together.
+
+We assume a strict subset of YAML that looks like:
+
+ # block of header comments
+ # here that should always
+ # be at the top of the file
+
+ # optional comments
+ # can go here
+ key: value
+ key: value
+
+ key: value
+
+In other words, we don't sort deeper than the top layer, and might corrupt
+complicated YAML files.
+"""
+from __future__ import annotations
+
+import argparse
+from typing import Sequence
+
+
+QUOTES = ["'", '"']
+
+
+def sort(lines: list[str]) -> list[str]:
+ """Sort a YAML file in alphabetical order, keeping blocks together.
+
+ :param lines: array of strings (without newlines)
+ :return: sorted array of strings
+ """
+ # make a copy of lines since we will clobber it
+ lines = list(lines)
+ new_lines = parse_block(lines, header=True)
+
+ for block in sorted(parse_blocks(lines), key=first_key):
+ if new_lines:
+ new_lines.append('')
+ new_lines.extend(block)
+
+ return new_lines
+
+
+def parse_block(lines: list[str], header: bool = False) -> list[str]:
+ """Parse and return a single block, popping off the start of `lines`.
+
+ If parsing a header block, we stop after we reach a line that is not a
+ comment. Otherwise, we stop after reaching an empty line.
+
+ :param lines: list of lines
+ :param header: whether we are parsing a header block
+ :return: list of lines that form the single block
+ """
+ block_lines = []
+ while lines and lines[0] and (not header or lines[0].startswith('#')):
+ block_lines.append(lines.pop(0))
+ return block_lines
+
+
+def parse_blocks(lines: list[str]) -> list[list[str]]:
+ """Parse and return all possible blocks, popping off the start of `lines`.
+
+ :param lines: list of lines
+ :return: list of blocks, where each block is a list of lines
+ """
+ blocks = []
+
+ while lines:
+ if lines[0] == '':
+ lines.pop(0)
+ else:
+ blocks.append(parse_block(lines))
+
+ return blocks
+
+
+def first_key(lines: list[str]) -> str:
+ """Returns a string representing the sort key of a block.
+
+ The sort key is the first YAML key we encounter, ignoring comments, and
+ stripping leading quotes.
+
+    >>> block = ['# some comment', "'foo': true"]
+    >>> first_key(block)
+    "foo': true"
+ """
+ for line in lines:
+ if line.startswith('#'):
+ continue
+ if any(line.startswith(quote) for quote in QUOTES):
+ return line[1:]
+ return line
+ else:
+        return ''  # fallback for an all-comment block; not hit in practice
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('filenames', nargs='*', help='Filenames to fix')
+ args = parser.parse_args(argv)
+
+ retval = 0
+
+ for filename in args.filenames:
+ with open(filename, 'r+') as f:
+ lines = [line.rstrip() for line in f.readlines()]
+ new_lines = sort(lines)
+
+ if lines != new_lines:
+ print(f'Fixing file `{filename}`')
+ f.seek(0)
+ f.write('\n'.join(new_lines) + '\n')
+ f.truncate()
+ retval = 1
+
+ return retval
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
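
sort() works on plain newline-stripped strings, so it is easy to try in
isolation (keys and comments here are illustrative):

    from pre_commit_hooks.sort_simple_yaml import sort

    lines = [
        '# header comment',
        '',
        'zebra: 1',
        '',
        '# belongs to apple',
        'apple: 2',
    ]
    print('\n'.join(sort(lines)))
    # # header comment
    #
    # # belongs to apple
    # apple: 2
    #
    # zebra: 1
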
diff --git a/pre_commit_hooks/string_fixer.py b/pre_commit_hooks/string_fixer.py
new file mode 100644
index 0000000..d1b1c4a
--- /dev/null
+++ b/pre_commit_hooks/string_fixer.py
@@ -0,0 +1,93 @@
+from __future__ import annotations
+
+import argparse
+import io
+import re
+import sys
+import tokenize
+from typing import Sequence
+
+if sys.version_info >= (3, 12): # pragma: >=3.12 cover
+ FSTRING_START = tokenize.FSTRING_START
+ FSTRING_END = tokenize.FSTRING_END
+else: # pragma: <3.12 cover
+ FSTRING_START = FSTRING_END = -1
+
+START_QUOTE_RE = re.compile('^[a-zA-Z]*"')
+
+
+def handle_match(token_text: str) -> str:
+ if '"""' in token_text or "'''" in token_text:
+ return token_text
+
+ match = START_QUOTE_RE.match(token_text)
+ if match is not None:
+ meat = token_text[match.end():-1]
+ if '"' in meat or "'" in meat:
+ return token_text
+ else:
+ return match.group().replace('"', "'") + meat + "'"
+ else:
+ return token_text
+
+
+def get_line_offsets_by_line_no(src: str) -> list[int]:
+ # Padded so we can index with line number
+ offsets = [-1, 0]
+ for line in src.splitlines(True):
+ offsets.append(offsets[-1] + len(line))
+ return offsets
+
+
+def fix_strings(filename: str) -> int:
+ with open(filename, encoding='UTF-8', newline='') as f:
+ contents = f.read()
+ line_offsets = get_line_offsets_by_line_no(contents)
+
+ # Basically a mutable string
+ splitcontents = list(contents)
+
+ fstring_depth = 0
+
+    # Iterate in reverse so replacements never shift the offsets of
+    # tokens that are still to be processed
+ tokens_l = list(tokenize.generate_tokens(io.StringIO(contents).readline))
+ tokens = reversed(tokens_l)
+ for token_type, token_text, (srow, scol), (erow, ecol), _ in tokens:
+ if token_type == FSTRING_START: # pragma: >=3.12 cover
+ fstring_depth += 1
+ elif token_type == FSTRING_END: # pragma: >=3.12 cover
+ fstring_depth -= 1
+ elif fstring_depth == 0 and token_type == tokenize.STRING:
+ new_text = handle_match(token_text)
+ splitcontents[
+ line_offsets[srow] + scol:
+ line_offsets[erow] + ecol
+ ] = new_text
+
+ new_contents = ''.join(splitcontents)
+ if contents != new_contents:
+ with open(filename, 'w', encoding='UTF-8', newline='') as f:
+ f.write(new_contents)
+ return 1
+ else:
+ return 0
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('filenames', nargs='*', help='Filenames to fix')
+ args = parser.parse_args(argv)
+
+ retv = 0
+
+ for filename in args.filenames:
+ return_value = fix_strings(filename)
+ if return_value != 0:
+ print(f'Fixing strings in {filename}')
+ retv |= return_value
+
+ return retv
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
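
handle_match() is the core rewrite rule: double quotes become single quotes
unless the change would require escaping. A minimal sketch:

    from pre_commit_hooks.string_fixer import handle_match

    print(handle_match('"hello"'))    # 'hello'
    print(handle_match('b"bytes"'))   # b'bytes' (prefix letters are kept)
    print(handle_match('"can\'t"'))   # unchanged: rewrite would need escaping
    print(handle_match("'''doc'''"))  # unchanged: triple-quoted is skipped
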
diff --git a/pre_commit_hooks/tests_should_end_in_test.py b/pre_commit_hooks/tests_should_end_in_test.py
new file mode 100644
index 0000000..e7842af
--- /dev/null
+++ b/pre_commit_hooks/tests_should_end_in_test.py
@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+import argparse
+import os.path
+import re
+from typing import Sequence
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('filenames', nargs='*')
+ mutex = parser.add_mutually_exclusive_group()
+ mutex.add_argument(
+ '--pytest',
+ dest='pattern',
+ action='store_const',
+ const=r'.*_test\.py',
+ default=r'.*_test\.py',
+ help='(the default) ensure tests match %(const)s',
+ )
+ mutex.add_argument(
+ '--pytest-test-first',
+ dest='pattern',
+ action='store_const',
+ const=r'test_.*\.py',
+ help='ensure tests match %(const)s',
+ )
+ mutex.add_argument(
+ '--django', '--unittest',
+ dest='pattern',
+ action='store_const',
+ const=r'test.*\.py',
+ help='ensure tests match %(const)s',
+ )
+ args = parser.parse_args(argv)
+
+ retcode = 0
+ reg = re.compile(args.pattern)
+ for filename in args.filenames:
+ base = os.path.basename(filename)
+ if (
+ not reg.fullmatch(base) and
+            base not in ('__init__.py', 'conftest.py')
+ ):
+ retcode = 1
+ print(f'{filename} does not match pattern "{args.pattern}"')
+
+ return retcode
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
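
The mutually exclusive flags just select a filename regex, and
__init__.py/conftest.py are always exempt. A minimal sketch (paths are
illustrative; failures also print a message):

    from pre_commit_hooks.tests_should_end_in_test import main

    assert main(['tests/foo_test.py']) == 0  # default --pytest pattern
    assert main(['tests/test_foo.py']) == 1  # wrong convention for default
    assert main(['--pytest-test-first', 'tests/test_foo.py']) == 0
    assert main(['tests/conftest.py']) == 0  # always allowed
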
diff --git a/pre_commit_hooks/trailing_whitespace_fixer.py b/pre_commit_hooks/trailing_whitespace_fixer.py
new file mode 100644
index 0000000..84f5067
--- /dev/null
+++ b/pre_commit_hooks/trailing_whitespace_fixer.py
@@ -0,0 +1,103 @@
+from __future__ import annotations
+
+import argparse
+import os
+from typing import Sequence
+
+
+def _fix_file(
+ filename: str,
+ is_markdown: bool,
+ chars: bytes | None,
+) -> bool:
+ with open(filename, mode='rb') as file_processed:
+ lines = file_processed.readlines()
+ newlines = [_process_line(line, is_markdown, chars) for line in lines]
+ if newlines != lines:
+ with open(filename, mode='wb') as file_processed:
+ for line in newlines:
+ file_processed.write(line)
+ return True
+ else:
+ return False
+
+
+def _process_line(
+ line: bytes,
+ is_markdown: bool,
+ chars: bytes | None,
+) -> bytes:
+ if line[-2:] == b'\r\n':
+ eol = b'\r\n'
+ line = line[:-2]
+ elif line[-1:] == b'\n':
+ eol = b'\n'
+ line = line[:-1]
+ else:
+ eol = b''
+    # preserve the two-space markdown hard linebreak on non-blank lines
+ if is_markdown and (not line.isspace()) and line.endswith(b' '):
+ return line[:-2].rstrip(chars) + b' ' + eol
+ return line.rstrip(chars) + eol
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--no-markdown-linebreak-ext',
+ action='store_true',
+ help=argparse.SUPPRESS,
+ )
+ parser.add_argument(
+ '--markdown-linebreak-ext',
+ action='append',
+ default=[],
+ metavar='*|EXT[,EXT,...]',
+ help=(
+ 'Markdown extensions (or *) to not strip linebreak spaces. '
+ 'default: %(default)s'
+ ),
+ )
+ parser.add_argument(
+ '--chars',
+ help=(
+ 'The set of characters to strip from the end of lines. '
+ 'Defaults to all whitespace characters.'
+ ),
+ )
+ parser.add_argument('filenames', nargs='*', help='Filenames to fix')
+ args = parser.parse_args(argv)
+
+ if args.no_markdown_linebreak_ext:
+ print('--no-markdown-linebreak-ext now does nothing!')
+
+ md_args = args.markdown_linebreak_ext
+ if '' in md_args:
+ parser.error('--markdown-linebreak-ext requires a non-empty argument')
+ all_markdown = '*' in md_args
+ # normalize extensions; split at ',', lowercase, and force 1 leading '.'
+ md_exts = [
+ '.' + x.lower().lstrip('.') for x in ','.join(md_args).split(',')
+ ]
+
+    # reject an extension that is probably a mistakenly-consumed filename;
+    # ext[1:] skips the normalized leading '.' added above
+ for ext in md_exts:
+ if any(c in ext[1:] for c in r'./\:'):
+ parser.error(
+ f'bad --markdown-linebreak-ext extension '
+ f'{ext!r} (has . / \\ :)\n'
+ f" (probably filename; use '--markdown-linebreak-ext=EXT')",
+ )
+ chars = None if args.chars is None else args.chars.encode()
+ return_code = 0
+ for filename in args.filenames:
+ _, extension = os.path.splitext(filename.lower())
+ md = all_markdown or extension in md_exts
+ if _fix_file(filename, md, chars):
+ print(f'Fixing {filename}')
+ return_code = 1
+ return return_code
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())
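
_process_line() preserves the original line ending and honors both the
markdown hard linebreak and a custom strip-character set. A minimal sketch:

    from pre_commit_hooks.trailing_whitespace_fixer import _process_line

    print(_process_line(b'text   \r\n', False, None))  # b'text\r\n'
    print(_process_line(b'text  \n', True, None))      # b'text  \n' (kept)
    print(_process_line(b'   \n', True, None))         # b'\n' (blank stripped)
    print(_process_line(b'text;;\n', False, b';'))     # b'text\n'
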
diff --git a/pre_commit_hooks/util.py b/pre_commit_hooks/util.py
new file mode 100644
index 0000000..d6c90ae
--- /dev/null
+++ b/pre_commit_hooks/util.py
@@ -0,0 +1,32 @@
+from __future__ import annotations
+
+import subprocess
+from typing import Any
+
+
+class CalledProcessError(RuntimeError):
+ pass
+
+
+def added_files() -> set[str]:
+ cmd = ('git', 'diff', '--staged', '--name-only', '--diff-filter=A')
+ return set(cmd_output(*cmd).splitlines())
+
+
+def cmd_output(*cmd: str, retcode: int | None = 0, **kwargs: Any) -> str:
+ kwargs.setdefault('stdout', subprocess.PIPE)
+ kwargs.setdefault('stderr', subprocess.PIPE)
+ proc = subprocess.Popen(cmd, **kwargs)
+ stdout, stderr = proc.communicate()
+ stdout = stdout.decode()
+ if retcode is not None and proc.returncode != retcode:
+ raise CalledProcessError(cmd, retcode, proc.returncode, stdout, stderr)
+ return stdout
+
+
+def zsplit(s: str) -> list[str]:
+ s = s.strip('\0')
+ if s:
+ return s.split('\0')
+ else:
+ return []
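
cmd_output() is a thin check-and-decode wrapper around subprocess, and
zsplit() undoes git's NUL-separated -z output. A minimal sketch, assuming
git is installed and the working directory is inside a repository:

    from pre_commit_hooks.util import CalledProcessError, cmd_output, zsplit

    print(zsplit('a\0b\0'))  # ['a', 'b']
    print(zsplit(''))        # []

    try:
        print(cmd_output('git', 'rev-parse', '--show-toplevel').strip())
    except CalledProcessError:
        print('not inside a git repository')
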