summaryrefslogtreecommitdiffstats
path: root/third_party/jpeg-xl/lib/threads/resizable_parallel_runner.cc
blob: 238ccbf20c14cf969bfbc1e714c0378c99f12186 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include <jxl/jxl_threads_export.h>
#include <jxl/memory_manager.h>
#include <jxl/parallel_runner.h>
#include <jxl/resizable_parallel_runner.h>

#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <mutex>
#include <thread>
#include <vector>

namespace jpegxl {
namespace {

// A thread pool that allows changing the number of threads it runs. It also
// runs tasks on the calling thread, which can work better on schedulers for
// heterogeneous architectures.
// Threading model: SetNumThreads() and Run() are expected to be called from a
// single controlling thread (they touch workers_ without holding the mutex).
// Background workers sleep on workers_can_proceed_ and are woken either to
// exit (pool shrink) or to help with the task range published by Run().
struct ResizeableParallelRunner {
  // Sets the total number of threads that take part in Run(), counting the
  // calling thread: `num` threads means `num - 1` background workers, so both
  // 0 and 1 result in no workers at all. Missing workers are spawned and
  // surplus workers are asked to exit and then joined.
  void SetNumThreads(size_t num) {
    // The thread calling Run() always processes tasks itself, so it occupies
    // one of the requested slots.
    const size_t num_workers = (num > 0) ? num - 1 : 0;
    {
      std::lock_guard<std::mutex> l(state_mutex_);
      num_desired_workers_ = num_workers;
      // Wake everybody so that workers whose id is now out of range can exit.
      workers_can_proceed_.notify_all();
    }
    // Grow: worker i gets id i, which stays fixed for its whole lifetime.
    for (size_t i = workers_.size(); i < num_workers; i++) {
      workers_.emplace_back([this, i]() { WorkerBody(i); });
    }
    // Shrink: workers with id >= num_workers return from WorkerBody once they
    // observe the new num_desired_workers_; wait for them before dropping the
    // std::thread objects.
    if (workers_.size() > num_workers) {
      for (size_t i = num_workers; i < workers_.size(); i++) {
        workers_[i].join();
      }
      workers_.resize(num_workers);
    }
  }

  // Stops and joins all workers.
  ~ResizeableParallelRunner() { SetNumThreads(0); }

  // Calls init(jxl_opaque, num_threads) once and then func(jxl_opaque, task,
  // thread_id) for every task in [start, end), distributing tasks over the
  // calling thread (thread_id 0) and the workers (thread_id worker_id + 1).
  // Blocks until every task has finished.
  //
  // Returns 0 on success, the non-zero value returned by init if init fails
  // (no task is run in that case), and -1 (the runner error code, see
  // JXL_PARALLEL_RET_RUNNER_ERROR) if start > end. An empty range succeeds
  // immediately without calling init or func.
  JxlParallelRetCode Run(void* jxl_opaque, JxlParallelRunInit init,
                         JxlParallelRunFunction func, uint32_t start,
                         uint32_t end) {
    // Reject inverted ranges: `end - start` would wrap around below.
    if (start > end) return -1;
    // Nothing to do. Without this guard init would be called with zero
    // threads and max_running_workers_ would underflow, waking every worker
    // for no work.
    if (start == end) return 0;

    const uint32_t num_tasks = end - start;

    // Single task: run it inline, no need to involve the workers.
    if (num_tasks == 1) {
      JxlParallelRetCode ret = init(jxl_opaque, 1);
      if (ret != 0) return ret;

      func(jxl_opaque, start, 0);
      return ret;
    }

    // Never announce more threads than there are tasks.
    const size_t num_threads =
        std::min<size_t>(workers_.size() + 1, num_tasks);
    JxlParallelRetCode ret = init(jxl_opaque, num_threads);
    if (ret != 0) {
      return ret;
    }

    {
      std::lock_guard<std::mutex> l(state_mutex_);
      // Avoid waking up more workers than needed: the calling thread takes
      // one task stream itself.
      max_running_workers_ = num_tasks - 1;
      next_task_ = start;
      end_task_ = end;
      func_ = func;
      jxl_opaque_ = jxl_opaque;
      work_available_ = true;
      // Account for the calling thread; DequeueTasks() undoes this.
      num_running_workers_++;
      workers_can_proceed_.notify_all();
    }

    // The calling thread participates as thread 0.
    DequeueTasks(0);

    // Wait until every thread that joined this run has left DequeueTasks().
    {
      std::unique_lock<std::mutex> l(state_mutex_);
      work_done_.wait(l, [this] { return num_running_workers_ == 0; });
    }

    return ret;
  }

 private:
  // Main loop of a background worker. Exits when the pool is shrunk below
  // worker_id; otherwise sleeps until work is published and this worker is
  // among those allowed to run, then processes tasks as thread worker_id + 1.
  void WorkerBody(size_t worker_id) {
    while (true) {
      {
        std::unique_lock<std::mutex> l(state_mutex_);
        // Worker pool was reduced, resize down.
        if (worker_id >= num_desired_workers_) {
          return;
        }
        // Nothing to do this time; re-evaluate everything after each wakeup.
        if (!work_available_ || worker_id >= max_running_workers_) {
          workers_can_proceed_.wait(l);
          continue;
        }
        // Registered under the lock so that Run() waits for this worker.
        num_running_workers_++;
      }
      DequeueTasks(worker_id + 1);
    }
  }

  // Claims tasks from the shared counter and runs them until the range is
  // exhausted, then deregisters the calling thread from the current run.
  void DequeueTasks(size_t thread_id) {
    while (true) {
      const uint32_t task = next_task_++;
      if (task >= end_task_) {
        std::lock_guard<std::mutex> l(state_mutex_);
        num_running_workers_--;
        // No more tasks to hand out; late wakers must not join this run.
        work_available_ = false;
        if (num_running_workers_ == 0) {
          work_done_.notify_all();
        }
        break;
      }
      func_(jxl_opaque_, task, thread_id);
    }
  }

  // Checks when the worker has something to do, which can be one of:
  // - quitting (when worker_id >= num_desired_workers_)
  // - having work available for them (work_available_ is true and worker_id <
  // max_running_workers_)
  std::condition_variable workers_can_proceed_;

  // Workers are done, and the main thread can proceed (num_running_workers_ ==
  // 0)
  std::condition_variable work_done_;

  std::vector<std::thread> workers_;

  // Protects all the remaining variables, except for func_, jxl_opaque_ and
  // end_task_ (for which only the write by the main thread is protected, and
  // subsequent uses by workers happen-after it) and next_task_ (which is
  // atomic).
  std::mutex state_mutex_;

  // Range of tasks that still need to be done.
  std::atomic<uint32_t> next_task_{0};
  uint32_t end_task_ = 0;

  // Function to run and its argument.
  JxlParallelRunFunction func_ = nullptr;
  void* jxl_opaque_ = nullptr;  // not owned

  // Variables that control the workers:
  // - work_available_ is set to true after a call to Run() and to false at the
  // end of it.
  // - num_desired_workers_ represents the number of workers that should be
  // present.
  // - max_running_workers_ represents the number of workers that should be
  // executing tasks.
  // - num_running_workers_ represents the number of threads (workers plus the
  // thread inside Run()) that are executing tasks.
  size_t num_desired_workers_ = 0;
  size_t max_running_workers_ = 0;
  size_t num_running_workers_ = 0;
  bool work_available_ = false;
};
}  // namespace
}  // namespace jpegxl

extern "C" {
// C entry point that forwards a run request to the ResizeableParallelRunner
// that `runner_opaque` points to and returns whatever its Run() returns.
// `runner_opaque` is cast without any check, so it must be a pointer to a
// ResizeableParallelRunner.
JXL_THREADS_EXPORT JxlParallelRetCode JxlResizableParallelRunner(
    void* runner_opaque, void* jpegxl_opaque, JxlParallelRunInit init,
    JxlParallelRunFunction func, uint32_t start_range, uint32_t end_range) {
  auto* const runner =
      static_cast<jpegxl::ResizeableParallelRunner*>(runner_opaque);
  return runner->Run(jpegxl_opaque, init, func, start_range, end_range);
}

// Allocates a new ResizeableParallelRunner and returns it as an opaque
// pointer. `memory_manager` is not used: the runner is allocated with the
// global operator new.
JXL_THREADS_EXPORT void* JxlResizableParallelRunnerCreate(
    const JxlMemoryManager* memory_manager) {
  auto* const runner = new jpegxl::ResizeableParallelRunner();
  return runner;
}

// C entry point that forwards `num_threads` to SetNumThreads() of the
// ResizeableParallelRunner that `runner_opaque` points to. `runner_opaque` is
// cast without any check.
JXL_THREADS_EXPORT void JxlResizableParallelRunnerSetThreads(
    void* runner_opaque, size_t num_threads) {
  auto* const runner =
      static_cast<jpegxl::ResizeableParallelRunner*>(runner_opaque);
  runner->SetNumThreads(num_threads);
}

// Deletes the ResizeableParallelRunner that `runner_opaque` points to; its
// destructor stops and joins the worker threads. The pointer must not be used
// afterwards.
JXL_THREADS_EXPORT void JxlResizableParallelRunnerDestroy(void* runner_opaque) {
  auto* const runner =
      static_cast<jpegxl::ResizeableParallelRunner*>(runner_opaque);
  delete runner;
}

// Suggests a thread count for an image of xsize x ysize pixels: roughly one
// thread per 256x256 pixel group, capped at the value reported by
// std::thread::hardware_concurrency(). The result may be 0 when the image
// covers less than one full group or when the hardware concurrency is
// reported as 0.
JXL_THREADS_EXPORT uint32_t
JxlResizableParallelRunnerSuggestThreads(uint64_t xsize, uint64_t ysize) {
  constexpr uint64_t kGroupArea = 256 * 256;
  const uint64_t max_threads = std::thread::hardware_concurrency();
  // xsize * ysize would wrap around for very large dimensions and produce an
  // arbitrary (possibly zero) suggestion. An area that does not fit in 64 bits
  // contains far more groups than there can be hardware threads, so the cap
  // is the answer.
  if (xsize != 0 && ysize > UINT64_MAX / xsize) {
    return static_cast<uint32_t>(max_threads);
  }
  // ~one thread per group.
  const uint64_t num_groups = xsize * ysize / kGroupArea;
  return static_cast<uint32_t>(std::min(max_threads, num_groups));
}
}