//! Thread pool for blocking operations use crate::loom::sync::{Arc, Condvar, Mutex}; use crate::loom::thread; use crate::runtime::blocking::schedule::BlockingSchedule; use crate::runtime::blocking::{shutdown, BlockingTask}; use crate::runtime::builder::ThreadNameFn; use crate::runtime::task::{self, JoinHandle}; use crate::runtime::{Builder, Callback, Handle}; use std::collections::{HashMap, VecDeque}; use std::fmt; use std::io; use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; pub(crate) struct BlockingPool { spawner: Spawner, shutdown_rx: shutdown::Receiver, } #[derive(Clone)] pub(crate) struct Spawner { inner: Arc, } #[derive(Default)] pub(crate) struct SpawnerMetrics { num_threads: AtomicUsize, num_idle_threads: AtomicUsize, queue_depth: AtomicUsize, } impl SpawnerMetrics { fn num_threads(&self) -> usize { self.num_threads.load(Ordering::Relaxed) } fn num_idle_threads(&self) -> usize { self.num_idle_threads.load(Ordering::Relaxed) } cfg_metrics! { fn queue_depth(&self) -> usize { self.queue_depth.load(Ordering::Relaxed) } } fn inc_num_threads(&self) { self.num_threads.fetch_add(1, Ordering::Relaxed); } fn dec_num_threads(&self) { self.num_threads.fetch_sub(1, Ordering::Relaxed); } fn inc_num_idle_threads(&self) { self.num_idle_threads.fetch_add(1, Ordering::Relaxed); } fn dec_num_idle_threads(&self) -> usize { self.num_idle_threads.fetch_sub(1, Ordering::Relaxed) } fn inc_queue_depth(&self) { self.queue_depth.fetch_add(1, Ordering::Relaxed); } fn dec_queue_depth(&self) { self.queue_depth.fetch_sub(1, Ordering::Relaxed); } } struct Inner { /// State shared between worker threads. shared: Mutex, /// Pool threads wait on this. condvar: Condvar, /// Spawned threads use this name. thread_name: ThreadNameFn, /// Spawned thread stack size. stack_size: Option, /// Call after a thread starts. after_start: Option, /// Call before a thread stops. before_stop: Option, // Maximum number of threads. thread_cap: usize, // Customizable wait timeout. keep_alive: Duration, // Metrics about the pool. metrics: SpawnerMetrics, } struct Shared { queue: VecDeque, num_notify: u32, shutdown: bool, shutdown_tx: Option, /// Prior to shutdown, we clean up JoinHandles by having each timed-out /// thread join on the previous timed-out thread. This is not strictly /// necessary but helps avoid Valgrind false positives, see /// /// for more information. last_exiting_thread: Option>, /// This holds the JoinHandles for all running threads; on shutdown, the thread /// calling shutdown handles joining on these. worker_threads: HashMap>, /// This is a counter used to iterate worker_threads in a consistent order (for loom's /// benefit). worker_thread_index: usize, } pub(crate) struct Task { task: task::UnownedTask, mandatory: Mandatory, } #[derive(PartialEq, Eq)] pub(crate) enum Mandatory { #[cfg_attr(not(fs), allow(dead_code))] Mandatory, NonMandatory, } pub(crate) enum SpawnError { /// Pool is shutting down and the task was not scheduled ShuttingDown, /// There are no worker threads available to take the task /// and the OS failed to spawn a new one NoThreads(io::Error), } impl From for io::Error { fn from(e: SpawnError) -> Self { match e { SpawnError::ShuttingDown => { io::Error::new(io::ErrorKind::Other, "blocking pool shutting down") } SpawnError::NoThreads(e) => e, } } } impl Task { pub(crate) fn new(task: task::UnownedTask, mandatory: Mandatory) -> Task { Task { task, mandatory } } fn run(self) { self.task.run(); } fn shutdown_or_run_if_mandatory(self) { match self.mandatory { Mandatory::NonMandatory => self.task.shutdown(), Mandatory::Mandatory => self.task.run(), } } } const KEEP_ALIVE: Duration = Duration::from_secs(10); /// Runs the provided function on an executor dedicated to blocking operations. /// Tasks will be scheduled as non-mandatory, meaning they may not get executed /// in case of runtime shutdown. #[track_caller] #[cfg_attr(tokio_wasi, allow(dead_code))] pub(crate) fn spawn_blocking(func: F) -> JoinHandle where F: FnOnce() -> R + Send + 'static, R: Send + 'static, { let rt = Handle::current(); rt.spawn_blocking(func) } cfg_fs! { #[cfg_attr(any( all(loom, not(test)), // the function is covered by loom tests test ), allow(dead_code))] /// Runs the provided function on an executor dedicated to blocking /// operations. Tasks will be scheduled as mandatory, meaning they are /// guaranteed to run unless a shutdown is already taking place. In case a /// shutdown is already taking place, `None` will be returned. pub(crate) fn spawn_mandatory_blocking(func: F) -> Option> where F: FnOnce() -> R + Send + 'static, R: Send + 'static, { let rt = Handle::current(); rt.inner.blocking_spawner().spawn_mandatory_blocking(&rt, func) } } // ===== impl BlockingPool ===== impl BlockingPool { pub(crate) fn new(builder: &Builder, thread_cap: usize) -> BlockingPool { let (shutdown_tx, shutdown_rx) = shutdown::channel(); let keep_alive = builder.keep_alive.unwrap_or(KEEP_ALIVE); BlockingPool { spawner: Spawner { inner: Arc::new(Inner { shared: Mutex::new(Shared { queue: VecDeque::new(), num_notify: 0, shutdown: false, shutdown_tx: Some(shutdown_tx), last_exiting_thread: None, worker_threads: HashMap::new(), worker_thread_index: 0, }), condvar: Condvar::new(), thread_name: builder.thread_name.clone(), stack_size: builder.thread_stack_size, after_start: builder.after_start.clone(), before_stop: builder.before_stop.clone(), thread_cap, keep_alive, metrics: Default::default(), }), }, shutdown_rx, } } pub(crate) fn spawner(&self) -> &Spawner { &self.spawner } pub(crate) fn shutdown(&mut self, timeout: Option) { let mut shared = self.spawner.inner.shared.lock(); // The function can be called multiple times. First, by explicitly // calling `shutdown` then by the drop handler calling `shutdown`. This // prevents shutting down twice. if shared.shutdown { return; } shared.shutdown = true; shared.shutdown_tx = None; self.spawner.inner.condvar.notify_all(); let last_exited_thread = std::mem::take(&mut shared.last_exiting_thread); let workers = std::mem::take(&mut shared.worker_threads); drop(shared); if self.shutdown_rx.wait(timeout) { let _ = last_exited_thread.map(|th| th.join()); // Loom requires that execution be deterministic, so sort by thread ID before joining. // (HashMaps use a randomly-seeded hash function, so the order is nondeterministic) let mut workers: Vec<(usize, thread::JoinHandle<()>)> = workers.into_iter().collect(); workers.sort_by_key(|(id, _)| *id); for (_id, handle) in workers.into_iter() { let _ = handle.join(); } } } } impl Drop for BlockingPool { fn drop(&mut self) { self.shutdown(None); } } impl fmt::Debug for BlockingPool { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("BlockingPool").finish() } } // ===== impl Spawner ===== impl Spawner { #[track_caller] pub(crate) fn spawn_blocking(&self, rt: &Handle, func: F) -> JoinHandle where F: FnOnce() -> R + Send + 'static, R: Send + 'static, { let (join_handle, spawn_result) = if cfg!(debug_assertions) && std::mem::size_of::() > 2048 { self.spawn_blocking_inner(Box::new(func), Mandatory::NonMandatory, None, rt) } else { self.spawn_blocking_inner(func, Mandatory::NonMandatory, None, rt) }; match spawn_result { Ok(()) => join_handle, // Compat: do not panic here, return the join_handle even though it will never resolve Err(SpawnError::ShuttingDown) => join_handle, Err(SpawnError::NoThreads(e)) => { panic!("OS can't spawn worker thread: {}", e) } } } cfg_fs! { #[track_caller] #[cfg_attr(any( all(loom, not(test)), // the function is covered by loom tests test ), allow(dead_code))] pub(crate) fn spawn_mandatory_blocking(&self, rt: &Handle, func: F) -> Option> where F: FnOnce() -> R + Send + 'static, R: Send + 'static, { let (join_handle, spawn_result) = if cfg!(debug_assertions) && std::mem::size_of::() > 2048 { self.spawn_blocking_inner( Box::new(func), Mandatory::Mandatory, None, rt, ) } else { self.spawn_blocking_inner( func, Mandatory::Mandatory, None, rt, ) }; if spawn_result.is_ok() { Some(join_handle) } else { None } } } #[track_caller] pub(crate) fn spawn_blocking_inner( &self, func: F, is_mandatory: Mandatory, name: Option<&str>, rt: &Handle, ) -> (JoinHandle, Result<(), SpawnError>) where F: FnOnce() -> R + Send + 'static, R: Send + 'static, { let fut = BlockingTask::new(func); let id = task::Id::next(); #[cfg(all(tokio_unstable, feature = "tracing"))] let fut = { use tracing::Instrument; let location = std::panic::Location::caller(); let span = tracing::trace_span!( target: "tokio::task::blocking", "runtime.spawn", kind = %"blocking", task.name = %name.unwrap_or_default(), task.id = id.as_u64(), "fn" = %std::any::type_name::(), loc.file = location.file(), loc.line = location.line(), loc.col = location.column(), ); fut.instrument(span) }; #[cfg(not(all(tokio_unstable, feature = "tracing")))] let _ = name; let (task, handle) = task::unowned(fut, BlockingSchedule::new(rt), id); let spawned = self.spawn_task(Task::new(task, is_mandatory), rt); (handle, spawned) } fn spawn_task(&self, task: Task, rt: &Handle) -> Result<(), SpawnError> { let mut shared = self.inner.shared.lock(); if shared.shutdown { // Shutdown the task: it's fine to shutdown this task (even if // mandatory) because it was scheduled after the shutdown of the // runtime began. task.task.shutdown(); // no need to even push this task; it would never get picked up return Err(SpawnError::ShuttingDown); } shared.queue.push_back(task); self.inner.metrics.inc_queue_depth(); if self.inner.metrics.num_idle_threads() == 0 { // No threads are able to process the task. if self.inner.metrics.num_threads() == self.inner.thread_cap { // At max number of threads } else { assert!(shared.shutdown_tx.is_some()); let shutdown_tx = shared.shutdown_tx.clone(); if let Some(shutdown_tx) = shutdown_tx { let id = shared.worker_thread_index; match self.spawn_thread(shutdown_tx, rt, id) { Ok(handle) => { self.inner.metrics.inc_num_threads(); shared.worker_thread_index += 1; shared.worker_threads.insert(id, handle); } Err(ref e) if is_temporary_os_thread_error(e) && self.inner.metrics.num_threads() > 0 => { // OS temporarily failed to spawn a new thread. // The task will be picked up eventually by a currently // busy thread. } Err(e) => { // The OS refused to spawn the thread and there is no thread // to pick up the task that has just been pushed to the queue. return Err(SpawnError::NoThreads(e)); } } } } } else { // Notify an idle worker thread. The notification counter // is used to count the needed amount of notifications // exactly. Thread libraries may generate spurious // wakeups, this counter is used to keep us in a // consistent state. self.inner.metrics.dec_num_idle_threads(); shared.num_notify += 1; self.inner.condvar.notify_one(); } Ok(()) } fn spawn_thread( &self, shutdown_tx: shutdown::Sender, rt: &Handle, id: usize, ) -> std::io::Result> { let mut builder = thread::Builder::new().name((self.inner.thread_name)()); if let Some(stack_size) = self.inner.stack_size { builder = builder.stack_size(stack_size); } let rt = rt.clone(); builder.spawn(move || { // Only the reference should be moved into the closure let _enter = rt.enter(); rt.inner.blocking_spawner().inner.run(id); drop(shutdown_tx); }) } } cfg_metrics! { impl Spawner { pub(crate) fn num_threads(&self) -> usize { self.inner.metrics.num_threads() } pub(crate) fn num_idle_threads(&self) -> usize { self.inner.metrics.num_idle_threads() } pub(crate) fn queue_depth(&self) -> usize { self.inner.metrics.queue_depth() } } } // Tells whether the error when spawning a thread is temporary. #[inline] fn is_temporary_os_thread_error(error: &std::io::Error) -> bool { matches!(error.kind(), std::io::ErrorKind::WouldBlock) } impl Inner { fn run(&self, worker_thread_id: usize) { if let Some(f) = &self.after_start { f() } let mut shared = self.shared.lock(); let mut join_on_thread = None; 'main: loop { // BUSY while let Some(task) = shared.queue.pop_front() { self.metrics.dec_queue_depth(); drop(shared); task.run(); shared = self.shared.lock(); } // IDLE self.metrics.inc_num_idle_threads(); while !shared.shutdown { let lock_result = self.condvar.wait_timeout(shared, self.keep_alive).unwrap(); shared = lock_result.0; let timeout_result = lock_result.1; if shared.num_notify != 0 { // We have received a legitimate wakeup, // acknowledge it by decrementing the counter // and transition to the BUSY state. shared.num_notify -= 1; break; } // Even if the condvar "timed out", if the pool is entering the // shutdown phase, we want to perform the cleanup logic. if !shared.shutdown && timeout_result.timed_out() { // We'll join the prior timed-out thread's JoinHandle after dropping the lock. // This isn't done when shutting down, because the thread calling shutdown will // handle joining everything. let my_handle = shared.worker_threads.remove(&worker_thread_id); join_on_thread = std::mem::replace(&mut shared.last_exiting_thread, my_handle); break 'main; } // Spurious wakeup detected, go back to sleep. } if shared.shutdown { // Drain the queue while let Some(task) = shared.queue.pop_front() { self.metrics.dec_queue_depth(); drop(shared); task.shutdown_or_run_if_mandatory(); shared = self.shared.lock(); } // Work was produced, and we "took" it (by decrementing num_notify). // This means that num_idle was decremented once for our wakeup. // But, since we are exiting, we need to "undo" that, as we'll stay idle. self.metrics.inc_num_idle_threads(); // NOTE: Technically we should also do num_notify++ and notify again, // but since we're shutting down anyway, that won't be necessary. break; } } // Thread exit self.metrics.dec_num_threads(); // num_idle should now be tracked exactly, panic // with a descriptive message if it is not the // case. let prev_idle = self.metrics.dec_num_idle_threads(); if prev_idle < self.metrics.num_idle_threads() { panic!("num_idle_threads underflowed on thread exit") } if shared.shutdown && self.metrics.num_threads() == 0 { self.condvar.notify_one(); } drop(shared); if let Some(f) = &self.before_stop { f() } if let Some(handle) = join_on_thread { let _ = handle.join(); } } } impl fmt::Debug for Spawner { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("blocking::Spawner").finish() } }