Diffstat (limited to 'sphinx/util/parallel.py')
-rw-r--r-- | sphinx/util/parallel.py | 154
1 file changed, 154 insertions, 0 deletions
diff --git a/sphinx/util/parallel.py b/sphinx/util/parallel.py
new file mode 100644
index 0000000..0afdff9
--- /dev/null
+++ b/sphinx/util/parallel.py
@@ -0,0 +1,154 @@
+"""Parallel building utilities."""
+
+from __future__ import annotations
+
+import os
+import time
+import traceback
+from math import sqrt
+from typing import TYPE_CHECKING, Any, Callable
+
+try:
+    import multiprocessing
+    HAS_MULTIPROCESSING = True
+except ImportError:
+    HAS_MULTIPROCESSING = False
+
+from sphinx.errors import SphinxParallelError
+from sphinx.util import logging
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+logger = logging.getLogger(__name__)
+
+# our parallel functionality only works for the forking Process
+parallel_available = HAS_MULTIPROCESSING and os.name == 'posix'
+
+
+class SerialTasks:
+    """Has the same interface as ParallelTasks, but executes tasks directly."""
+
+    def __init__(self, nproc: int = 1) -> None:
+        pass
+
+    def add_task(
+        self, task_func: Callable, arg: Any = None, result_func: Callable | None = None,
+    ) -> None:
+        if arg is not None:
+            res = task_func(arg)
+        else:
+            res = task_func()
+        if result_func:
+            result_func(res)
+
+    def join(self) -> None:
+        pass
+
+
+class ParallelTasks:
+    """Executes *nproc* tasks in parallel after forking."""
+
+    def __init__(self, nproc: int) -> None:
+        self.nproc = nproc
+        # (optional) function performed by each task on the result of main task
+        self._result_funcs: dict[int, Callable] = {}
+        # task arguments
+        self._args: dict[int, list[Any] | None] = {}
+        # list of subprocesses (both started and waiting)
+        self._procs: dict[int, Any] = {}
+        # list of receiving pipe connections of running subprocesses
+        self._precvs: dict[int, Any] = {}
+        # list of receiving pipe connections of waiting subprocesses
+        self._precvsWaiting: dict[int, Any] = {}
+        # number of working subprocesses
+        self._pworking = 0
+        # task number of each subprocess
+        self._taskid = 0
+
+    def _process(self, pipe: Any, func: Callable, arg: Any) -> None:
+        try:
+            collector = logging.LogCollector()
+            with collector.collect():
+                if arg is None:
+                    ret = func()
+                else:
+                    ret = func(arg)
+            failed = False
+        except BaseException as err:
+            failed = True
+            errmsg = traceback.format_exception_only(err.__class__, err)[0].strip()
+            ret = (errmsg, traceback.format_exc())
+        logging.convert_serializable(collector.logs)
+        pipe.send((failed, collector.logs, ret))
+
+    def add_task(
+        self, task_func: Callable, arg: Any = None, result_func: Callable | None = None,
+    ) -> None:
+        tid = self._taskid
+        self._taskid += 1
+        self._result_funcs[tid] = result_func or (lambda arg, result: None)
+        self._args[tid] = arg
+        precv, psend = multiprocessing.Pipe(False)
+        context: Any = multiprocessing.get_context('fork')
+        proc = context.Process(target=self._process, args=(psend, task_func, arg))
+        self._procs[tid] = proc
+        self._precvsWaiting[tid] = precv
+        self._join_one()
+
+    def join(self) -> None:
+        try:
+            while self._pworking:
+                if not self._join_one():
+                    time.sleep(0.02)
+        finally:
+            # shutdown other child processes on failure
+            self.terminate()
+
+    def terminate(self) -> None:
+        for tid in list(self._precvs):
+            self._procs[tid].terminate()
+            self._result_funcs.pop(tid)
+            self._procs.pop(tid)
+            self._precvs.pop(tid)
+            self._pworking -= 1
+
+    def _join_one(self) -> bool:
+        joined_any = False
+        for tid, pipe in self._precvs.items():
+            if pipe.poll():
+                exc, logs, result = pipe.recv()
+                if exc:
+                    raise SphinxParallelError(*result)
+                for log in logs:
+                    logger.handle(log)
+                self._result_funcs.pop(tid)(self._args.pop(tid), result)
+                self._procs[tid].join()
+                self._precvs.pop(tid)
+                self._pworking -= 1
+                joined_any = True
+                break
+
+        while self._precvsWaiting and self._pworking < self.nproc:
+            newtid, newprecv = self._precvsWaiting.popitem()
+            self._precvs[newtid] = newprecv
+            self._procs[newtid].start()
+            self._pworking += 1
+
+        return joined_any
+
+
+def make_chunks(arguments: Sequence[str], nproc: int, maxbatch: int = 10) -> list[Any]:
+    # determine how many documents to read in one go
+    nargs = len(arguments)
+    chunksize = nargs // nproc
+    if chunksize >= maxbatch:
+        # try to improve batch size vs. number of batches
+        chunksize = int(sqrt(nargs / nproc * maxbatch))
+    if chunksize == 0:
+        chunksize = 1
+    nchunks, rest = divmod(nargs, chunksize)
+    if rest:
+        nchunks += 1
+    # partition documents in "chunks" that will be written by one Process
+    return [arguments[i * chunksize:(i + 1) * chunksize] for i in range(nchunks)]
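
For context, here is a minimal sketch of how the pieces of this module fit together. The task function process_chunk, the callback report, and the document names are hypothetical illustrations, not part of Sphinx; only ParallelTasks, make_chunks, and parallel_available come from the module above:

    from sphinx.util.parallel import ParallelTasks, make_chunks, parallel_available

    def process_chunk(docnames):
        # runs in a forked child process; the return value travels back
        # to the parent over the pipe created in add_task()
        return [name.upper() for name in docnames]

    def report(arg, result):
        # runs in the parent during join(); ParallelTasks calls it with both
        # the original task argument and the child's return value
        print(arg, '->', result)

    if parallel_available:
        docnames = ['intro', 'install', 'usage', 'api', 'faq', 'changes']
        tasks = ParallelTasks(nproc=2)
        for chunk in make_chunks(docnames, nproc=2):
            tasks.add_task(process_chunk, chunk, report)
        tasks.join()  # blocks, forwarding child logs, until every child exits

Note one asymmetry visible above: SerialTasks invokes result_func with the result alone, while ParallelTasks invokes it with (arg, result), so a callback shared between both paths must handle either signature.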
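The batching heuristic in make_chunks is easiest to see with concrete numbers; the figures below are only an illustration:

    from sphinx.util.parallel import make_chunks

    # 100 documents across 4 processes: the naive chunk size 100 // 4 == 25
    # exceeds maxbatch (10), so it is rebalanced to int(sqrt(100 / 4 * 10)) == 15,
    # giving ceil(100 / 15) == 7 chunks rather than 4 oversized ones
    chunks = make_chunks([f'doc{i}' for i in range(100)], nproc=4)
    assert [len(c) for c in chunks] == [15] * 6 + [10]

Smaller batches let a worker that finishes early pick up another chunk instead of idling while a long batch completes elsewhere.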