# path: root/sphinx/util/parallel.py
# blob: 0afdff9dee5c226ca368935e3e98104fc73af36b (plain)
"""Parallel building utilities."""

from __future__ import annotations

import os
import time
import traceback
from math import sqrt
from typing import TYPE_CHECKING, Any, Callable

try:
    import multiprocessing
    HAS_MULTIPROCESSING = True
except ImportError:
    HAS_MULTIPROCESSING = False

from sphinx.errors import SphinxParallelError
from sphinx.util import logging

if TYPE_CHECKING:
    from collections.abc import Sequence

logger = logging.getLogger(__name__)

# Our parallel functionality only works for the forking Process.
# Test the HAS_MULTIPROCESSING flag rather than the module name itself: when
# the guarded import failed, ``multiprocessing`` is unbound and referencing it
# would raise NameError instead of reporting "not available".
parallel_available = HAS_MULTIPROCESSING and os.name == 'posix'


class SerialTasks:
    """In-process stand-in offering the same methods as ParallelTasks.

    Every task runs synchronously at the moment it is added, so nothing is
    ever left pending.
    """

    def __init__(self, nproc: int = 1) -> None:
        # *nproc* is accepted for signature compatibility only; it is unused.
        return None

    def add_task(
        self, task_func: Callable, arg: Any = None, result_func: Callable | None = None,
    ) -> None:
        """Run *task_func* immediately and hand its result to *result_func*.

        :param task_func: callable to execute; called without arguments when
            *arg* is ``None``, otherwise called with *arg*.
        :param arg: optional single argument for *task_func*.
        :param result_func: optional callback; when truthy it is called with
            the return value of *task_func* as its only argument.
        """
        outcome = task_func() if arg is None else task_func(arg)
        if not result_func:
            return
        result_func(outcome)

    def join(self) -> None:
        """Do nothing: each task already completed inside add_task()."""
        return None


class ParallelTasks:
    """Executes *nproc* tasks in parallel after forking.

    Every task added with :meth:`add_task` gets its own subprocess (created
    from the ``'fork'`` multiprocessing context) and a one-way pipe through
    which the subprocess sends back its outcome.  No more than *nproc*
    subprocesses are started at the same time; the remaining ones wait until
    a running task has been collected.  :meth:`join` blocks until all started
    tasks are collected.
    """

    def __init__(self, nproc: int) -> None:
        """Set up empty bookkeeping for at most *nproc* concurrent tasks."""
        self.nproc = nproc
        # (optional) function performed by each task on the result of main task
        self._result_funcs: dict[int, Callable] = {}
        # task arguments
        self._args: dict[int, list[Any] | None] = {}
        # list of subprocesses (both started and waiting)
        self._procs: dict[int, Any] = {}
        # list of receiving pipe connections of running subprocesses
        self._precvs: dict[int, Any] = {}
        # list of receiving pipe connections of waiting subprocesses
        self._precvsWaiting: dict[int, Any] = {}
        # number of working subprocesses
        self._pworking = 0
        # task number of each subprocess
        self._taskid = 0

    def _process(self, pipe: Any, func: Callable, arg: Any) -> None:
        """Run *func* and send the outcome through *pipe*.

        This is the ``target`` of the subprocess created in :meth:`add_task`.
        *func* is called without arguments when *arg* is ``None``, otherwise
        with *arg*.  The tuple sent is ``(failed, logs, ret)``: on success
        *ret* is the return value of *func*; on failure it is
        ``(error message, formatted traceback)``.
        """
        try:
            collector = logging.LogCollector()
            with collector.collect():
                if arg is None:
                    ret = func()
                else:
                    ret = func(arg)
            failed = False
        except BaseException as err:
            # report every failure to the parent instead of dying silently
            failed = True
            errmsg = traceback.format_exception_only(err.__class__, err)[0].strip()
            ret = (errmsg, traceback.format_exc())
        # presumably makes the collected records safe to send through the
        # pipe -- see sphinx.util.logging.convert_serializable
        logging.convert_serializable(collector.logs)
        pipe.send((failed, collector.logs, ret))

    def add_task(
        self, task_func: Callable, arg: Any = None, result_func: Callable | None = None,
    ) -> None:
        """Register a task and start it as soon as a slot is free.

        :param task_func: callable executed in the subprocess.
        :param arg: optional single argument for *task_func*.
        :param result_func: optional callback run in this process as
            ``result_func(arg, result)`` once the task has been collected.
        """
        tid = self._taskid
        self._taskid += 1
        self._result_funcs[tid] = result_func or (lambda arg, result: None)
        self._args[tid] = arg
        # one-way pipe: the subprocess sends, this process receives
        precv, psend = multiprocessing.Pipe(False)
        context: Any = multiprocessing.get_context('fork')
        proc = context.Process(target=self._process, args=(psend, task_func, arg))
        self._procs[tid] = proc
        self._precvsWaiting[tid] = precv
        # collect a finished task (if any) and start waiting ones
        self._join_one()

    def join(self) -> None:
        """Wait until every running task has been collected.

        Exceptions raised while collecting (see :meth:`_join_one`) propagate
        to the caller after the remaining subprocesses were terminated.
        """
        try:
            while self._pworking:
                if not self._join_one():
                    # nothing finished yet; avoid busy-waiting
                    time.sleep(0.02)
        finally:
            # shutdown other child processes on failure
            self.terminate()

    def terminate(self) -> None:
        """Terminate all running subprocesses and forget their bookkeeping."""
        for tid in list(self._precvs):
            self._procs[tid].terminate()
            # _join_one() pops the callback and the argument *before* it
            # invokes the callback.  If the callback raised, they are already
            # gone here, and an unconditional pop would raise KeyError and
            # hide the original exception.
            self._result_funcs.pop(tid, None)
            self._args.pop(tid, None)
            self._procs.pop(tid)
            self._precvs.pop(tid)
            self._pworking -= 1

    def _join_one(self) -> bool:
        """Collect at most one finished task, then start waiting tasks.

        :returns: ``True`` if a finished task was collected.
        :raises SphinxParallelError: if the subprocess reported a failure.
        """
        joined_any = False
        for tid, pipe in self._precvs.items():
            if pipe.poll():
                exc, logs, result = pipe.recv()
                if exc:
                    raise SphinxParallelError(*result)
                # replay the records collected in the subprocess
                for log in logs:
                    logger.handle(log)
                self._result_funcs.pop(tid)(self._args.pop(tid), result)
                self._procs[tid].join()
                self._precvs.pop(tid)
                self._pworking -= 1
                joined_any = True
                # the dict was modified, so stop iterating over it
                break

        # fill free slots with waiting tasks
        while self._precvsWaiting and self._pworking < self.nproc:
            newtid, newprecv = self._precvsWaiting.popitem()
            self._precvs[newtid] = newprecv
            self._procs[newtid].start()
            self._pworking += 1

        return joined_any


def make_chunks(arguments: Sequence[str], nproc: int, maxbatch: int = 10) -> list[Any]:
    """Partition *arguments* into consecutive slices ("chunks").

    :param arguments: the items (documents) to distribute.
    :param nproc: number of processes the work is spread over.
    :param maxbatch: batch size from which the chunk size is re-balanced
        against the number of chunks.
    :returns: list of slices of *arguments*, in the original order; every
        slice except possibly the last one has the same length.
    """
    total = len(arguments)
    # determine how many documents to read in one go
    per_chunk = total // nproc
    if per_chunk >= maxbatch:
        # try to improve batch size vs. number of batches
        per_chunk = int(sqrt(total / nproc * maxbatch))
    # never use an empty chunk size
    per_chunk = per_chunk or 1
    # partition documents in "chunks" that will be written by one Process
    return [arguments[start:start + per_chunk] for start in range(0, total, per_chunk)]