summaryrefslogtreecommitdiffstats
path: root/ml/dlib/dlib/threads/thread_pool_extension.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ml/dlib/dlib/threads/thread_pool_extension.cpp')
-rw-r--r--ml/dlib/dlib/threads/thread_pool_extension.cpp347
1 files changed, 347 insertions, 0 deletions
diff --git a/ml/dlib/dlib/threads/thread_pool_extension.cpp b/ml/dlib/dlib/threads/thread_pool_extension.cpp
new file mode 100644
index 000000000..00d99b910
--- /dev/null
+++ b/ml/dlib/dlib/threads/thread_pool_extension.cpp
@@ -0,0 +1,347 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_THREAD_POOl_CPPh_
+#define DLIB_THREAD_POOl_CPPh_
+
+#include "thread_pool_extension.h"
+#include <memory>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ thread_pool_implementation::
+ thread_pool_implementation (
+ unsigned long num_threads
+ ) :
+ task_done_signaler(m),
+ task_ready_signaler(m),
+ we_are_destructing(false)
+ {
+ tasks.resize(num_threads);
+ threads.resize(num_threads);
+ for (unsigned long i = 0; i < num_threads; ++i)
+ {
+ threads[i] = std::thread([&](){this->thread();});
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void thread_pool_implementation::
+ shutdown_pool (
+ )
+ {
+ {
+ auto_mutex M(m);
+
+ // first wait for all pending tasks to finish
+ bool found_task = true;
+ while (found_task)
+ {
+ found_task = false;
+ for (unsigned long i = 0; i < tasks.size(); ++i)
+ {
+ // If task bucket i has a task that is currently supposed to be processed
+ if (tasks[i].is_empty() == false)
+ {
+ found_task = true;
+ break;
+ }
+ }
+
+ if (found_task)
+ task_done_signaler.wait();
+ }
+
+ // now tell the threads to kill themselves
+ we_are_destructing = true;
+ task_ready_signaler.broadcast();
+ }
+
+ // wait for all threads to terminate
+ for (auto& t : threads)
+ t.join();
+ threads.clear();
+
+ // Throw any unhandled exceptions. Since shutdown_pool() is only called in the
+ // destructor this will kill the program.
+ for (auto&& task : tasks)
+ task.propagate_exception();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ thread_pool_implementation::
+ ~thread_pool_implementation()
+ {
+ shutdown_pool();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ unsigned long thread_pool_implementation::
+ num_threads_in_pool (
+ ) const
+ {
+ auto_mutex M(m);
+ return tasks.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void thread_pool_implementation::
+ wait_for_task (
+ uint64 task_id
+ ) const
+ {
+ auto_mutex M(m);
+ if (tasks.size() != 0)
+ {
+ const unsigned long idx = task_id_to_index(task_id);
+ while (tasks[idx].task_id == task_id)
+ task_done_signaler.wait();
+
+ for (auto&& task : tasks)
+ task.propagate_exception();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void thread_pool_implementation::
+ wait_for_all_tasks (
+ ) const
+ {
+ const thread_id_type thread_id = get_thread_id();
+
+ auto_mutex M(m);
+ bool found_task = true;
+ while (found_task)
+ {
+ found_task = false;
+ for (unsigned long i = 0; i < tasks.size(); ++i)
+ {
+ // If task bucket i has a task that is currently supposed to be processed
+ // and it originated from the calling thread
+ if (tasks[i].is_empty() == false && tasks[i].thread_id == thread_id)
+ {
+ found_task = true;
+ break;
+ }
+ }
+
+ if (found_task)
+ task_done_signaler.wait();
+ }
+
+ // throw any exceptions generated by the tasks
+ for (auto&& task : tasks)
+ task.propagate_exception();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool thread_pool_implementation::
+ is_worker_thread (
+ const thread_id_type id
+ ) const
+ {
+ for (unsigned long i = 0; i < worker_thread_ids.size(); ++i)
+ {
+ if (worker_thread_ids[i] == id)
+ return true;
+ }
+
+ // if there aren't any threads in the pool then we consider all threads
+ // to be worker threads
+ if (tasks.size() == 0)
+ return true;
+ else
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void thread_pool_implementation::
+ thread (
+ )
+ {
+ {
+ // save the id of this worker thread into worker_thread_ids
+ auto_mutex M(m);
+ thread_id_type id = get_thread_id();
+ worker_thread_ids.push_back(id);
+ }
+
+ task_state_type task;
+ while (we_are_destructing == false)
+ {
+ long idx = 0;
+
+ // wait for a task to do
+ { auto_mutex M(m);
+ while ( (idx = find_ready_task()) == -1 && we_are_destructing == false)
+ task_ready_signaler.wait();
+
+ if (we_are_destructing)
+ break;
+
+ tasks[idx].is_being_processed = true;
+ task = tasks[idx];
+ }
+
+ std::exception_ptr eptr = nullptr;
+ try
+ {
+ // now do the task
+ if (task.bfp)
+ task.bfp();
+ else if (task.mfp0)
+ task.mfp0();
+ else if (task.mfp1)
+ task.mfp1(task.arg1);
+ else if (task.mfp2)
+ task.mfp2(task.arg1, task.arg2);
+ }
+ catch(...)
+ {
+ eptr = std::current_exception();
+ }
+
+ // Now let others know that we finished the task. We do this
+ // by clearing out the state of this task
+ { auto_mutex M(m);
+ tasks[idx].is_being_processed = false;
+ tasks[idx].task_id = 0;
+ tasks[idx].bfp.clear();
+ tasks[idx].mfp0.clear();
+ tasks[idx].mfp1.clear();
+ tasks[idx].mfp2.clear();
+ tasks[idx].arg1 = 0;
+ tasks[idx].arg2 = 0;
+ tasks[idx].eptr = eptr;
+ task_done_signaler.broadcast();
+ }
+
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ long thread_pool_implementation::
+ find_empty_task_slot (
+ ) const
+ {
+ for (auto&& task : tasks)
+ task.propagate_exception();
+
+ for (unsigned long i = 0; i < tasks.size(); ++i)
+ {
+ if (tasks[i].is_empty())
+ return i;
+ }
+
+ return -1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ long thread_pool_implementation::
+ find_ready_task (
+ ) const
+ {
+ for (unsigned long i = 0; i < tasks.size(); ++i)
+ {
+ if (tasks[i].is_ready())
+ return i;
+ }
+
+ return -1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ uint64 thread_pool_implementation::
+ make_next_task_id (
+ long idx
+ )
+ {
+ uint64 id = tasks[idx].next_task_id * tasks.size() + idx;
+ tasks[idx].next_task_id += 1;
+ return id;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ unsigned long thread_pool_implementation::
+ task_id_to_index (
+ uint64 id
+ ) const
+ {
+ return static_cast<unsigned long>(id%tasks.size());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ uint64 thread_pool_implementation::
+ add_task_internal (
+ const bfp_type& bfp,
+ std::shared_ptr<function_object_copy>& item
+ )
+ {
+ auto_mutex M(m);
+ const thread_id_type my_thread_id = get_thread_id();
+
+ // find a thread that isn't doing anything
+ long idx = find_empty_task_slot();
+ if (idx == -1 && is_worker_thread(my_thread_id))
+ {
+ // this function is being called from within a worker thread and there
+ // aren't any other worker threads free so just perform the task right
+ // here
+
+ M.unlock();
+ bfp();
+
+ // return a task id that is both non-zero and also one
+ // that is never normally returned. This way calls
+ // to wait_for_task() will never block given this id.
+ return 1;
+ }
+
+ // wait until there is a thread that isn't doing anything
+ while (idx == -1)
+ {
+ task_done_signaler.wait();
+ idx = find_empty_task_slot();
+ }
+
+ tasks[idx].thread_id = my_thread_id;
+ tasks[idx].task_id = make_next_task_id(idx);
+ tasks[idx].bfp = bfp;
+ tasks[idx].function_copy.swap(item);
+
+ task_ready_signaler.signal();
+
+ return tasks[idx].task_id;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool thread_pool_implementation::
+ is_task_thread (
+ ) const
+ {
+ auto_mutex M(m);
+ return is_worker_thread(get_thread_id());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+
+#endif // DLIB_THREAD_POOl_CPPh_
+