From d4583dcad7d68d3c1503b04ec0d3364809304807 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 30 Jan 2022 12:02:58 +0100 Subject: Adding upstream version 4.5.0+dfsg. Signed-off-by: Daniel Baumann --- pre_commit_hooks/requirements_txt_fixer.py | 153 +++++++++++++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 pre_commit_hooks/requirements_txt_fixer.py (limited to 'pre_commit_hooks/requirements_txt_fixer.py') diff --git a/pre_commit_hooks/requirements_txt_fixer.py b/pre_commit_hooks/requirements_txt_fixer.py new file mode 100644 index 0000000..5884394 --- /dev/null +++ b/pre_commit_hooks/requirements_txt_fixer.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +import argparse +import re +from typing import IO +from typing import Sequence + + +PASS = 0 +FAIL = 1 + + +class Requirement: + UNTIL_COMPARISON = re.compile(b'={2,3}|!=|~=|>=?|<=?') + UNTIL_SEP = re.compile(rb'[^;\s]+') + + def __init__(self) -> None: + self.value: bytes | None = None + self.comments: list[bytes] = [] + + @property + def name(self) -> bytes: + assert self.value is not None, self.value + name = self.value.lower() + for egg in (b'#egg=', b'&egg='): + if egg in self.value: + return name.partition(egg)[-1] + + m = self.UNTIL_SEP.match(name) + assert m is not None + + name = m.group() + m = self.UNTIL_COMPARISON.search(name) + if not m: + return name + + return name[:m.start()] + + def __lt__(self, requirement: Requirement) -> bool: + # \n means top of file comment, so always return True, + # otherwise just do a string comparison with value. + assert self.value is not None, self.value + if self.value == b'\n': + return True + elif requirement.value == b'\n': + return False + else: + return self.name < requirement.name + + def is_complete(self) -> bool: + return ( + self.value is not None and + not self.value.rstrip(b'\r\n').endswith(b'\\') + ) + + def append_value(self, value: bytes) -> None: + if self.value is not None: + self.value += value + else: + self.value = value + + +def fix_requirements(f: IO[bytes]) -> int: + requirements: list[Requirement] = [] + before = list(f) + after: list[bytes] = [] + + before_string = b''.join(before) + + # adds new line in case one is missing + # AND a change to the requirements file is needed regardless: + if before and not before[-1].endswith(b'\n'): + before[-1] += b'\n' + + # If the file is empty (i.e. only whitespace/newlines) exit early + if before_string.strip() == b'': + return PASS + + for line in before: + # If the most recent requirement object has a value, then it's + # time to start building the next requirement object. + + if not len(requirements) or requirements[-1].is_complete(): + requirements.append(Requirement()) + + requirement = requirements[-1] + + # If we see a newline before any requirements, then this is a + # top of file comment. + if len(requirements) == 1 and line.strip() == b'': + if ( + len(requirement.comments) and + requirement.comments[0].startswith(b'#') + ): + requirement.value = b'\n' + else: + requirement.comments.append(line) + elif line.lstrip().startswith(b'#') or line.strip() == b'': + requirement.comments.append(line) + else: + requirement.append_value(line) + + # if a file ends in a comment, preserve it at the end + if requirements[-1].value is None: + rest = requirements.pop().comments + else: + rest = [] + + # find and remove pkg-resources==0.0.0 + # which is automatically added by broken pip package under Debian + requirements = [ + req for req in requirements + if req.value != b'pkg-resources==0.0.0\n' + ] + + for requirement in sorted(requirements): + after.extend(requirement.comments) + assert requirement.value, requirement.value + after.append(requirement.value) + after.extend(rest) + + after_string = b''.join(after) + + if before_string == after_string: + return PASS + else: + f.seek(0) + f.write(after_string) + f.truncate() + return FAIL + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument('filenames', nargs='*', help='Filenames to fix') + args = parser.parse_args(argv) + + retv = PASS + + for arg in args.filenames: + with open(arg, 'rb+') as file_obj: + ret_for_file = fix_requirements(file_obj) + + if ret_for_file: + print(f'Sorting {arg}') + + retv |= ret_for_file + + return retv + + +if __name__ == '__main__': + raise SystemExit(main()) -- cgit v1.2.3