summaryrefslogtreecommitdiffstats
path: root/pre_commit_hooks/requirements_txt_fixer.py
diff options
context:
space:
mode:
Diffstat (limited to 'pre_commit_hooks/requirements_txt_fixer.py')
-rw-r--r--pre_commit_hooks/requirements_txt_fixer.py153
1 files changed, 153 insertions, 0 deletions
diff --git a/pre_commit_hooks/requirements_txt_fixer.py b/pre_commit_hooks/requirements_txt_fixer.py
new file mode 100644
index 0000000..63f891f
--- /dev/null
+++ b/pre_commit_hooks/requirements_txt_fixer.py
@@ -0,0 +1,153 @@
+import argparse
+import re
+from typing import IO
+from typing import List
+from typing import Optional
+from typing import Sequence
+
+
+PASS = 0
+FAIL = 1
+
+
+class Requirement:
+ UNTIL_COMPARISON = re.compile(b'={2,3}|!=|~=|>=?|<=?')
+ UNTIL_SEP = re.compile(rb'[^;\s]+')
+
+ def __init__(self) -> None:
+ self.value: Optional[bytes] = None
+ self.comments: List[bytes] = []
+
+ @property
+ def name(self) -> bytes:
+ assert self.value is not None, self.value
+ name = self.value.lower()
+ for egg in (b'#egg=', b'&egg='):
+ if egg in self.value:
+ return name.partition(egg)[-1]
+
+ m = self.UNTIL_SEP.match(name)
+ assert m is not None
+
+ name = m.group()
+ m = self.UNTIL_COMPARISON.search(name)
+ if not m:
+ return name
+
+ return name[:m.start()]
+
+ def __lt__(self, requirement: 'Requirement') -> bool:
+ # \n means top of file comment, so always return True,
+ # otherwise just do a string comparison with value.
+ assert self.value is not None, self.value
+ if self.value == b'\n':
+ return True
+ elif requirement.value == b'\n':
+ return False
+ else:
+ return self.name < requirement.name
+
+ def is_complete(self) -> bool:
+ return (
+ self.value is not None and
+ not self.value.rstrip(b'\r\n').endswith(b'\\')
+ )
+
+ def append_value(self, value: bytes) -> None:
+ if self.value is not None:
+ self.value += value
+ else:
+ self.value = value
+
+
+def fix_requirements(f: IO[bytes]) -> int:
+ requirements: List[Requirement] = []
+ before = list(f)
+ after: List[bytes] = []
+
+ before_string = b''.join(before)
+
+ # adds new line in case one is missing
+ # AND a change to the requirements file is needed regardless:
+ if before and not before[-1].endswith(b'\n'):
+ before[-1] += b'\n'
+
+ # If the file is empty (i.e. only whitespace/newlines) exit early
+ if before_string.strip() == b'':
+ return PASS
+
+ for line in before:
+ # If the most recent requirement object has a value, then it's
+ # time to start building the next requirement object.
+
+ if not len(requirements) or requirements[-1].is_complete():
+ requirements.append(Requirement())
+
+ requirement = requirements[-1]
+
+ # If we see a newline before any requirements, then this is a
+ # top of file comment.
+ if len(requirements) == 1 and line.strip() == b'':
+ if (
+ len(requirement.comments) and
+ requirement.comments[0].startswith(b'#')
+ ):
+ requirement.value = b'\n'
+ else:
+ requirement.comments.append(line)
+ elif line.lstrip().startswith(b'#') or line.strip() == b'':
+ requirement.comments.append(line)
+ else:
+ requirement.append_value(line)
+
+ # if a file ends in a comment, preserve it at the end
+ if requirements[-1].value is None:
+ rest = requirements.pop().comments
+ else:
+ rest = []
+
+ # find and remove pkg-resources==0.0.0
+ # which is automatically added by broken pip package under Debian
+ requirements = [
+ req for req in requirements
+ if req.value != b'pkg-resources==0.0.0\n'
+ ]
+
+ for requirement in sorted(requirements):
+ after.extend(requirement.comments)
+ assert requirement.value, requirement.value
+ after.append(requirement.value)
+ after.extend(rest)
+
+ after_string = b''.join(after)
+
+ if before_string == after_string:
+ return PASS
+ else:
+ f.seek(0)
+ f.write(after_string)
+ f.truncate()
+ return FAIL
+
+
+def main(argv: Optional[Sequence[str]] = None) -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('filenames', nargs='*', help='Filenames to fix')
+ args = parser.parse_args(argv)
+
+ retv = PASS
+
+ for arg in args.filenames:
+ with open(arg, 'rb+') as file_obj:
+ ret_for_file = fix_requirements(file_obj)
+
+ if ret_for_file:
+ print(f'Sorting {arg}')
+
+ retv |= ret_for_file
+
+ return retv
+
+
+if __name__ == '__main__':
+ raise SystemExit(main())