diff options
Diffstat (limited to 'ml/dlib/dlib/threads/thread_pool_extension.cpp')
-rw-r--r-- | ml/dlib/dlib/threads/thread_pool_extension.cpp | 347 |
1 file changed, 347 insertions(+), 0 deletions(-)
diff --git a/ml/dlib/dlib/threads/thread_pool_extension.cpp b/ml/dlib/dlib/threads/thread_pool_extension.cpp new file mode 100644 index 000000000..00d99b910 --- /dev/null +++ b/ml/dlib/dlib/threads/thread_pool_extension.cpp @@ -0,0 +1,347 @@
// Copyright (C) 2008  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_THREAD_POOl_CPPh_
#define DLIB_THREAD_POOl_CPPh_

#include "thread_pool_extension.h"
#include <memory>

namespace dlib
{

// ----------------------------------------------------------------------------------------

    // Construct the pool and launch num_threads worker threads, each running
    // this->thread().  With num_threads == 0 no workers exist and (per
    // is_worker_thread() below) every caller counts as a worker, so tasks
    // run inline in the submitting thread.
    //
    // NOTE(review): if a std::thread constructor throws partway through this
    // loop, the destructor will not run (the object was never fully
    // constructed) and already-started threads would not be joined —
    // presumably acceptable here since thread creation failure is treated as
    // fatal; confirm against the rest of the library's conventions.
    thread_pool_implementation::
    thread_pool_implementation (
        unsigned long num_threads
    ) :
        task_done_signaler(m),
        task_ready_signaler(m),
        we_are_destructing(false)
    {
        // One task slot per worker thread; a slot is "empty" when no task
        // occupies it (see find_empty_task_slot / is_empty usage below).
        tasks.resize(num_threads);
        threads.resize(num_threads);
        for (unsigned long i = 0; i < num_threads; ++i)
        {
            // Capturing this (via [&]) is safe: shutdown_pool() joins every
            // thread before this object is destroyed.
            threads[i] = std::thread([&](){this->thread();});
        }
    }

// ----------------------------------------------------------------------------------------

    // Drain all pending/in-flight tasks, tell the workers to exit, join them,
    // and finally rethrow the first unhandled task exception (if any).
    // Called only from the destructor.
    void thread_pool_implementation::
    shutdown_pool (
    )
    {
        {
            auto_mutex M(m);

            // first wait for all pending tasks to finish
            bool found_task = true;
            while (found_task)
            {
                found_task = false;
                for (unsigned long i = 0; i < tasks.size(); ++i)
                {
                    // If task bucket i has a task that is currently supposed to be processed
                    if (tasks[i].is_empty() == false)
                    {
                        found_task = true;
                        break;
                    }
                }

                // task_done_signaler.wait() releases m while blocked, so
                // workers can make progress; we re-scan after each wakeup.
                if (found_task)
                    task_done_signaler.wait();
            }

            // now tell the threads to kill themselves
            we_are_destructing = true;
            task_ready_signaler.broadcast();
        }

        // wait for all threads to terminate
        for (auto& t : threads)
            t.join();
        threads.clear();

        // Throw any unhandled exceptions.  Since shutdown_pool() is only called in the
        // destructor this will kill the program.
        for (auto&& task : tasks)
            task.propagate_exception();
    }

// ----------------------------------------------------------------------------------------

    thread_pool_implementation::
    ~thread_pool_implementation()
    {
        shutdown_pool();
    }

// ----------------------------------------------------------------------------------------

    // Returns the number of worker threads (tasks.size() == threads.size()
    // by construction, and threads is cleared during shutdown, so tasks is
    // the stable source of truth).
    unsigned long thread_pool_implementation::
    num_threads_in_pool (
    ) const
    {
        auto_mutex M(m);
        return tasks.size();
    }

// ----------------------------------------------------------------------------------------

    // Block until the task with the given id has completed, then rethrow any
    // exception a task produced.  Ids map to slots via task_id_to_index();
    // when the slot's task_id no longer matches, the task has finished (the
    // worker zeroes task_id on completion).  The sentinel id 1 returned by
    // add_task_internal() for inline execution never matches a live slot, so
    // waiting on it returns immediately.
    void thread_pool_implementation::
    wait_for_task (
        uint64 task_id
    ) const
    {
        auto_mutex M(m);
        if (tasks.size() != 0)
        {
            const unsigned long idx = task_id_to_index(task_id);
            while (tasks[idx].task_id == task_id)
                task_done_signaler.wait();

            // throw any exceptions generated by the tasks
            for (auto&& task : tasks)
                task.propagate_exception();
        }
    }

// ----------------------------------------------------------------------------------------

    // Block until every task submitted by the *calling* thread has finished
    // (tasks from other threads are ignored), then rethrow any stored task
    // exception.
    void thread_pool_implementation::
    wait_for_all_tasks (
    ) const
    {
        const thread_id_type thread_id = get_thread_id();

        auto_mutex M(m);
        bool found_task = true;
        while (found_task)
        {
            found_task = false;
            for (unsigned long i = 0; i < tasks.size(); ++i)
            {
                // If task bucket i has a task that is currently supposed to be processed
                // and it originated from the calling thread
                if (tasks[i].is_empty() == false && tasks[i].thread_id == thread_id)
                {
                    found_task = true;
                    break;
                }
            }

            if (found_task)
                task_done_signaler.wait();
        }

        // throw any exceptions generated by the tasks
        for (auto&& task : tasks)
            task.propagate_exception();
    }

// ----------------------------------------------------------------------------------------

    // True if id belongs to one of this pool's worker threads.  Callers hold
    // m (this is a private helper used under the lock); worker_thread_ids is
    // only appended to by thread() at startup.
    bool thread_pool_implementation::
    is_worker_thread (
        const thread_id_type id
    ) const
    {
        for (unsigned long i = 0; i < worker_thread_ids.size(); ++i)
        {
            if (worker_thread_ids[i] == id)
                return true;
        }

        // if there aren't any threads in the pool then we consider all threads
        // to be worker threads
        if (tasks.size() == 0)
            return true;
        else
            return false;
    }

// ----------------------------------------------------------------------------------------

    // Main loop of each worker thread: wait for a ready task slot, run the
    // task outside the lock, then clear the slot and notify waiters.
    // Exceptions escaping a task are captured into the slot's eptr and
    // rethrown later by propagate_exception() in the waiting thread.
    void thread_pool_implementation::
    thread (
    )
    {
        {
            // save the id of this worker thread into worker_thread_ids
            auto_mutex M(m);
            thread_id_type id = get_thread_id();
            worker_thread_ids.push_back(id);
        }

        task_state_type task;
        while (we_are_destructing == false)
        {
            long idx = 0;

            // wait for a task to do
            { auto_mutex M(m);
                while ( (idx = find_ready_task()) == -1 && we_are_destructing == false)
                    task_ready_signaler.wait();

                if (we_are_destructing)
                    break;

                // Mark the slot in-progress and copy the task state so the
                // callable can be invoked below without holding the mutex.
                tasks[idx].is_being_processed = true;
                task = tasks[idx];
            }

            std::exception_ptr eptr = nullptr;
            try
            {
                // now do the task
                if (task.bfp)
                    task.bfp();
                else if (task.mfp0)
                    task.mfp0();
                else if (task.mfp1)
                    task.mfp1(task.arg1);
                else if (task.mfp2)
                    task.mfp2(task.arg1, task.arg2);
            }
            catch(...)
            {
                // Stash the exception; it is re-raised in whichever thread
                // later calls propagate_exception() on this slot.
                eptr = std::current_exception();
            }

            // Now let others know that we finished the task.  We do this
            // by clearing out the state of this task
            { auto_mutex M(m);
                tasks[idx].is_being_processed = false;
                tasks[idx].task_id = 0;
                tasks[idx].bfp.clear();
                tasks[idx].mfp0.clear();
                tasks[idx].mfp1.clear();
                tasks[idx].mfp2.clear();
                tasks[idx].arg1 = 0;
                tasks[idx].arg2 = 0;
                tasks[idx].eptr = eptr;
                // NOTE(review): function_copy is not cleared here, so the
                // copied function object lives until this slot is reused by
                // add_task_internal() — presumably harmless; confirm no
                // destructor-ordering requirement on user function objects.
                task_done_signaler.broadcast();
            }

        }
    }

// ----------------------------------------------------------------------------------------

    // Return the index of a free task slot, or -1 if none.  Also rethrows any
    // exception left behind by a previously completed task, which is why this
    // is called from the task-submission path (callers hold m).
    long thread_pool_implementation::
    find_empty_task_slot (
    ) const
    {
        // throw any unhandled exceptions from finished tasks before reusing slots
        for (auto&& task : tasks)
            task.propagate_exception();

        for (unsigned long i = 0; i < tasks.size(); ++i)
        {
            if (tasks[i].is_empty())
                return i;
        }

        return -1;
    }

// ----------------------------------------------------------------------------------------

    // Return the index of a slot holding a task that is ready to run (queued
    // but not yet being processed), or -1 if none.  Callers hold m.
    long thread_pool_implementation::
    find_ready_task (
    ) const
    {
        for (unsigned long i = 0; i < tasks.size(); ++i)
        {
            if (tasks[i].is_ready())
                return i;
        }

        return -1;
    }

// ----------------------------------------------------------------------------------------

    // Generate a fresh id for slot idx.  Encoding id = counter*N + idx lets
    // task_id_to_index() recover the slot with id % N, and the per-slot
    // counter makes ids from successive uses of the same slot distinct.
    uint64 thread_pool_implementation::
    make_next_task_id (
        long idx
    )
    {
        uint64 id = tasks[idx].next_task_id * tasks.size() + idx;
        tasks[idx].next_task_id += 1;
        return id;
    }

// ----------------------------------------------------------------------------------------

    // Inverse of the encoding in make_next_task_id(): map a task id back to
    // its slot index.
    unsigned long thread_pool_implementation::
    task_id_to_index (
        uint64 id
    ) const
    {
        return static_cast<unsigned long>(id%tasks.size());
    }

// ----------------------------------------------------------------------------------------

    // Queue bfp for execution on some worker thread and return its task id.
    // item is a copy of the caller's function object; ownership is moved into
    // the task slot (via swap) so it stays alive while the task runs.  If the
    // caller is itself a worker thread and no slot is free, the task runs
    // inline to avoid deadlocking the pool.
    uint64 thread_pool_implementation::
    add_task_internal (
        const bfp_type& bfp,
        std::shared_ptr<function_object_copy>& item
    )
    {
        auto_mutex M(m);
        const thread_id_type my_thread_id = get_thread_id();

        // find a thread that isn't doing anything
        long idx = find_empty_task_slot();
        if (idx == -1 && is_worker_thread(my_thread_id))
        {
            // this function is being called from within a worker thread and there
            // aren't any other worker threads free so just perform the task right
            // here

            // Must release m before invoking user code: the task may itself
            // submit or wait on tasks.
            M.unlock();
            bfp();

            // return a task id that is both non-zero and also one
            // that is never normally returned.  This way calls
            // to wait_for_task() will never block given this id.
            return 1;
        }

        // wait until there is a thread that isn't doing anything
        while (idx == -1)
        {
            task_done_signaler.wait();
            idx = find_empty_task_slot();
        }

        tasks[idx].thread_id = my_thread_id;
        tasks[idx].task_id = make_next_task_id(idx);
        tasks[idx].bfp = bfp;
        tasks[idx].function_copy.swap(item);

        task_ready_signaler.signal();

        return tasks[idx].task_id;
    }

// ----------------------------------------------------------------------------------------

    // True if the calling thread is one of this pool's workers (or if the
    // pool has no workers, in which case every thread counts — see
    // is_worker_thread()).
    bool thread_pool_implementation::
    is_task_thread (
    ) const
    {
        auto_mutex M(m);
        return is_worker_thread(get_thread_id());
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_THREAD_POOl_CPPh_