Diffstat (limited to 'third_party/rust/futures-executor/src/thread_pool.rs')
-rw-r--r--  third_party/rust/futures-executor/src/thread_pool.rs  380
1 file changed, 380 insertions(+), 0 deletions(-)
diff --git a/third_party/rust/futures-executor/src/thread_pool.rs b/third_party/rust/futures-executor/src/thread_pool.rs
new file mode 100644
index 0000000000..5371008953
--- /dev/null
+++ b/third_party/rust/futures-executor/src/thread_pool.rs
@@ -0,0 +1,380 @@
+use crate::enter;
+use crate::unpark_mutex::UnparkMutex;
+use futures_core::future::Future;
+use futures_core::task::{Context, Poll};
+use futures_task::{waker_ref, ArcWake};
+use futures_task::{FutureObj, Spawn, SpawnError};
+use futures_util::future::FutureExt;
+use std::cmp;
+use std::fmt;
+use std::io;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::mpsc;
+use std::sync::{Arc, Mutex};
+use std::thread;
+
+/// A general-purpose thread pool for scheduling tasks that poll futures to
+/// completion.
+///
+/// The thread pool multiplexes any number of tasks onto a fixed number of
+/// worker threads.
+///
+/// This type is a cloneable handle to the thread pool itself.
+/// Cloning it only creates a new reference to the same pool, not a new
+/// thread pool.
+///
+/// This type is only available when the `thread-pool` feature of this
+/// library is activated.
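+///
+/// # Examples
+///
+/// A minimal usage sketch, going through the `futures` facade re-export
+/// (the same path the doctest on `spawn_ok` below uses):
+///
+/// ```
+/// # {
+/// use futures::executor::ThreadPool;
+///
+/// let pool = ThreadPool::new().unwrap();
+/// pool.spawn_ok(async { /* do work */ });
+/// # }
+/// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads to shut down: https://github.com/rust-lang/miri/issues/1371
+/// ```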
+#[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))]
+pub struct ThreadPool {
+ state: Arc<PoolState>,
+}
+
+/// Thread pool configuration object.
+///
+/// This type is only available when the `thread-pool` feature of this
+/// library is activated.
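+///
+/// # Examples
+///
+/// A minimal configuration sketch, assuming the `futures` facade re-exports
+/// this builder next to `ThreadPool` (the pool size and name prefix are
+/// illustrative values only):
+///
+/// ```
+/// # {
+/// use futures::executor::ThreadPoolBuilder;
+///
+/// let pool = ThreadPoolBuilder::new()
+///     .pool_size(2)
+///     .name_prefix("my-pool-")
+///     .create()
+///     .unwrap();
+/// pool.spawn_ok(async { /* do work */ });
+/// # }
+/// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads to shut down: https://github.com/rust-lang/miri/issues/1371
+/// ```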
+#[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))]
+pub struct ThreadPoolBuilder {
+ pool_size: usize,
+ stack_size: usize,
+ name_prefix: Option<String>,
+ after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>,
+ before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>,
+}
+
+trait AssertSendSync: Send + Sync {}
+impl AssertSendSync for ThreadPool {}
+
+struct PoolState {
+    tx: Mutex<mpsc::Sender<Message>>,
+    rx: Mutex<mpsc::Receiver<Message>>,
+    // Number of live `ThreadPool` handles; the last one dropped shuts the pool down.
+    cnt: AtomicUsize,
+    // Number of worker threads.
+    size: usize,
+}
+
+impl fmt::Debug for ThreadPool {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("ThreadPool").field("size", &self.state.size).finish()
+ }
+}
+
+impl fmt::Debug for ThreadPoolBuilder {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("ThreadPoolBuilder")
+ .field("pool_size", &self.pool_size)
+ .field("name_prefix", &self.name_prefix)
+ .finish()
+ }
+}
+
+enum Message {
+ Run(Task),
+ Close,
+}
+
+impl ThreadPool {
+ /// Creates a new thread pool with the default configuration.
+ ///
+ /// See documentation for the methods in
+ /// [`ThreadPoolBuilder`](ThreadPoolBuilder) for details on the default
+ /// configuration.
+ pub fn new() -> Result<Self, io::Error> {
+ ThreadPoolBuilder::new().create()
+ }
+
+ /// Create a default thread pool configuration, which can then be customized.
+ ///
+ /// See documentation for the methods in
+ /// [`ThreadPoolBuilder`](ThreadPoolBuilder) for details on the default
+ /// configuration.
+ pub fn builder() -> ThreadPoolBuilder {
+ ThreadPoolBuilder::new()
+ }
+
+ /// Spawns a future that will be run to completion.
+ ///
+ /// > **Note**: This method is similar to `Spawn::spawn_obj`, except that
+ /// > it is guaranteed to always succeed.
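+    ///
+    /// A minimal sketch: box an async block into a `FutureObj` (re-exported
+    /// as `futures::task::FutureObj`), which is exactly what `spawn_ok`
+    /// below does internally.
+    ///
+    /// ```
+    /// # {
+    /// use futures::executor::ThreadPool;
+    /// use futures::task::FutureObj;
+    ///
+    /// let pool = ThreadPool::new().unwrap();
+    /// pool.spawn_obj_ok(FutureObj::new(Box::new(async { /* do work */ })));
+    /// # }
+    /// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads to shut down: https://github.com/rust-lang/miri/issues/1371
+    /// ```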
+ pub fn spawn_obj_ok(&self, future: FutureObj<'static, ()>) {
+ let task = Task {
+ future,
+ wake_handle: Arc::new(WakeHandle { exec: self.clone(), mutex: UnparkMutex::new() }),
+ exec: self.clone(),
+ };
+ self.state.send(Message::Run(task));
+ }
+
+ /// Spawns a task that polls the given future with output `()` to
+ /// completion.
+ ///
+ /// ```
+ /// # {
+ /// use futures::executor::ThreadPool;
+ ///
+ /// let pool = ThreadPool::new().unwrap();
+ ///
+ /// let future = async { /* ... */ };
+ /// pool.spawn_ok(future);
+ /// # }
+    /// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads to shut down: https://github.com/rust-lang/miri/issues/1371
+ /// ```
+ ///
+ /// > **Note**: This method is similar to `SpawnExt::spawn`, except that
+ /// > it is guaranteed to always succeed.
+ pub fn spawn_ok<Fut>(&self, future: Fut)
+ where
+ Fut: Future<Output = ()> + Send + 'static,
+ {
+ self.spawn_obj_ok(FutureObj::new(Box::new(future)))
+ }
+}
+
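+/// A minimal sketch of spawning through the `Spawn` trait, using the
+/// `SpawnExt::spawn` extension method from `futures::task` (defined in the
+/// facade crate, not in this file):
+///
+/// ```
+/// # {
+/// use futures::executor::ThreadPool;
+/// use futures::task::SpawnExt;
+///
+/// let pool = ThreadPool::new().unwrap();
+/// pool.spawn(async { /* do work */ }).expect("ThreadPool::spawn_obj never fails");
+/// # }
+/// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads to shut down: https://github.com/rust-lang/miri/issues/1371
+/// ```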
+impl Spawn for ThreadPool {
+ fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> {
+ self.spawn_obj_ok(future);
+ Ok(())
+ }
+}
+
+impl PoolState {
+ fn send(&self, msg: Message) {
+ self.tx.lock().unwrap().send(msg).unwrap();
+ }
+
+ fn work(
+ &self,
+ idx: usize,
+ after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>,
+ before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>,
+ ) {
+ let _scope = enter().unwrap();
+ if let Some(after_start) = after_start {
+ after_start(idx);
+ }
+ loop {
+ let msg = self.rx.lock().unwrap().recv().unwrap();
+ match msg {
+ Message::Run(task) => task.run(),
+ Message::Close => break,
+ }
+ }
+ if let Some(before_stop) = before_stop {
+ before_stop(idx);
+ }
+ }
+}
+
+impl Clone for ThreadPool {
+ fn clone(&self) -> Self {
+ self.state.cnt.fetch_add(1, Ordering::Relaxed);
+ Self { state: self.state.clone() }
+ }
+}
+
+impl Drop for ThreadPool {
+ fn drop(&mut self) {
+        // The last handle to be dropped shuts the pool down by sending one
+        // `Close` message per worker thread.
+        if self.state.cnt.fetch_sub(1, Ordering::Relaxed) == 1 {
+ for _ in 0..self.state.size {
+ self.state.send(Message::Close);
+ }
+ }
+ }
+}
+
+impl ThreadPoolBuilder {
+ /// Create a default thread pool configuration.
+ ///
+ /// See the other methods on this type for details on the defaults.
+ pub fn new() -> Self {
+ Self {
+ pool_size: cmp::max(1, num_cpus::get()),
+ stack_size: 0,
+ name_prefix: None,
+ after_start: None,
+ before_stop: None,
+ }
+ }
+
+    /// Set the size of the future `ThreadPool`.
+ ///
+ /// The size of a thread pool is the number of worker threads spawned. By
+ /// default, this is equal to the number of CPU cores.
+ ///
+ /// # Panics
+ ///
+ /// Panics if `pool_size == 0`.
+ pub fn pool_size(&mut self, size: usize) -> &mut Self {
+ assert!(size > 0);
+ self.pool_size = size;
+ self
+ }
+
+    /// Set the stack size of worker threads in the pool, in bytes.
+ ///
+ /// By default, worker threads use Rust's standard stack size.
+ pub fn stack_size(&mut self, stack_size: usize) -> &mut Self {
+ self.stack_size = stack_size;
+ self
+ }
+
+    /// Set the thread name prefix of the future `ThreadPool`.
+    ///
+    /// The thread name prefix is used to generate thread names. For example,
+    /// if the prefix is `my-pool-`, threads in the pool will get names like
+    /// `my-pool-0`, `my-pool-1`, and so on (the index starts at zero).
+ ///
+ /// By default, worker threads are assigned Rust's standard thread name.
+ pub fn name_prefix<S: Into<String>>(&mut self, name_prefix: S) -> &mut Self {
+ self.name_prefix = Some(name_prefix.into());
+ self
+ }
+
+ /// Execute the closure `f` immediately after each worker thread is started,
+ /// but before running any tasks on it.
+ ///
+ /// This hook is intended for bookkeeping and monitoring.
+ /// The closure `f` will be dropped after the `builder` is dropped
+ /// and all worker threads in the pool have executed it.
+ ///
+ /// The closure provided will receive an index corresponding to the worker
+ /// thread it's running on.
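+    ///
+    /// A minimal sketch (the hook body and pool size are illustrative only):
+    ///
+    /// ```
+    /// # {
+    /// use futures::executor::ThreadPoolBuilder;
+    ///
+    /// let pool = ThreadPoolBuilder::new()
+    ///     .pool_size(2)
+    ///     .after_start(|idx| println!("worker thread {} started", idx))
+    ///     .create()
+    ///     .unwrap();
+    /// pool.spawn_ok(async { /* do work */ });
+    /// # }
+    /// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads to shut down: https://github.com/rust-lang/miri/issues/1371
+    /// ```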
+ pub fn after_start<F>(&mut self, f: F) -> &mut Self
+ where
+ F: Fn(usize) + Send + Sync + 'static,
+ {
+ self.after_start = Some(Arc::new(f));
+ self
+ }
+
+    /// Execute the closure `f` just prior to shutting down each worker thread.
+ ///
+ /// This hook is intended for bookkeeping and monitoring.
+ /// The closure `f` will be dropped after the `builder` is dropped
+ /// and all threads in the pool have executed it.
+ ///
+ /// The closure provided will receive an index corresponding to the worker
+ /// thread it's running on.
+ pub fn before_stop<F>(&mut self, f: F) -> &mut Self
+ where
+ F: Fn(usize) + Send + Sync + 'static,
+ {
+ self.before_stop = Some(Arc::new(f));
+ self
+ }
+
+ /// Create a [`ThreadPool`](ThreadPool) with the given configuration.
+ pub fn create(&mut self) -> Result<ThreadPool, io::Error> {
+ let (tx, rx) = mpsc::channel();
+ let pool = ThreadPool {
+ state: Arc::new(PoolState {
+ tx: Mutex::new(tx),
+ rx: Mutex::new(rx),
+ cnt: AtomicUsize::new(1),
+ size: self.pool_size,
+ }),
+ };
+
+ for counter in 0..self.pool_size {
+ let state = pool.state.clone();
+ let after_start = self.after_start.clone();
+ let before_stop = self.before_stop.clone();
+ let mut thread_builder = thread::Builder::new();
+ if let Some(ref name_prefix) = self.name_prefix {
+ thread_builder = thread_builder.name(format!("{}{}", name_prefix, counter));
+ }
+ if self.stack_size > 0 {
+ thread_builder = thread_builder.stack_size(self.stack_size);
+ }
+ thread_builder.spawn(move || state.work(counter, after_start, before_stop))?;
+ }
+ Ok(pool)
+ }
+}
+
+impl Default for ThreadPoolBuilder {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+/// A task responsible for polling a future to completion.
+struct Task {
+ future: FutureObj<'static, ()>,
+ exec: ThreadPool,
+ wake_handle: Arc<WakeHandle>,
+}
+
+/// Shared waker state for a `Task`: the `UnparkMutex` coordinating polling,
+/// plus a handle to the pool used to reschedule the task when it is woken.
+struct WakeHandle {
+    mutex: UnparkMutex<Task>,
+    exec: ThreadPool,
+}
+
+impl Task {
+ /// Actually run the task (invoking `poll` on the future) on the current
+ /// thread.
+ fn run(self) {
+ let Self { mut future, wake_handle, mut exec } = self;
+ let waker = waker_ref(&wake_handle);
+ let mut cx = Context::from_waker(&waker);
+
+ // Safety: The ownership of this `Task` object is evidence that
+ // we are in the `POLLING`/`REPOLL` state for the mutex.
+ unsafe {
+ wake_handle.mutex.start_poll();
+
+ loop {
+ let res = future.poll_unpin(&mut cx);
+ match res {
+ Poll::Pending => {}
+ Poll::Ready(()) => return wake_handle.mutex.complete(),
+ }
+ let task = Self { future, wake_handle: wake_handle.clone(), exec };
+ match wake_handle.mutex.wait(task) {
+ Ok(()) => return, // we've waited
+ Err(task) => {
+ // someone's notified us
+ future = task.future;
+ exec = task.exec;
+ }
+ }
+ }
+ }
+ }
+}
+
+impl fmt::Debug for Task {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("Task").field("contents", &"...").finish()
+ }
+}
+
+impl ArcWake for WakeHandle {
+ fn wake_by_ref(arc_self: &Arc<Self>) {
+ if let Ok(task) = arc_self.mutex.notify() {
+ arc_self.exec.state.send(Message::Run(task))
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::sync::mpsc;
+
+ #[test]
+ fn test_drop_after_start() {
+ {
+ let (tx, rx) = mpsc::sync_channel(2);
+ let _cpu_pool = ThreadPoolBuilder::new()
+ .pool_size(2)
+ .after_start(move |_| tx.send(1).unwrap())
+ .create()
+ .unwrap();
+
+            // After the `ThreadPoolBuilder` is dropped, the `after_start`
+            // closure (and the `tx` it captured) is dropped once every worker
+            // has run it, so `rx` can be consumed as an iterator.
+ let count = rx.into_iter().count();
+ assert_eq!(count, 2);
+ }
+        std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads to shut down: https://github.com/rust-lang/miri/issues/1371
+ }
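+
+    // Illustrative sketch (test name and values are arbitrary): spawn a task
+    // with `spawn_ok` and observe its side effect through an mpsc channel.
+    #[test]
+    fn test_spawn_ok_runs_future() {
+        {
+            let (tx, rx) = mpsc::sync_channel(1);
+            let pool = ThreadPool::new().unwrap();
+            pool.spawn_ok(async move {
+                tx.send(42).unwrap();
+            });
+            assert_eq!(rx.recv().unwrap(), 42);
+        }
+        std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads to shut down: https://github.com/rust-lang/miri/issues/1371
+    }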
+}