diff options
Diffstat (limited to '')
-rw-r--r-- | pre_commit/xargs.py | 185 |
1 files changed, 185 insertions, 0 deletions
diff --git a/pre_commit/xargs.py b/pre_commit/xargs.py new file mode 100644 index 0000000..22580f5 --- /dev/null +++ b/pre_commit/xargs.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +import concurrent.futures +import contextlib +import math +import multiprocessing +import os +import subprocess +import sys +from collections.abc import Generator +from collections.abc import Iterable +from collections.abc import MutableMapping +from collections.abc import Sequence +from typing import Any +from typing import Callable +from typing import TypeVar + +from pre_commit import parse_shebang +from pre_commit.util import cmd_output_b +from pre_commit.util import cmd_output_p + +TArg = TypeVar('TArg') +TRet = TypeVar('TRet') + + +def cpu_count() -> int: + try: + # On systems that support it, this will return a more accurate count of + # usable CPUs for the current process, which will take into account + # cgroup limits + return len(os.sched_getaffinity(0)) + except AttributeError: + pass + + try: + return multiprocessing.cpu_count() + except NotImplementedError: + return 1 + + +def _environ_size(_env: MutableMapping[str, str] | None = None) -> int: + environ = _env if _env is not None else getattr(os, 'environb', os.environ) + size = 8 * len(environ) # number of pointers in `envp` + for k, v in environ.items(): + size += len(k) + len(v) + 2 # c strings in `envp` + return size + + +def _get_platform_max_length() -> int: # pragma: no cover (platform specific) + if os.name == 'posix': + maximum = os.sysconf('SC_ARG_MAX') - 2048 - _environ_size() + maximum = max(min(maximum, 2 ** 17), 2 ** 12) + return maximum + elif os.name == 'nt': + return 2 ** 15 - 2048 # UNICODE_STRING max - headroom + else: + # posix minimum + return 2 ** 12 + + +def _command_length(*cmd: str) -> int: + full_cmd = ' '.join(cmd) + + # win32 uses the amount of characters, more details at: + # https://github.com/pre-commit/pre-commit/pull/839 + if sys.platform == 'win32': + return len(full_cmd.encode('utf-16le')) // 2 + else: + return len(full_cmd.encode(sys.getfilesystemencoding())) + + +class ArgumentTooLongError(RuntimeError): + pass + + +def partition( + cmd: Sequence[str], + varargs: Sequence[str], + target_concurrency: int, + _max_length: int | None = None, +) -> tuple[tuple[str, ...], ...]: + _max_length = _max_length or _get_platform_max_length() + + # Generally, we try to partition evenly into at least `target_concurrency` + # partitions, but we don't want a bunch of tiny partitions. + max_args = max(4, math.ceil(len(varargs) / target_concurrency)) + + cmd = tuple(cmd) + ret = [] + + ret_cmd: list[str] = [] + # Reversed so arguments are in order + varargs = list(reversed(varargs)) + + total_length = _command_length(*cmd) + 1 + while varargs: + arg = varargs.pop() + + arg_length = _command_length(arg) + 1 + if ( + total_length + arg_length <= _max_length and + len(ret_cmd) < max_args + ): + ret_cmd.append(arg) + total_length += arg_length + elif not ret_cmd: + raise ArgumentTooLongError(arg) + else: + # We've exceeded the length, yield a command + ret.append(cmd + tuple(ret_cmd)) + ret_cmd = [] + total_length = _command_length(*cmd) + 1 + varargs.append(arg) + + ret.append(cmd + tuple(ret_cmd)) + + return tuple(ret) + + +@contextlib.contextmanager +def _thread_mapper(maxsize: int) -> Generator[ + Callable[[Callable[[TArg], TRet], Iterable[TArg]], Iterable[TRet]], + None, None, +]: + if maxsize == 1: + yield map + else: + with concurrent.futures.ThreadPoolExecutor(maxsize) as ex: + yield ex.map + + +def xargs( + cmd: tuple[str, ...], + varargs: Sequence[str], + *, + color: bool = False, + target_concurrency: int = 1, + _max_length: int = _get_platform_max_length(), + **kwargs: Any, +) -> tuple[int, bytes]: + """A simplified implementation of xargs. + + color: Make a pty if on a platform that supports it + target_concurrency: Target number of partitions to run concurrently + """ + cmd_fn = cmd_output_p if color else cmd_output_b + retcode = 0 + stdout = b'' + + try: + cmd = parse_shebang.normalize_cmd(cmd) + except parse_shebang.ExecutableNotFoundError as e: + return e.to_output()[:2] + + # on windows, batch files have a separate length limit than windows itself + if ( + sys.platform == 'win32' and + cmd[0].lower().endswith(('.bat', '.cmd')) + ): # pragma: win32 cover + # this is implementation details but the command gets translated into + # full/path/to/cmd.exe /c *cmd + cmd_exe = parse_shebang.find_executable('cmd.exe') + # 1024 is additionally subtracted to give headroom for further + # expansion inside the batch file + _max_length = 8192 - len(cmd_exe) - len(' /c ') - 1024 + + partitions = partition(cmd, varargs, target_concurrency, _max_length) + + def run_cmd_partition( + run_cmd: tuple[str, ...], + ) -> tuple[int, bytes, bytes | None]: + return cmd_fn( + *run_cmd, check=False, stderr=subprocess.STDOUT, **kwargs, + ) + + threads = min(len(partitions), target_concurrency) + with _thread_mapper(threads) as thread_map: + results = thread_map(run_cmd_partition, partitions) + + for proc_retcode, proc_out, _ in results: + if abs(proc_retcode) > abs(retcode): + retcode = proc_retcode + stdout += proc_out + + return retcode, stdout |