summaryrefslogtreecommitdiffstats
path: root/sphinx/util/parallel.py
diff options
context:
space:
mode:
Diffstat (limited to 'sphinx/util/parallel.py')
-rw-r--r--sphinx/util/parallel.py154
1 files changed, 154 insertions, 0 deletions
diff --git a/sphinx/util/parallel.py b/sphinx/util/parallel.py
new file mode 100644
index 0000000..0afdff9
--- /dev/null
+++ b/sphinx/util/parallel.py
@@ -0,0 +1,154 @@
+"""Parallel building utilities."""
+
+from __future__ import annotations
+
+import os
+import time
+import traceback
+from math import sqrt
+from typing import TYPE_CHECKING, Any, Callable
+
+try:
+ import multiprocessing
+ HAS_MULTIPROCESSING = True
+except ImportError:
+ HAS_MULTIPROCESSING = False
+
+from sphinx.errors import SphinxParallelError
+from sphinx.util import logging
+
+if TYPE_CHECKING:
+ from collections.abc import Sequence
+
+logger = logging.getLogger(__name__)
+
+# our parallel functionality only works for the forking Process
+parallel_available = multiprocessing and os.name == 'posix'
+
+
+class SerialTasks:
+ """Has the same interface as ParallelTasks, but executes tasks directly."""
+
+ def __init__(self, nproc: int = 1) -> None:
+ pass
+
+ def add_task(
+ self, task_func: Callable, arg: Any = None, result_func: Callable | None = None,
+ ) -> None:
+ if arg is not None:
+ res = task_func(arg)
+ else:
+ res = task_func()
+ if result_func:
+ result_func(res)
+
+ def join(self) -> None:
+ pass
+
+
+class ParallelTasks:
+ """Executes *nproc* tasks in parallel after forking."""
+
+ def __init__(self, nproc: int) -> None:
+ self.nproc = nproc
+ # (optional) function performed by each task on the result of main task
+ self._result_funcs: dict[int, Callable] = {}
+ # task arguments
+ self._args: dict[int, list[Any] | None] = {}
+ # list of subprocesses (both started and waiting)
+ self._procs: dict[int, Any] = {}
+ # list of receiving pipe connections of running subprocesses
+ self._precvs: dict[int, Any] = {}
+ # list of receiving pipe connections of waiting subprocesses
+ self._precvsWaiting: dict[int, Any] = {}
+ # number of working subprocesses
+ self._pworking = 0
+ # task number of each subprocess
+ self._taskid = 0
+
+ def _process(self, pipe: Any, func: Callable, arg: Any) -> None:
+ try:
+ collector = logging.LogCollector()
+ with collector.collect():
+ if arg is None:
+ ret = func()
+ else:
+ ret = func(arg)
+ failed = False
+ except BaseException as err:
+ failed = True
+ errmsg = traceback.format_exception_only(err.__class__, err)[0].strip()
+ ret = (errmsg, traceback.format_exc())
+ logging.convert_serializable(collector.logs)
+ pipe.send((failed, collector.logs, ret))
+
+ def add_task(
+ self, task_func: Callable, arg: Any = None, result_func: Callable | None = None,
+ ) -> None:
+ tid = self._taskid
+ self._taskid += 1
+ self._result_funcs[tid] = result_func or (lambda arg, result: None)
+ self._args[tid] = arg
+ precv, psend = multiprocessing.Pipe(False)
+ context: Any = multiprocessing.get_context('fork')
+ proc = context.Process(target=self._process, args=(psend, task_func, arg))
+ self._procs[tid] = proc
+ self._precvsWaiting[tid] = precv
+ self._join_one()
+
+ def join(self) -> None:
+ try:
+ while self._pworking:
+ if not self._join_one():
+ time.sleep(0.02)
+ finally:
+ # shutdown other child processes on failure
+ self.terminate()
+
+ def terminate(self) -> None:
+ for tid in list(self._precvs):
+ self._procs[tid].terminate()
+ self._result_funcs.pop(tid)
+ self._procs.pop(tid)
+ self._precvs.pop(tid)
+ self._pworking -= 1
+
+ def _join_one(self) -> bool:
+ joined_any = False
+ for tid, pipe in self._precvs.items():
+ if pipe.poll():
+ exc, logs, result = pipe.recv()
+ if exc:
+ raise SphinxParallelError(*result)
+ for log in logs:
+ logger.handle(log)
+ self._result_funcs.pop(tid)(self._args.pop(tid), result)
+ self._procs[tid].join()
+ self._precvs.pop(tid)
+ self._pworking -= 1
+ joined_any = True
+ break
+
+ while self._precvsWaiting and self._pworking < self.nproc:
+ newtid, newprecv = self._precvsWaiting.popitem()
+ self._precvs[newtid] = newprecv
+ self._procs[newtid].start()
+ self._pworking += 1
+
+ return joined_any
+
+
+def make_chunks(arguments: Sequence[str], nproc: int, maxbatch: int = 10) -> list[Any]:
+ # determine how many documents to read in one go
+ nargs = len(arguments)
+ chunksize = nargs // nproc
+ if chunksize >= maxbatch:
+ # try to improve batch size vs. number of batches
+ chunksize = int(sqrt(nargs / nproc * maxbatch))
+ if chunksize == 0:
+ chunksize = 1
+ nchunks, rest = divmod(nargs, chunksize)
+ if rest:
+ nchunks += 1
+ # partition documents in "chunks" that will be written by one Process
+ return [arguments[i * chunksize:(i + 1) * chunksize] for i in range(nchunks)]