diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
commit | 36d22d82aa202bb199967e9512281e9a53db42c9 (patch) | |
tree | 105e8c98ddea1c1e4784a60a5a6410fa416be2de /third_party/rust/tokio/src/runtime | |
parent | Initial commit. (diff) | |
download | firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip |
Adding upstream version 115.7.0esr.upstream/115.7.0esr
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/tokio/src/runtime')
46 files changed, 11285 insertions, 0 deletions
diff --git a/third_party/rust/tokio/src/runtime/basic_scheduler.rs b/third_party/rust/tokio/src/runtime/basic_scheduler.rs new file mode 100644 index 0000000000..401f55b3f2 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/basic_scheduler.rs @@ -0,0 +1,574 @@ +use crate::future::poll_fn; +use crate::loom::sync::atomic::AtomicBool; +use crate::loom::sync::{Arc, Mutex}; +use crate::park::{Park, Unpark}; +use crate::runtime::context::EnterGuard; +use crate::runtime::driver::Driver; +use crate::runtime::task::{self, JoinHandle, OwnedTasks, Schedule, Task}; +use crate::runtime::Callback; +use crate::runtime::{MetricsBatch, SchedulerMetrics, WorkerMetrics}; +use crate::sync::notify::Notify; +use crate::util::atomic_cell::AtomicCell; +use crate::util::{waker_ref, Wake, WakerRef}; + +use std::cell::RefCell; +use std::collections::VecDeque; +use std::fmt; +use std::future::Future; +use std::sync::atomic::Ordering::{AcqRel, Release}; +use std::task::Poll::{Pending, Ready}; +use std::time::Duration; + +/// Executes tasks on the current thread +pub(crate) struct BasicScheduler { + /// Core scheduler data is acquired by a thread entering `block_on`. + core: AtomicCell<Core>, + + /// Notifier for waking up other threads to steal the + /// driver. + notify: Notify, + + /// Sendable task spawner + spawner: Spawner, + + /// This is usually None, but right before dropping the BasicScheduler, it + /// is changed to `Some` with the context being the runtime's own context. + /// This ensures that any tasks dropped in the `BasicScheduler`s destructor + /// run in that runtime's context. + context_guard: Option<EnterGuard>, +} + +/// Data required for executing the scheduler. The struct is passed around to +/// a function that will perform the scheduling work and acts as a capability token. +struct Core { + /// Scheduler run queue + tasks: VecDeque<task::Notified<Arc<Shared>>>, + + /// Sendable task spawner + spawner: Spawner, + + /// Current tick + tick: u8, + + /// Runtime driver + /// + /// The driver is removed before starting to park the thread + driver: Option<Driver>, + + /// Metrics batch + metrics: MetricsBatch, +} + +#[derive(Clone)] +pub(crate) struct Spawner { + shared: Arc<Shared>, +} + +/// Scheduler state shared between threads. +struct Shared { + /// Remote run queue. None if the `Runtime` has been dropped. + queue: Mutex<Option<VecDeque<task::Notified<Arc<Shared>>>>>, + + /// Collection of all active tasks spawned onto this executor. + owned: OwnedTasks<Arc<Shared>>, + + /// Unpark the blocked thread. + unpark: <Driver as Park>::Unpark, + + /// Indicates whether the blocked on thread was woken. + woken: AtomicBool, + + /// Callback for a worker parking itself + before_park: Option<Callback>, + + /// Callback for a worker unparking itself + after_unpark: Option<Callback>, + + /// Keeps track of various runtime metrics. + scheduler_metrics: SchedulerMetrics, + + /// This scheduler only has one worker. + worker_metrics: WorkerMetrics, +} + +/// Thread-local context. +struct Context { + /// Handle to the spawner + spawner: Spawner, + + /// Scheduler core, enabling the holder of `Context` to execute the + /// scheduler. + core: RefCell<Option<Box<Core>>>, +} + +/// Initial queue capacity. +const INITIAL_CAPACITY: usize = 64; + +/// Max number of tasks to poll per tick. +#[cfg(loom)] +const MAX_TASKS_PER_TICK: usize = 4; +#[cfg(not(loom))] +const MAX_TASKS_PER_TICK: usize = 61; + +/// How often to check the remote queue first. +const REMOTE_FIRST_INTERVAL: u8 = 31; + +// Tracks the current BasicScheduler. +scoped_thread_local!(static CURRENT: Context); + +impl BasicScheduler { + pub(crate) fn new( + driver: Driver, + before_park: Option<Callback>, + after_unpark: Option<Callback>, + ) -> BasicScheduler { + let unpark = driver.unpark(); + + let spawner = Spawner { + shared: Arc::new(Shared { + queue: Mutex::new(Some(VecDeque::with_capacity(INITIAL_CAPACITY))), + owned: OwnedTasks::new(), + unpark, + woken: AtomicBool::new(false), + before_park, + after_unpark, + scheduler_metrics: SchedulerMetrics::new(), + worker_metrics: WorkerMetrics::new(), + }), + }; + + let core = AtomicCell::new(Some(Box::new(Core { + tasks: VecDeque::with_capacity(INITIAL_CAPACITY), + spawner: spawner.clone(), + tick: 0, + driver: Some(driver), + metrics: MetricsBatch::new(), + }))); + + BasicScheduler { + core, + notify: Notify::new(), + spawner, + context_guard: None, + } + } + + pub(crate) fn spawner(&self) -> &Spawner { + &self.spawner + } + + pub(crate) fn block_on<F: Future>(&self, future: F) -> F::Output { + pin!(future); + + // Attempt to steal the scheduler core and block_on the future if we can + // there, otherwise, lets select on a notification that the core is + // available or the future is complete. + loop { + if let Some(core) = self.take_core() { + return core.block_on(future); + } else { + let mut enter = crate::runtime::enter(false); + + let notified = self.notify.notified(); + pin!(notified); + + if let Some(out) = enter + .block_on(poll_fn(|cx| { + if notified.as_mut().poll(cx).is_ready() { + return Ready(None); + } + + if let Ready(out) = future.as_mut().poll(cx) { + return Ready(Some(out)); + } + + Pending + })) + .expect("Failed to `Enter::block_on`") + { + return out; + } + } + } + } + + fn take_core(&self) -> Option<CoreGuard<'_>> { + let core = self.core.take()?; + + Some(CoreGuard { + context: Context { + spawner: self.spawner.clone(), + core: RefCell::new(Some(core)), + }, + basic_scheduler: self, + }) + } + + pub(super) fn set_context_guard(&mut self, guard: EnterGuard) { + self.context_guard = Some(guard); + } +} + +impl Drop for BasicScheduler { + fn drop(&mut self) { + // Avoid a double panic if we are currently panicking and + // the lock may be poisoned. + + let core = match self.take_core() { + Some(core) => core, + None if std::thread::panicking() => return, + None => panic!("Oh no! We never placed the Core back, this is a bug!"), + }; + + core.enter(|mut core, context| { + // Drain the OwnedTasks collection. This call also closes the + // collection, ensuring that no tasks are ever pushed after this + // call returns. + context.spawner.shared.owned.close_and_shutdown_all(); + + // Drain local queue + // We already shut down every task, so we just need to drop the task. + while let Some(task) = core.pop_task() { + drop(task); + } + + // Drain remote queue and set it to None + let remote_queue = core.spawner.shared.queue.lock().take(); + + // Using `Option::take` to replace the shared queue with `None`. + // We already shut down every task, so we just need to drop the task. + if let Some(remote_queue) = remote_queue { + for task in remote_queue { + drop(task); + } + } + + assert!(context.spawner.shared.owned.is_empty()); + + // Submit metrics + core.metrics.submit(&core.spawner.shared.worker_metrics); + + (core, ()) + }); + } +} + +impl fmt::Debug for BasicScheduler { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("BasicScheduler").finish() + } +} + +// ===== impl Core ===== + +impl Core { + fn pop_task(&mut self) -> Option<task::Notified<Arc<Shared>>> { + let ret = self.tasks.pop_front(); + self.spawner + .shared + .worker_metrics + .set_queue_depth(self.tasks.len()); + ret + } + + fn push_task(&mut self, task: task::Notified<Arc<Shared>>) { + self.tasks.push_back(task); + self.metrics.inc_local_schedule_count(); + self.spawner + .shared + .worker_metrics + .set_queue_depth(self.tasks.len()); + } +} + +// ===== impl Context ===== + +impl Context { + /// Execute the closure with the given scheduler core stored in the + /// thread-local context. + fn run_task<R>(&self, mut core: Box<Core>, f: impl FnOnce() -> R) -> (Box<Core>, R) { + core.metrics.incr_poll_count(); + self.enter(core, || crate::coop::budget(f)) + } + + /// Blocks the current thread until an event is received by the driver, + /// including I/O events, timer events, ... + fn park(&self, mut core: Box<Core>) -> Box<Core> { + let mut driver = core.driver.take().expect("driver missing"); + + if let Some(f) = &self.spawner.shared.before_park { + // Incorrect lint, the closures are actually different types so `f` + // cannot be passed as an argument to `enter`. + #[allow(clippy::redundant_closure)] + let (c, _) = self.enter(core, || f()); + core = c; + } + + // This check will fail if `before_park` spawns a task for us to run + // instead of parking the thread + if core.tasks.is_empty() { + // Park until the thread is signaled + core.metrics.about_to_park(); + core.metrics.submit(&core.spawner.shared.worker_metrics); + + let (c, _) = self.enter(core, || { + driver.park().expect("failed to park"); + }); + + core = c; + core.metrics.returned_from_park(); + } + + if let Some(f) = &self.spawner.shared.after_unpark { + // Incorrect lint, the closures are actually different types so `f` + // cannot be passed as an argument to `enter`. + #[allow(clippy::redundant_closure)] + let (c, _) = self.enter(core, || f()); + core = c; + } + + core.driver = Some(driver); + core + } + + /// Checks the driver for new events without blocking the thread. + fn park_yield(&self, mut core: Box<Core>) -> Box<Core> { + let mut driver = core.driver.take().expect("driver missing"); + + core.metrics.submit(&core.spawner.shared.worker_metrics); + let (mut core, _) = self.enter(core, || { + driver + .park_timeout(Duration::from_millis(0)) + .expect("failed to park"); + }); + + core.driver = Some(driver); + core + } + + fn enter<R>(&self, core: Box<Core>, f: impl FnOnce() -> R) -> (Box<Core>, R) { + // Store the scheduler core in the thread-local context + // + // A drop-guard is employed at a higher level. + *self.core.borrow_mut() = Some(core); + + // Execute the closure while tracking the execution budget + let ret = f(); + + // Take the scheduler core back + let core = self.core.borrow_mut().take().expect("core missing"); + (core, ret) + } +} + +// ===== impl Spawner ===== + +impl Spawner { + /// Spawns a future onto the basic scheduler + pub(crate) fn spawn<F>(&self, future: F) -> JoinHandle<F::Output> + where + F: crate::future::Future + Send + 'static, + F::Output: Send + 'static, + { + let (handle, notified) = self.shared.owned.bind(future, self.shared.clone()); + + if let Some(notified) = notified { + self.shared.schedule(notified); + } + + handle + } + + fn pop(&self) -> Option<task::Notified<Arc<Shared>>> { + match self.shared.queue.lock().as_mut() { + Some(queue) => queue.pop_front(), + None => None, + } + } + + fn waker_ref(&self) -> WakerRef<'_> { + // Set woken to true when enter block_on, ensure outer future + // be polled for the first time when enter loop + self.shared.woken.store(true, Release); + waker_ref(&self.shared) + } + + // reset woken to false and return original value + pub(crate) fn reset_woken(&self) -> bool { + self.shared.woken.swap(false, AcqRel) + } +} + +cfg_metrics! { + impl Spawner { + pub(crate) fn scheduler_metrics(&self) -> &SchedulerMetrics { + &self.shared.scheduler_metrics + } + + pub(crate) fn injection_queue_depth(&self) -> usize { + // TODO: avoid having to lock. The multi-threaded injection queue + // could probably be used here. + self.shared.queue.lock() + .as_ref() + .map(|queue| queue.len()) + .unwrap_or(0) + } + + pub(crate) fn worker_metrics(&self, worker: usize) -> &WorkerMetrics { + assert_eq!(0, worker); + &self.shared.worker_metrics + } + } +} + +impl fmt::Debug for Spawner { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Spawner").finish() + } +} + +// ===== impl Shared ===== + +impl Schedule for Arc<Shared> { + fn release(&self, task: &Task<Self>) -> Option<Task<Self>> { + self.owned.remove(task) + } + + fn schedule(&self, task: task::Notified<Self>) { + CURRENT.with(|maybe_cx| match maybe_cx { + Some(cx) if Arc::ptr_eq(self, &cx.spawner.shared) => { + let mut core = cx.core.borrow_mut(); + + // If `None`, the runtime is shutting down, so there is no need + // to schedule the task. + if let Some(core) = core.as_mut() { + core.push_task(task); + } + } + _ => { + // Track that a task was scheduled from **outside** of the runtime. + self.scheduler_metrics.inc_remote_schedule_count(); + + // If the queue is None, then the runtime has shut down. We + // don't need to do anything with the notification in that case. + let mut guard = self.queue.lock(); + if let Some(queue) = guard.as_mut() { + queue.push_back(task); + drop(guard); + self.unpark.unpark(); + } + } + }); + } +} + +impl Wake for Shared { + fn wake(arc_self: Arc<Self>) { + Wake::wake_by_ref(&arc_self) + } + + /// Wake by reference + fn wake_by_ref(arc_self: &Arc<Self>) { + arc_self.woken.store(true, Release); + arc_self.unpark.unpark(); + } +} + +// ===== CoreGuard ===== + +/// Used to ensure we always place the `Core` value back into its slot in +/// `BasicScheduler`, even if the future panics. +struct CoreGuard<'a> { + context: Context, + basic_scheduler: &'a BasicScheduler, +} + +impl CoreGuard<'_> { + fn block_on<F: Future>(self, future: F) -> F::Output { + self.enter(|mut core, context| { + let _enter = crate::runtime::enter(false); + let waker = context.spawner.waker_ref(); + let mut cx = std::task::Context::from_waker(&waker); + + pin!(future); + + 'outer: loop { + if core.spawner.reset_woken() { + let (c, res) = context.enter(core, || { + crate::coop::budget(|| future.as_mut().poll(&mut cx)) + }); + + core = c; + + if let Ready(v) = res { + return (core, v); + } + } + + for _ in 0..MAX_TASKS_PER_TICK { + // Get and increment the current tick + let tick = core.tick; + core.tick = core.tick.wrapping_add(1); + + let entry = if tick % REMOTE_FIRST_INTERVAL == 0 { + core.spawner.pop().or_else(|| core.tasks.pop_front()) + } else { + core.tasks.pop_front().or_else(|| core.spawner.pop()) + }; + + let task = match entry { + Some(entry) => entry, + None => { + core = context.park(core); + + // Try polling the `block_on` future next + continue 'outer; + } + }; + + let task = context.spawner.shared.owned.assert_owner(task); + + let (c, _) = context.run_task(core, || { + task.run(); + }); + + core = c; + } + + // Yield to the driver, this drives the timer and pulls any + // pending I/O events. + core = context.park_yield(core); + } + }) + } + + /// Enters the scheduler context. This sets the queue and other necessary + /// scheduler state in the thread-local. + fn enter<F, R>(self, f: F) -> R + where + F: FnOnce(Box<Core>, &Context) -> (Box<Core>, R), + { + // Remove `core` from `context` to pass into the closure. + let core = self.context.core.borrow_mut().take().expect("core missing"); + + // Call the closure and place `core` back + let (core, ret) = CURRENT.set(&self.context, || f(core, &self.context)); + + *self.context.core.borrow_mut() = Some(core); + + ret + } +} + +impl Drop for CoreGuard<'_> { + fn drop(&mut self) { + if let Some(core) = self.context.core.borrow_mut().take() { + // Replace old scheduler back into the state to allow + // other threads to pick it up and drive it. + self.basic_scheduler.core.set(core); + + // Wake up other possible threads that could steal the driver. + self.basic_scheduler.notify.notify_one() + } + } +} diff --git a/third_party/rust/tokio/src/runtime/blocking/mod.rs b/third_party/rust/tokio/src/runtime/blocking/mod.rs new file mode 100644 index 0000000000..15fe05c9ad --- /dev/null +++ b/third_party/rust/tokio/src/runtime/blocking/mod.rs @@ -0,0 +1,48 @@ +//! Abstracts out the APIs necessary to `Runtime` for integrating the blocking +//! pool. When the `blocking` feature flag is **not** enabled, these APIs are +//! shells. This isolates the complexity of dealing with conditional +//! compilation. + +mod pool; +pub(crate) use pool::{spawn_blocking, BlockingPool, Mandatory, Spawner, Task}; + +cfg_fs! { + pub(crate) use pool::spawn_mandatory_blocking; +} + +mod schedule; +mod shutdown; +mod task; +pub(crate) use schedule::NoopSchedule; +pub(crate) use task::BlockingTask; + +use crate::runtime::Builder; + +pub(crate) fn create_blocking_pool(builder: &Builder, thread_cap: usize) -> BlockingPool { + BlockingPool::new(builder, thread_cap) +} + +/* +cfg_not_blocking_impl! { + use crate::runtime::Builder; + use std::time::Duration; + + #[derive(Debug, Clone)] + pub(crate) struct BlockingPool {} + + pub(crate) use BlockingPool as Spawner; + + pub(crate) fn create_blocking_pool(_builder: &Builder, _thread_cap: usize) -> BlockingPool { + BlockingPool {} + } + + impl BlockingPool { + pub(crate) fn spawner(&self) -> &BlockingPool { + self + } + + pub(crate) fn shutdown(&mut self, _duration: Option<Duration>) { + } + } +} +*/ diff --git a/third_party/rust/tokio/src/runtime/blocking/pool.rs b/third_party/rust/tokio/src/runtime/blocking/pool.rs new file mode 100644 index 0000000000..daf1f63fac --- /dev/null +++ b/third_party/rust/tokio/src/runtime/blocking/pool.rs @@ -0,0 +1,396 @@ +//! Thread pool for blocking operations + +use crate::loom::sync::{Arc, Condvar, Mutex}; +use crate::loom::thread; +use crate::runtime::blocking::schedule::NoopSchedule; +use crate::runtime::blocking::shutdown; +use crate::runtime::builder::ThreadNameFn; +use crate::runtime::context; +use crate::runtime::task::{self, JoinHandle}; +use crate::runtime::{Builder, Callback, Handle}; + +use std::collections::{HashMap, VecDeque}; +use std::fmt; +use std::time::Duration; + +pub(crate) struct BlockingPool { + spawner: Spawner, + shutdown_rx: shutdown::Receiver, +} + +#[derive(Clone)] +pub(crate) struct Spawner { + inner: Arc<Inner>, +} + +struct Inner { + /// State shared between worker threads. + shared: Mutex<Shared>, + + /// Pool threads wait on this. + condvar: Condvar, + + /// Spawned threads use this name. + thread_name: ThreadNameFn, + + /// Spawned thread stack size. + stack_size: Option<usize>, + + /// Call after a thread starts. + after_start: Option<Callback>, + + /// Call before a thread stops. + before_stop: Option<Callback>, + + // Maximum number of threads. + thread_cap: usize, + + // Customizable wait timeout. + keep_alive: Duration, +} + +struct Shared { + queue: VecDeque<Task>, + num_th: usize, + num_idle: u32, + num_notify: u32, + shutdown: bool, + shutdown_tx: Option<shutdown::Sender>, + /// Prior to shutdown, we clean up JoinHandles by having each timed-out + /// thread join on the previous timed-out thread. This is not strictly + /// necessary but helps avoid Valgrind false positives, see + /// <https://github.com/tokio-rs/tokio/commit/646fbae76535e397ef79dbcaacb945d4c829f666> + /// for more information. + last_exiting_thread: Option<thread::JoinHandle<()>>, + /// This holds the JoinHandles for all running threads; on shutdown, the thread + /// calling shutdown handles joining on these. + worker_threads: HashMap<usize, thread::JoinHandle<()>>, + /// This is a counter used to iterate worker_threads in a consistent order (for loom's + /// benefit). + worker_thread_index: usize, +} + +pub(crate) struct Task { + task: task::UnownedTask<NoopSchedule>, + mandatory: Mandatory, +} + +#[derive(PartialEq, Eq)] +pub(crate) enum Mandatory { + #[cfg_attr(not(fs), allow(dead_code))] + Mandatory, + NonMandatory, +} + +impl Task { + pub(crate) fn new(task: task::UnownedTask<NoopSchedule>, mandatory: Mandatory) -> Task { + Task { task, mandatory } + } + + fn run(self) { + self.task.run(); + } + + fn shutdown_or_run_if_mandatory(self) { + match self.mandatory { + Mandatory::NonMandatory => self.task.shutdown(), + Mandatory::Mandatory => self.task.run(), + } + } +} + +const KEEP_ALIVE: Duration = Duration::from_secs(10); + +/// Runs the provided function on an executor dedicated to blocking operations. +/// Tasks will be scheduled as non-mandatory, meaning they may not get executed +/// in case of runtime shutdown. +pub(crate) fn spawn_blocking<F, R>(func: F) -> JoinHandle<R> +where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + let rt = context::current(); + rt.spawn_blocking(func) +} + +cfg_fs! { + #[cfg_attr(any( + all(loom, not(test)), // the function is covered by loom tests + test + ), allow(dead_code))] + /// Runs the provided function on an executor dedicated to blocking + /// operations. Tasks will be scheduled as mandatory, meaning they are + /// guaranteed to run unless a shutdown is already taking place. In case a + /// shutdown is already taking place, `None` will be returned. + pub(crate) fn spawn_mandatory_blocking<F, R>(func: F) -> Option<JoinHandle<R>> + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + let rt = context::current(); + rt.spawn_mandatory_blocking(func) + } +} + +// ===== impl BlockingPool ===== + +impl BlockingPool { + pub(crate) fn new(builder: &Builder, thread_cap: usize) -> BlockingPool { + let (shutdown_tx, shutdown_rx) = shutdown::channel(); + let keep_alive = builder.keep_alive.unwrap_or(KEEP_ALIVE); + + BlockingPool { + spawner: Spawner { + inner: Arc::new(Inner { + shared: Mutex::new(Shared { + queue: VecDeque::new(), + num_th: 0, + num_idle: 0, + num_notify: 0, + shutdown: false, + shutdown_tx: Some(shutdown_tx), + last_exiting_thread: None, + worker_threads: HashMap::new(), + worker_thread_index: 0, + }), + condvar: Condvar::new(), + thread_name: builder.thread_name.clone(), + stack_size: builder.thread_stack_size, + after_start: builder.after_start.clone(), + before_stop: builder.before_stop.clone(), + thread_cap, + keep_alive, + }), + }, + shutdown_rx, + } + } + + pub(crate) fn spawner(&self) -> &Spawner { + &self.spawner + } + + pub(crate) fn shutdown(&mut self, timeout: Option<Duration>) { + let mut shared = self.spawner.inner.shared.lock(); + + // The function can be called multiple times. First, by explicitly + // calling `shutdown` then by the drop handler calling `shutdown`. This + // prevents shutting down twice. + if shared.shutdown { + return; + } + + shared.shutdown = true; + shared.shutdown_tx = None; + self.spawner.inner.condvar.notify_all(); + + let last_exited_thread = std::mem::take(&mut shared.last_exiting_thread); + let workers = std::mem::take(&mut shared.worker_threads); + + drop(shared); + + if self.shutdown_rx.wait(timeout) { + let _ = last_exited_thread.map(|th| th.join()); + + // Loom requires that execution be deterministic, so sort by thread ID before joining. + // (HashMaps use a randomly-seeded hash function, so the order is nondeterministic) + let mut workers: Vec<(usize, thread::JoinHandle<()>)> = workers.into_iter().collect(); + workers.sort_by_key(|(id, _)| *id); + + for (_id, handle) in workers.into_iter() { + let _ = handle.join(); + } + } + } +} + +impl Drop for BlockingPool { + fn drop(&mut self) { + self.shutdown(None); + } +} + +impl fmt::Debug for BlockingPool { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("BlockingPool").finish() + } +} + +// ===== impl Spawner ===== + +impl Spawner { + pub(crate) fn spawn(&self, task: Task, rt: &Handle) -> Result<(), ()> { + let mut shared = self.inner.shared.lock(); + + if shared.shutdown { + // Shutdown the task: it's fine to shutdown this task (even if + // mandatory) because it was scheduled after the shutdown of the + // runtime began. + task.task.shutdown(); + + // no need to even push this task; it would never get picked up + return Err(()); + } + + shared.queue.push_back(task); + + if shared.num_idle == 0 { + // No threads are able to process the task. + + if shared.num_th == self.inner.thread_cap { + // At max number of threads + } else { + shared.num_th += 1; + assert!(shared.shutdown_tx.is_some()); + let shutdown_tx = shared.shutdown_tx.clone(); + + if let Some(shutdown_tx) = shutdown_tx { + let id = shared.worker_thread_index; + shared.worker_thread_index += 1; + + let handle = self.spawn_thread(shutdown_tx, rt, id); + + shared.worker_threads.insert(id, handle); + } + } + } else { + // Notify an idle worker thread. The notification counter + // is used to count the needed amount of notifications + // exactly. Thread libraries may generate spurious + // wakeups, this counter is used to keep us in a + // consistent state. + shared.num_idle -= 1; + shared.num_notify += 1; + self.inner.condvar.notify_one(); + } + + Ok(()) + } + + fn spawn_thread( + &self, + shutdown_tx: shutdown::Sender, + rt: &Handle, + id: usize, + ) -> thread::JoinHandle<()> { + let mut builder = thread::Builder::new().name((self.inner.thread_name)()); + + if let Some(stack_size) = self.inner.stack_size { + builder = builder.stack_size(stack_size); + } + + let rt = rt.clone(); + + builder + .spawn(move || { + // Only the reference should be moved into the closure + let _enter = crate::runtime::context::enter(rt.clone()); + rt.blocking_spawner.inner.run(id); + drop(shutdown_tx); + }) + .expect("OS can't spawn a new worker thread") + } +} + +impl Inner { + fn run(&self, worker_thread_id: usize) { + if let Some(f) = &self.after_start { + f() + } + + let mut shared = self.shared.lock(); + let mut join_on_thread = None; + + 'main: loop { + // BUSY + while let Some(task) = shared.queue.pop_front() { + drop(shared); + task.run(); + + shared = self.shared.lock(); + } + + // IDLE + shared.num_idle += 1; + + while !shared.shutdown { + let lock_result = self.condvar.wait_timeout(shared, self.keep_alive).unwrap(); + + shared = lock_result.0; + let timeout_result = lock_result.1; + + if shared.num_notify != 0 { + // We have received a legitimate wakeup, + // acknowledge it by decrementing the counter + // and transition to the BUSY state. + shared.num_notify -= 1; + break; + } + + // Even if the condvar "timed out", if the pool is entering the + // shutdown phase, we want to perform the cleanup logic. + if !shared.shutdown && timeout_result.timed_out() { + // We'll join the prior timed-out thread's JoinHandle after dropping the lock. + // This isn't done when shutting down, because the thread calling shutdown will + // handle joining everything. + let my_handle = shared.worker_threads.remove(&worker_thread_id); + join_on_thread = std::mem::replace(&mut shared.last_exiting_thread, my_handle); + + break 'main; + } + + // Spurious wakeup detected, go back to sleep. + } + + if shared.shutdown { + // Drain the queue + while let Some(task) = shared.queue.pop_front() { + drop(shared); + + task.shutdown_or_run_if_mandatory(); + + shared = self.shared.lock(); + } + + // Work was produced, and we "took" it (by decrementing num_notify). + // This means that num_idle was decremented once for our wakeup. + // But, since we are exiting, we need to "undo" that, as we'll stay idle. + shared.num_idle += 1; + // NOTE: Technically we should also do num_notify++ and notify again, + // but since we're shutting down anyway, that won't be necessary. + break; + } + } + + // Thread exit + shared.num_th -= 1; + + // num_idle should now be tracked exactly, panic + // with a descriptive message if it is not the + // case. + shared.num_idle = shared + .num_idle + .checked_sub(1) + .expect("num_idle underflowed on thread exit"); + + if shared.shutdown && shared.num_th == 0 { + self.condvar.notify_one(); + } + + drop(shared); + + if let Some(f) = &self.before_stop { + f() + } + + if let Some(handle) = join_on_thread { + let _ = handle.join(); + } + } +} + +impl fmt::Debug for Spawner { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("blocking::Spawner").finish() + } +} diff --git a/third_party/rust/tokio/src/runtime/blocking/schedule.rs b/third_party/rust/tokio/src/runtime/blocking/schedule.rs new file mode 100644 index 0000000000..54252241d9 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/blocking/schedule.rs @@ -0,0 +1,19 @@ +use crate::runtime::task::{self, Task}; + +/// `task::Schedule` implementation that does nothing. This is unique to the +/// blocking scheduler as tasks scheduled are not really futures but blocking +/// operations. +/// +/// We avoid storing the task by forgetting it in `bind` and re-materializing it +/// in `release. +pub(crate) struct NoopSchedule; + +impl task::Schedule for NoopSchedule { + fn release(&self, _task: &Task<Self>) -> Option<Task<Self>> { + None + } + + fn schedule(&self, _task: task::Notified<Self>) { + unreachable!(); + } +} diff --git a/third_party/rust/tokio/src/runtime/blocking/shutdown.rs b/third_party/rust/tokio/src/runtime/blocking/shutdown.rs new file mode 100644 index 0000000000..e6f4674183 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/blocking/shutdown.rs @@ -0,0 +1,71 @@ +//! A shutdown channel. +//! +//! Each worker holds the `Sender` half. When all the `Sender` halves are +//! dropped, the `Receiver` receives a notification. + +use crate::loom::sync::Arc; +use crate::sync::oneshot; + +use std::time::Duration; + +#[derive(Debug, Clone)] +pub(super) struct Sender { + _tx: Arc<oneshot::Sender<()>>, +} + +#[derive(Debug)] +pub(super) struct Receiver { + rx: oneshot::Receiver<()>, +} + +pub(super) fn channel() -> (Sender, Receiver) { + let (tx, rx) = oneshot::channel(); + let tx = Sender { _tx: Arc::new(tx) }; + let rx = Receiver { rx }; + + (tx, rx) +} + +impl Receiver { + /// Blocks the current thread until all `Sender` handles drop. + /// + /// If `timeout` is `Some`, the thread is blocked for **at most** `timeout` + /// duration. If `timeout` is `None`, then the thread is blocked until the + /// shutdown signal is received. + /// + /// If the timeout has elapsed, it returns `false`, otherwise it returns `true`. + pub(crate) fn wait(&mut self, timeout: Option<Duration>) -> bool { + use crate::runtime::enter::try_enter; + + if timeout == Some(Duration::from_nanos(0)) { + return false; + } + + let mut e = match try_enter(false) { + Some(enter) => enter, + _ => { + if std::thread::panicking() { + // Don't panic in a panic + return false; + } else { + panic!( + "Cannot drop a runtime in a context where blocking is not allowed. \ + This happens when a runtime is dropped from within an asynchronous context." + ); + } + } + }; + + // The oneshot completes with an Err + // + // If blocking fails to wait, this indicates a problem parking the + // current thread (usually, shutting down a runtime stored in a + // thread-local). + if let Some(timeout) = timeout { + e.block_on_timeout(&mut self.rx, timeout).is_ok() + } else { + let _ = e.block_on(&mut self.rx); + true + } + } +} diff --git a/third_party/rust/tokio/src/runtime/blocking/task.rs b/third_party/rust/tokio/src/runtime/blocking/task.rs new file mode 100644 index 0000000000..0b7803a6c0 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/blocking/task.rs @@ -0,0 +1,44 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// Converts a function to a future that completes on poll. +pub(crate) struct BlockingTask<T> { + func: Option<T>, +} + +impl<T> BlockingTask<T> { + /// Initializes a new blocking task from the given function. + pub(crate) fn new(func: T) -> BlockingTask<T> { + BlockingTask { func: Some(func) } + } +} + +// The closure `F` is never pinned +impl<T> Unpin for BlockingTask<T> {} + +impl<T, R> Future for BlockingTask<T> +where + T: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + type Output = R; + + fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<R> { + let me = &mut *self; + let func = me + .func + .take() + .expect("[internal exception] blocking task ran twice."); + + // This is a little subtle: + // For convenience, we'd like _every_ call tokio ever makes to Task::poll() to be budgeted + // using coop. However, the way things are currently modeled, even running a blocking task + // currently goes through Task::poll(), and so is subject to budgeting. That isn't really + // what we want; a blocking task may itself want to run tasks (it might be a Worker!), so + // we want it to start without any budgeting. + crate::coop::stop(); + + Poll::Ready(func()) + } +} diff --git a/third_party/rust/tokio/src/runtime/builder.rs b/third_party/rust/tokio/src/runtime/builder.rs new file mode 100644 index 0000000000..91c365fd51 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/builder.rs @@ -0,0 +1,718 @@ +use crate::runtime::handle::Handle; +use crate::runtime::{blocking, driver, Callback, Runtime, Spawner}; + +use std::fmt; +use std::io; +use std::time::Duration; + +/// Builds Tokio Runtime with custom configuration values. +/// +/// Methods can be chained in order to set the configuration values. The +/// Runtime is constructed by calling [`build`]. +/// +/// New instances of `Builder` are obtained via [`Builder::new_multi_thread`] +/// or [`Builder::new_current_thread`]. +/// +/// See function level documentation for details on the various configuration +/// settings. +/// +/// [`build`]: method@Self::build +/// [`Builder::new_multi_thread`]: method@Self::new_multi_thread +/// [`Builder::new_current_thread`]: method@Self::new_current_thread +/// +/// # Examples +/// +/// ``` +/// use tokio::runtime::Builder; +/// +/// fn main() { +/// // build runtime +/// let runtime = Builder::new_multi_thread() +/// .worker_threads(4) +/// .thread_name("my-custom-name") +/// .thread_stack_size(3 * 1024 * 1024) +/// .build() +/// .unwrap(); +/// +/// // use runtime ... +/// } +/// ``` +pub struct Builder { + /// Runtime type + kind: Kind, + + /// Whether or not to enable the I/O driver + enable_io: bool, + + /// Whether or not to enable the time driver + enable_time: bool, + + /// Whether or not the clock should start paused. + start_paused: bool, + + /// The number of worker threads, used by Runtime. + /// + /// Only used when not using the current-thread executor. + worker_threads: Option<usize>, + + /// Cap on thread usage. + max_blocking_threads: usize, + + /// Name fn used for threads spawned by the runtime. + pub(super) thread_name: ThreadNameFn, + + /// Stack size used for threads spawned by the runtime. + pub(super) thread_stack_size: Option<usize>, + + /// Callback to run after each thread starts. + pub(super) after_start: Option<Callback>, + + /// To run before each worker thread stops + pub(super) before_stop: Option<Callback>, + + /// To run before each worker thread is parked. + pub(super) before_park: Option<Callback>, + + /// To run after each thread is unparked. + pub(super) after_unpark: Option<Callback>, + + /// Customizable keep alive timeout for BlockingPool + pub(super) keep_alive: Option<Duration>, +} + +pub(crate) type ThreadNameFn = std::sync::Arc<dyn Fn() -> String + Send + Sync + 'static>; + +pub(crate) enum Kind { + CurrentThread, + #[cfg(feature = "rt-multi-thread")] + MultiThread, +} + +impl Builder { + /// Returns a new builder with the current thread scheduler selected. + /// + /// Configuration methods can be chained on the return value. + /// + /// To spawn non-`Send` tasks on the resulting runtime, combine it with a + /// [`LocalSet`]. + /// + /// [`LocalSet`]: crate::task::LocalSet + pub fn new_current_thread() -> Builder { + Builder::new(Kind::CurrentThread) + } + + /// Returns a new builder with the multi thread scheduler selected. + /// + /// Configuration methods can be chained on the return value. + #[cfg(feature = "rt-multi-thread")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt-multi-thread")))] + pub fn new_multi_thread() -> Builder { + Builder::new(Kind::MultiThread) + } + + /// Returns a new runtime builder initialized with default configuration + /// values. + /// + /// Configuration methods can be chained on the return value. + pub(crate) fn new(kind: Kind) -> Builder { + Builder { + kind, + + // I/O defaults to "off" + enable_io: false, + + // Time defaults to "off" + enable_time: false, + + // The clock starts not-paused + start_paused: false, + + // Default to lazy auto-detection (one thread per CPU core) + worker_threads: None, + + max_blocking_threads: 512, + + // Default thread name + thread_name: std::sync::Arc::new(|| "tokio-runtime-worker".into()), + + // Do not set a stack size by default + thread_stack_size: None, + + // No worker thread callbacks + after_start: None, + before_stop: None, + before_park: None, + after_unpark: None, + + keep_alive: None, + } + } + + /// Enables both I/O and time drivers. + /// + /// Doing this is a shorthand for calling `enable_io` and `enable_time` + /// individually. If additional components are added to Tokio in the future, + /// `enable_all` will include these future components. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime; + /// + /// let rt = runtime::Builder::new_multi_thread() + /// .enable_all() + /// .build() + /// .unwrap(); + /// ``` + pub fn enable_all(&mut self) -> &mut Self { + #[cfg(any(feature = "net", feature = "process", all(unix, feature = "signal")))] + self.enable_io(); + #[cfg(feature = "time")] + self.enable_time(); + + self + } + + /// Sets the number of worker threads the `Runtime` will use. + /// + /// This can be any number above 0 though it is advised to keep this value + /// on the smaller side. + /// + /// # Default + /// + /// The default value is the number of cores available to the system. + /// + /// # Panic + /// + /// When using the `current_thread` runtime this method will panic, since + /// those variants do not allow setting worker thread counts. + /// + /// + /// # Examples + /// + /// ## Multi threaded runtime with 4 threads + /// + /// ``` + /// use tokio::runtime; + /// + /// // This will spawn a work-stealing runtime with 4 worker threads. + /// let rt = runtime::Builder::new_multi_thread() + /// .worker_threads(4) + /// .build() + /// .unwrap(); + /// + /// rt.spawn(async move {}); + /// ``` + /// + /// ## Current thread runtime (will only run on the current thread via `Runtime::block_on`) + /// + /// ``` + /// use tokio::runtime; + /// + /// // Create a runtime that _must_ be driven from a call + /// // to `Runtime::block_on`. + /// let rt = runtime::Builder::new_current_thread() + /// .build() + /// .unwrap(); + /// + /// // This will run the runtime and future on the current thread + /// rt.block_on(async move {}); + /// ``` + /// + /// # Panic + /// + /// This will panic if `val` is not larger than `0`. + pub fn worker_threads(&mut self, val: usize) -> &mut Self { + assert!(val > 0, "Worker threads cannot be set to 0"); + self.worker_threads = Some(val); + self + } + + /// Specifies the limit for additional threads spawned by the Runtime. + /// + /// These threads are used for blocking operations like tasks spawned + /// through [`spawn_blocking`]. Unlike the [`worker_threads`], they are not + /// always active and will exit if left idle for too long. You can change + /// this timeout duration with [`thread_keep_alive`]. + /// + /// The default value is 512. + /// + /// # Panic + /// + /// This will panic if `val` is not larger than `0`. + /// + /// # Upgrading from 0.x + /// + /// In old versions `max_threads` limited both blocking and worker threads, but the + /// current `max_blocking_threads` does not include async worker threads in the count. + /// + /// [`spawn_blocking`]: fn@crate::task::spawn_blocking + /// [`worker_threads`]: Self::worker_threads + /// [`thread_keep_alive`]: Self::thread_keep_alive + #[cfg_attr(docsrs, doc(alias = "max_threads"))] + pub fn max_blocking_threads(&mut self, val: usize) -> &mut Self { + assert!(val > 0, "Max blocking threads cannot be set to 0"); + self.max_blocking_threads = val; + self + } + + /// Sets name of threads spawned by the `Runtime`'s thread pool. + /// + /// The default name is "tokio-runtime-worker". + /// + /// # Examples + /// + /// ``` + /// # use tokio::runtime; + /// + /// # pub fn main() { + /// let rt = runtime::Builder::new_multi_thread() + /// .thread_name("my-pool") + /// .build(); + /// # } + /// ``` + pub fn thread_name(&mut self, val: impl Into<String>) -> &mut Self { + let val = val.into(); + self.thread_name = std::sync::Arc::new(move || val.clone()); + self + } + + /// Sets a function used to generate the name of threads spawned by the `Runtime`'s thread pool. + /// + /// The default name fn is `|| "tokio-runtime-worker".into()`. + /// + /// # Examples + /// + /// ``` + /// # use tokio::runtime; + /// # use std::sync::atomic::{AtomicUsize, Ordering}; + /// + /// # pub fn main() { + /// let rt = runtime::Builder::new_multi_thread() + /// .thread_name_fn(|| { + /// static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0); + /// let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst); + /// format!("my-pool-{}", id) + /// }) + /// .build(); + /// # } + /// ``` + pub fn thread_name_fn<F>(&mut self, f: F) -> &mut Self + where + F: Fn() -> String + Send + Sync + 'static, + { + self.thread_name = std::sync::Arc::new(f); + self + } + + /// Sets the stack size (in bytes) for worker threads. + /// + /// The actual stack size may be greater than this value if the platform + /// specifies minimal stack size. + /// + /// The default stack size for spawned threads is 2 MiB, though this + /// particular stack size is subject to change in the future. + /// + /// # Examples + /// + /// ``` + /// # use tokio::runtime; + /// + /// # pub fn main() { + /// let rt = runtime::Builder::new_multi_thread() + /// .thread_stack_size(32 * 1024) + /// .build(); + /// # } + /// ``` + pub fn thread_stack_size(&mut self, val: usize) -> &mut Self { + self.thread_stack_size = Some(val); + self + } + + /// Executes function `f` after each thread is started but before it starts + /// doing work. + /// + /// This is intended for bookkeeping and monitoring use cases. + /// + /// # Examples + /// + /// ``` + /// # use tokio::runtime; + /// + /// # pub fn main() { + /// let runtime = runtime::Builder::new_multi_thread() + /// .on_thread_start(|| { + /// println!("thread started"); + /// }) + /// .build(); + /// # } + /// ``` + #[cfg(not(loom))] + pub fn on_thread_start<F>(&mut self, f: F) -> &mut Self + where + F: Fn() + Send + Sync + 'static, + { + self.after_start = Some(std::sync::Arc::new(f)); + self + } + + /// Executes function `f` before each thread stops. + /// + /// This is intended for bookkeeping and monitoring use cases. + /// + /// # Examples + /// + /// ``` + /// # use tokio::runtime; + /// + /// # pub fn main() { + /// let runtime = runtime::Builder::new_multi_thread() + /// .on_thread_stop(|| { + /// println!("thread stopping"); + /// }) + /// .build(); + /// # } + /// ``` + #[cfg(not(loom))] + pub fn on_thread_stop<F>(&mut self, f: F) -> &mut Self + where + F: Fn() + Send + Sync + 'static, + { + self.before_stop = Some(std::sync::Arc::new(f)); + self + } + + /// Executes function `f` just before a thread is parked (goes idle). + /// `f` is called within the Tokio context, so functions like [`tokio::spawn`](crate::spawn) + /// can be called, and may result in this thread being unparked immediately. + /// + /// This can be used to start work only when the executor is idle, or for bookkeeping + /// and monitoring purposes. + /// + /// Note: There can only be one park callback for a runtime; calling this function + /// more than once replaces the last callback defined, rather than adding to it. + /// + /// # Examples + /// + /// ## Multithreaded executor + /// ``` + /// # use std::sync::Arc; + /// # use std::sync::atomic::{AtomicBool, Ordering}; + /// # use tokio::runtime; + /// # use tokio::sync::Barrier; + /// # pub fn main() { + /// let once = AtomicBool::new(true); + /// let barrier = Arc::new(Barrier::new(2)); + /// + /// let runtime = runtime::Builder::new_multi_thread() + /// .worker_threads(1) + /// .on_thread_park({ + /// let barrier = barrier.clone(); + /// move || { + /// let barrier = barrier.clone(); + /// if once.swap(false, Ordering::Relaxed) { + /// tokio::spawn(async move { barrier.wait().await; }); + /// } + /// } + /// }) + /// .build() + /// .unwrap(); + /// + /// runtime.block_on(async { + /// barrier.wait().await; + /// }) + /// # } + /// ``` + /// ## Current thread executor + /// ``` + /// # use std::sync::Arc; + /// # use std::sync::atomic::{AtomicBool, Ordering}; + /// # use tokio::runtime; + /// # use tokio::sync::Barrier; + /// # pub fn main() { + /// let once = AtomicBool::new(true); + /// let barrier = Arc::new(Barrier::new(2)); + /// + /// let runtime = runtime::Builder::new_current_thread() + /// .on_thread_park({ + /// let barrier = barrier.clone(); + /// move || { + /// let barrier = barrier.clone(); + /// if once.swap(false, Ordering::Relaxed) { + /// tokio::spawn(async move { barrier.wait().await; }); + /// } + /// } + /// }) + /// .build() + /// .unwrap(); + /// + /// runtime.block_on(async { + /// barrier.wait().await; + /// }) + /// # } + /// ``` + #[cfg(not(loom))] + pub fn on_thread_park<F>(&mut self, f: F) -> &mut Self + where + F: Fn() + Send + Sync + 'static, + { + self.before_park = Some(std::sync::Arc::new(f)); + self + } + + /// Executes function `f` just after a thread unparks (starts executing tasks). + /// + /// This is intended for bookkeeping and monitoring use cases; note that work + /// in this callback will increase latencies when the application has allowed one or + /// more runtime threads to go idle. + /// + /// Note: There can only be one unpark callback for a runtime; calling this function + /// more than once replaces the last callback defined, rather than adding to it. + /// + /// # Examples + /// + /// ``` + /// # use tokio::runtime; + /// + /// # pub fn main() { + /// let runtime = runtime::Builder::new_multi_thread() + /// .on_thread_unpark(|| { + /// println!("thread unparking"); + /// }) + /// .build(); + /// + /// runtime.unwrap().block_on(async { + /// tokio::task::yield_now().await; + /// println!("Hello from Tokio!"); + /// }) + /// # } + /// ``` + #[cfg(not(loom))] + pub fn on_thread_unpark<F>(&mut self, f: F) -> &mut Self + where + F: Fn() + Send + Sync + 'static, + { + self.after_unpark = Some(std::sync::Arc::new(f)); + self + } + + /// Creates the configured `Runtime`. + /// + /// The returned `Runtime` instance is ready to spawn tasks. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Builder; + /// + /// let rt = Builder::new_multi_thread().build().unwrap(); + /// + /// rt.block_on(async { + /// println!("Hello from the Tokio runtime"); + /// }); + /// ``` + pub fn build(&mut self) -> io::Result<Runtime> { + match &self.kind { + Kind::CurrentThread => self.build_basic_runtime(), + #[cfg(feature = "rt-multi-thread")] + Kind::MultiThread => self.build_threaded_runtime(), + } + } + + fn get_cfg(&self) -> driver::Cfg { + driver::Cfg { + enable_pause_time: match self.kind { + Kind::CurrentThread => true, + #[cfg(feature = "rt-multi-thread")] + Kind::MultiThread => false, + }, + enable_io: self.enable_io, + enable_time: self.enable_time, + start_paused: self.start_paused, + } + } + + /// Sets a custom timeout for a thread in the blocking pool. + /// + /// By default, the timeout for a thread is set to 10 seconds. This can + /// be overridden using .thread_keep_alive(). + /// + /// # Example + /// + /// ``` + /// # use tokio::runtime; + /// # use std::time::Duration; + /// + /// # pub fn main() { + /// let rt = runtime::Builder::new_multi_thread() + /// .thread_keep_alive(Duration::from_millis(100)) + /// .build(); + /// # } + /// ``` + pub fn thread_keep_alive(&mut self, duration: Duration) -> &mut Self { + self.keep_alive = Some(duration); + self + } + + fn build_basic_runtime(&mut self) -> io::Result<Runtime> { + use crate::runtime::{BasicScheduler, Kind}; + + let (driver, resources) = driver::Driver::new(self.get_cfg())?; + + // And now put a single-threaded scheduler on top of the timer. When + // there are no futures ready to do something, it'll let the timer or + // the reactor to generate some new stimuli for the futures to continue + // in their life. + let scheduler = + BasicScheduler::new(driver, self.before_park.clone(), self.after_unpark.clone()); + let spawner = Spawner::Basic(scheduler.spawner().clone()); + + // Blocking pool + let blocking_pool = blocking::create_blocking_pool(self, self.max_blocking_threads); + let blocking_spawner = blocking_pool.spawner().clone(); + + Ok(Runtime { + kind: Kind::CurrentThread(scheduler), + handle: Handle { + spawner, + io_handle: resources.io_handle, + time_handle: resources.time_handle, + signal_handle: resources.signal_handle, + clock: resources.clock, + blocking_spawner, + }, + blocking_pool, + }) + } +} + +cfg_io_driver! { + impl Builder { + /// Enables the I/O driver. + /// + /// Doing this enables using net, process, signal, and some I/O types on + /// the runtime. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime; + /// + /// let rt = runtime::Builder::new_multi_thread() + /// .enable_io() + /// .build() + /// .unwrap(); + /// ``` + pub fn enable_io(&mut self) -> &mut Self { + self.enable_io = true; + self + } + } +} + +cfg_time! { + impl Builder { + /// Enables the time driver. + /// + /// Doing this enables using `tokio::time` on the runtime. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime; + /// + /// let rt = runtime::Builder::new_multi_thread() + /// .enable_time() + /// .build() + /// .unwrap(); + /// ``` + pub fn enable_time(&mut self) -> &mut Self { + self.enable_time = true; + self + } + } +} + +cfg_test_util! { + impl Builder { + /// Controls if the runtime's clock starts paused or advancing. + /// + /// Pausing time requires the current-thread runtime; construction of + /// the runtime will panic otherwise. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime; + /// + /// let rt = runtime::Builder::new_current_thread() + /// .enable_time() + /// .start_paused(true) + /// .build() + /// .unwrap(); + /// ``` + pub fn start_paused(&mut self, start_paused: bool) -> &mut Self { + self.start_paused = start_paused; + self + } + } +} + +cfg_rt_multi_thread! { + impl Builder { + fn build_threaded_runtime(&mut self) -> io::Result<Runtime> { + use crate::loom::sys::num_cpus; + use crate::runtime::{Kind, ThreadPool}; + use crate::runtime::park::Parker; + + let core_threads = self.worker_threads.unwrap_or_else(num_cpus); + + let (driver, resources) = driver::Driver::new(self.get_cfg())?; + + let (scheduler, launch) = ThreadPool::new(core_threads, Parker::new(driver), self.before_park.clone(), self.after_unpark.clone()); + let spawner = Spawner::ThreadPool(scheduler.spawner().clone()); + + // Create the blocking pool + let blocking_pool = blocking::create_blocking_pool(self, self.max_blocking_threads + core_threads); + let blocking_spawner = blocking_pool.spawner().clone(); + + // Create the runtime handle + let handle = Handle { + spawner, + io_handle: resources.io_handle, + time_handle: resources.time_handle, + signal_handle: resources.signal_handle, + clock: resources.clock, + blocking_spawner, + }; + + // Spawn the thread pool workers + let _enter = crate::runtime::context::enter(handle.clone()); + launch.launch(); + + Ok(Runtime { + kind: Kind::ThreadPool(scheduler), + handle, + blocking_pool, + }) + } + } +} + +impl fmt::Debug for Builder { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Builder") + .field("worker_threads", &self.worker_threads) + .field("max_blocking_threads", &self.max_blocking_threads) + .field( + "thread_name", + &"<dyn Fn() -> String + Send + Sync + 'static>", + ) + .field("thread_stack_size", &self.thread_stack_size) + .field("after_start", &self.after_start.as_ref().map(|_| "...")) + .field("before_stop", &self.before_stop.as_ref().map(|_| "...")) + .field("before_park", &self.before_park.as_ref().map(|_| "...")) + .field("after_unpark", &self.after_unpark.as_ref().map(|_| "...")) + .finish() + } +} diff --git a/third_party/rust/tokio/src/runtime/context.rs b/third_party/rust/tokio/src/runtime/context.rs new file mode 100644 index 0000000000..1f44a53402 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/context.rs @@ -0,0 +1,111 @@ +//! Thread local runtime context +use crate::runtime::{Handle, TryCurrentError}; + +use std::cell::RefCell; + +thread_local! { + static CONTEXT: RefCell<Option<Handle>> = RefCell::new(None) +} + +pub(crate) fn try_current() -> Result<Handle, crate::runtime::TryCurrentError> { + match CONTEXT.try_with(|ctx| ctx.borrow().clone()) { + Ok(Some(handle)) => Ok(handle), + Ok(None) => Err(TryCurrentError::new_no_context()), + Err(_access_error) => Err(TryCurrentError::new_thread_local_destroyed()), + } +} + +pub(crate) fn current() -> Handle { + match try_current() { + Ok(handle) => handle, + Err(e) => panic!("{}", e), + } +} + +cfg_io_driver! { + pub(crate) fn io_handle() -> crate::runtime::driver::IoHandle { + match CONTEXT.try_with(|ctx| { + let ctx = ctx.borrow(); + ctx.as_ref().expect(crate::util::error::CONTEXT_MISSING_ERROR).io_handle.clone() + }) { + Ok(io_handle) => io_handle, + Err(_) => panic!("{}", crate::util::error::THREAD_LOCAL_DESTROYED_ERROR), + } + } +} + +cfg_signal_internal! { + #[cfg(unix)] + pub(crate) fn signal_handle() -> crate::runtime::driver::SignalHandle { + match CONTEXT.try_with(|ctx| { + let ctx = ctx.borrow(); + ctx.as_ref().expect(crate::util::error::CONTEXT_MISSING_ERROR).signal_handle.clone() + }) { + Ok(signal_handle) => signal_handle, + Err(_) => panic!("{}", crate::util::error::THREAD_LOCAL_DESTROYED_ERROR), + } + } +} + +cfg_time! { + pub(crate) fn time_handle() -> crate::runtime::driver::TimeHandle { + match CONTEXT.try_with(|ctx| { + let ctx = ctx.borrow(); + ctx.as_ref().expect(crate::util::error::CONTEXT_MISSING_ERROR).time_handle.clone() + }) { + Ok(time_handle) => time_handle, + Err(_) => panic!("{}", crate::util::error::THREAD_LOCAL_DESTROYED_ERROR), + } + } + + cfg_test_util! { + pub(crate) fn clock() -> Option<crate::runtime::driver::Clock> { + match CONTEXT.try_with(|ctx| (*ctx.borrow()).as_ref().map(|ctx| ctx.clock.clone())) { + Ok(clock) => clock, + Err(_) => panic!("{}", crate::util::error::THREAD_LOCAL_DESTROYED_ERROR), + } + } + } +} + +cfg_rt! { + pub(crate) fn spawn_handle() -> Option<crate::runtime::Spawner> { + match CONTEXT.try_with(|ctx| (*ctx.borrow()).as_ref().map(|ctx| ctx.spawner.clone())) { + Ok(spawner) => spawner, + Err(_) => panic!("{}", crate::util::error::THREAD_LOCAL_DESTROYED_ERROR), + } + } +} + +/// Sets this [`Handle`] as the current active [`Handle`]. +/// +/// [`Handle`]: Handle +pub(crate) fn enter(new: Handle) -> EnterGuard { + match try_enter(new) { + Some(guard) => guard, + None => panic!("{}", crate::util::error::THREAD_LOCAL_DESTROYED_ERROR), + } +} + +/// Sets this [`Handle`] as the current active [`Handle`]. +/// +/// [`Handle`]: Handle +pub(crate) fn try_enter(new: Handle) -> Option<EnterGuard> { + CONTEXT + .try_with(|ctx| { + let old = ctx.borrow_mut().replace(new); + EnterGuard(old) + }) + .ok() +} + +#[derive(Debug)] +pub(crate) struct EnterGuard(Option<Handle>); + +impl Drop for EnterGuard { + fn drop(&mut self) { + CONTEXT.with(|ctx| { + *ctx.borrow_mut() = self.0.take(); + }); + } +} diff --git a/third_party/rust/tokio/src/runtime/driver.rs b/third_party/rust/tokio/src/runtime/driver.rs new file mode 100644 index 0000000000..7e459779bb --- /dev/null +++ b/third_party/rust/tokio/src/runtime/driver.rs @@ -0,0 +1,208 @@ +//! Abstracts out the entire chain of runtime sub-drivers into common types. +use crate::park::thread::ParkThread; +use crate::park::Park; + +use std::io; +use std::time::Duration; + +// ===== io driver ===== + +cfg_io_driver! { + type IoDriver = crate::io::driver::Driver; + type IoStack = crate::park::either::Either<ProcessDriver, ParkThread>; + pub(crate) type IoHandle = Option<crate::io::driver::Handle>; + + fn create_io_stack(enabled: bool) -> io::Result<(IoStack, IoHandle, SignalHandle)> { + use crate::park::either::Either; + + #[cfg(loom)] + assert!(!enabled); + + let ret = if enabled { + let io_driver = crate::io::driver::Driver::new()?; + let io_handle = io_driver.handle(); + + let (signal_driver, signal_handle) = create_signal_driver(io_driver)?; + let process_driver = create_process_driver(signal_driver); + + (Either::A(process_driver), Some(io_handle), signal_handle) + } else { + (Either::B(ParkThread::new()), Default::default(), Default::default()) + }; + + Ok(ret) + } +} + +cfg_not_io_driver! { + pub(crate) type IoHandle = (); + type IoStack = ParkThread; + + fn create_io_stack(_enabled: bool) -> io::Result<(IoStack, IoHandle, SignalHandle)> { + Ok((ParkThread::new(), Default::default(), Default::default())) + } +} + +// ===== signal driver ===== + +macro_rules! cfg_signal_internal_and_unix { + ($($item:item)*) => { + #[cfg(unix)] + cfg_signal_internal! { $($item)* } + } +} + +cfg_signal_internal_and_unix! { + type SignalDriver = crate::signal::unix::driver::Driver; + pub(crate) type SignalHandle = Option<crate::signal::unix::driver::Handle>; + + fn create_signal_driver(io_driver: IoDriver) -> io::Result<(SignalDriver, SignalHandle)> { + let driver = crate::signal::unix::driver::Driver::new(io_driver)?; + let handle = driver.handle(); + Ok((driver, Some(handle))) + } +} + +cfg_not_signal_internal! { + pub(crate) type SignalHandle = (); + + cfg_io_driver! { + type SignalDriver = IoDriver; + + fn create_signal_driver(io_driver: IoDriver) -> io::Result<(SignalDriver, SignalHandle)> { + Ok((io_driver, ())) + } + } +} + +// ===== process driver ===== + +cfg_process_driver! { + type ProcessDriver = crate::process::unix::driver::Driver; + + fn create_process_driver(signal_driver: SignalDriver) -> ProcessDriver { + crate::process::unix::driver::Driver::new(signal_driver) + } +} + +cfg_not_process_driver! { + cfg_io_driver! { + type ProcessDriver = SignalDriver; + + fn create_process_driver(signal_driver: SignalDriver) -> ProcessDriver { + signal_driver + } + } +} + +// ===== time driver ===== + +cfg_time! { + type TimeDriver = crate::park::either::Either<crate::time::driver::Driver<IoStack>, IoStack>; + + pub(crate) type Clock = crate::time::Clock; + pub(crate) type TimeHandle = Option<crate::time::driver::Handle>; + + fn create_clock(enable_pausing: bool, start_paused: bool) -> Clock { + crate::time::Clock::new(enable_pausing, start_paused) + } + + fn create_time_driver( + enable: bool, + io_stack: IoStack, + clock: Clock, + ) -> (TimeDriver, TimeHandle) { + use crate::park::either::Either; + + if enable { + let driver = crate::time::driver::Driver::new(io_stack, clock); + let handle = driver.handle(); + + (Either::A(driver), Some(handle)) + } else { + (Either::B(io_stack), None) + } + } +} + +cfg_not_time! { + type TimeDriver = IoStack; + + pub(crate) type Clock = (); + pub(crate) type TimeHandle = (); + + fn create_clock(_enable_pausing: bool, _start_paused: bool) -> Clock { + () + } + + fn create_time_driver( + _enable: bool, + io_stack: IoStack, + _clock: Clock, + ) -> (TimeDriver, TimeHandle) { + (io_stack, ()) + } +} + +// ===== runtime driver ===== + +#[derive(Debug)] +pub(crate) struct Driver { + inner: TimeDriver, +} + +pub(crate) struct Resources { + pub(crate) io_handle: IoHandle, + pub(crate) signal_handle: SignalHandle, + pub(crate) time_handle: TimeHandle, + pub(crate) clock: Clock, +} + +pub(crate) struct Cfg { + pub(crate) enable_io: bool, + pub(crate) enable_time: bool, + pub(crate) enable_pause_time: bool, + pub(crate) start_paused: bool, +} + +impl Driver { + pub(crate) fn new(cfg: Cfg) -> io::Result<(Self, Resources)> { + let (io_stack, io_handle, signal_handle) = create_io_stack(cfg.enable_io)?; + + let clock = create_clock(cfg.enable_pause_time, cfg.start_paused); + + let (time_driver, time_handle) = + create_time_driver(cfg.enable_time, io_stack, clock.clone()); + + Ok(( + Self { inner: time_driver }, + Resources { + io_handle, + signal_handle, + time_handle, + clock, + }, + )) + } +} + +impl Park for Driver { + type Unpark = <TimeDriver as Park>::Unpark; + type Error = <TimeDriver as Park>::Error; + + fn unpark(&self) -> Self::Unpark { + self.inner.unpark() + } + + fn park(&mut self) -> Result<(), Self::Error> { + self.inner.park() + } + + fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> { + self.inner.park_timeout(duration) + } + + fn shutdown(&mut self) { + self.inner.shutdown() + } +} diff --git a/third_party/rust/tokio/src/runtime/enter.rs b/third_party/rust/tokio/src/runtime/enter.rs new file mode 100644 index 0000000000..3f14cb5878 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/enter.rs @@ -0,0 +1,205 @@ +use std::cell::{Cell, RefCell}; +use std::fmt; +use std::marker::PhantomData; + +#[derive(Debug, Clone, Copy)] +pub(crate) enum EnterContext { + #[cfg_attr(not(feature = "rt"), allow(dead_code))] + Entered { + allow_blocking: bool, + }, + NotEntered, +} + +impl EnterContext { + pub(crate) fn is_entered(self) -> bool { + matches!(self, EnterContext::Entered { .. }) + } +} + +thread_local!(static ENTERED: Cell<EnterContext> = Cell::new(EnterContext::NotEntered)); + +/// Represents an executor context. +pub(crate) struct Enter { + _p: PhantomData<RefCell<()>>, +} + +cfg_rt! { + use crate::park::thread::ParkError; + + use std::time::Duration; + + /// Marks the current thread as being within the dynamic extent of an + /// executor. + pub(crate) fn enter(allow_blocking: bool) -> Enter { + if let Some(enter) = try_enter(allow_blocking) { + return enter; + } + + panic!( + "Cannot start a runtime from within a runtime. This happens \ + because a function (like `block_on`) attempted to block the \ + current thread while the thread is being used to drive \ + asynchronous tasks." + ); + } + + /// Tries to enter a runtime context, returns `None` if already in a runtime + /// context. + pub(crate) fn try_enter(allow_blocking: bool) -> Option<Enter> { + ENTERED.with(|c| { + if c.get().is_entered() { + None + } else { + c.set(EnterContext::Entered { allow_blocking }); + Some(Enter { _p: PhantomData }) + } + }) + } +} + +// Forces the current "entered" state to be cleared while the closure +// is executed. +// +// # Warning +// +// This is hidden for a reason. Do not use without fully understanding +// executors. Misusing can easily cause your program to deadlock. +cfg_rt_multi_thread! { + pub(crate) fn exit<F: FnOnce() -> R, R>(f: F) -> R { + // Reset in case the closure panics + struct Reset(EnterContext); + impl Drop for Reset { + fn drop(&mut self) { + ENTERED.with(|c| { + assert!(!c.get().is_entered(), "closure claimed permanent executor"); + c.set(self.0); + }); + } + } + + let was = ENTERED.with(|c| { + let e = c.get(); + assert!(e.is_entered(), "asked to exit when not entered"); + c.set(EnterContext::NotEntered); + e + }); + + let _reset = Reset(was); + // dropping _reset after f() will reset ENTERED + f() + } +} + +cfg_rt! { + /// Disallows blocking in the current runtime context until the guard is dropped. + pub(crate) fn disallow_blocking() -> DisallowBlockingGuard { + let reset = ENTERED.with(|c| { + if let EnterContext::Entered { + allow_blocking: true, + } = c.get() + { + c.set(EnterContext::Entered { + allow_blocking: false, + }); + true + } else { + false + } + }); + DisallowBlockingGuard(reset) + } + + pub(crate) struct DisallowBlockingGuard(bool); + impl Drop for DisallowBlockingGuard { + fn drop(&mut self) { + if self.0 { + // XXX: Do we want some kind of assertion here, or is "best effort" okay? + ENTERED.with(|c| { + if let EnterContext::Entered { + allow_blocking: false, + } = c.get() + { + c.set(EnterContext::Entered { + allow_blocking: true, + }); + } + }) + } + } + } +} + +cfg_rt_multi_thread! { + /// Returns true if in a runtime context. + pub(crate) fn context() -> EnterContext { + ENTERED.with(|c| c.get()) + } +} + +cfg_rt! { + impl Enter { + /// Blocks the thread on the specified future, returning the value with + /// which that future completes. + pub(crate) fn block_on<F>(&mut self, f: F) -> Result<F::Output, ParkError> + where + F: std::future::Future, + { + use crate::park::thread::CachedParkThread; + + let mut park = CachedParkThread::new(); + park.block_on(f) + } + + /// Blocks the thread on the specified future for **at most** `timeout` + /// + /// If the future completes before `timeout`, the result is returned. If + /// `timeout` elapses, then `Err` is returned. + pub(crate) fn block_on_timeout<F>(&mut self, f: F, timeout: Duration) -> Result<F::Output, ParkError> + where + F: std::future::Future, + { + use crate::park::Park; + use crate::park::thread::CachedParkThread; + use std::task::Context; + use std::task::Poll::Ready; + use std::time::Instant; + + let mut park = CachedParkThread::new(); + let waker = park.get_unpark()?.into_waker(); + let mut cx = Context::from_waker(&waker); + + pin!(f); + let when = Instant::now() + timeout; + + loop { + if let Ready(v) = crate::coop::budget(|| f.as_mut().poll(&mut cx)) { + return Ok(v); + } + + let now = Instant::now(); + + if now >= when { + return Err(()); + } + + park.park_timeout(when - now)?; + } + } + } +} + +impl fmt::Debug for Enter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Enter").finish() + } +} + +impl Drop for Enter { + fn drop(&mut self) { + ENTERED.with(|c| { + assert!(c.get().is_entered()); + c.set(EnterContext::NotEntered); + }); + } +} diff --git a/third_party/rust/tokio/src/runtime/handle.rs b/third_party/rust/tokio/src/runtime/handle.rs new file mode 100644 index 0000000000..9dbe6774dd --- /dev/null +++ b/third_party/rust/tokio/src/runtime/handle.rs @@ -0,0 +1,435 @@ +use crate::runtime::blocking::{BlockingTask, NoopSchedule}; +use crate::runtime::task::{self, JoinHandle}; +use crate::runtime::{blocking, context, driver, Spawner}; +use crate::util::error::{CONTEXT_MISSING_ERROR, THREAD_LOCAL_DESTROYED_ERROR}; + +use std::future::Future; +use std::marker::PhantomData; +use std::{error, fmt}; + +/// Handle to the runtime. +/// +/// The handle is internally reference-counted and can be freely cloned. A handle can be +/// obtained using the [`Runtime::handle`] method. +/// +/// [`Runtime::handle`]: crate::runtime::Runtime::handle() +#[derive(Debug, Clone)] +pub struct Handle { + pub(super) spawner: Spawner, + + /// Handles to the I/O drivers + #[cfg_attr( + not(any(feature = "net", feature = "process", all(unix, feature = "signal"))), + allow(dead_code) + )] + pub(super) io_handle: driver::IoHandle, + + /// Handles to the signal drivers + #[cfg_attr( + any( + loom, + not(all(unix, feature = "signal")), + not(all(unix, feature = "process")), + ), + allow(dead_code) + )] + pub(super) signal_handle: driver::SignalHandle, + + /// Handles to the time drivers + #[cfg_attr(not(feature = "time"), allow(dead_code))] + pub(super) time_handle: driver::TimeHandle, + + /// Source of `Instant::now()` + #[cfg_attr(not(all(feature = "time", feature = "test-util")), allow(dead_code))] + pub(super) clock: driver::Clock, + + /// Blocking pool spawner + pub(super) blocking_spawner: blocking::Spawner, +} + +/// Runtime context guard. +/// +/// Returned by [`Runtime::enter`] and [`Handle::enter`], the context guard exits +/// the runtime context on drop. +/// +/// [`Runtime::enter`]: fn@crate::runtime::Runtime::enter +#[derive(Debug)] +#[must_use = "Creating and dropping a guard does nothing"] +pub struct EnterGuard<'a> { + _guard: context::EnterGuard, + _handle_lifetime: PhantomData<&'a Handle>, +} + +impl Handle { + /// Enters the runtime context. This allows you to construct types that must + /// have an executor available on creation such as [`Sleep`] or [`TcpStream`]. + /// It will also allow you to call methods such as [`tokio::spawn`]. + /// + /// [`Sleep`]: struct@crate::time::Sleep + /// [`TcpStream`]: struct@crate::net::TcpStream + /// [`tokio::spawn`]: fn@crate::spawn + pub fn enter(&self) -> EnterGuard<'_> { + EnterGuard { + _guard: context::enter(self.clone()), + _handle_lifetime: PhantomData, + } + } + + /// Returns a `Handle` view over the currently running `Runtime`. + /// + /// # Panic + /// + /// This will panic if called outside the context of a Tokio runtime. That means that you must + /// call this on one of the threads **being run by the runtime**. Calling this from within a + /// thread created by `std::thread::spawn` (for example) will cause a panic. + /// + /// # Examples + /// + /// This can be used to obtain the handle of the surrounding runtime from an async + /// block or function running on that runtime. + /// + /// ``` + /// # use std::thread; + /// # use tokio::runtime::Runtime; + /// # fn dox() { + /// # let rt = Runtime::new().unwrap(); + /// # rt.spawn(async { + /// use tokio::runtime::Handle; + /// + /// // Inside an async block or function. + /// let handle = Handle::current(); + /// handle.spawn(async { + /// println!("now running in the existing Runtime"); + /// }); + /// + /// # let handle = + /// thread::spawn(move || { + /// // Notice that the handle is created outside of this thread and then moved in + /// handle.spawn(async { /* ... */ }) + /// // This next line would cause a panic + /// // let handle2 = Handle::current(); + /// }); + /// # handle.join().unwrap(); + /// # }); + /// # } + /// ``` + pub fn current() -> Self { + context::current() + } + + /// Returns a Handle view over the currently running Runtime + /// + /// Returns an error if no Runtime has been started + /// + /// Contrary to `current`, this never panics + pub fn try_current() -> Result<Self, TryCurrentError> { + context::try_current() + } + + /// Spawns a future onto the Tokio runtime. + /// + /// This spawns the given future onto the runtime's executor, usually a + /// thread pool. The thread pool is then responsible for polling the future + /// until it completes. + /// + /// See [module level][mod] documentation for more details. + /// + /// [mod]: index.html + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// # fn dox() { + /// // Create the runtime + /// let rt = Runtime::new().unwrap(); + /// // Get a handle from this runtime + /// let handle = rt.handle(); + /// + /// // Spawn a future onto the runtime using the handle + /// handle.spawn(async { + /// println!("now running on a worker thread"); + /// }); + /// # } + /// ``` + #[track_caller] + pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output> + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let future = crate::util::trace::task(future, "task", None); + self.spawner.spawn(future) + } + + /// Runs the provided function on an executor dedicated to blocking. + /// operations. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// # fn dox() { + /// // Create the runtime + /// let rt = Runtime::new().unwrap(); + /// // Get a handle from this runtime + /// let handle = rt.handle(); + /// + /// // Spawn a blocking function onto the runtime using the handle + /// handle.spawn_blocking(|| { + /// println!("now running on a worker thread"); + /// }); + /// # } + #[track_caller] + pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R> + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + let (join_handle, _was_spawned) = + if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 { + self.spawn_blocking_inner(Box::new(func), blocking::Mandatory::NonMandatory, None) + } else { + self.spawn_blocking_inner(func, blocking::Mandatory::NonMandatory, None) + }; + + join_handle + } + + cfg_fs! { + #[track_caller] + #[cfg_attr(any( + all(loom, not(test)), // the function is covered by loom tests + test + ), allow(dead_code))] + pub(crate) fn spawn_mandatory_blocking<F, R>(&self, func: F) -> Option<JoinHandle<R>> + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + let (join_handle, was_spawned) = if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 { + self.spawn_blocking_inner( + Box::new(func), + blocking::Mandatory::Mandatory, + None + ) + } else { + self.spawn_blocking_inner( + func, + blocking::Mandatory::Mandatory, + None + ) + }; + + if was_spawned { + Some(join_handle) + } else { + None + } + } + } + + #[track_caller] + pub(crate) fn spawn_blocking_inner<F, R>( + &self, + func: F, + is_mandatory: blocking::Mandatory, + name: Option<&str>, + ) -> (JoinHandle<R>, bool) + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + let fut = BlockingTask::new(func); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let fut = { + use tracing::Instrument; + let location = std::panic::Location::caller(); + let span = tracing::trace_span!( + target: "tokio::task::blocking", + "runtime.spawn", + kind = %"blocking", + task.name = %name.unwrap_or_default(), + "fn" = %std::any::type_name::<F>(), + spawn.location = %format_args!("{}:{}:{}", location.file(), location.line(), location.column()), + ); + fut.instrument(span) + }; + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let _ = name; + + let (task, handle) = task::unowned(fut, NoopSchedule); + let spawned = self + .blocking_spawner + .spawn(blocking::Task::new(task, is_mandatory), self); + (handle, spawned.is_ok()) + } + + /// Runs a future to completion on this `Handle`'s associated `Runtime`. + /// + /// This runs the given future on the current thread, blocking until it is + /// complete, and yielding its resolved result. Any tasks or timers which + /// the future spawns internally will be executed on the runtime. + /// + /// When this is used on a `current_thread` runtime, only the + /// [`Runtime::block_on`] method can drive the IO and timer drivers, but the + /// `Handle::block_on` method cannot drive them. This means that, when using + /// this method on a current_thread runtime, anything that relies on IO or + /// timers will not work unless there is another thread currently calling + /// [`Runtime::block_on`] on the same runtime. + /// + /// # If the runtime has been shut down + /// + /// If the `Handle`'s associated `Runtime` has been shut down (through + /// [`Runtime::shutdown_background`], [`Runtime::shutdown_timeout`], or by + /// dropping it) and `Handle::block_on` is used it might return an error or + /// panic. Specifically IO resources will return an error and timers will + /// panic. Runtime independent futures will run as normal. + /// + /// # Panics + /// + /// This function panics if the provided future panics, if called within an + /// asynchronous execution context, or if a timer future is executed on a + /// runtime that has been shut down. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// // Create the runtime + /// let rt = Runtime::new().unwrap(); + /// + /// // Get a handle from this runtime + /// let handle = rt.handle(); + /// + /// // Execute the future, blocking the current thread until completion + /// handle.block_on(async { + /// println!("hello"); + /// }); + /// ``` + /// + /// Or using `Handle::current`: + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main () { + /// let handle = Handle::current(); + /// std::thread::spawn(move || { + /// // Using Handle::block_on to run async code in the new thread. + /// handle.block_on(async { + /// println!("hello"); + /// }); + /// }); + /// } + /// ``` + /// + /// [`JoinError`]: struct@crate::task::JoinError + /// [`JoinHandle`]: struct@crate::task::JoinHandle + /// [`Runtime::block_on`]: fn@crate::runtime::Runtime::block_on + /// [`Runtime::shutdown_background`]: fn@crate::runtime::Runtime::shutdown_background + /// [`Runtime::shutdown_timeout`]: fn@crate::runtime::Runtime::shutdown_timeout + /// [`spawn_blocking`]: crate::task::spawn_blocking + /// [`tokio::fs`]: crate::fs + /// [`tokio::net`]: crate::net + /// [`tokio::time`]: crate::time + #[track_caller] + pub fn block_on<F: Future>(&self, future: F) -> F::Output { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let future = crate::util::trace::task(future, "block_on", None); + + // Enter the **runtime** context. This configures spawning, the current I/O driver, ... + let _rt_enter = self.enter(); + + // Enter a **blocking** context. This prevents blocking from a runtime. + let mut blocking_enter = crate::runtime::enter(true); + + // Block on the future + blocking_enter + .block_on(future) + .expect("failed to park thread") + } + + pub(crate) fn shutdown(mut self) { + self.spawner.shutdown(); + } +} + +cfg_metrics! { + use crate::runtime::RuntimeMetrics; + + impl Handle { + /// Returns a view that lets you get information about how the runtime + /// is performing. + pub fn metrics(&self) -> RuntimeMetrics { + RuntimeMetrics::new(self.clone()) + } + } +} + +/// Error returned by `try_current` when no Runtime has been started +#[derive(Debug)] +pub struct TryCurrentError { + kind: TryCurrentErrorKind, +} + +impl TryCurrentError { + pub(crate) fn new_no_context() -> Self { + Self { + kind: TryCurrentErrorKind::NoContext, + } + } + + pub(crate) fn new_thread_local_destroyed() -> Self { + Self { + kind: TryCurrentErrorKind::ThreadLocalDestroyed, + } + } + + /// Returns true if the call failed because there is currently no runtime in + /// the Tokio context. + pub fn is_missing_context(&self) -> bool { + matches!(self.kind, TryCurrentErrorKind::NoContext) + } + + /// Returns true if the call failed because the Tokio context thread-local + /// had been destroyed. This can usually only happen if in the destructor of + /// other thread-locals. + pub fn is_thread_local_destroyed(&self) -> bool { + matches!(self.kind, TryCurrentErrorKind::ThreadLocalDestroyed) + } +} + +enum TryCurrentErrorKind { + NoContext, + ThreadLocalDestroyed, +} + +impl fmt::Debug for TryCurrentErrorKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use TryCurrentErrorKind::*; + match self { + NoContext => f.write_str("NoContext"), + ThreadLocalDestroyed => f.write_str("ThreadLocalDestroyed"), + } + } +} + +impl fmt::Display for TryCurrentError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use TryCurrentErrorKind::*; + match self.kind { + NoContext => f.write_str(CONTEXT_MISSING_ERROR), + ThreadLocalDestroyed => f.write_str(THREAD_LOCAL_DESTROYED_ERROR), + } + } +} + +impl error::Error for TryCurrentError {} diff --git a/third_party/rust/tokio/src/runtime/metrics/batch.rs b/third_party/rust/tokio/src/runtime/metrics/batch.rs new file mode 100644 index 0000000000..f1c3fa6b74 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/metrics/batch.rs @@ -0,0 +1,105 @@ +use crate::runtime::WorkerMetrics; + +use std::convert::TryFrom; +use std::sync::atomic::Ordering::Relaxed; +use std::time::Instant; + +pub(crate) struct MetricsBatch { + /// Number of times the worker parked. + park_count: u64, + + /// Number of times the worker woke w/o doing work. + noop_count: u64, + + /// Number of times stolen. + steal_count: u64, + + /// Number of tasks that were polled by the worker. + poll_count: u64, + + /// Number of tasks polled when the worker entered park. This is used to + /// track the noop count. + poll_count_on_last_park: u64, + + /// Number of tasks that were scheduled locally on this worker. + local_schedule_count: u64, + + /// Number of tasks moved to the global queue to make space in the local + /// queue + overflow_count: u64, + + /// The total busy duration in nanoseconds. + busy_duration_total: u64, + last_resume_time: Instant, +} + +impl MetricsBatch { + pub(crate) fn new() -> MetricsBatch { + MetricsBatch { + park_count: 0, + noop_count: 0, + steal_count: 0, + poll_count: 0, + poll_count_on_last_park: 0, + local_schedule_count: 0, + overflow_count: 0, + busy_duration_total: 0, + last_resume_time: Instant::now(), + } + } + + pub(crate) fn submit(&mut self, worker: &WorkerMetrics) { + worker.park_count.store(self.park_count, Relaxed); + worker.noop_count.store(self.noop_count, Relaxed); + worker.steal_count.store(self.steal_count, Relaxed); + worker.poll_count.store(self.poll_count, Relaxed); + + worker + .busy_duration_total + .store(self.busy_duration_total, Relaxed); + + worker + .local_schedule_count + .store(self.local_schedule_count, Relaxed); + worker.overflow_count.store(self.overflow_count, Relaxed); + } + + /// The worker is about to park. + pub(crate) fn about_to_park(&mut self) { + self.park_count += 1; + + if self.poll_count_on_last_park == self.poll_count { + self.noop_count += 1; + } else { + self.poll_count_on_last_park = self.poll_count; + } + + let busy_duration = self.last_resume_time.elapsed(); + let busy_duration = u64::try_from(busy_duration.as_nanos()).unwrap_or(u64::MAX); + self.busy_duration_total += busy_duration; + } + + pub(crate) fn returned_from_park(&mut self) { + self.last_resume_time = Instant::now(); + } + + pub(crate) fn inc_local_schedule_count(&mut self) { + self.local_schedule_count += 1; + } + + pub(crate) fn incr_poll_count(&mut self) { + self.poll_count += 1; + } +} + +cfg_rt_multi_thread! { + impl MetricsBatch { + pub(crate) fn incr_steal_count(&mut self, by: u16) { + self.steal_count += by as u64; + } + + pub(crate) fn incr_overflow_count(&mut self) { + self.overflow_count += 1; + } + } +} diff --git a/third_party/rust/tokio/src/runtime/metrics/mock.rs b/third_party/rust/tokio/src/runtime/metrics/mock.rs new file mode 100644 index 0000000000..6b9cf704f4 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/metrics/mock.rs @@ -0,0 +1,43 @@ +//! This file contains mocks of the types in src/runtime/metrics + +pub(crate) struct SchedulerMetrics {} + +pub(crate) struct WorkerMetrics {} + +pub(crate) struct MetricsBatch {} + +impl SchedulerMetrics { + pub(crate) fn new() -> Self { + Self {} + } + + /// Increment the number of tasks scheduled externally + pub(crate) fn inc_remote_schedule_count(&self) {} +} + +impl WorkerMetrics { + pub(crate) fn new() -> Self { + Self {} + } + + pub(crate) fn set_queue_depth(&self, _len: usize) {} +} + +impl MetricsBatch { + pub(crate) fn new() -> Self { + Self {} + } + + pub(crate) fn submit(&mut self, _to: &WorkerMetrics) {} + pub(crate) fn about_to_park(&mut self) {} + pub(crate) fn returned_from_park(&mut self) {} + pub(crate) fn incr_poll_count(&mut self) {} + pub(crate) fn inc_local_schedule_count(&mut self) {} +} + +cfg_rt_multi_thread! { + impl MetricsBatch { + pub(crate) fn incr_steal_count(&mut self, _by: u16) {} + pub(crate) fn incr_overflow_count(&mut self) {} + } +} diff --git a/third_party/rust/tokio/src/runtime/metrics/mod.rs b/third_party/rust/tokio/src/runtime/metrics/mod.rs new file mode 100644 index 0000000000..ca643a5904 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/metrics/mod.rs @@ -0,0 +1,30 @@ +//! This module contains information need to view information about how the +//! runtime is performing. +//! +//! **Note**: This is an [unstable API][unstable]. The public API of types in +//! this module may break in 1.x releases. See [the documentation on unstable +//! features][unstable] for details. +//! +//! [unstable]: crate#unstable-features +#![allow(clippy::module_inception)] + +cfg_metrics! { + mod batch; + pub(crate) use batch::MetricsBatch; + + mod runtime; + #[allow(unreachable_pub)] // rust-lang/rust#57411 + pub use runtime::RuntimeMetrics; + + mod scheduler; + pub(crate) use scheduler::SchedulerMetrics; + + mod worker; + pub(crate) use worker::WorkerMetrics; +} + +cfg_not_metrics! { + mod mock; + + pub(crate) use mock::{SchedulerMetrics, WorkerMetrics, MetricsBatch}; +} diff --git a/third_party/rust/tokio/src/runtime/metrics/runtime.rs b/third_party/rust/tokio/src/runtime/metrics/runtime.rs new file mode 100644 index 0000000000..0f8055907f --- /dev/null +++ b/third_party/rust/tokio/src/runtime/metrics/runtime.rs @@ -0,0 +1,449 @@ +use crate::runtime::Handle; + +use std::sync::atomic::Ordering::Relaxed; +use std::time::Duration; + +/// Handle to the runtime's metrics. +/// +/// This handle is internally reference-counted and can be freely cloned. A +/// `RuntimeMetrics` handle is obtained using the [`Runtime::metrics`] method. +/// +/// [`Runtime::metrics`]: crate::runtime::Runtime::metrics() +#[derive(Clone, Debug)] +pub struct RuntimeMetrics { + handle: Handle, +} + +impl RuntimeMetrics { + pub(crate) fn new(handle: Handle) -> RuntimeMetrics { + RuntimeMetrics { handle } + } + + /// Returns the number of worker threads used by the runtime. + /// + /// The number of workers is set by configuring `worker_threads` on + /// `runtime::Builder`. When using the `current_thread` runtime, the return + /// value is always `1`. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main() { + /// let metrics = Handle::current().metrics(); + /// + /// let n = metrics.num_workers(); + /// println!("Runtime is using {} workers", n); + /// } + /// ``` + pub fn num_workers(&self) -> usize { + self.handle.spawner.num_workers() + } + + /// Returns the number of tasks scheduled from **outside** of the runtime. + /// + /// The remote schedule count starts at zero when the runtime is created and + /// increases by one each time a task is woken from **outside** of the + /// runtime. This usually means that a task is spawned or notified from a + /// non-runtime thread and must be queued using the Runtime's injection + /// queue, which tends to be slower. + /// + /// The counter is monotonically increasing. It is never decremented or + /// reset to zero. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main() { + /// let metrics = Handle::current().metrics(); + /// + /// let n = metrics.remote_schedule_count(); + /// println!("{} tasks were scheduled from outside the runtime", n); + /// } + /// ``` + pub fn remote_schedule_count(&self) -> u64 { + self.handle + .spawner + .scheduler_metrics() + .remote_schedule_count + .load(Relaxed) + } + + /// Returns the total number of times the given worker thread has parked. + /// + /// The worker park count starts at zero when the runtime is created and + /// increases by one each time the worker parks the thread waiting for new + /// inbound events to process. This usually means the worker has processed + /// all pending work and is currently idle. + /// + /// The counter is monotonically increasing. It is never decremented or + /// reset to zero. + /// + /// # Arguments + /// + /// `worker` is the index of the worker being queried. The given value must + /// be between 0 and `num_workers()`. The index uniquely identifies a single + /// worker and will continue to indentify the worker throughout the lifetime + /// of the runtime instance. + /// + /// # Panics + /// + /// The method panics when `worker` represents an invalid worker, i.e. is + /// greater than or equal to `num_workers()`. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main() { + /// let metrics = Handle::current().metrics(); + /// + /// let n = metrics.worker_park_count(0); + /// println!("worker 0 parked {} times", n); + /// } + /// ``` + pub fn worker_park_count(&self, worker: usize) -> u64 { + self.handle + .spawner + .worker_metrics(worker) + .park_count + .load(Relaxed) + } + + /// Returns the number of times the given worker thread unparked but + /// performed no work before parking again. + /// + /// The worker no-op count starts at zero when the runtime is created and + /// increases by one each time the worker unparks the thread but finds no + /// new work and goes back to sleep. This indicates a false-positive wake up. + /// + /// The counter is monotonically increasing. It is never decremented or + /// reset to zero. + /// + /// # Arguments + /// + /// `worker` is the index of the worker being queried. The given value must + /// be between 0 and `num_workers()`. The index uniquely identifies a single + /// worker and will continue to indentify the worker throughout the lifetime + /// of the runtime instance. + /// + /// # Panics + /// + /// The method panics when `worker` represents an invalid worker, i.e. is + /// greater than or equal to `num_workers()`. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main() { + /// let metrics = Handle::current().metrics(); + /// + /// let n = metrics.worker_noop_count(0); + /// println!("worker 0 had {} no-op unparks", n); + /// } + /// ``` + pub fn worker_noop_count(&self, worker: usize) -> u64 { + self.handle + .spawner + .worker_metrics(worker) + .noop_count + .load(Relaxed) + } + + /// Returns the number of times the given worker thread stole tasks from + /// another worker thread. + /// + /// This metric only applies to the **multi-threaded** runtime and will always return `0` when using the current thread runtime. + /// + /// The worker steal count starts at zero when the runtime is created and + /// increases by one each time the worker has processed its scheduled queue + /// and successfully steals more pending tasks from another worker. + /// + /// The counter is monotonically increasing. It is never decremented or + /// reset to zero. + /// + /// # Arguments + /// + /// `worker` is the index of the worker being queried. The given value must + /// be between 0 and `num_workers()`. The index uniquely identifies a single + /// worker and will continue to indentify the worker throughout the lifetime + /// of the runtime instance. + /// + /// # Panics + /// + /// The method panics when `worker` represents an invalid worker, i.e. is + /// greater than or equal to `num_workers()`. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main() { + /// let metrics = Handle::current().metrics(); + /// + /// let n = metrics.worker_noop_count(0); + /// println!("worker 0 has stolen tasks {} times", n); + /// } + /// ``` + pub fn worker_steal_count(&self, worker: usize) -> u64 { + self.handle + .spawner + .worker_metrics(worker) + .steal_count + .load(Relaxed) + } + + /// Returns the number of tasks the given worker thread has polled. + /// + /// The worker poll count starts at zero when the runtime is created and + /// increases by one each time the worker polls a scheduled task. + /// + /// The counter is monotonically increasing. It is never decremented or + /// reset to zero. + /// + /// # Arguments + /// + /// `worker` is the index of the worker being queried. The given value must + /// be between 0 and `num_workers()`. The index uniquely identifies a single + /// worker and will continue to indentify the worker throughout the lifetime + /// of the runtime instance. + /// + /// # Panics + /// + /// The method panics when `worker` represents an invalid worker, i.e. is + /// greater than or equal to `num_workers()`. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main() { + /// let metrics = Handle::current().metrics(); + /// + /// let n = metrics.worker_poll_count(0); + /// println!("worker 0 has polled {} tasks", n); + /// } + /// ``` + pub fn worker_poll_count(&self, worker: usize) -> u64 { + self.handle + .spawner + .worker_metrics(worker) + .poll_count + .load(Relaxed) + } + + /// Returns the amount of time the given worker thread has been busy. + /// + /// The worker busy duration starts at zero when the runtime is created and + /// increases whenever the worker is spending time processing work. Using + /// this value can indicate the load of the given worker. If a lot of time + /// is spent busy, then the worker is under load and will check for inbound + /// events less often. + /// + /// The timer is monotonically increasing. It is never decremented or reset + /// to zero. + /// + /// # Arguments + /// + /// `worker` is the index of the worker being queried. The given value must + /// be between 0 and `num_workers()`. The index uniquely identifies a single + /// worker and will continue to indentify the worker throughout the lifetime + /// of the runtime instance. + /// + /// # Panics + /// + /// The method panics when `worker` represents an invalid worker, i.e. is + /// greater than or equal to `num_workers()`. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main() { + /// let metrics = Handle::current().metrics(); + /// + /// let n = metrics.worker_poll_count(0); + /// println!("worker 0 has polled {} tasks", n); + /// } + /// ``` + pub fn worker_total_busy_duration(&self, worker: usize) -> Duration { + let nanos = self + .handle + .spawner + .worker_metrics(worker) + .busy_duration_total + .load(Relaxed); + Duration::from_nanos(nanos) + } + + /// Returns the number of tasks scheduled from **within** the runtime on the + /// given worker's local queue. + /// + /// The local schedule count starts at zero when the runtime is created and + /// increases by one each time a task is woken from **inside** of the + /// runtime on the given worker. This usually means that a task is spawned + /// or notified from within a runtime thread and will be queued on the + /// worker-local queue. + /// + /// The counter is monotonically increasing. It is never decremented or + /// reset to zero. + /// + /// # Arguments + /// + /// `worker` is the index of the worker being queried. The given value must + /// be between 0 and `num_workers()`. The index uniquely identifies a single + /// worker and will continue to indentify the worker throughout the lifetime + /// of the runtime instance. + /// + /// # Panics + /// + /// The method panics when `worker` represents an invalid worker, i.e. is + /// greater than or equal to `num_workers()`. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main() { + /// let metrics = Handle::current().metrics(); + /// + /// let n = metrics.worker_local_schedule_count(0); + /// println!("{} tasks were scheduled on the worker's local queue", n); + /// } + /// ``` + pub fn worker_local_schedule_count(&self, worker: usize) -> u64 { + self.handle + .spawner + .worker_metrics(worker) + .local_schedule_count + .load(Relaxed) + } + + /// Returns the number of times the given worker thread saturated its local + /// queue. + /// + /// This metric only applies to the **multi-threaded** scheduler. + /// + /// The worker steal count starts at zero when the runtime is created and + /// increases by one each time the worker attempts to schedule a task + /// locally, but its local queue is full. When this happens, half of the + /// local queue is moved to the injection queue. + /// + /// The counter is monotonically increasing. It is never decremented or + /// reset to zero. + /// + /// # Arguments + /// + /// `worker` is the index of the worker being queried. The given value must + /// be between 0 and `num_workers()`. The index uniquely identifies a single + /// worker and will continue to indentify the worker throughout the lifetime + /// of the runtime instance. + /// + /// # Panics + /// + /// The method panics when `worker` represents an invalid worker, i.e. is + /// greater than or equal to `num_workers()`. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main() { + /// let metrics = Handle::current().metrics(); + /// + /// let n = metrics.worker_overflow_count(0); + /// println!("worker 0 has overflowed its queue {} times", n); + /// } + /// ``` + pub fn worker_overflow_count(&self, worker: usize) -> u64 { + self.handle + .spawner + .worker_metrics(worker) + .overflow_count + .load(Relaxed) + } + + /// Returns the number of tasks currently scheduled in the runtime's + /// injection queue. + /// + /// Tasks that are spanwed or notified from a non-runtime thread are + /// scheduled using the runtime's injection queue. This metric returns the + /// **current** number of tasks pending in the injection queue. As such, the + /// returned value may increase or decrease as new tasks are scheduled and + /// processed. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main() { + /// let metrics = Handle::current().metrics(); + /// + /// let n = metrics.injection_queue_depth(); + /// println!("{} tasks currently pending in the runtime's injection queue", n); + /// } + /// ``` + pub fn injection_queue_depth(&self) -> usize { + self.handle.spawner.injection_queue_depth() + } + + /// Returns the number of tasks currently scheduled in the given worker's + /// local queue. + /// + /// Tasks that are spawned or notified from within a runtime thread are + /// scheduled using that worker's local queue. This metric returns the + /// **current** number of tasks pending in the worker's local queue. As + /// such, the returned value may increase or decrease as new tasks are + /// scheduled and processed. + /// + /// # Arguments + /// + /// `worker` is the index of the worker being queried. The given value must + /// be between 0 and `num_workers()`. The index uniquely identifies a single + /// worker and will continue to indentify the worker throughout the lifetime + /// of the runtime instance. + /// + /// # Panics + /// + /// The method panics when `worker` represents an invalid worker, i.e. is + /// greater than or equal to `num_workers()`. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main() { + /// let metrics = Handle::current().metrics(); + /// + /// let n = metrics.worker_local_queue_depth(0); + /// println!("{} tasks currently pending in worker 0's local queue", n); + /// } + /// ``` + pub fn worker_local_queue_depth(&self, worker: usize) -> usize { + self.handle.spawner.worker_local_queue_depth(worker) + } +} diff --git a/third_party/rust/tokio/src/runtime/metrics/scheduler.rs b/third_party/rust/tokio/src/runtime/metrics/scheduler.rs new file mode 100644 index 0000000000..d1ba3b6442 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/metrics/scheduler.rs @@ -0,0 +1,27 @@ +use crate::loom::sync::atomic::{AtomicU64, Ordering::Relaxed}; + +/// Retrieves metrics from the Tokio runtime. +/// +/// **Note**: This is an [unstable API][unstable]. The public API of this type +/// may break in 1.x releases. See [the documentation on unstable +/// features][unstable] for details. +/// +/// [unstable]: crate#unstable-features +#[derive(Debug)] +pub(crate) struct SchedulerMetrics { + /// Number of tasks that are scheduled from outside the runtime. + pub(super) remote_schedule_count: AtomicU64, +} + +impl SchedulerMetrics { + pub(crate) fn new() -> SchedulerMetrics { + SchedulerMetrics { + remote_schedule_count: AtomicU64::new(0), + } + } + + /// Increment the number of tasks scheduled externally + pub(crate) fn inc_remote_schedule_count(&self) { + self.remote_schedule_count.fetch_add(1, Relaxed); + } +} diff --git a/third_party/rust/tokio/src/runtime/metrics/worker.rs b/third_party/rust/tokio/src/runtime/metrics/worker.rs new file mode 100644 index 0000000000..c9b85e48e4 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/metrics/worker.rs @@ -0,0 +1,61 @@ +use crate::loom::sync::atomic::Ordering::Relaxed; +use crate::loom::sync::atomic::{AtomicU64, AtomicUsize}; + +/// Retreive runtime worker metrics. +/// +/// **Note**: This is an [unstable API][unstable]. The public API of this type +/// may break in 1.x releases. See [the documentation on unstable +/// features][unstable] for details. +/// +/// [unstable]: crate#unstable-features +#[derive(Debug)] +#[repr(align(128))] +pub(crate) struct WorkerMetrics { + /// Number of times the worker parked. + pub(crate) park_count: AtomicU64, + + /// Number of times the worker woke then parked again without doing work. + pub(crate) noop_count: AtomicU64, + + /// Number of times the worker attempted to steal. + pub(crate) steal_count: AtomicU64, + + /// Number of tasks the worker polled. + pub(crate) poll_count: AtomicU64, + + /// Amount of time the worker spent doing work vs. parking. + pub(crate) busy_duration_total: AtomicU64, + + /// Number of tasks scheduled for execution on the worker's local queue. + pub(crate) local_schedule_count: AtomicU64, + + /// Number of tasks moved from the local queue to the global queue to free space. + pub(crate) overflow_count: AtomicU64, + + /// Number of tasks currently in the local queue. Used only by the + /// current-thread scheduler. + pub(crate) queue_depth: AtomicUsize, +} + +impl WorkerMetrics { + pub(crate) fn new() -> WorkerMetrics { + WorkerMetrics { + park_count: AtomicU64::new(0), + noop_count: AtomicU64::new(0), + steal_count: AtomicU64::new(0), + poll_count: AtomicU64::new(0), + overflow_count: AtomicU64::new(0), + busy_duration_total: AtomicU64::new(0), + local_schedule_count: AtomicU64::new(0), + queue_depth: AtomicUsize::new(0), + } + } + + pub(crate) fn queue_depth(&self) -> usize { + self.queue_depth.load(Relaxed) + } + + pub(crate) fn set_queue_depth(&self, len: usize) { + self.queue_depth.store(len, Relaxed); + } +} diff --git a/third_party/rust/tokio/src/runtime/mod.rs b/third_party/rust/tokio/src/runtime/mod.rs new file mode 100644 index 0000000000..7c381b0bbd --- /dev/null +++ b/third_party/rust/tokio/src/runtime/mod.rs @@ -0,0 +1,623 @@ +//! The Tokio runtime. +//! +//! Unlike other Rust programs, asynchronous applications require runtime +//! support. In particular, the following runtime services are necessary: +//! +//! * An **I/O event loop**, called the driver, which drives I/O resources and +//! dispatches I/O events to tasks that depend on them. +//! * A **scheduler** to execute [tasks] that use these I/O resources. +//! * A **timer** for scheduling work to run after a set period of time. +//! +//! Tokio's [`Runtime`] bundles all of these services as a single type, allowing +//! them to be started, shut down, and configured together. However, often it is +//! not required to configure a [`Runtime`] manually, and a user may just use the +//! [`tokio::main`] attribute macro, which creates a [`Runtime`] under the hood. +//! +//! # Usage +//! +//! When no fine tuning is required, the [`tokio::main`] attribute macro can be +//! used. +//! +//! ```no_run +//! use tokio::net::TcpListener; +//! use tokio::io::{AsyncReadExt, AsyncWriteExt}; +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! let listener = TcpListener::bind("127.0.0.1:8080").await?; +//! +//! loop { +//! let (mut socket, _) = listener.accept().await?; +//! +//! tokio::spawn(async move { +//! let mut buf = [0; 1024]; +//! +//! // In a loop, read data from the socket and write the data back. +//! loop { +//! let n = match socket.read(&mut buf).await { +//! // socket closed +//! Ok(n) if n == 0 => return, +//! Ok(n) => n, +//! Err(e) => { +//! println!("failed to read from socket; err = {:?}", e); +//! return; +//! } +//! }; +//! +//! // Write the data back +//! if let Err(e) = socket.write_all(&buf[0..n]).await { +//! println!("failed to write to socket; err = {:?}", e); +//! return; +//! } +//! } +//! }); +//! } +//! } +//! ``` +//! +//! From within the context of the runtime, additional tasks are spawned using +//! the [`tokio::spawn`] function. Futures spawned using this function will be +//! executed on the same thread pool used by the [`Runtime`]. +//! +//! A [`Runtime`] instance can also be used directly. +//! +//! ```no_run +//! use tokio::net::TcpListener; +//! use tokio::io::{AsyncReadExt, AsyncWriteExt}; +//! use tokio::runtime::Runtime; +//! +//! fn main() -> Result<(), Box<dyn std::error::Error>> { +//! // Create the runtime +//! let rt = Runtime::new()?; +//! +//! // Spawn the root task +//! rt.block_on(async { +//! let listener = TcpListener::bind("127.0.0.1:8080").await?; +//! +//! loop { +//! let (mut socket, _) = listener.accept().await?; +//! +//! tokio::spawn(async move { +//! let mut buf = [0; 1024]; +//! +//! // In a loop, read data from the socket and write the data back. +//! loop { +//! let n = match socket.read(&mut buf).await { +//! // socket closed +//! Ok(n) if n == 0 => return, +//! Ok(n) => n, +//! Err(e) => { +//! println!("failed to read from socket; err = {:?}", e); +//! return; +//! } +//! }; +//! +//! // Write the data back +//! if let Err(e) = socket.write_all(&buf[0..n]).await { +//! println!("failed to write to socket; err = {:?}", e); +//! return; +//! } +//! } +//! }); +//! } +//! }) +//! } +//! ``` +//! +//! ## Runtime Configurations +//! +//! Tokio provides multiple task scheduling strategies, suitable for different +//! applications. The [runtime builder] or `#[tokio::main]` attribute may be +//! used to select which scheduler to use. +//! +//! #### Multi-Thread Scheduler +//! +//! The multi-thread scheduler executes futures on a _thread pool_, using a +//! work-stealing strategy. By default, it will start a worker thread for each +//! CPU core available on the system. This tends to be the ideal configuration +//! for most applications. The multi-thread scheduler requires the `rt-multi-thread` +//! feature flag, and is selected by default: +//! ``` +//! use tokio::runtime; +//! +//! # fn main() -> Result<(), Box<dyn std::error::Error>> { +//! let threaded_rt = runtime::Runtime::new()?; +//! # Ok(()) } +//! ``` +//! +//! Most applications should use the multi-thread scheduler, except in some +//! niche use-cases, such as when running only a single thread is required. +//! +//! #### Current-Thread Scheduler +//! +//! The current-thread scheduler provides a _single-threaded_ future executor. +//! All tasks will be created and executed on the current thread. This requires +//! the `rt` feature flag. +//! ``` +//! use tokio::runtime; +//! +//! # fn main() -> Result<(), Box<dyn std::error::Error>> { +//! let basic_rt = runtime::Builder::new_current_thread() +//! .build()?; +//! # Ok(()) } +//! ``` +//! +//! #### Resource drivers +//! +//! When configuring a runtime by hand, no resource drivers are enabled by +//! default. In this case, attempting to use networking types or time types will +//! fail. In order to enable these types, the resource drivers must be enabled. +//! This is done with [`Builder::enable_io`] and [`Builder::enable_time`]. As a +//! shorthand, [`Builder::enable_all`] enables both resource drivers. +//! +//! ## Lifetime of spawned threads +//! +//! The runtime may spawn threads depending on its configuration and usage. The +//! multi-thread scheduler spawns threads to schedule tasks and for `spawn_blocking` +//! calls. +//! +//! While the `Runtime` is active, threads may shutdown after periods of being +//! idle. Once `Runtime` is dropped, all runtime threads are forcibly shutdown. +//! Any tasks that have not yet completed will be dropped. +//! +//! [tasks]: crate::task +//! [`Runtime`]: Runtime +//! [`tokio::spawn`]: crate::spawn +//! [`tokio::main`]: ../attr.main.html +//! [runtime builder]: crate::runtime::Builder +//! [`Runtime::new`]: crate::runtime::Runtime::new +//! [`Builder::basic_scheduler`]: crate::runtime::Builder::basic_scheduler +//! [`Builder::threaded_scheduler`]: crate::runtime::Builder::threaded_scheduler +//! [`Builder::enable_io`]: crate::runtime::Builder::enable_io +//! [`Builder::enable_time`]: crate::runtime::Builder::enable_time +//! [`Builder::enable_all`]: crate::runtime::Builder::enable_all + +// At the top due to macros +#[cfg(test)] +#[cfg(not(target_arch = "wasm32"))] +#[macro_use] +mod tests; + +pub(crate) mod enter; + +pub(crate) mod task; + +cfg_metrics! { + mod metrics; + pub use metrics::RuntimeMetrics; + + pub(crate) use metrics::{MetricsBatch, SchedulerMetrics, WorkerMetrics}; +} + +cfg_not_metrics! { + pub(crate) mod metrics; + pub(crate) use metrics::{SchedulerMetrics, WorkerMetrics, MetricsBatch}; +} + +cfg_rt! { + mod basic_scheduler; + use basic_scheduler::BasicScheduler; + + mod blocking; + use blocking::BlockingPool; + pub(crate) use blocking::spawn_blocking; + + cfg_trace! { + pub(crate) use blocking::Mandatory; + } + + cfg_fs! { + pub(crate) use blocking::spawn_mandatory_blocking; + } + + mod builder; + pub use self::builder::Builder; + + pub(crate) mod context; + pub(crate) mod driver; + + use self::enter::enter; + + mod handle; + pub use handle::{EnterGuard, Handle, TryCurrentError}; + + mod spawner; + use self::spawner::Spawner; +} + +cfg_rt_multi_thread! { + mod park; + use park::Parker; +} + +cfg_rt_multi_thread! { + mod queue; + + pub(crate) mod thread_pool; + use self::thread_pool::ThreadPool; +} + +cfg_rt! { + use crate::task::JoinHandle; + + use std::future::Future; + use std::time::Duration; + + /// The Tokio runtime. + /// + /// The runtime provides an I/O driver, task scheduler, [timer], and + /// blocking pool, necessary for running asynchronous tasks. + /// + /// Instances of `Runtime` can be created using [`new`], or [`Builder`]. + /// However, most users will use the `#[tokio::main]` annotation on their + /// entry point instead. + /// + /// See [module level][mod] documentation for more details. + /// + /// # Shutdown + /// + /// Shutting down the runtime is done by dropping the value. The current + /// thread will block until the shut down operation has completed. + /// + /// * Drain any scheduled work queues. + /// * Drop any futures that have not yet completed. + /// * Drop the reactor. + /// + /// Once the reactor has dropped, any outstanding I/O resources bound to + /// that reactor will no longer function. Calling any method on them will + /// result in an error. + /// + /// # Sharing + /// + /// The Tokio runtime implements `Sync` and `Send` to allow you to wrap it + /// in a `Arc`. Most fn take `&self` to allow you to call them concurrently + /// across multiple threads. + /// + /// Calls to `shutdown` and `shutdown_timeout` require exclusive ownership of + /// the runtime type and this can be achieved via `Arc::try_unwrap` when only + /// one strong count reference is left over. + /// + /// [timer]: crate::time + /// [mod]: index.html + /// [`new`]: method@Self::new + /// [`Builder`]: struct@Builder + #[derive(Debug)] + pub struct Runtime { + /// Task executor + kind: Kind, + + /// Handle to runtime, also contains driver handles + handle: Handle, + + /// Blocking pool handle, used to signal shutdown + blocking_pool: BlockingPool, + } + + /// The runtime executor is either a thread-pool or a current-thread executor. + #[derive(Debug)] + enum Kind { + /// Execute all tasks on the current-thread. + CurrentThread(BasicScheduler), + + /// Execute tasks across multiple threads. + #[cfg(feature = "rt-multi-thread")] + ThreadPool(ThreadPool), + } + + /// After thread starts / before thread stops + type Callback = std::sync::Arc<dyn Fn() + Send + Sync>; + + impl Runtime { + /// Creates a new runtime instance with default configuration values. + /// + /// This results in the multi threaded scheduler, I/O driver, and time driver being + /// initialized. + /// + /// Most applications will not need to call this function directly. Instead, + /// they will use the [`#[tokio::main]` attribute][main]. When a more complex + /// configuration is necessary, the [runtime builder] may be used. + /// + /// See [module level][mod] documentation for more details. + /// + /// # Examples + /// + /// Creating a new `Runtime` with default configuration values. + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// let rt = Runtime::new() + /// .unwrap(); + /// + /// // Use the runtime... + /// ``` + /// + /// [mod]: index.html + /// [main]: ../attr.main.html + /// [threaded scheduler]: index.html#threaded-scheduler + /// [basic scheduler]: index.html#basic-scheduler + /// [runtime builder]: crate::runtime::Builder + #[cfg(feature = "rt-multi-thread")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt-multi-thread")))] + pub fn new() -> std::io::Result<Runtime> { + Builder::new_multi_thread().enable_all().build() + } + + /// Returns a handle to the runtime's spawner. + /// + /// The returned handle can be used to spawn tasks that run on this runtime, and can + /// be cloned to allow moving the `Handle` to other threads. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// let rt = Runtime::new() + /// .unwrap(); + /// + /// let handle = rt.handle(); + /// + /// // Use the handle... + /// ``` + pub fn handle(&self) -> &Handle { + &self.handle + } + + /// Spawns a future onto the Tokio runtime. + /// + /// This spawns the given future onto the runtime's executor, usually a + /// thread pool. The thread pool is then responsible for polling the future + /// until it completes. + /// + /// See [module level][mod] documentation for more details. + /// + /// [mod]: index.html + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// # fn dox() { + /// // Create the runtime + /// let rt = Runtime::new().unwrap(); + /// + /// // Spawn a future onto the runtime + /// rt.spawn(async { + /// println!("now running on a worker thread"); + /// }); + /// # } + /// ``` + #[track_caller] + pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output> + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + self.handle.spawn(future) + } + + /// Runs the provided function on an executor dedicated to blocking operations. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// # fn dox() { + /// // Create the runtime + /// let rt = Runtime::new().unwrap(); + /// + /// // Spawn a blocking function onto the runtime + /// rt.spawn_blocking(|| { + /// println!("now running on a worker thread"); + /// }); + /// # } + #[track_caller] + pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R> + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + self.handle.spawn_blocking(func) + } + + /// Runs a future to completion on the Tokio runtime. This is the + /// runtime's entry point. + /// + /// This runs the given future on the current thread, blocking until it is + /// complete, and yielding its resolved result. Any tasks or timers + /// which the future spawns internally will be executed on the runtime. + /// + /// # Multi thread scheduler + /// + /// When the multi thread scheduler is used this will allow futures + /// to run within the io driver and timer context of the overall runtime. + /// + /// # Current thread scheduler + /// + /// When the current thread scheduler is enabled `block_on` + /// can be called concurrently from multiple threads. The first call + /// will take ownership of the io and timer drivers. This means + /// other threads which do not own the drivers will hook into that one. + /// When the first `block_on` completes, other threads will be able to + /// "steal" the driver to allow continued execution of their futures. + /// + /// # Panics + /// + /// This function panics if the provided future panics, or if called within an + /// asynchronous execution context. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::runtime::Runtime; + /// + /// // Create the runtime + /// let rt = Runtime::new().unwrap(); + /// + /// // Execute the future, blocking the current thread until completion + /// rt.block_on(async { + /// println!("hello"); + /// }); + /// ``` + /// + /// [handle]: fn@Handle::block_on + #[track_caller] + pub fn block_on<F: Future>(&self, future: F) -> F::Output { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let future = crate::util::trace::task(future, "block_on", None); + + let _enter = self.enter(); + + match &self.kind { + Kind::CurrentThread(exec) => exec.block_on(future), + #[cfg(feature = "rt-multi-thread")] + Kind::ThreadPool(exec) => exec.block_on(future), + } + } + + /// Enters the runtime context. + /// + /// This allows you to construct types that must have an executor + /// available on creation such as [`Sleep`] or [`TcpStream`]. It will + /// also allow you to call methods such as [`tokio::spawn`]. + /// + /// [`Sleep`]: struct@crate::time::Sleep + /// [`TcpStream`]: struct@crate::net::TcpStream + /// [`tokio::spawn`]: fn@crate::spawn + /// + /// # Example + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// fn function_that_spawns(msg: String) { + /// // Had we not used `rt.enter` below, this would panic. + /// tokio::spawn(async move { + /// println!("{}", msg); + /// }); + /// } + /// + /// fn main() { + /// let rt = Runtime::new().unwrap(); + /// + /// let s = "Hello World!".to_string(); + /// + /// // By entering the context, we tie `tokio::spawn` to this executor. + /// let _guard = rt.enter(); + /// function_that_spawns(s); + /// } + /// ``` + pub fn enter(&self) -> EnterGuard<'_> { + self.handle.enter() + } + + /// Shuts down the runtime, waiting for at most `duration` for all spawned + /// task to shutdown. + /// + /// Usually, dropping a `Runtime` handle is sufficient as tasks are able to + /// shutdown in a timely fashion. However, dropping a `Runtime` will wait + /// indefinitely for all tasks to terminate, and there are cases where a long + /// blocking task has been spawned, which can block dropping `Runtime`. + /// + /// In this case, calling `shutdown_timeout` with an explicit wait timeout + /// can work. The `shutdown_timeout` will signal all tasks to shutdown and + /// will wait for at most `duration` for all spawned tasks to terminate. If + /// `timeout` elapses before all tasks are dropped, the function returns and + /// outstanding tasks are potentially leaked. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Runtime; + /// use tokio::task; + /// + /// use std::thread; + /// use std::time::Duration; + /// + /// fn main() { + /// let runtime = Runtime::new().unwrap(); + /// + /// runtime.block_on(async move { + /// task::spawn_blocking(move || { + /// thread::sleep(Duration::from_secs(10_000)); + /// }); + /// }); + /// + /// runtime.shutdown_timeout(Duration::from_millis(100)); + /// } + /// ``` + pub fn shutdown_timeout(mut self, duration: Duration) { + // Wakeup and shutdown all the worker threads + self.handle.clone().shutdown(); + self.blocking_pool.shutdown(Some(duration)); + } + + /// Shuts down the runtime, without waiting for any spawned tasks to shutdown. + /// + /// This can be useful if you want to drop a runtime from within another runtime. + /// Normally, dropping a runtime will block indefinitely for spawned blocking tasks + /// to complete, which would normally not be permitted within an asynchronous context. + /// By calling `shutdown_background()`, you can drop the runtime from such a context. + /// + /// Note however, that because we do not wait for any blocking tasks to complete, this + /// may result in a resource leak (in that any blocking tasks are still running until they + /// return. + /// + /// This function is equivalent to calling `shutdown_timeout(Duration::of_nanos(0))`. + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// fn main() { + /// let runtime = Runtime::new().unwrap(); + /// + /// runtime.block_on(async move { + /// let inner_runtime = Runtime::new().unwrap(); + /// // ... + /// inner_runtime.shutdown_background(); + /// }); + /// } + /// ``` + pub fn shutdown_background(self) { + self.shutdown_timeout(Duration::from_nanos(0)) + } + } + + #[allow(clippy::single_match)] // there are comments in the error branch, so we don't want if-let + impl Drop for Runtime { + fn drop(&mut self) { + match &mut self.kind { + Kind::CurrentThread(basic) => { + // This ensures that tasks spawned on the basic runtime are dropped inside the + // runtime's context. + match self::context::try_enter(self.handle.clone()) { + Some(guard) => basic.set_context_guard(guard), + None => { + // The context thread-local has already been destroyed. + // + // We don't set the guard in this case. Calls to tokio::spawn in task + // destructors would fail regardless if this happens. + }, + } + }, + #[cfg(feature = "rt-multi-thread")] + Kind::ThreadPool(_) => { + // The threaded scheduler drops its tasks on its worker threads, which is + // already in the runtime's context. + }, + } + } + } + + cfg_metrics! { + impl Runtime { + /// TODO + pub fn metrics(&self) -> RuntimeMetrics { + self.handle.metrics() + } + } + } +} diff --git a/third_party/rust/tokio/src/runtime/park.rs b/third_party/rust/tokio/src/runtime/park.rs new file mode 100644 index 0000000000..033b9f20be --- /dev/null +++ b/third_party/rust/tokio/src/runtime/park.rs @@ -0,0 +1,257 @@ +//! Parks the runtime. +//! +//! A combination of the various resource driver park handles. + +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::{Arc, Condvar, Mutex}; +use crate::loom::thread; +use crate::park::{Park, Unpark}; +use crate::runtime::driver::Driver; +use crate::util::TryLock; + +use std::sync::atomic::Ordering::SeqCst; +use std::time::Duration; + +pub(crate) struct Parker { + inner: Arc<Inner>, +} + +pub(crate) struct Unparker { + inner: Arc<Inner>, +} + +struct Inner { + /// Avoids entering the park if possible + state: AtomicUsize, + + /// Used to coordinate access to the driver / condvar + mutex: Mutex<()>, + + /// Condvar to block on if the driver is unavailable. + condvar: Condvar, + + /// Resource (I/O, time, ...) driver + shared: Arc<Shared>, +} + +const EMPTY: usize = 0; +const PARKED_CONDVAR: usize = 1; +const PARKED_DRIVER: usize = 2; +const NOTIFIED: usize = 3; + +/// Shared across multiple Parker handles +struct Shared { + /// Shared driver. Only one thread at a time can use this + driver: TryLock<Driver>, + + /// Unpark handle + handle: <Driver as Park>::Unpark, +} + +impl Parker { + pub(crate) fn new(driver: Driver) -> Parker { + let handle = driver.unpark(); + + Parker { + inner: Arc::new(Inner { + state: AtomicUsize::new(EMPTY), + mutex: Mutex::new(()), + condvar: Condvar::new(), + shared: Arc::new(Shared { + driver: TryLock::new(driver), + handle, + }), + }), + } + } +} + +impl Clone for Parker { + fn clone(&self) -> Parker { + Parker { + inner: Arc::new(Inner { + state: AtomicUsize::new(EMPTY), + mutex: Mutex::new(()), + condvar: Condvar::new(), + shared: self.inner.shared.clone(), + }), + } + } +} + +impl Park for Parker { + type Unpark = Unparker; + type Error = (); + + fn unpark(&self) -> Unparker { + Unparker { + inner: self.inner.clone(), + } + } + + fn park(&mut self) -> Result<(), Self::Error> { + self.inner.park(); + Ok(()) + } + + fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> { + // Only parking with zero is supported... + assert_eq!(duration, Duration::from_millis(0)); + + if let Some(mut driver) = self.inner.shared.driver.try_lock() { + driver.park_timeout(duration).map_err(|_| ()) + } else { + Ok(()) + } + } + + fn shutdown(&mut self) { + self.inner.shutdown(); + } +} + +impl Unpark for Unparker { + fn unpark(&self) { + self.inner.unpark(); + } +} + +impl Inner { + /// Parks the current thread for at most `dur`. + fn park(&self) { + for _ in 0..3 { + // If we were previously notified then we consume this notification and + // return quickly. + if self + .state + .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst) + .is_ok() + { + return; + } + + thread::yield_now(); + } + + if let Some(mut driver) = self.shared.driver.try_lock() { + self.park_driver(&mut driver); + } else { + self.park_condvar(); + } + } + + fn park_condvar(&self) { + // Otherwise we need to coordinate going to sleep + let mut m = self.mutex.lock(); + + match self + .state + .compare_exchange(EMPTY, PARKED_CONDVAR, SeqCst, SeqCst) + { + Ok(_) => {} + Err(NOTIFIED) => { + // We must read here, even though we know it will be `NOTIFIED`. + // This is because `unpark` may have been called again since we read + // `NOTIFIED` in the `compare_exchange` above. We must perform an + // acquire operation that synchronizes with that `unpark` to observe + // any writes it made before the call to unpark. To do that we must + // read from the write it made to `state`. + let old = self.state.swap(EMPTY, SeqCst); + debug_assert_eq!(old, NOTIFIED, "park state changed unexpectedly"); + + return; + } + Err(actual) => panic!("inconsistent park state; actual = {}", actual), + } + + loop { + m = self.condvar.wait(m).unwrap(); + + if self + .state + .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst) + .is_ok() + { + // got a notification + return; + } + + // spurious wakeup, go back to sleep + } + } + + fn park_driver(&self, driver: &mut Driver) { + match self + .state + .compare_exchange(EMPTY, PARKED_DRIVER, SeqCst, SeqCst) + { + Ok(_) => {} + Err(NOTIFIED) => { + // We must read here, even though we know it will be `NOTIFIED`. + // This is because `unpark` may have been called again since we read + // `NOTIFIED` in the `compare_exchange` above. We must perform an + // acquire operation that synchronizes with that `unpark` to observe + // any writes it made before the call to unpark. To do that we must + // read from the write it made to `state`. + let old = self.state.swap(EMPTY, SeqCst); + debug_assert_eq!(old, NOTIFIED, "park state changed unexpectedly"); + + return; + } + Err(actual) => panic!("inconsistent park state; actual = {}", actual), + } + + // TODO: don't unwrap + driver.park().unwrap(); + + match self.state.swap(EMPTY, SeqCst) { + NOTIFIED => {} // got a notification, hurray! + PARKED_DRIVER => {} // no notification, alas + n => panic!("inconsistent park_timeout state: {}", n), + } + } + + fn unpark(&self) { + // To ensure the unparked thread will observe any writes we made before + // this call, we must perform a release operation that `park` can + // synchronize with. To do that we must write `NOTIFIED` even if `state` + // is already `NOTIFIED`. That is why this must be a swap rather than a + // compare-and-swap that returns if it reads `NOTIFIED` on failure. + match self.state.swap(NOTIFIED, SeqCst) { + EMPTY => {} // no one was waiting + NOTIFIED => {} // already unparked + PARKED_CONDVAR => self.unpark_condvar(), + PARKED_DRIVER => self.unpark_driver(), + actual => panic!("inconsistent state in unpark; actual = {}", actual), + } + } + + fn unpark_condvar(&self) { + // There is a period between when the parked thread sets `state` to + // `PARKED` (or last checked `state` in the case of a spurious wake + // up) and when it actually waits on `cvar`. If we were to notify + // during this period it would be ignored and then when the parked + // thread went to sleep it would never wake up. Fortunately, it has + // `lock` locked at this stage so we can acquire `lock` to wait until + // it is ready to receive the notification. + // + // Releasing `lock` before the call to `notify_one` means that when the + // parked thread wakes it doesn't get woken only to have to wait for us + // to release `lock`. + drop(self.mutex.lock()); + + self.condvar.notify_one() + } + + fn unpark_driver(&self) { + self.shared.handle.unpark(); + } + + fn shutdown(&self) { + if let Some(mut driver) = self.shared.driver.try_lock() { + driver.shutdown(); + } + + self.condvar.notify_all(); + } +} diff --git a/third_party/rust/tokio/src/runtime/queue.rs b/third_party/rust/tokio/src/runtime/queue.rs new file mode 100644 index 0000000000..ad9085a654 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/queue.rs @@ -0,0 +1,511 @@ +//! Run-queue structures to support a work-stealing scheduler + +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::atomic::{AtomicU16, AtomicU32}; +use crate::loom::sync::Arc; +use crate::runtime::task::{self, Inject}; +use crate::runtime::MetricsBatch; + +use std::mem::MaybeUninit; +use std::ptr; +use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release}; + +/// Producer handle. May only be used from a single thread. +pub(super) struct Local<T: 'static> { + inner: Arc<Inner<T>>, +} + +/// Consumer handle. May be used from many threads. +pub(super) struct Steal<T: 'static>(Arc<Inner<T>>); + +pub(super) struct Inner<T: 'static> { + /// Concurrently updated by many threads. + /// + /// Contains two `u16` values. The LSB byte is the "real" head of the queue. + /// The `u16` in the MSB is set by a stealer in process of stealing values. + /// It represents the first value being stolen in the batch. `u16` is used + /// in order to distinguish between `head == tail` and `head == tail - + /// capacity`. + /// + /// When both `u16` values are the same, there is no active stealer. + /// + /// Tracking an in-progress stealer prevents a wrapping scenario. + head: AtomicU32, + + /// Only updated by producer thread but read by many threads. + tail: AtomicU16, + + /// Elements + buffer: Box<[UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY]>, +} + +unsafe impl<T> Send for Inner<T> {} +unsafe impl<T> Sync for Inner<T> {} + +#[cfg(not(loom))] +const LOCAL_QUEUE_CAPACITY: usize = 256; + +// Shrink the size of the local queue when using loom. This shouldn't impact +// logic, but allows loom to test more edge cases in a reasonable a mount of +// time. +#[cfg(loom)] +const LOCAL_QUEUE_CAPACITY: usize = 4; + +const MASK: usize = LOCAL_QUEUE_CAPACITY - 1; + +// Constructing the fixed size array directly is very awkward. The only way to +// do it is to repeat `UnsafeCell::new(MaybeUninit::uninit())` 256 times, as +// the contents are not Copy. The trick with defining a const doesn't work for +// generic types. +fn make_fixed_size<T>(buffer: Box<[T]>) -> Box<[T; LOCAL_QUEUE_CAPACITY]> { + assert_eq!(buffer.len(), LOCAL_QUEUE_CAPACITY); + + // safety: We check that the length is correct. + unsafe { Box::from_raw(Box::into_raw(buffer).cast()) } +} + +/// Create a new local run-queue +pub(super) fn local<T: 'static>() -> (Steal<T>, Local<T>) { + let mut buffer = Vec::with_capacity(LOCAL_QUEUE_CAPACITY); + + for _ in 0..LOCAL_QUEUE_CAPACITY { + buffer.push(UnsafeCell::new(MaybeUninit::uninit())); + } + + let inner = Arc::new(Inner { + head: AtomicU32::new(0), + tail: AtomicU16::new(0), + buffer: make_fixed_size(buffer.into_boxed_slice()), + }); + + let local = Local { + inner: inner.clone(), + }; + + let remote = Steal(inner); + + (remote, local) +} + +impl<T> Local<T> { + /// Returns true if the queue has entries that can be stealed. + pub(super) fn is_stealable(&self) -> bool { + !self.inner.is_empty() + } + + /// Returns false if there are any entries in the queue + /// + /// Separate to is_stealable so that refactors of is_stealable to "protect" + /// some tasks from stealing won't affect this + pub(super) fn has_tasks(&self) -> bool { + !self.inner.is_empty() + } + + /// Pushes a task to the back of the local queue, skipping the LIFO slot. + pub(super) fn push_back( + &mut self, + mut task: task::Notified<T>, + inject: &Inject<T>, + metrics: &mut MetricsBatch, + ) { + let tail = loop { + let head = self.inner.head.load(Acquire); + let (steal, real) = unpack(head); + + // safety: this is the **only** thread that updates this cell. + let tail = unsafe { self.inner.tail.unsync_load() }; + + if tail.wrapping_sub(steal) < LOCAL_QUEUE_CAPACITY as u16 { + // There is capacity for the task + break tail; + } else if steal != real { + // Concurrently stealing, this will free up capacity, so only + // push the task onto the inject queue + inject.push(task); + return; + } else { + // Push the current task and half of the queue into the + // inject queue. + match self.push_overflow(task, real, tail, inject, metrics) { + Ok(_) => return, + // Lost the race, try again + Err(v) => { + task = v; + } + } + } + }; + + // Map the position to a slot index. + let idx = tail as usize & MASK; + + self.inner.buffer[idx].with_mut(|ptr| { + // Write the task to the slot + // + // Safety: There is only one producer and the above `if` + // condition ensures we don't touch a cell if there is a + // value, thus no consumer. + unsafe { + ptr::write((*ptr).as_mut_ptr(), task); + } + }); + + // Make the task available. Synchronizes with a load in + // `steal_into2`. + self.inner.tail.store(tail.wrapping_add(1), Release); + } + + /// Moves a batch of tasks into the inject queue. + /// + /// This will temporarily make some of the tasks unavailable to stealers. + /// Once `push_overflow` is done, a notification is sent out, so if other + /// workers "missed" some of the tasks during a steal, they will get + /// another opportunity. + #[inline(never)] + fn push_overflow( + &mut self, + task: task::Notified<T>, + head: u16, + tail: u16, + inject: &Inject<T>, + metrics: &mut MetricsBatch, + ) -> Result<(), task::Notified<T>> { + /// How many elements are we taking from the local queue. + /// + /// This is one less than the number of tasks pushed to the inject + /// queue as we are also inserting the `task` argument. + const NUM_TASKS_TAKEN: u16 = (LOCAL_QUEUE_CAPACITY / 2) as u16; + + assert_eq!( + tail.wrapping_sub(head) as usize, + LOCAL_QUEUE_CAPACITY, + "queue is not full; tail = {}; head = {}", + tail, + head + ); + + let prev = pack(head, head); + + // Claim a bunch of tasks + // + // We are claiming the tasks **before** reading them out of the buffer. + // This is safe because only the **current** thread is able to push new + // tasks. + // + // There isn't really any need for memory ordering... Relaxed would + // work. This is because all tasks are pushed into the queue from the + // current thread (or memory has been acquired if the local queue handle + // moved). + if self + .inner + .head + .compare_exchange( + prev, + pack( + head.wrapping_add(NUM_TASKS_TAKEN), + head.wrapping_add(NUM_TASKS_TAKEN), + ), + Release, + Relaxed, + ) + .is_err() + { + // We failed to claim the tasks, losing the race. Return out of + // this function and try the full `push` routine again. The queue + // may not be full anymore. + return Err(task); + } + + /// An iterator that takes elements out of the run queue. + struct BatchTaskIter<'a, T: 'static> { + buffer: &'a [UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY], + head: u32, + i: u32, + } + impl<'a, T: 'static> Iterator for BatchTaskIter<'a, T> { + type Item = task::Notified<T>; + + #[inline] + fn next(&mut self) -> Option<task::Notified<T>> { + if self.i == u32::from(NUM_TASKS_TAKEN) { + None + } else { + let i_idx = self.i.wrapping_add(self.head) as usize & MASK; + let slot = &self.buffer[i_idx]; + + // safety: Our CAS from before has assumed exclusive ownership + // of the task pointers in this range. + let task = slot.with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) }); + + self.i += 1; + Some(task) + } + } + } + + // safety: The CAS above ensures that no consumer will look at these + // values again, and we are the only producer. + let batch_iter = BatchTaskIter { + buffer: &*self.inner.buffer, + head: head as u32, + i: 0, + }; + inject.push_batch(batch_iter.chain(std::iter::once(task))); + + // Add 1 to factor in the task currently being scheduled. + metrics.incr_overflow_count(); + + Ok(()) + } + + /// Pops a task from the local queue. + pub(super) fn pop(&mut self) -> Option<task::Notified<T>> { + let mut head = self.inner.head.load(Acquire); + + let idx = loop { + let (steal, real) = unpack(head); + + // safety: this is the **only** thread that updates this cell. + let tail = unsafe { self.inner.tail.unsync_load() }; + + if real == tail { + // queue is empty + return None; + } + + let next_real = real.wrapping_add(1); + + // If `steal == real` there are no concurrent stealers. Both `steal` + // and `real` are updated. + let next = if steal == real { + pack(next_real, next_real) + } else { + assert_ne!(steal, next_real); + pack(steal, next_real) + }; + + // Attempt to claim a task. + let res = self + .inner + .head + .compare_exchange(head, next, AcqRel, Acquire); + + match res { + Ok(_) => break real as usize & MASK, + Err(actual) => head = actual, + } + }; + + Some(self.inner.buffer[idx].with(|ptr| unsafe { ptr::read(ptr).assume_init() })) + } +} + +impl<T> Steal<T> { + pub(super) fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// Steals half the tasks from self and place them into `dst`. + pub(super) fn steal_into( + &self, + dst: &mut Local<T>, + dst_metrics: &mut MetricsBatch, + ) -> Option<task::Notified<T>> { + // Safety: the caller is the only thread that mutates `dst.tail` and + // holds a mutable reference. + let dst_tail = unsafe { dst.inner.tail.unsync_load() }; + + // To the caller, `dst` may **look** empty but still have values + // contained in the buffer. If another thread is concurrently stealing + // from `dst` there may not be enough capacity to steal. + let (steal, _) = unpack(dst.inner.head.load(Acquire)); + + if dst_tail.wrapping_sub(steal) > LOCAL_QUEUE_CAPACITY as u16 / 2 { + // we *could* try to steal less here, but for simplicity, we're just + // going to abort. + return None; + } + + // Steal the tasks into `dst`'s buffer. This does not yet expose the + // tasks in `dst`. + let mut n = self.steal_into2(dst, dst_tail); + + if n == 0 { + // No tasks were stolen + return None; + } + + dst_metrics.incr_steal_count(n); + + // We are returning a task here + n -= 1; + + let ret_pos = dst_tail.wrapping_add(n); + let ret_idx = ret_pos as usize & MASK; + + // safety: the value was written as part of `steal_into2` and not + // exposed to stealers, so no other thread can access it. + let ret = dst.inner.buffer[ret_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) }); + + if n == 0 { + // The `dst` queue is empty, but a single task was stolen + return Some(ret); + } + + // Make the stolen items available to consumers + dst.inner.tail.store(dst_tail.wrapping_add(n), Release); + + Some(ret) + } + + // Steal tasks from `self`, placing them into `dst`. Returns the number of + // tasks that were stolen. + fn steal_into2(&self, dst: &mut Local<T>, dst_tail: u16) -> u16 { + let mut prev_packed = self.0.head.load(Acquire); + let mut next_packed; + + let n = loop { + let (src_head_steal, src_head_real) = unpack(prev_packed); + let src_tail = self.0.tail.load(Acquire); + + // If these two do not match, another thread is concurrently + // stealing from the queue. + if src_head_steal != src_head_real { + return 0; + } + + // Number of available tasks to steal + let n = src_tail.wrapping_sub(src_head_real); + let n = n - n / 2; + + if n == 0 { + // No tasks available to steal + return 0; + } + + // Update the real head index to acquire the tasks. + let steal_to = src_head_real.wrapping_add(n); + assert_ne!(src_head_steal, steal_to); + next_packed = pack(src_head_steal, steal_to); + + // Claim all those tasks. This is done by incrementing the "real" + // head but not the steal. By doing this, no other thread is able to + // steal from this queue until the current thread completes. + let res = self + .0 + .head + .compare_exchange(prev_packed, next_packed, AcqRel, Acquire); + + match res { + Ok(_) => break n, + Err(actual) => prev_packed = actual, + } + }; + + assert!(n <= LOCAL_QUEUE_CAPACITY as u16 / 2, "actual = {}", n); + + let (first, _) = unpack(next_packed); + + // Take all the tasks + for i in 0..n { + // Compute the positions + let src_pos = first.wrapping_add(i); + let dst_pos = dst_tail.wrapping_add(i); + + // Map to slots + let src_idx = src_pos as usize & MASK; + let dst_idx = dst_pos as usize & MASK; + + // Read the task + // + // safety: We acquired the task with the atomic exchange above. + let task = self.0.buffer[src_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) }); + + // Write the task to the new slot + // + // safety: `dst` queue is empty and we are the only producer to + // this queue. + dst.inner.buffer[dst_idx] + .with_mut(|ptr| unsafe { ptr::write((*ptr).as_mut_ptr(), task) }); + } + + let mut prev_packed = next_packed; + + // Update `src_head_steal` to match `src_head_real` signalling that the + // stealing routine is complete. + loop { + let head = unpack(prev_packed).1; + next_packed = pack(head, head); + + let res = self + .0 + .head + .compare_exchange(prev_packed, next_packed, AcqRel, Acquire); + + match res { + Ok(_) => return n, + Err(actual) => { + let (actual_steal, actual_real) = unpack(actual); + + assert_ne!(actual_steal, actual_real); + + prev_packed = actual; + } + } + } + } +} + +cfg_metrics! { + impl<T> Steal<T> { + pub(crate) fn len(&self) -> usize { + self.0.len() as _ + } + } +} + +impl<T> Clone for Steal<T> { + fn clone(&self) -> Steal<T> { + Steal(self.0.clone()) + } +} + +impl<T> Drop for Local<T> { + fn drop(&mut self) { + if !std::thread::panicking() { + assert!(self.pop().is_none(), "queue not empty"); + } + } +} + +impl<T> Inner<T> { + fn len(&self) -> u16 { + let (_, head) = unpack(self.head.load(Acquire)); + let tail = self.tail.load(Acquire); + + tail.wrapping_sub(head) + } + + fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +/// Split the head value into the real head and the index a stealer is working +/// on. +fn unpack(n: u32) -> (u16, u16) { + let real = n & u16::MAX as u32; + let steal = n >> 16; + + (steal as u16, real as u16) +} + +/// Join the two head values +fn pack(steal: u16, real: u16) -> u32 { + (real as u32) | ((steal as u32) << 16) +} + +#[test] +fn test_local_queue_capacity() { + assert!(LOCAL_QUEUE_CAPACITY - 1 <= u8::MAX as usize); +} diff --git a/third_party/rust/tokio/src/runtime/spawner.rs b/third_party/rust/tokio/src/runtime/spawner.rs new file mode 100644 index 0000000000..d81a806cb5 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/spawner.rs @@ -0,0 +1,83 @@ +use crate::future::Future; +use crate::runtime::basic_scheduler; +use crate::task::JoinHandle; + +cfg_rt_multi_thread! { + use crate::runtime::thread_pool; +} + +#[derive(Debug, Clone)] +pub(crate) enum Spawner { + Basic(basic_scheduler::Spawner), + #[cfg(feature = "rt-multi-thread")] + ThreadPool(thread_pool::Spawner), +} + +impl Spawner { + pub(crate) fn shutdown(&mut self) { + #[cfg(feature = "rt-multi-thread")] + { + if let Spawner::ThreadPool(spawner) = self { + spawner.shutdown(); + } + } + } + + pub(crate) fn spawn<F>(&self, future: F) -> JoinHandle<F::Output> + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + match self { + Spawner::Basic(spawner) => spawner.spawn(future), + #[cfg(feature = "rt-multi-thread")] + Spawner::ThreadPool(spawner) => spawner.spawn(future), + } + } +} + +cfg_metrics! { + use crate::runtime::{SchedulerMetrics, WorkerMetrics}; + + impl Spawner { + pub(crate) fn num_workers(&self) -> usize { + match self { + Spawner::Basic(_) => 1, + #[cfg(feature = "rt-multi-thread")] + Spawner::ThreadPool(spawner) => spawner.num_workers(), + } + } + + pub(crate) fn scheduler_metrics(&self) -> &SchedulerMetrics { + match self { + Spawner::Basic(spawner) => spawner.scheduler_metrics(), + #[cfg(feature = "rt-multi-thread")] + Spawner::ThreadPool(spawner) => spawner.scheduler_metrics(), + } + } + + pub(crate) fn worker_metrics(&self, worker: usize) -> &WorkerMetrics { + match self { + Spawner::Basic(spawner) => spawner.worker_metrics(worker), + #[cfg(feature = "rt-multi-thread")] + Spawner::ThreadPool(spawner) => spawner.worker_metrics(worker), + } + } + + pub(crate) fn injection_queue_depth(&self) -> usize { + match self { + Spawner::Basic(spawner) => spawner.injection_queue_depth(), + #[cfg(feature = "rt-multi-thread")] + Spawner::ThreadPool(spawner) => spawner.injection_queue_depth(), + } + } + + pub(crate) fn worker_local_queue_depth(&self, worker: usize) -> usize { + match self { + Spawner::Basic(spawner) => spawner.worker_metrics(worker).queue_depth(), + #[cfg(feature = "rt-multi-thread")] + Spawner::ThreadPool(spawner) => spawner.worker_local_queue_depth(worker), + } + } + } +} diff --git a/third_party/rust/tokio/src/runtime/task/core.rs b/third_party/rust/tokio/src/runtime/task/core.rs new file mode 100644 index 0000000000..776e8341f3 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/core.rs @@ -0,0 +1,267 @@ +//! Core task module. +//! +//! # Safety +//! +//! The functions in this module are private to the `task` module. All of them +//! should be considered `unsafe` to use, but are not marked as such since it +//! would be too noisy. +//! +//! Make sure to consult the relevant safety section of each function before +//! use. + +use crate::future::Future; +use crate::loom::cell::UnsafeCell; +use crate::runtime::task::raw::{self, Vtable}; +use crate::runtime::task::state::State; +use crate::runtime::task::Schedule; +use crate::util::linked_list; + +use std::pin::Pin; +use std::ptr::NonNull; +use std::task::{Context, Poll, Waker}; + +/// The task cell. Contains the components of the task. +/// +/// It is critical for `Header` to be the first field as the task structure will +/// be referenced by both *mut Cell and *mut Header. +#[repr(C)] +pub(super) struct Cell<T: Future, S> { + /// Hot task state data + pub(super) header: Header, + + /// Either the future or output, depending on the execution stage. + pub(super) core: Core<T, S>, + + /// Cold data + pub(super) trailer: Trailer, +} + +pub(super) struct CoreStage<T: Future> { + stage: UnsafeCell<Stage<T>>, +} + +/// The core of the task. +/// +/// Holds the future or output, depending on the stage of execution. +pub(super) struct Core<T: Future, S> { + /// Scheduler used to drive this future. + pub(super) scheduler: S, + + /// Either the future or the output. + pub(super) stage: CoreStage<T>, +} + +/// Crate public as this is also needed by the pool. +#[repr(C)] +pub(crate) struct Header { + /// Task state. + pub(super) state: State, + + pub(super) owned: UnsafeCell<linked_list::Pointers<Header>>, + + /// Pointer to next task, used with the injection queue. + pub(super) queue_next: UnsafeCell<Option<NonNull<Header>>>, + + /// Table of function pointers for executing actions on the task. + pub(super) vtable: &'static Vtable, + + /// This integer contains the id of the OwnedTasks or LocalOwnedTasks that + /// this task is stored in. If the task is not in any list, should be the + /// id of the list that it was previously in, or zero if it has never been + /// in any list. + /// + /// Once a task has been bound to a list, it can never be bound to another + /// list, even if removed from the first list. + /// + /// The id is not unset when removed from a list because we want to be able + /// to read the id without synchronization, even if it is concurrently being + /// removed from the list. + pub(super) owner_id: UnsafeCell<u64>, + + /// The tracing ID for this instrumented task. + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) id: Option<tracing::Id>, +} + +unsafe impl Send for Header {} +unsafe impl Sync for Header {} + +/// Cold data is stored after the future. +pub(super) struct Trailer { + /// Consumer task waiting on completion of this task. + pub(super) waker: UnsafeCell<Option<Waker>>, +} + +/// Either the future or the output. +pub(super) enum Stage<T: Future> { + Running(T), + Finished(super::Result<T::Output>), + Consumed, +} + +impl<T: Future, S: Schedule> Cell<T, S> { + /// Allocates a new task cell, containing the header, trailer, and core + /// structures. + pub(super) fn new(future: T, scheduler: S, state: State) -> Box<Cell<T, S>> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let id = future.id(); + Box::new(Cell { + header: Header { + state, + owned: UnsafeCell::new(linked_list::Pointers::new()), + queue_next: UnsafeCell::new(None), + vtable: raw::vtable::<T, S>(), + owner_id: UnsafeCell::new(0), + #[cfg(all(tokio_unstable, feature = "tracing"))] + id, + }, + core: Core { + scheduler, + stage: CoreStage { + stage: UnsafeCell::new(Stage::Running(future)), + }, + }, + trailer: Trailer { + waker: UnsafeCell::new(None), + }, + }) + } +} + +impl<T: Future> CoreStage<T> { + pub(super) fn with_mut<R>(&self, f: impl FnOnce(*mut Stage<T>) -> R) -> R { + self.stage.with_mut(f) + } + + /// Polls the future. + /// + /// # Safety + /// + /// The caller must ensure it is safe to mutate the `state` field. This + /// requires ensuring mutual exclusion between any concurrent thread that + /// might modify the future or output field. + /// + /// The mutual exclusion is implemented by `Harness` and the `Lifecycle` + /// component of the task state. + /// + /// `self` must also be pinned. This is handled by storing the task on the + /// heap. + pub(super) fn poll(&self, mut cx: Context<'_>) -> Poll<T::Output> { + let res = { + self.stage.with_mut(|ptr| { + // Safety: The caller ensures mutual exclusion to the field. + let future = match unsafe { &mut *ptr } { + Stage::Running(future) => future, + _ => unreachable!("unexpected stage"), + }; + + // Safety: The caller ensures the future is pinned. + let future = unsafe { Pin::new_unchecked(future) }; + + future.poll(&mut cx) + }) + }; + + if res.is_ready() { + self.drop_future_or_output(); + } + + res + } + + /// Drops the future. + /// + /// # Safety + /// + /// The caller must ensure it is safe to mutate the `stage` field. + pub(super) fn drop_future_or_output(&self) { + // Safety: the caller ensures mutual exclusion to the field. + unsafe { + self.set_stage(Stage::Consumed); + } + } + + /// Stores the task output. + /// + /// # Safety + /// + /// The caller must ensure it is safe to mutate the `stage` field. + pub(super) fn store_output(&self, output: super::Result<T::Output>) { + // Safety: the caller ensures mutual exclusion to the field. + unsafe { + self.set_stage(Stage::Finished(output)); + } + } + + /// Takes the task output. + /// + /// # Safety + /// + /// The caller must ensure it is safe to mutate the `stage` field. + pub(super) fn take_output(&self) -> super::Result<T::Output> { + use std::mem; + + self.stage.with_mut(|ptr| { + // Safety:: the caller ensures mutual exclusion to the field. + match mem::replace(unsafe { &mut *ptr }, Stage::Consumed) { + Stage::Finished(output) => output, + _ => panic!("JoinHandle polled after completion"), + } + }) + } + + unsafe fn set_stage(&self, stage: Stage<T>) { + self.stage.with_mut(|ptr| *ptr = stage) + } +} + +cfg_rt_multi_thread! { + impl Header { + pub(super) unsafe fn set_next(&self, next: Option<NonNull<Header>>) { + self.queue_next.with_mut(|ptr| *ptr = next); + } + } +} + +impl Header { + // safety: The caller must guarantee exclusive access to this field, and + // must ensure that the id is either 0 or the id of the OwnedTasks + // containing this task. + pub(super) unsafe fn set_owner_id(&self, owner: u64) { + self.owner_id.with_mut(|ptr| *ptr = owner); + } + + pub(super) fn get_owner_id(&self) -> u64 { + // safety: If there are concurrent writes, then that write has violated + // the safety requirements on `set_owner_id`. + unsafe { self.owner_id.with(|ptr| *ptr) } + } +} + +impl Trailer { + pub(super) unsafe fn set_waker(&self, waker: Option<Waker>) { + self.waker.with_mut(|ptr| { + *ptr = waker; + }); + } + + pub(super) unsafe fn will_wake(&self, waker: &Waker) -> bool { + self.waker + .with(|ptr| (*ptr).as_ref().unwrap().will_wake(waker)) + } + + pub(super) fn wake_join(&self) { + self.waker.with(|ptr| match unsafe { &*ptr } { + Some(waker) => waker.wake_by_ref(), + None => panic!("waker missing"), + }); + } +} + +#[test] +#[cfg(not(loom))] +fn header_lte_cache_line() { + use std::mem::size_of; + + assert!(size_of::<Header>() <= 8 * size_of::<*const ()>()); +} diff --git a/third_party/rust/tokio/src/runtime/task/error.rs b/third_party/rust/tokio/src/runtime/task/error.rs new file mode 100644 index 0000000000..1a8129b2b6 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/error.rs @@ -0,0 +1,146 @@ +use std::any::Any; +use std::fmt; +use std::io; + +use crate::util::SyncWrapper; + +cfg_rt! { + /// Task failed to execute to completion. + pub struct JoinError { + repr: Repr, + } +} + +enum Repr { + Cancelled, + Panic(SyncWrapper<Box<dyn Any + Send + 'static>>), +} + +impl JoinError { + pub(crate) fn cancelled() -> JoinError { + JoinError { + repr: Repr::Cancelled, + } + } + + pub(crate) fn panic(err: Box<dyn Any + Send + 'static>) -> JoinError { + JoinError { + repr: Repr::Panic(SyncWrapper::new(err)), + } + } + + /// Returns true if the error was caused by the task being cancelled. + pub fn is_cancelled(&self) -> bool { + matches!(&self.repr, Repr::Cancelled) + } + + /// Returns true if the error was caused by the task panicking. + /// + /// # Examples + /// + /// ``` + /// use std::panic; + /// + /// #[tokio::main] + /// async fn main() { + /// let err = tokio::spawn(async { + /// panic!("boom"); + /// }).await.unwrap_err(); + /// + /// assert!(err.is_panic()); + /// } + /// ``` + pub fn is_panic(&self) -> bool { + matches!(&self.repr, Repr::Panic(_)) + } + + /// Consumes the join error, returning the object with which the task panicked. + /// + /// # Panics + /// + /// `into_panic()` panics if the `Error` does not represent the underlying + /// task terminating with a panic. Use `is_panic` to check the error reason + /// or `try_into_panic` for a variant that does not panic. + /// + /// # Examples + /// + /// ```should_panic + /// use std::panic; + /// + /// #[tokio::main] + /// async fn main() { + /// let err = tokio::spawn(async { + /// panic!("boom"); + /// }).await.unwrap_err(); + /// + /// if err.is_panic() { + /// // Resume the panic on the main task + /// panic::resume_unwind(err.into_panic()); + /// } + /// } + /// ``` + pub fn into_panic(self) -> Box<dyn Any + Send + 'static> { + self.try_into_panic() + .expect("`JoinError` reason is not a panic.") + } + + /// Consumes the join error, returning the object with which the task + /// panicked if the task terminated due to a panic. Otherwise, `self` is + /// returned. + /// + /// # Examples + /// + /// ```should_panic + /// use std::panic; + /// + /// #[tokio::main] + /// async fn main() { + /// let err = tokio::spawn(async { + /// panic!("boom"); + /// }).await.unwrap_err(); + /// + /// if let Ok(reason) = err.try_into_panic() { + /// // Resume the panic on the main task + /// panic::resume_unwind(reason); + /// } + /// } + /// ``` + pub fn try_into_panic(self) -> Result<Box<dyn Any + Send + 'static>, JoinError> { + match self.repr { + Repr::Panic(p) => Ok(p.into_inner()), + _ => Err(self), + } + } +} + +impl fmt::Display for JoinError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.repr { + Repr::Cancelled => write!(fmt, "cancelled"), + Repr::Panic(_) => write!(fmt, "panic"), + } + } +} + +impl fmt::Debug for JoinError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.repr { + Repr::Cancelled => write!(fmt, "JoinError::Cancelled"), + Repr::Panic(_) => write!(fmt, "JoinError::Panic(...)"), + } + } +} + +impl std::error::Error for JoinError {} + +impl From<JoinError> for io::Error { + fn from(src: JoinError) -> io::Error { + io::Error::new( + io::ErrorKind::Other, + match src.repr { + Repr::Cancelled => "task was cancelled", + Repr::Panic(_) => "task panicked", + }, + ) + } +} diff --git a/third_party/rust/tokio/src/runtime/task/harness.rs b/third_party/rust/tokio/src/runtime/task/harness.rs new file mode 100644 index 0000000000..261dccea41 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/harness.rs @@ -0,0 +1,485 @@ +use crate::future::Future; +use crate::runtime::task::core::{Cell, Core, CoreStage, Header, Trailer}; +use crate::runtime::task::state::Snapshot; +use crate::runtime::task::waker::waker_ref; +use crate::runtime::task::{JoinError, Notified, Schedule, Task}; + +use std::mem; +use std::mem::ManuallyDrop; +use std::panic; +use std::ptr::NonNull; +use std::task::{Context, Poll, Waker}; + +/// Typed raw task handle. +pub(super) struct Harness<T: Future, S: 'static> { + cell: NonNull<Cell<T, S>>, +} + +impl<T, S> Harness<T, S> +where + T: Future, + S: 'static, +{ + pub(super) unsafe fn from_raw(ptr: NonNull<Header>) -> Harness<T, S> { + Harness { + cell: ptr.cast::<Cell<T, S>>(), + } + } + + fn header_ptr(&self) -> NonNull<Header> { + self.cell.cast() + } + + fn header(&self) -> &Header { + unsafe { &self.cell.as_ref().header } + } + + fn trailer(&self) -> &Trailer { + unsafe { &self.cell.as_ref().trailer } + } + + fn core(&self) -> &Core<T, S> { + unsafe { &self.cell.as_ref().core } + } +} + +impl<T, S> Harness<T, S> +where + T: Future, + S: Schedule, +{ + /// Polls the inner future. A ref-count is consumed. + /// + /// All necessary state checks and transitions are performed. + /// Panics raised while polling the future are handled. + pub(super) fn poll(self) { + // We pass our ref-count to `poll_inner`. + match self.poll_inner() { + PollFuture::Notified => { + // The `poll_inner` call has given us two ref-counts back. + // We give one of them to a new task and call `yield_now`. + self.core() + .scheduler + .yield_now(Notified(self.get_new_task())); + + // The remaining ref-count is now dropped. We kept the extra + // ref-count until now to ensure that even if the `yield_now` + // call drops the provided task, the task isn't deallocated + // before after `yield_now` returns. + self.drop_reference(); + } + PollFuture::Complete => { + self.complete(); + } + PollFuture::Dealloc => { + self.dealloc(); + } + PollFuture::Done => (), + } + } + + /// Polls the task and cancel it if necessary. This takes ownership of a + /// ref-count. + /// + /// If the return value is Notified, the caller is given ownership of two + /// ref-counts. + /// + /// If the return value is Complete, the caller is given ownership of a + /// single ref-count, which should be passed on to `complete`. + /// + /// If the return value is Dealloc, then this call consumed the last + /// ref-count and the caller should call `dealloc`. + /// + /// Otherwise the ref-count is consumed and the caller should not access + /// `self` again. + fn poll_inner(&self) -> PollFuture { + use super::state::{TransitionToIdle, TransitionToRunning}; + + match self.header().state.transition_to_running() { + TransitionToRunning::Success => { + let header_ptr = self.header_ptr(); + let waker_ref = waker_ref::<T, S>(&header_ptr); + let cx = Context::from_waker(&*waker_ref); + let res = poll_future(&self.core().stage, cx); + + if res == Poll::Ready(()) { + // The future completed. Move on to complete the task. + return PollFuture::Complete; + } + + match self.header().state.transition_to_idle() { + TransitionToIdle::Ok => PollFuture::Done, + TransitionToIdle::OkNotified => PollFuture::Notified, + TransitionToIdle::OkDealloc => PollFuture::Dealloc, + TransitionToIdle::Cancelled => { + // The transition to idle failed because the task was + // cancelled during the poll. + + cancel_task(&self.core().stage); + PollFuture::Complete + } + } + } + TransitionToRunning::Cancelled => { + cancel_task(&self.core().stage); + PollFuture::Complete + } + TransitionToRunning::Failed => PollFuture::Done, + TransitionToRunning::Dealloc => PollFuture::Dealloc, + } + } + + /// Forcibly shuts down the task. + /// + /// Attempt to transition to `Running` in order to forcibly shutdown the + /// task. If the task is currently running or in a state of completion, then + /// there is nothing further to do. When the task completes running, it will + /// notice the `CANCELLED` bit and finalize the task. + pub(super) fn shutdown(self) { + if !self.header().state.transition_to_shutdown() { + // The task is concurrently running. No further work needed. + self.drop_reference(); + return; + } + + // By transitioning the lifecycle to `Running`, we have permission to + // drop the future. + cancel_task(&self.core().stage); + self.complete(); + } + + pub(super) fn dealloc(self) { + // Release the join waker, if there is one. + self.trailer().waker.with_mut(drop); + + // Check causality + self.core().stage.with_mut(drop); + + unsafe { + drop(Box::from_raw(self.cell.as_ptr())); + } + } + + // ===== join handle ===== + + /// Read the task output into `dst`. + pub(super) fn try_read_output(self, dst: &mut Poll<super::Result<T::Output>>, waker: &Waker) { + if can_read_output(self.header(), self.trailer(), waker) { + *dst = Poll::Ready(self.core().stage.take_output()); + } + } + + /// Try to set the waker notified when the task is complete. Returns true if + /// the task has already completed. If this call returns false, then the + /// waker will not be notified. + pub(super) fn try_set_join_waker(self, waker: &Waker) -> bool { + can_read_output(self.header(), self.trailer(), waker) + } + + pub(super) fn drop_join_handle_slow(self) { + // Try to unset `JOIN_INTEREST`. This must be done as a first step in + // case the task concurrently completed. + if self.header().state.unset_join_interested().is_err() { + // It is our responsibility to drop the output. This is critical as + // the task output may not be `Send` and as such must remain with + // the scheduler or `JoinHandle`. i.e. if the output remains in the + // task structure until the task is deallocated, it may be dropped + // by a Waker on any arbitrary thread. + // + // Panics are delivered to the user via the `JoinHandle`. Given that + // they are dropping the `JoinHandle`, we assume they are not + // interested in the panic and swallow it. + let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + self.core().stage.drop_future_or_output(); + })); + } + + // Drop the `JoinHandle` reference, possibly deallocating the task + self.drop_reference(); + } + + /// Remotely aborts the task. + /// + /// The caller should hold a ref-count, but we do not consume it. + /// + /// This is similar to `shutdown` except that it asks the runtime to perform + /// the shutdown. This is necessary to avoid the shutdown happening in the + /// wrong thread for non-Send tasks. + pub(super) fn remote_abort(self) { + if self.header().state.transition_to_notified_and_cancel() { + // The transition has created a new ref-count, which we turn into + // a Notified and pass to the task. + // + // Since the caller holds a ref-count, the task cannot be destroyed + // before the call to `schedule` returns even if the call drops the + // `Notified` internally. + self.core() + .scheduler + .schedule(Notified(self.get_new_task())); + } + } + + // ===== waker behavior ===== + + /// This call consumes a ref-count and notifies the task. This will create a + /// new Notified and submit it if necessary. + /// + /// The caller does not need to hold a ref-count besides the one that was + /// passed to this call. + pub(super) fn wake_by_val(self) { + use super::state::TransitionToNotifiedByVal; + + match self.header().state.transition_to_notified_by_val() { + TransitionToNotifiedByVal::Submit => { + // The caller has given us a ref-count, and the transition has + // created a new ref-count, so we now hold two. We turn the new + // ref-count Notified and pass it to the call to `schedule`. + // + // The old ref-count is retained for now to ensure that the task + // is not dropped during the call to `schedule` if the call + // drops the task it was given. + self.core() + .scheduler + .schedule(Notified(self.get_new_task())); + + // Now that we have completed the call to schedule, we can + // release our ref-count. + self.drop_reference(); + } + TransitionToNotifiedByVal::Dealloc => { + self.dealloc(); + } + TransitionToNotifiedByVal::DoNothing => {} + } + } + + /// This call notifies the task. It will not consume any ref-counts, but the + /// caller should hold a ref-count. This will create a new Notified and + /// submit it if necessary. + pub(super) fn wake_by_ref(&self) { + use super::state::TransitionToNotifiedByRef; + + match self.header().state.transition_to_notified_by_ref() { + TransitionToNotifiedByRef::Submit => { + // The transition above incremented the ref-count for a new task + // and the caller also holds a ref-count. The caller's ref-count + // ensures that the task is not destroyed even if the new task + // is dropped before `schedule` returns. + self.core() + .scheduler + .schedule(Notified(self.get_new_task())); + } + TransitionToNotifiedByRef::DoNothing => {} + } + } + + pub(super) fn drop_reference(self) { + if self.header().state.ref_dec() { + self.dealloc(); + } + } + + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) fn id(&self) -> Option<&tracing::Id> { + self.header().id.as_ref() + } + + // ====== internal ====== + + /// Completes the task. This method assumes that the state is RUNNING. + fn complete(self) { + // The future has completed and its output has been written to the task + // stage. We transition from running to complete. + + let snapshot = self.header().state.transition_to_complete(); + + // We catch panics here in case dropping the future or waking the + // JoinHandle panics. + let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + if !snapshot.is_join_interested() { + // The `JoinHandle` is not interested in the output of + // this task. It is our responsibility to drop the + // output. + self.core().stage.drop_future_or_output(); + } else if snapshot.has_join_waker() { + // Notify the join handle. The previous transition obtains the + // lock on the waker cell. + self.trailer().wake_join(); + } + })); + + // The task has completed execution and will no longer be scheduled. + let num_release = self.release(); + + if self.header().state.transition_to_terminal(num_release) { + self.dealloc(); + } + } + + /// Releases the task from the scheduler. Returns the number of ref-counts + /// that should be decremented. + fn release(&self) -> usize { + // We don't actually increment the ref-count here, but the new task is + // never destroyed, so that's ok. + let me = ManuallyDrop::new(self.get_new_task()); + + if let Some(task) = self.core().scheduler.release(&me) { + mem::forget(task); + 2 + } else { + 1 + } + } + + /// Creates a new task that holds its own ref-count. + /// + /// # Safety + /// + /// Any use of `self` after this call must ensure that a ref-count to the + /// task holds the task alive until after the use of `self`. Passing the + /// returned Task to any method on `self` is unsound if dropping the Task + /// could drop `self` before the call on `self` returned. + fn get_new_task(&self) -> Task<S> { + // safety: The header is at the beginning of the cell, so this cast is + // safe. + unsafe { Task::from_raw(self.cell.cast()) } + } +} + +fn can_read_output(header: &Header, trailer: &Trailer, waker: &Waker) -> bool { + // Load a snapshot of the current task state + let snapshot = header.state.load(); + + debug_assert!(snapshot.is_join_interested()); + + if !snapshot.is_complete() { + // The waker must be stored in the task struct. + let res = if snapshot.has_join_waker() { + // There already is a waker stored in the struct. If it matches + // the provided waker, then there is no further work to do. + // Otherwise, the waker must be swapped. + let will_wake = unsafe { + // Safety: when `JOIN_INTEREST` is set, only `JOIN_HANDLE` + // may mutate the `waker` field. + trailer.will_wake(waker) + }; + + if will_wake { + // The task is not complete **and** the waker is up to date, + // there is nothing further that needs to be done. + return false; + } + + // Unset the `JOIN_WAKER` to gain mutable access to the `waker` + // field then update the field with the new join worker. + // + // This requires two atomic operations, unsetting the bit and + // then resetting it. If the task transitions to complete + // concurrently to either one of those operations, then setting + // the join waker fails and we proceed to reading the task + // output. + header + .state + .unset_waker() + .and_then(|snapshot| set_join_waker(header, trailer, waker.clone(), snapshot)) + } else { + set_join_waker(header, trailer, waker.clone(), snapshot) + }; + + match res { + Ok(_) => return false, + Err(snapshot) => { + assert!(snapshot.is_complete()); + } + } + } + true +} + +fn set_join_waker( + header: &Header, + trailer: &Trailer, + waker: Waker, + snapshot: Snapshot, +) -> Result<Snapshot, Snapshot> { + assert!(snapshot.is_join_interested()); + assert!(!snapshot.has_join_waker()); + + // Safety: Only the `JoinHandle` may set the `waker` field. When + // `JOIN_INTEREST` is **not** set, nothing else will touch the field. + unsafe { + trailer.set_waker(Some(waker)); + } + + // Update the `JoinWaker` state accordingly + let res = header.state.set_join_waker(); + + // If the state could not be updated, then clear the join waker + if res.is_err() { + unsafe { + trailer.set_waker(None); + } + } + + res +} + +enum PollFuture { + Complete, + Notified, + Done, + Dealloc, +} + +/// Cancels the task and store the appropriate error in the stage field. +fn cancel_task<T: Future>(stage: &CoreStage<T>) { + // Drop the future from a panic guard. + let res = panic::catch_unwind(panic::AssertUnwindSafe(|| { + stage.drop_future_or_output(); + })); + + match res { + Ok(()) => { + stage.store_output(Err(JoinError::cancelled())); + } + Err(panic) => { + stage.store_output(Err(JoinError::panic(panic))); + } + } +} + +/// Polls the future. If the future completes, the output is written to the +/// stage field. +fn poll_future<T: Future>(core: &CoreStage<T>, cx: Context<'_>) -> Poll<()> { + // Poll the future. + let output = panic::catch_unwind(panic::AssertUnwindSafe(|| { + struct Guard<'a, T: Future> { + core: &'a CoreStage<T>, + } + impl<'a, T: Future> Drop for Guard<'a, T> { + fn drop(&mut self) { + // If the future panics on poll, we drop it inside the panic + // guard. + self.core.drop_future_or_output(); + } + } + let guard = Guard { core }; + let res = guard.core.poll(cx); + mem::forget(guard); + res + })); + + // Prepare output for being placed in the core stage. + let output = match output { + Ok(Poll::Pending) => return Poll::Pending, + Ok(Poll::Ready(output)) => Ok(output), + Err(panic) => Err(JoinError::panic(panic)), + }; + + // Catch and ignore panics if the future panics on drop. + let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + core.store_output(output); + })); + + Poll::Ready(()) +} diff --git a/third_party/rust/tokio/src/runtime/task/inject.rs b/third_party/rust/tokio/src/runtime/task/inject.rs new file mode 100644 index 0000000000..1585e13a01 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/inject.rs @@ -0,0 +1,220 @@ +//! Inject queue used to send wakeups to a work-stealing scheduler + +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::Mutex; +use crate::runtime::task; + +use std::marker::PhantomData; +use std::ptr::NonNull; +use std::sync::atomic::Ordering::{Acquire, Release}; + +/// Growable, MPMC queue used to inject new tasks into the scheduler and as an +/// overflow queue when the local, fixed-size, array queue overflows. +pub(crate) struct Inject<T: 'static> { + /// Pointers to the head and tail of the queue. + pointers: Mutex<Pointers>, + + /// Number of pending tasks in the queue. This helps prevent unnecessary + /// locking in the hot path. + len: AtomicUsize, + + _p: PhantomData<T>, +} + +struct Pointers { + /// True if the queue is closed. + is_closed: bool, + + /// Linked-list head. + head: Option<NonNull<task::Header>>, + + /// Linked-list tail. + tail: Option<NonNull<task::Header>>, +} + +unsafe impl<T> Send for Inject<T> {} +unsafe impl<T> Sync for Inject<T> {} + +impl<T: 'static> Inject<T> { + pub(crate) fn new() -> Inject<T> { + Inject { + pointers: Mutex::new(Pointers { + is_closed: false, + head: None, + tail: None, + }), + len: AtomicUsize::new(0), + _p: PhantomData, + } + } + + pub(crate) fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Closes the injection queue, returns `true` if the queue is open when the + /// transition is made. + pub(crate) fn close(&self) -> bool { + let mut p = self.pointers.lock(); + + if p.is_closed { + return false; + } + + p.is_closed = true; + true + } + + pub(crate) fn is_closed(&self) -> bool { + self.pointers.lock().is_closed + } + + pub(crate) fn len(&self) -> usize { + self.len.load(Acquire) + } + + /// Pushes a value into the queue. + /// + /// This does nothing if the queue is closed. + pub(crate) fn push(&self, task: task::Notified<T>) { + // Acquire queue lock + let mut p = self.pointers.lock(); + + if p.is_closed { + return; + } + + // safety: only mutated with the lock held + let len = unsafe { self.len.unsync_load() }; + let task = task.into_raw(); + + // The next pointer should already be null + debug_assert!(get_next(task).is_none()); + + if let Some(tail) = p.tail { + // safety: Holding the Notified for a task guarantees exclusive + // access to the `queue_next` field. + set_next(tail, Some(task)); + } else { + p.head = Some(task); + } + + p.tail = Some(task); + + self.len.store(len + 1, Release); + } + + /// Pushes several values into the queue. + #[inline] + pub(crate) fn push_batch<I>(&self, mut iter: I) + where + I: Iterator<Item = task::Notified<T>>, + { + let first = match iter.next() { + Some(first) => first.into_raw(), + None => return, + }; + + // Link up all the tasks. + let mut prev = first; + let mut counter = 1; + + // We are going to be called with an `std::iter::Chain`, and that + // iterator overrides `for_each` to something that is easier for the + // compiler to optimize than a loop. + iter.for_each(|next| { + let next = next.into_raw(); + + // safety: Holding the Notified for a task guarantees exclusive + // access to the `queue_next` field. + set_next(prev, Some(next)); + prev = next; + counter += 1; + }); + + // Now that the tasks are linked together, insert them into the + // linked list. + self.push_batch_inner(first, prev, counter); + } + + /// Inserts several tasks that have been linked together into the queue. + /// + /// The provided head and tail may be be the same task. In this case, a + /// single task is inserted. + #[inline] + fn push_batch_inner( + &self, + batch_head: NonNull<task::Header>, + batch_tail: NonNull<task::Header>, + num: usize, + ) { + debug_assert!(get_next(batch_tail).is_none()); + + let mut p = self.pointers.lock(); + + if let Some(tail) = p.tail { + set_next(tail, Some(batch_head)); + } else { + p.head = Some(batch_head); + } + + p.tail = Some(batch_tail); + + // Increment the count. + // + // safety: All updates to the len atomic are guarded by the mutex. As + // such, a non-atomic load followed by a store is safe. + let len = unsafe { self.len.unsync_load() }; + + self.len.store(len + num, Release); + } + + pub(crate) fn pop(&self) -> Option<task::Notified<T>> { + // Fast path, if len == 0, then there are no values + if self.is_empty() { + return None; + } + + let mut p = self.pointers.lock(); + + // It is possible to hit null here if another thread popped the last + // task between us checking `len` and acquiring the lock. + let task = p.head?; + + p.head = get_next(task); + + if p.head.is_none() { + p.tail = None; + } + + set_next(task, None); + + // Decrement the count. + // + // safety: All updates to the len atomic are guarded by the mutex. As + // such, a non-atomic load followed by a store is safe. + self.len + .store(unsafe { self.len.unsync_load() } - 1, Release); + + // safety: a `Notified` is pushed into the queue and now it is popped! + Some(unsafe { task::Notified::from_raw(task) }) + } +} + +impl<T: 'static> Drop for Inject<T> { + fn drop(&mut self) { + if !std::thread::panicking() { + assert!(self.pop().is_none(), "queue not empty"); + } + } +} + +fn get_next(header: NonNull<task::Header>) -> Option<NonNull<task::Header>> { + unsafe { header.as_ref().queue_next.with(|ptr| *ptr) } +} + +fn set_next(header: NonNull<task::Header>, val: Option<NonNull<task::Header>>) { + unsafe { + header.as_ref().set_next(val); + } +} diff --git a/third_party/rust/tokio/src/runtime/task/join.rs b/third_party/rust/tokio/src/runtime/task/join.rs new file mode 100644 index 0000000000..8beed2eaac --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/join.rs @@ -0,0 +1,275 @@ +use crate::runtime::task::RawTask; + +use std::fmt; +use std::future::Future; +use std::marker::PhantomData; +use std::panic::{RefUnwindSafe, UnwindSafe}; +use std::pin::Pin; +use std::task::{Context, Poll, Waker}; + +cfg_rt! { + /// An owned permission to join on a task (await its termination). + /// + /// This can be thought of as the equivalent of [`std::thread::JoinHandle`] for + /// a task rather than a thread. + /// + /// A `JoinHandle` *detaches* the associated task when it is dropped, which + /// means that there is no longer any handle to the task, and no way to `join` + /// on it. + /// + /// This `struct` is created by the [`task::spawn`] and [`task::spawn_blocking`] + /// functions. + /// + /// # Examples + /// + /// Creation from [`task::spawn`]: + /// + /// ``` + /// use tokio::task; + /// + /// # async fn doc() { + /// let join_handle: task::JoinHandle<_> = task::spawn(async { + /// // some work here + /// }); + /// # } + /// ``` + /// + /// Creation from [`task::spawn_blocking`]: + /// + /// ``` + /// use tokio::task; + /// + /// # async fn doc() { + /// let join_handle: task::JoinHandle<_> = task::spawn_blocking(|| { + /// // some blocking work here + /// }); + /// # } + /// ``` + /// + /// The generic parameter `T` in `JoinHandle<T>` is the return type of the spawned task. + /// If the return value is an i32, the join handle has type `JoinHandle<i32>`: + /// + /// ``` + /// use tokio::task; + /// + /// # async fn doc() { + /// let join_handle: task::JoinHandle<i32> = task::spawn(async { + /// 5 + 3 + /// }); + /// # } + /// + /// ``` + /// + /// If the task does not have a return value, the join handle has type `JoinHandle<()>`: + /// + /// ``` + /// use tokio::task; + /// + /// # async fn doc() { + /// let join_handle: task::JoinHandle<()> = task::spawn(async { + /// println!("I return nothing."); + /// }); + /// # } + /// ``` + /// + /// Note that `handle.await` doesn't give you the return type directly. It is wrapped in a + /// `Result` because panics in the spawned task are caught by Tokio. The `?` operator has + /// to be double chained to extract the returned value: + /// + /// ``` + /// use tokio::task; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let join_handle: task::JoinHandle<Result<i32, io::Error>> = tokio::spawn(async { + /// Ok(5 + 3) + /// }); + /// + /// let result = join_handle.await??; + /// assert_eq!(result, 8); + /// Ok(()) + /// } + /// ``` + /// + /// If the task panics, the error is a [`JoinError`] that contains the panic: + /// + /// ``` + /// use tokio::task; + /// use std::io; + /// use std::panic; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let join_handle: task::JoinHandle<Result<i32, io::Error>> = tokio::spawn(async { + /// panic!("boom"); + /// }); + /// + /// let err = join_handle.await.unwrap_err(); + /// assert!(err.is_panic()); + /// Ok(()) + /// } + /// + /// ``` + /// Child being detached and outliving its parent: + /// + /// ```no_run + /// use tokio::task; + /// use tokio::time; + /// use std::time::Duration; + /// + /// # #[tokio::main] async fn main() { + /// let original_task = task::spawn(async { + /// let _detached_task = task::spawn(async { + /// // Here we sleep to make sure that the first task returns before. + /// time::sleep(Duration::from_millis(10)).await; + /// // This will be called, even though the JoinHandle is dropped. + /// println!("♫ Still alive ♫"); + /// }); + /// }); + /// + /// original_task.await.expect("The task being joined has panicked"); + /// println!("Original task is joined."); + /// + /// // We make sure that the new task has time to run, before the main + /// // task returns. + /// + /// time::sleep(Duration::from_millis(1000)).await; + /// # } + /// ``` + /// + /// [`task::spawn`]: crate::task::spawn() + /// [`task::spawn_blocking`]: crate::task::spawn_blocking + /// [`std::thread::JoinHandle`]: std::thread::JoinHandle + /// [`JoinError`]: crate::task::JoinError + pub struct JoinHandle<T> { + raw: Option<RawTask>, + _p: PhantomData<T>, + } +} + +unsafe impl<T: Send> Send for JoinHandle<T> {} +unsafe impl<T: Send> Sync for JoinHandle<T> {} + +impl<T> UnwindSafe for JoinHandle<T> {} +impl<T> RefUnwindSafe for JoinHandle<T> {} + +impl<T> JoinHandle<T> { + pub(super) fn new(raw: RawTask) -> JoinHandle<T> { + JoinHandle { + raw: Some(raw), + _p: PhantomData, + } + } + + /// Abort the task associated with the handle. + /// + /// Awaiting a cancelled task might complete as usual if the task was + /// already completed at the time it was cancelled, but most likely it + /// will fail with a [cancelled] `JoinError`. + /// + /// ```rust + /// use tokio::time; + /// + /// #[tokio::main] + /// async fn main() { + /// let mut handles = Vec::new(); + /// + /// handles.push(tokio::spawn(async { + /// time::sleep(time::Duration::from_secs(10)).await; + /// true + /// })); + /// + /// handles.push(tokio::spawn(async { + /// time::sleep(time::Duration::from_secs(10)).await; + /// false + /// })); + /// + /// for handle in &handles { + /// handle.abort(); + /// } + /// + /// for handle in handles { + /// assert!(handle.await.unwrap_err().is_cancelled()); + /// } + /// } + /// ``` + /// [cancelled]: method@super::error::JoinError::is_cancelled + pub fn abort(&self) { + if let Some(raw) = self.raw { + raw.remote_abort(); + } + } + + /// Set the waker that is notified when the task completes. + pub(crate) fn set_join_waker(&mut self, waker: &Waker) { + if let Some(raw) = self.raw { + if raw.try_set_join_waker(waker) { + // In this case the task has already completed. We wake the waker immediately. + waker.wake_by_ref(); + } + } + } +} + +impl<T> Unpin for JoinHandle<T> {} + +impl<T> Future for JoinHandle<T> { + type Output = super::Result<T>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let mut ret = Poll::Pending; + + // Keep track of task budget + let coop = ready!(crate::coop::poll_proceed(cx)); + + // Raw should always be set. If it is not, this is due to polling after + // completion + let raw = self + .raw + .as_ref() + .expect("polling after `JoinHandle` already completed"); + + // Try to read the task output. If the task is not yet complete, the + // waker is stored and is notified once the task does complete. + // + // The function must go via the vtable, which requires erasing generic + // types. To do this, the function "return" is placed on the stack + // **before** calling the function and is passed into the function using + // `*mut ()`. + // + // Safety: + // + // The type of `T` must match the task's output type. + unsafe { + raw.try_read_output(&mut ret as *mut _ as *mut (), cx.waker()); + } + + if ret.is_ready() { + coop.made_progress(); + } + + ret + } +} + +impl<T> Drop for JoinHandle<T> { + fn drop(&mut self) { + if let Some(raw) = self.raw.take() { + if raw.header().state.drop_join_handle_fast().is_ok() { + return; + } + + raw.drop_join_handle_slow(); + } + } +} + +impl<T> fmt::Debug for JoinHandle<T> +where + T: fmt::Debug, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("JoinHandle").finish() + } +} diff --git a/third_party/rust/tokio/src/runtime/task/list.rs b/third_party/rust/tokio/src/runtime/task/list.rs new file mode 100644 index 0000000000..7758f8db7a --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/list.rs @@ -0,0 +1,297 @@ +//! This module has containers for storing the tasks spawned on a scheduler. The +//! `OwnedTasks` container is thread-safe but can only store tasks that +//! implement Send. The `LocalOwnedTasks` container is not thread safe, but can +//! store non-Send tasks. +//! +//! The collections can be closed to prevent adding new tasks during shutdown of +//! the scheduler with the collection. + +use crate::future::Future; +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::Mutex; +use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, Task}; +use crate::util::linked_list::{Link, LinkedList}; + +use std::marker::PhantomData; + +// The id from the module below is used to verify whether a given task is stored +// in this OwnedTasks, or some other task. The counter starts at one so we can +// use zero for tasks not owned by any list. +// +// The safety checks in this file can technically be violated if the counter is +// overflown, but the checks are not supposed to ever fail unless there is a +// bug in Tokio, so we accept that certain bugs would not be caught if the two +// mixed up runtimes happen to have the same id. + +cfg_has_atomic_u64! { + use std::sync::atomic::{AtomicU64, Ordering}; + + static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1); + + fn get_next_id() -> u64 { + loop { + let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed); + if id != 0 { + return id; + } + } + } +} + +cfg_not_has_atomic_u64! { + use std::sync::atomic::{AtomicU32, Ordering}; + + static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1); + + fn get_next_id() -> u64 { + loop { + let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed); + if id != 0 { + return u64::from(id); + } + } + } +} + +pub(crate) struct OwnedTasks<S: 'static> { + inner: Mutex<OwnedTasksInner<S>>, + id: u64, +} +pub(crate) struct LocalOwnedTasks<S: 'static> { + inner: UnsafeCell<OwnedTasksInner<S>>, + id: u64, + _not_send_or_sync: PhantomData<*const ()>, +} +struct OwnedTasksInner<S: 'static> { + list: LinkedList<Task<S>, <Task<S> as Link>::Target>, + closed: bool, +} + +impl<S: 'static> OwnedTasks<S> { + pub(crate) fn new() -> Self { + Self { + inner: Mutex::new(OwnedTasksInner { + list: LinkedList::new(), + closed: false, + }), + id: get_next_id(), + } + } + + /// Binds the provided task to this OwnedTasks instance. This fails if the + /// OwnedTasks has been closed. + pub(crate) fn bind<T>( + &self, + task: T, + scheduler: S, + ) -> (JoinHandle<T::Output>, Option<Notified<S>>) + where + S: Schedule, + T: Future + Send + 'static, + T::Output: Send + 'static, + { + let (task, notified, join) = super::new_task(task, scheduler); + + unsafe { + // safety: We just created the task, so we have exclusive access + // to the field. + task.header().set_owner_id(self.id); + } + + let mut lock = self.inner.lock(); + if lock.closed { + drop(lock); + drop(notified); + task.shutdown(); + (join, None) + } else { + lock.list.push_front(task); + (join, Some(notified)) + } + } + + /// Asserts that the given task is owned by this OwnedTasks and convert it to + /// a LocalNotified, giving the thread permission to poll this task. + #[inline] + pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> { + assert_eq!(task.header().get_owner_id(), self.id); + + // safety: All tasks bound to this OwnedTasks are Send, so it is safe + // to poll it on this thread no matter what thread we are on. + LocalNotified { + task: task.0, + _not_send: PhantomData, + } + } + + /// Shuts down all tasks in the collection. This call also closes the + /// collection, preventing new items from being added. + pub(crate) fn close_and_shutdown_all(&self) + where + S: Schedule, + { + // The first iteration of the loop was unrolled so it can set the + // closed bool. + let first_task = { + let mut lock = self.inner.lock(); + lock.closed = true; + lock.list.pop_back() + }; + match first_task { + Some(task) => task.shutdown(), + None => return, + } + + loop { + let task = match self.inner.lock().list.pop_back() { + Some(task) => task, + None => return, + }; + + task.shutdown(); + } + } + + pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> { + let task_id = task.header().get_owner_id(); + if task_id == 0 { + // The task is unowned. + return None; + } + + assert_eq!(task_id, self.id); + + // safety: We just checked that the provided task is not in some other + // linked list. + unsafe { self.inner.lock().list.remove(task.header().into()) } + } + + pub(crate) fn is_empty(&self) -> bool { + self.inner.lock().list.is_empty() + } +} + +impl<S: 'static> LocalOwnedTasks<S> { + pub(crate) fn new() -> Self { + Self { + inner: UnsafeCell::new(OwnedTasksInner { + list: LinkedList::new(), + closed: false, + }), + id: get_next_id(), + _not_send_or_sync: PhantomData, + } + } + + pub(crate) fn bind<T>( + &self, + task: T, + scheduler: S, + ) -> (JoinHandle<T::Output>, Option<Notified<S>>) + where + S: Schedule, + T: Future + 'static, + T::Output: 'static, + { + let (task, notified, join) = super::new_task(task, scheduler); + + unsafe { + // safety: We just created the task, so we have exclusive access + // to the field. + task.header().set_owner_id(self.id); + } + + if self.is_closed() { + drop(notified); + task.shutdown(); + (join, None) + } else { + self.with_inner(|inner| { + inner.list.push_front(task); + }); + (join, Some(notified)) + } + } + + /// Shuts down all tasks in the collection. This call also closes the + /// collection, preventing new items from being added. + pub(crate) fn close_and_shutdown_all(&self) + where + S: Schedule, + { + self.with_inner(|inner| inner.closed = true); + + while let Some(task) = self.with_inner(|inner| inner.list.pop_back()) { + task.shutdown(); + } + } + + pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> { + let task_id = task.header().get_owner_id(); + if task_id == 0 { + // The task is unowned. + return None; + } + + assert_eq!(task_id, self.id); + + self.with_inner(|inner| + // safety: We just checked that the provided task is not in some + // other linked list. + unsafe { inner.list.remove(task.header().into()) }) + } + + /// Asserts that the given task is owned by this LocalOwnedTasks and convert + /// it to a LocalNotified, giving the thread permission to poll this task. + #[inline] + pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> { + assert_eq!(task.header().get_owner_id(), self.id); + + // safety: The task was bound to this LocalOwnedTasks, and the + // LocalOwnedTasks is not Send or Sync, so we are on the right thread + // for polling this task. + LocalNotified { + task: task.0, + _not_send: PhantomData, + } + } + + #[inline] + fn with_inner<F, T>(&self, f: F) -> T + where + F: FnOnce(&mut OwnedTasksInner<S>) -> T, + { + // safety: This type is not Sync, so concurrent calls of this method + // can't happen. Furthermore, all uses of this method in this file make + // sure that they don't call `with_inner` recursively. + self.inner.with_mut(|ptr| unsafe { f(&mut *ptr) }) + } + + pub(crate) fn is_closed(&self) -> bool { + self.with_inner(|inner| inner.closed) + } + + pub(crate) fn is_empty(&self) -> bool { + self.with_inner(|inner| inner.list.is_empty()) + } +} + +#[cfg(all(test))] +mod tests { + use super::*; + + // This test may run in parallel with other tests, so we only test that ids + // come in increasing order. + #[test] + fn test_id_not_broken() { + let mut last_id = get_next_id(); + assert_ne!(last_id, 0); + + for _ in 0..1000 { + let next_id = get_next_id(); + assert_ne!(next_id, 0); + assert!(last_id < next_id); + last_id = next_id; + } + } +} diff --git a/third_party/rust/tokio/src/runtime/task/mod.rs b/third_party/rust/tokio/src/runtime/task/mod.rs new file mode 100644 index 0000000000..2a492dc985 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/mod.rs @@ -0,0 +1,445 @@ +//! The task module. +//! +//! The task module contains the code that manages spawned tasks and provides a +//! safe API for the rest of the runtime to use. Each task in a runtime is +//! stored in an OwnedTasks or LocalOwnedTasks object. +//! +//! # Task reference types +//! +//! A task is usually referenced by multiple handles, and there are several +//! types of handles. +//! +//! * OwnedTask - tasks stored in an OwnedTasks or LocalOwnedTasks are of this +//! reference type. +//! +//! * JoinHandle - each task has a JoinHandle that allows access to the output +//! of the task. +//! +//! * Waker - every waker for a task has this reference type. There can be any +//! number of waker references. +//! +//! * Notified - tracks whether the task is notified. +//! +//! * Unowned - this task reference type is used for tasks not stored in any +//! runtime. Mainly used for blocking tasks, but also in tests. +//! +//! The task uses a reference count to keep track of how many active references +//! exist. The Unowned reference type takes up two ref-counts. All other +//! reference types take up a single ref-count. +//! +//! Besides the waker type, each task has at most one of each reference type. +//! +//! # State +//! +//! The task stores its state in an atomic usize with various bitfields for the +//! necessary information. The state has the following bitfields: +//! +//! * RUNNING - Tracks whether the task is currently being polled or cancelled. +//! This bit functions as a lock around the task. +//! +//! * COMPLETE - Is one once the future has fully completed and has been +//! dropped. Never unset once set. Never set together with RUNNING. +//! +//! * NOTIFIED - Tracks whether a Notified object currently exists. +//! +//! * CANCELLED - Is set to one for tasks that should be cancelled as soon as +//! possible. May take any value for completed tasks. +//! +//! * JOIN_INTEREST - Is set to one if there exists a JoinHandle. +//! +//! * JOIN_WAKER - Is set to one if the JoinHandle has set a waker. +//! +//! The rest of the bits are used for the ref-count. +//! +//! # Fields in the task +//! +//! The task has various fields. This section describes how and when it is safe +//! to access a field. +//! +//! * The state field is accessed with atomic instructions. +//! +//! * The OwnedTask reference has exclusive access to the `owned` field. +//! +//! * The Notified reference has exclusive access to the `queue_next` field. +//! +//! * The `owner_id` field can be set as part of construction of the task, but +//! is otherwise immutable and anyone can access the field immutably without +//! synchronization. +//! +//! * If COMPLETE is one, then the JoinHandle has exclusive access to the +//! stage field. If COMPLETE is zero, then the RUNNING bitfield functions as +//! a lock for the stage field, and it can be accessed only by the thread +//! that set RUNNING to one. +//! +//! * If JOIN_WAKER is zero, then the JoinHandle has exclusive access to the +//! join handle waker. If JOIN_WAKER and COMPLETE are both one, then the +//! thread that set COMPLETE to one has exclusive access to the join handle +//! waker. +//! +//! All other fields are immutable and can be accessed immutably without +//! synchronization by anyone. +//! +//! # Safety +//! +//! This section goes through various situations and explains why the API is +//! safe in that situation. +//! +//! ## Polling or dropping the future +//! +//! Any mutable access to the future happens after obtaining a lock by modifying +//! the RUNNING field, so exclusive access is ensured. +//! +//! When the task completes, exclusive access to the output is transferred to +//! the JoinHandle. If the JoinHandle is already dropped when the transition to +//! complete happens, the thread performing that transition retains exclusive +//! access to the output and should immediately drop it. +//! +//! ## Non-Send futures +//! +//! If a future is not Send, then it is bound to a LocalOwnedTasks. The future +//! will only ever be polled or dropped given a LocalNotified or inside a call +//! to LocalOwnedTasks::shutdown_all. In either case, it is guaranteed that the +//! future is on the right thread. +//! +//! If the task is never removed from the LocalOwnedTasks, then it is leaked, so +//! there is no risk that the task is dropped on some other thread when the last +//! ref-count drops. +//! +//! ## Non-Send output +//! +//! When a task completes, the output is placed in the stage of the task. Then, +//! a transition that sets COMPLETE to true is performed, and the value of +//! JOIN_INTEREST when this transition happens is read. +//! +//! If JOIN_INTEREST is zero when the transition to COMPLETE happens, then the +//! output is immediately dropped. +//! +//! If JOIN_INTEREST is one when the transition to COMPLETE happens, then the +//! JoinHandle is responsible for cleaning up the output. If the output is not +//! Send, then this happens: +//! +//! 1. The output is created on the thread that the future was polled on. Since +//! only non-Send futures can have non-Send output, the future was polled on +//! the thread that the future was spawned from. +//! 2. Since JoinHandle<Output> is not Send if Output is not Send, the +//! JoinHandle is also on the thread that the future was spawned from. +//! 3. Thus, the JoinHandle will not move the output across threads when it +//! takes or drops the output. +//! +//! ## Recursive poll/shutdown +//! +//! Calling poll from inside a shutdown call or vice-versa is not prevented by +//! the API exposed by the task module, so this has to be safe. In either case, +//! the lock in the RUNNING bitfield makes the inner call return immediately. If +//! the inner call is a `shutdown` call, then the CANCELLED bit is set, and the +//! poll call will notice it when the poll finishes, and the task is cancelled +//! at that point. + +// Some task infrastructure is here to support `JoinSet`, which is currently +// unstable. This should be removed once `JoinSet` is stabilized. +#![cfg_attr(not(tokio_unstable), allow(dead_code))] + +mod core; +use self::core::Cell; +use self::core::Header; + +mod error; +#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 +pub use self::error::JoinError; + +mod harness; +use self::harness::Harness; + +cfg_rt_multi_thread! { + mod inject; + pub(super) use self::inject::Inject; +} + +mod join; +#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 +pub use self::join::JoinHandle; + +mod list; +pub(crate) use self::list::{LocalOwnedTasks, OwnedTasks}; + +mod raw; +use self::raw::RawTask; + +mod state; +use self::state::State; + +mod waker; + +use crate::future::Future; +use crate::util::linked_list; + +use std::marker::PhantomData; +use std::ptr::NonNull; +use std::{fmt, mem}; + +/// An owned handle to the task, tracked by ref count. +#[repr(transparent)] +pub(crate) struct Task<S: 'static> { + raw: RawTask, + _p: PhantomData<S>, +} + +unsafe impl<S> Send for Task<S> {} +unsafe impl<S> Sync for Task<S> {} + +/// A task was notified. +#[repr(transparent)] +pub(crate) struct Notified<S: 'static>(Task<S>); + +// safety: This type cannot be used to touch the task without first verifying +// that the value is on a thread where it is safe to poll the task. +unsafe impl<S: Schedule> Send for Notified<S> {} +unsafe impl<S: Schedule> Sync for Notified<S> {} + +/// A non-Send variant of Notified with the invariant that it is on a thread +/// where it is safe to poll it. +#[repr(transparent)] +pub(crate) struct LocalNotified<S: 'static> { + task: Task<S>, + _not_send: PhantomData<*const ()>, +} + +/// A task that is not owned by any OwnedTasks. Used for blocking tasks. +/// This type holds two ref-counts. +pub(crate) struct UnownedTask<S: 'static> { + raw: RawTask, + _p: PhantomData<S>, +} + +// safety: This type can only be created given a Send task. +unsafe impl<S> Send for UnownedTask<S> {} +unsafe impl<S> Sync for UnownedTask<S> {} + +/// Task result sent back. +pub(crate) type Result<T> = std::result::Result<T, JoinError>; + +pub(crate) trait Schedule: Sync + Sized + 'static { + /// The task has completed work and is ready to be released. The scheduler + /// should release it immediately and return it. The task module will batch + /// the ref-dec with setting other options. + /// + /// If the scheduler has already released the task, then None is returned. + fn release(&self, task: &Task<Self>) -> Option<Task<Self>>; + + /// Schedule the task + fn schedule(&self, task: Notified<Self>); + + /// Schedule the task to run in the near future, yielding the thread to + /// other tasks. + fn yield_now(&self, task: Notified<Self>) { + self.schedule(task); + } +} + +cfg_rt! { + /// This is the constructor for a new task. Three references to the task are + /// created. The first task reference is usually put into an OwnedTasks + /// immediately. The Notified is sent to the scheduler as an ordinary + /// notification. + fn new_task<T, S>( + task: T, + scheduler: S + ) -> (Task<S>, Notified<S>, JoinHandle<T::Output>) + where + S: Schedule, + T: Future + 'static, + T::Output: 'static, + { + let raw = RawTask::new::<T, S>(task, scheduler); + let task = Task { + raw, + _p: PhantomData, + }; + let notified = Notified(Task { + raw, + _p: PhantomData, + }); + let join = JoinHandle::new(raw); + + (task, notified, join) + } + + /// Creates a new task with an associated join handle. This method is used + /// only when the task is not going to be stored in an `OwnedTasks` list. + /// + /// Currently only blocking tasks use this method. + pub(crate) fn unowned<T, S>(task: T, scheduler: S) -> (UnownedTask<S>, JoinHandle<T::Output>) + where + S: Schedule, + T: Send + Future + 'static, + T::Output: Send + 'static, + { + let (task, notified, join) = new_task(task, scheduler); + + // This transfers the ref-count of task and notified into an UnownedTask. + // This is valid because an UnownedTask holds two ref-counts. + let unowned = UnownedTask { + raw: task.raw, + _p: PhantomData, + }; + std::mem::forget(task); + std::mem::forget(notified); + + (unowned, join) + } +} + +impl<S: 'static> Task<S> { + unsafe fn from_raw(ptr: NonNull<Header>) -> Task<S> { + Task { + raw: RawTask::from_raw(ptr), + _p: PhantomData, + } + } + + fn header(&self) -> &Header { + self.raw.header() + } +} + +impl<S: 'static> Notified<S> { + fn header(&self) -> &Header { + self.0.header() + } +} + +cfg_rt_multi_thread! { + impl<S: 'static> Notified<S> { + unsafe fn from_raw(ptr: NonNull<Header>) -> Notified<S> { + Notified(Task::from_raw(ptr)) + } + } + + impl<S: 'static> Task<S> { + fn into_raw(self) -> NonNull<Header> { + let ret = self.raw.header_ptr(); + mem::forget(self); + ret + } + } + + impl<S: 'static> Notified<S> { + fn into_raw(self) -> NonNull<Header> { + self.0.into_raw() + } + } +} + +impl<S: Schedule> Task<S> { + /// Pre-emptively cancels the task as part of the shutdown process. + pub(crate) fn shutdown(self) { + let raw = self.raw; + mem::forget(self); + raw.shutdown(); + } +} + +impl<S: Schedule> LocalNotified<S> { + /// Runs the task. + pub(crate) fn run(self) { + let raw = self.task.raw; + mem::forget(self); + raw.poll(); + } +} + +impl<S: Schedule> UnownedTask<S> { + // Used in test of the inject queue. + #[cfg(test)] + #[cfg_attr(target_arch = "wasm32", allow(dead_code))] + pub(super) fn into_notified(self) -> Notified<S> { + Notified(self.into_task()) + } + + fn into_task(self) -> Task<S> { + // Convert into a task. + let task = Task { + raw: self.raw, + _p: PhantomData, + }; + mem::forget(self); + + // Drop a ref-count since an UnownedTask holds two. + task.header().state.ref_dec(); + + task + } + + pub(crate) fn run(self) { + let raw = self.raw; + mem::forget(self); + + // Transfer one ref-count to a Task object. + let task = Task::<S> { + raw, + _p: PhantomData, + }; + + // Use the other ref-count to poll the task. + raw.poll(); + // Decrement our extra ref-count + drop(task); + } + + pub(crate) fn shutdown(self) { + self.into_task().shutdown() + } +} + +impl<S: 'static> Drop for Task<S> { + fn drop(&mut self) { + // Decrement the ref count + if self.header().state.ref_dec() { + // Deallocate if this is the final ref count + self.raw.dealloc(); + } + } +} + +impl<S: 'static> Drop for UnownedTask<S> { + fn drop(&mut self) { + // Decrement the ref count + if self.raw.header().state.ref_dec_twice() { + // Deallocate if this is the final ref count + self.raw.dealloc(); + } + } +} + +impl<S> fmt::Debug for Task<S> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "Task({:p})", self.header()) + } +} + +impl<S> fmt::Debug for Notified<S> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "task::Notified({:p})", self.0.header()) + } +} + +/// # Safety +/// +/// Tasks are pinned. +unsafe impl<S> linked_list::Link for Task<S> { + type Handle = Task<S>; + type Target = Header; + + fn as_raw(handle: &Task<S>) -> NonNull<Header> { + handle.raw.header_ptr() + } + + unsafe fn from_raw(ptr: NonNull<Header>) -> Task<S> { + Task::from_raw(ptr) + } + + unsafe fn pointers(target: NonNull<Header>) -> NonNull<linked_list::Pointers<Header>> { + // Not super great as it avoids some of looms checking... + NonNull::from(target.as_ref().owned.with_mut(|ptr| &mut *ptr)) + } +} diff --git a/third_party/rust/tokio/src/runtime/task/raw.rs b/third_party/rust/tokio/src/runtime/task/raw.rs new file mode 100644 index 0000000000..2e4420b5c1 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/raw.rs @@ -0,0 +1,165 @@ +use crate::future::Future; +use crate::runtime::task::{Cell, Harness, Header, Schedule, State}; + +use std::ptr::NonNull; +use std::task::{Poll, Waker}; + +/// Raw task handle +pub(super) struct RawTask { + ptr: NonNull<Header>, +} + +pub(super) struct Vtable { + /// Polls the future. + pub(super) poll: unsafe fn(NonNull<Header>), + + /// Deallocates the memory. + pub(super) dealloc: unsafe fn(NonNull<Header>), + + /// Reads the task output, if complete. + pub(super) try_read_output: unsafe fn(NonNull<Header>, *mut (), &Waker), + + /// Try to set the waker notified when the task is complete. Returns true if + /// the task has already completed. If this call returns false, then the + /// waker will not be notified. + pub(super) try_set_join_waker: unsafe fn(NonNull<Header>, &Waker) -> bool, + + /// The join handle has been dropped. + pub(super) drop_join_handle_slow: unsafe fn(NonNull<Header>), + + /// The task is remotely aborted. + pub(super) remote_abort: unsafe fn(NonNull<Header>), + + /// Scheduler is being shutdown. + pub(super) shutdown: unsafe fn(NonNull<Header>), +} + +/// Get the vtable for the requested `T` and `S` generics. +pub(super) fn vtable<T: Future, S: Schedule>() -> &'static Vtable { + &Vtable { + poll: poll::<T, S>, + dealloc: dealloc::<T, S>, + try_read_output: try_read_output::<T, S>, + try_set_join_waker: try_set_join_waker::<T, S>, + drop_join_handle_slow: drop_join_handle_slow::<T, S>, + remote_abort: remote_abort::<T, S>, + shutdown: shutdown::<T, S>, + } +} + +impl RawTask { + pub(super) fn new<T, S>(task: T, scheduler: S) -> RawTask + where + T: Future, + S: Schedule, + { + let ptr = Box::into_raw(Cell::<_, S>::new(task, scheduler, State::new())); + let ptr = unsafe { NonNull::new_unchecked(ptr as *mut Header) }; + + RawTask { ptr } + } + + pub(super) unsafe fn from_raw(ptr: NonNull<Header>) -> RawTask { + RawTask { ptr } + } + + pub(super) fn header_ptr(&self) -> NonNull<Header> { + self.ptr + } + + /// Returns a reference to the task's meta structure. + /// + /// Safe as `Header` is `Sync`. + pub(super) fn header(&self) -> &Header { + unsafe { self.ptr.as_ref() } + } + + /// Safety: mutual exclusion is required to call this function. + pub(super) fn poll(self) { + let vtable = self.header().vtable; + unsafe { (vtable.poll)(self.ptr) } + } + + pub(super) fn dealloc(self) { + let vtable = self.header().vtable; + unsafe { + (vtable.dealloc)(self.ptr); + } + } + + /// Safety: `dst` must be a `*mut Poll<super::Result<T::Output>>` where `T` + /// is the future stored by the task. + pub(super) unsafe fn try_read_output(self, dst: *mut (), waker: &Waker) { + let vtable = self.header().vtable; + (vtable.try_read_output)(self.ptr, dst, waker); + } + + pub(super) fn try_set_join_waker(self, waker: &Waker) -> bool { + let vtable = self.header().vtable; + unsafe { (vtable.try_set_join_waker)(self.ptr, waker) } + } + + pub(super) fn drop_join_handle_slow(self) { + let vtable = self.header().vtable; + unsafe { (vtable.drop_join_handle_slow)(self.ptr) } + } + + pub(super) fn shutdown(self) { + let vtable = self.header().vtable; + unsafe { (vtable.shutdown)(self.ptr) } + } + + pub(super) fn remote_abort(self) { + let vtable = self.header().vtable; + unsafe { (vtable.remote_abort)(self.ptr) } + } +} + +impl Clone for RawTask { + fn clone(&self) -> Self { + RawTask { ptr: self.ptr } + } +} + +impl Copy for RawTask {} + +unsafe fn poll<T: Future, S: Schedule>(ptr: NonNull<Header>) { + let harness = Harness::<T, S>::from_raw(ptr); + harness.poll(); +} + +unsafe fn dealloc<T: Future, S: Schedule>(ptr: NonNull<Header>) { + let harness = Harness::<T, S>::from_raw(ptr); + harness.dealloc(); +} + +unsafe fn try_read_output<T: Future, S: Schedule>( + ptr: NonNull<Header>, + dst: *mut (), + waker: &Waker, +) { + let out = &mut *(dst as *mut Poll<super::Result<T::Output>>); + + let harness = Harness::<T, S>::from_raw(ptr); + harness.try_read_output(out, waker); +} + +unsafe fn try_set_join_waker<T: Future, S: Schedule>(ptr: NonNull<Header>, waker: &Waker) -> bool { + let harness = Harness::<T, S>::from_raw(ptr); + harness.try_set_join_waker(waker) +} + +unsafe fn drop_join_handle_slow<T: Future, S: Schedule>(ptr: NonNull<Header>) { + let harness = Harness::<T, S>::from_raw(ptr); + harness.drop_join_handle_slow() +} + +unsafe fn remote_abort<T: Future, S: Schedule>(ptr: NonNull<Header>) { + let harness = Harness::<T, S>::from_raw(ptr); + harness.remote_abort() +} + +unsafe fn shutdown<T: Future, S: Schedule>(ptr: NonNull<Header>) { + let harness = Harness::<T, S>::from_raw(ptr); + harness.shutdown() +} diff --git a/third_party/rust/tokio/src/runtime/task/state.rs b/third_party/rust/tokio/src/runtime/task/state.rs new file mode 100644 index 0000000000..c2d5b28eac --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/state.rs @@ -0,0 +1,595 @@ +use crate::loom::sync::atomic::AtomicUsize; + +use std::fmt; +use std::sync::atomic::Ordering::{AcqRel, Acquire, Release}; +use std::usize; + +pub(super) struct State { + val: AtomicUsize, +} + +/// Current state value. +#[derive(Copy, Clone)] +pub(super) struct Snapshot(usize); + +type UpdateResult = Result<Snapshot, Snapshot>; + +/// The task is currently being run. +const RUNNING: usize = 0b0001; + +/// The task is complete. +/// +/// Once this bit is set, it is never unset. +const COMPLETE: usize = 0b0010; + +/// Extracts the task's lifecycle value from the state. +const LIFECYCLE_MASK: usize = 0b11; + +/// Flag tracking if the task has been pushed into a run queue. +const NOTIFIED: usize = 0b100; + +/// The join handle is still around. +#[allow(clippy::unusual_byte_groupings)] // https://github.com/rust-lang/rust-clippy/issues/6556 +const JOIN_INTEREST: usize = 0b1_000; + +/// A join handle waker has been set. +#[allow(clippy::unusual_byte_groupings)] // https://github.com/rust-lang/rust-clippy/issues/6556 +const JOIN_WAKER: usize = 0b10_000; + +/// The task has been forcibly cancelled. +#[allow(clippy::unusual_byte_groupings)] // https://github.com/rust-lang/rust-clippy/issues/6556 +const CANCELLED: usize = 0b100_000; + +/// All bits. +const STATE_MASK: usize = LIFECYCLE_MASK | NOTIFIED | JOIN_INTEREST | JOIN_WAKER | CANCELLED; + +/// Bits used by the ref count portion of the state. +const REF_COUNT_MASK: usize = !STATE_MASK; + +/// Number of positions to shift the ref count. +const REF_COUNT_SHIFT: usize = REF_COUNT_MASK.count_zeros() as usize; + +/// One ref count. +const REF_ONE: usize = 1 << REF_COUNT_SHIFT; + +/// State a task is initialized with. +/// +/// A task is initialized with three references: +/// +/// * A reference that will be stored in an OwnedTasks or LocalOwnedTasks. +/// * A reference that will be sent to the scheduler as an ordinary notification. +/// * A reference for the JoinHandle. +/// +/// As the task starts with a `JoinHandle`, `JOIN_INTEREST` is set. +/// As the task starts with a `Notified`, `NOTIFIED` is set. +const INITIAL_STATE: usize = (REF_ONE * 3) | JOIN_INTEREST | NOTIFIED; + +#[must_use] +pub(super) enum TransitionToRunning { + Success, + Cancelled, + Failed, + Dealloc, +} + +#[must_use] +pub(super) enum TransitionToIdle { + Ok, + OkNotified, + OkDealloc, + Cancelled, +} + +#[must_use] +pub(super) enum TransitionToNotifiedByVal { + DoNothing, + Submit, + Dealloc, +} + +#[must_use] +pub(super) enum TransitionToNotifiedByRef { + DoNothing, + Submit, +} + +/// All transitions are performed via RMW operations. This establishes an +/// unambiguous modification order. +impl State { + /// Returns a task's initial state. + pub(super) fn new() -> State { + // The raw task returned by this method has a ref-count of three. See + // the comment on INITIAL_STATE for more. + State { + val: AtomicUsize::new(INITIAL_STATE), + } + } + + /// Loads the current state, establishes `Acquire` ordering. + pub(super) fn load(&self) -> Snapshot { + Snapshot(self.val.load(Acquire)) + } + + /// Attempts to transition the lifecycle to `Running`. This sets the + /// notified bit to false so notifications during the poll can be detected. + pub(super) fn transition_to_running(&self) -> TransitionToRunning { + self.fetch_update_action(|mut next| { + let action; + assert!(next.is_notified()); + + if !next.is_idle() { + // This happens if the task is either currently running or if it + // has already completed, e.g. if it was cancelled during + // shutdown. Consume the ref-count and return. + next.ref_dec(); + if next.ref_count() == 0 { + action = TransitionToRunning::Dealloc; + } else { + action = TransitionToRunning::Failed; + } + } else { + // We are able to lock the RUNNING bit. + next.set_running(); + next.unset_notified(); + + if next.is_cancelled() { + action = TransitionToRunning::Cancelled; + } else { + action = TransitionToRunning::Success; + } + } + (action, Some(next)) + }) + } + + /// Transitions the task from `Running` -> `Idle`. + /// + /// Returns `true` if the transition to `Idle` is successful, `false` otherwise. + /// The transition to `Idle` fails if the task has been flagged to be + /// cancelled. + pub(super) fn transition_to_idle(&self) -> TransitionToIdle { + self.fetch_update_action(|curr| { + assert!(curr.is_running()); + + if curr.is_cancelled() { + return (TransitionToIdle::Cancelled, None); + } + + let mut next = curr; + let action; + next.unset_running(); + + if !next.is_notified() { + // Polling the future consumes the ref-count of the Notified. + next.ref_dec(); + if next.ref_count() == 0 { + action = TransitionToIdle::OkDealloc; + } else { + action = TransitionToIdle::Ok; + } + } else { + // The caller will schedule a new notification, so we create a + // new ref-count for the notification. Our own ref-count is kept + // for now, and the caller will drop it shortly. + next.ref_inc(); + action = TransitionToIdle::OkNotified; + } + + (action, Some(next)) + }) + } + + /// Transitions the task from `Running` -> `Complete`. + pub(super) fn transition_to_complete(&self) -> Snapshot { + const DELTA: usize = RUNNING | COMPLETE; + + let prev = Snapshot(self.val.fetch_xor(DELTA, AcqRel)); + assert!(prev.is_running()); + assert!(!prev.is_complete()); + + Snapshot(prev.0 ^ DELTA) + } + + /// Transitions from `Complete` -> `Terminal`, decrementing the reference + /// count the specified number of times. + /// + /// Returns true if the task should be deallocated. + pub(super) fn transition_to_terminal(&self, count: usize) -> bool { + let prev = Snapshot(self.val.fetch_sub(count * REF_ONE, AcqRel)); + assert!( + prev.ref_count() >= count, + "current: {}, sub: {}", + prev.ref_count(), + count + ); + prev.ref_count() == count + } + + /// Transitions the state to `NOTIFIED`. + /// + /// If no task needs to be submitted, a ref-count is consumed. + /// + /// If a task needs to be submitted, the ref-count is incremented for the + /// new Notified. + pub(super) fn transition_to_notified_by_val(&self) -> TransitionToNotifiedByVal { + self.fetch_update_action(|mut snapshot| { + let action; + + if snapshot.is_running() { + // If the task is running, we mark it as notified, but we should + // not submit anything as the thread currently running the + // future is responsible for that. + snapshot.set_notified(); + snapshot.ref_dec(); + + // The thread that set the running bit also holds a ref-count. + assert!(snapshot.ref_count() > 0); + + action = TransitionToNotifiedByVal::DoNothing; + } else if snapshot.is_complete() || snapshot.is_notified() { + // We do not need to submit any notifications, but we have to + // decrement the ref-count. + snapshot.ref_dec(); + + if snapshot.ref_count() == 0 { + action = TransitionToNotifiedByVal::Dealloc; + } else { + action = TransitionToNotifiedByVal::DoNothing; + } + } else { + // We create a new notified that we can submit. The caller + // retains ownership of the ref-count they passed in. + snapshot.set_notified(); + snapshot.ref_inc(); + action = TransitionToNotifiedByVal::Submit; + } + + (action, Some(snapshot)) + }) + } + + /// Transitions the state to `NOTIFIED`. + pub(super) fn transition_to_notified_by_ref(&self) -> TransitionToNotifiedByRef { + self.fetch_update_action(|mut snapshot| { + if snapshot.is_complete() || snapshot.is_notified() { + // There is nothing to do in this case. + (TransitionToNotifiedByRef::DoNothing, None) + } else if snapshot.is_running() { + // If the task is running, we mark it as notified, but we should + // not submit as the thread currently running the future is + // responsible for that. + snapshot.set_notified(); + (TransitionToNotifiedByRef::DoNothing, Some(snapshot)) + } else { + // The task is idle and not notified. We should submit a + // notification. + snapshot.set_notified(); + snapshot.ref_inc(); + (TransitionToNotifiedByRef::Submit, Some(snapshot)) + } + }) + } + + /// Sets the cancelled bit and transitions the state to `NOTIFIED` if idle. + /// + /// Returns `true` if the task needs to be submitted to the pool for + /// execution. + pub(super) fn transition_to_notified_and_cancel(&self) -> bool { + self.fetch_update_action(|mut snapshot| { + if snapshot.is_cancelled() || snapshot.is_complete() { + // Aborts to completed or cancelled tasks are no-ops. + (false, None) + } else if snapshot.is_running() { + // If the task is running, we mark it as cancelled. The thread + // running the task will notice the cancelled bit when it + // stops polling and it will kill the task. + // + // The set_notified() call is not strictly necessary but it will + // in some cases let a wake_by_ref call return without having + // to perform a compare_exchange. + snapshot.set_notified(); + snapshot.set_cancelled(); + (false, Some(snapshot)) + } else { + // The task is idle. We set the cancelled and notified bits and + // submit a notification if the notified bit was not already + // set. + snapshot.set_cancelled(); + if !snapshot.is_notified() { + snapshot.set_notified(); + snapshot.ref_inc(); + (true, Some(snapshot)) + } else { + (false, Some(snapshot)) + } + } + }) + } + + /// Sets the `CANCELLED` bit and attempts to transition to `Running`. + /// + /// Returns `true` if the transition to `Running` succeeded. + pub(super) fn transition_to_shutdown(&self) -> bool { + let mut prev = Snapshot(0); + + let _ = self.fetch_update(|mut snapshot| { + prev = snapshot; + + if snapshot.is_idle() { + snapshot.set_running(); + } + + // If the task was not idle, the thread currently running the task + // will notice the cancelled bit and cancel it once the poll + // completes. + snapshot.set_cancelled(); + Some(snapshot) + }); + + prev.is_idle() + } + + /// Optimistically tries to swap the state assuming the join handle is + /// __immediately__ dropped on spawn. + pub(super) fn drop_join_handle_fast(&self) -> Result<(), ()> { + use std::sync::atomic::Ordering::Relaxed; + + // Relaxed is acceptable as if this function is called and succeeds, + // then nothing has been done w/ the join handle. + // + // The moment the join handle is used (polled), the `JOIN_WAKER` flag is + // set, at which point the CAS will fail. + // + // Given this, there is no risk if this operation is reordered. + self.val + .compare_exchange_weak( + INITIAL_STATE, + (INITIAL_STATE - REF_ONE) & !JOIN_INTEREST, + Release, + Relaxed, + ) + .map(|_| ()) + .map_err(|_| ()) + } + + /// Tries to unset the JOIN_INTEREST flag. + /// + /// Returns `Ok` if the operation happens before the task transitions to a + /// completed state, `Err` otherwise. + pub(super) fn unset_join_interested(&self) -> UpdateResult { + self.fetch_update(|curr| { + assert!(curr.is_join_interested()); + + if curr.is_complete() { + return None; + } + + let mut next = curr; + next.unset_join_interested(); + + Some(next) + }) + } + + /// Sets the `JOIN_WAKER` bit. + /// + /// Returns `Ok` if the bit is set, `Err` otherwise. This operation fails if + /// the task has completed. + pub(super) fn set_join_waker(&self) -> UpdateResult { + self.fetch_update(|curr| { + assert!(curr.is_join_interested()); + assert!(!curr.has_join_waker()); + + if curr.is_complete() { + return None; + } + + let mut next = curr; + next.set_join_waker(); + + Some(next) + }) + } + + /// Unsets the `JOIN_WAKER` bit. + /// + /// Returns `Ok` has been unset, `Err` otherwise. This operation fails if + /// the task has completed. + pub(super) fn unset_waker(&self) -> UpdateResult { + self.fetch_update(|curr| { + assert!(curr.is_join_interested()); + assert!(curr.has_join_waker()); + + if curr.is_complete() { + return None; + } + + let mut next = curr; + next.unset_join_waker(); + + Some(next) + }) + } + + pub(super) fn ref_inc(&self) { + use std::process; + use std::sync::atomic::Ordering::Relaxed; + + // Using a relaxed ordering is alright here, as knowledge of the + // original reference prevents other threads from erroneously deleting + // the object. + // + // As explained in the [Boost documentation][1], Increasing the + // reference counter can always be done with memory_order_relaxed: New + // references to an object can only be formed from an existing + // reference, and passing an existing reference from one thread to + // another must already provide any required synchronization. + // + // [1]: (www.boost.org/doc/libs/1_55_0/doc/html/atomic/usage_examples.html) + let prev = self.val.fetch_add(REF_ONE, Relaxed); + + // If the reference count overflowed, abort. + if prev > isize::MAX as usize { + process::abort(); + } + } + + /// Returns `true` if the task should be released. + pub(super) fn ref_dec(&self) -> bool { + let prev = Snapshot(self.val.fetch_sub(REF_ONE, AcqRel)); + assert!(prev.ref_count() >= 1); + prev.ref_count() == 1 + } + + /// Returns `true` if the task should be released. + pub(super) fn ref_dec_twice(&self) -> bool { + let prev = Snapshot(self.val.fetch_sub(2 * REF_ONE, AcqRel)); + assert!(prev.ref_count() >= 2); + prev.ref_count() == 2 + } + + fn fetch_update_action<F, T>(&self, mut f: F) -> T + where + F: FnMut(Snapshot) -> (T, Option<Snapshot>), + { + let mut curr = self.load(); + + loop { + let (output, next) = f(curr); + let next = match next { + Some(next) => next, + None => return output, + }; + + let res = self.val.compare_exchange(curr.0, next.0, AcqRel, Acquire); + + match res { + Ok(_) => return output, + Err(actual) => curr = Snapshot(actual), + } + } + } + + fn fetch_update<F>(&self, mut f: F) -> Result<Snapshot, Snapshot> + where + F: FnMut(Snapshot) -> Option<Snapshot>, + { + let mut curr = self.load(); + + loop { + let next = match f(curr) { + Some(next) => next, + None => return Err(curr), + }; + + let res = self.val.compare_exchange(curr.0, next.0, AcqRel, Acquire); + + match res { + Ok(_) => return Ok(next), + Err(actual) => curr = Snapshot(actual), + } + } + } +} + +// ===== impl Snapshot ===== + +impl Snapshot { + /// Returns `true` if the task is in an idle state. + pub(super) fn is_idle(self) -> bool { + self.0 & (RUNNING | COMPLETE) == 0 + } + + /// Returns `true` if the task has been flagged as notified. + pub(super) fn is_notified(self) -> bool { + self.0 & NOTIFIED == NOTIFIED + } + + fn unset_notified(&mut self) { + self.0 &= !NOTIFIED + } + + fn set_notified(&mut self) { + self.0 |= NOTIFIED + } + + pub(super) fn is_running(self) -> bool { + self.0 & RUNNING == RUNNING + } + + fn set_running(&mut self) { + self.0 |= RUNNING; + } + + fn unset_running(&mut self) { + self.0 &= !RUNNING; + } + + pub(super) fn is_cancelled(self) -> bool { + self.0 & CANCELLED == CANCELLED + } + + fn set_cancelled(&mut self) { + self.0 |= CANCELLED; + } + + /// Returns `true` if the task's future has completed execution. + pub(super) fn is_complete(self) -> bool { + self.0 & COMPLETE == COMPLETE + } + + pub(super) fn is_join_interested(self) -> bool { + self.0 & JOIN_INTEREST == JOIN_INTEREST + } + + fn unset_join_interested(&mut self) { + self.0 &= !JOIN_INTEREST + } + + pub(super) fn has_join_waker(self) -> bool { + self.0 & JOIN_WAKER == JOIN_WAKER + } + + fn set_join_waker(&mut self) { + self.0 |= JOIN_WAKER; + } + + fn unset_join_waker(&mut self) { + self.0 &= !JOIN_WAKER + } + + pub(super) fn ref_count(self) -> usize { + (self.0 & REF_COUNT_MASK) >> REF_COUNT_SHIFT + } + + fn ref_inc(&mut self) { + assert!(self.0 <= isize::MAX as usize); + self.0 += REF_ONE; + } + + pub(super) fn ref_dec(&mut self) { + assert!(self.ref_count() > 0); + self.0 -= REF_ONE + } +} + +impl fmt::Debug for State { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + let snapshot = self.load(); + snapshot.fmt(fmt) + } +} + +impl fmt::Debug for Snapshot { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Snapshot") + .field("is_running", &self.is_running()) + .field("is_complete", &self.is_complete()) + .field("is_notified", &self.is_notified()) + .field("is_cancelled", &self.is_cancelled()) + .field("is_join_interested", &self.is_join_interested()) + .field("has_join_waker", &self.has_join_waker()) + .field("ref_count", &self.ref_count()) + .finish() + } +} diff --git a/third_party/rust/tokio/src/runtime/task/waker.rs b/third_party/rust/tokio/src/runtime/task/waker.rs new file mode 100644 index 0000000000..74a29f4a84 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/waker.rs @@ -0,0 +1,130 @@ +use crate::future::Future; +use crate::runtime::task::harness::Harness; +use crate::runtime::task::{Header, Schedule}; + +use std::marker::PhantomData; +use std::mem::ManuallyDrop; +use std::ops; +use std::ptr::NonNull; +use std::task::{RawWaker, RawWakerVTable, Waker}; + +pub(super) struct WakerRef<'a, S: 'static> { + waker: ManuallyDrop<Waker>, + _p: PhantomData<(&'a Header, S)>, +} + +/// Returns a `WakerRef` which avoids having to pre-emptively increase the +/// refcount if there is no need to do so. +pub(super) fn waker_ref<T, S>(header: &NonNull<Header>) -> WakerRef<'_, S> +where + T: Future, + S: Schedule, +{ + // `Waker::will_wake` uses the VTABLE pointer as part of the check. This + // means that `will_wake` will always return false when using the current + // task's waker. (discussion at rust-lang/rust#66281). + // + // To fix this, we use a single vtable. Since we pass in a reference at this + // point and not an *owned* waker, we must ensure that `drop` is never + // called on this waker instance. This is done by wrapping it with + // `ManuallyDrop` and then never calling drop. + let waker = unsafe { ManuallyDrop::new(Waker::from_raw(raw_waker::<T, S>(*header))) }; + + WakerRef { + waker, + _p: PhantomData, + } +} + +impl<S> ops::Deref for WakerRef<'_, S> { + type Target = Waker; + + fn deref(&self) -> &Waker { + &self.waker + } +} + +cfg_trace! { + macro_rules! trace { + ($harness:expr, $op:expr) => { + if let Some(id) = $harness.id() { + tracing::trace!( + target: "tokio::task::waker", + op = $op, + task.id = id.into_u64(), + ); + } + } + } +} + +cfg_not_trace! { + macro_rules! trace { + ($harness:expr, $op:expr) => { + // noop + let _ = &$harness; + } + } +} + +unsafe fn clone_waker<T, S>(ptr: *const ()) -> RawWaker +where + T: Future, + S: Schedule, +{ + let header = ptr as *const Header; + let ptr = NonNull::new_unchecked(ptr as *mut Header); + let harness = Harness::<T, S>::from_raw(ptr); + trace!(harness, "waker.clone"); + (*header).state.ref_inc(); + raw_waker::<T, S>(ptr) +} + +unsafe fn drop_waker<T, S>(ptr: *const ()) +where + T: Future, + S: Schedule, +{ + let ptr = NonNull::new_unchecked(ptr as *mut Header); + let harness = Harness::<T, S>::from_raw(ptr); + trace!(harness, "waker.drop"); + harness.drop_reference(); +} + +unsafe fn wake_by_val<T, S>(ptr: *const ()) +where + T: Future, + S: Schedule, +{ + let ptr = NonNull::new_unchecked(ptr as *mut Header); + let harness = Harness::<T, S>::from_raw(ptr); + trace!(harness, "waker.wake"); + harness.wake_by_val(); +} + +// Wake without consuming the waker +unsafe fn wake_by_ref<T, S>(ptr: *const ()) +where + T: Future, + S: Schedule, +{ + let ptr = NonNull::new_unchecked(ptr as *mut Header); + let harness = Harness::<T, S>::from_raw(ptr); + trace!(harness, "waker.wake_by_ref"); + harness.wake_by_ref(); +} + +fn raw_waker<T, S>(header: NonNull<Header>) -> RawWaker +where + T: Future, + S: Schedule, +{ + let ptr = header.as_ptr() as *const (); + let vtable = &RawWakerVTable::new( + clone_waker::<T, S>, + wake_by_val::<T, S>, + wake_by_ref::<T, S>, + drop_waker::<T, S>, + ); + RawWaker::new(ptr, vtable) +} diff --git a/third_party/rust/tokio/src/runtime/tests/loom_basic_scheduler.rs b/third_party/rust/tokio/src/runtime/tests/loom_basic_scheduler.rs new file mode 100644 index 0000000000..a772603f71 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/loom_basic_scheduler.rs @@ -0,0 +1,142 @@ +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::Arc; +use crate::loom::thread; +use crate::runtime::{Builder, Runtime}; +use crate::sync::oneshot::{self, Receiver}; +use crate::task; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::Ordering::{Acquire, Release}; +use std::task::{Context, Poll}; + +fn assert_at_most_num_polls(rt: Arc<Runtime>, at_most_polls: usize) { + let (tx, rx) = oneshot::channel(); + let num_polls = Arc::new(AtomicUsize::new(0)); + rt.spawn(async move { + for _ in 0..12 { + task::yield_now().await; + } + tx.send(()).unwrap(); + }); + + rt.block_on(async { + BlockedFuture { + rx, + num_polls: num_polls.clone(), + } + .await; + }); + + let polls = num_polls.load(Acquire); + assert!(polls <= at_most_polls); +} + +#[test] +fn block_on_num_polls() { + loom::model(|| { + // we expect at most 4 number of polls because there are three points at + // which we poll the future and an opportunity for a false-positive.. At + // any of these points it can be ready: + // + // - when we fail to steal the parker and we block on a notification + // that it is available. + // + // - when we steal the parker and we schedule the future + // + // - when the future is woken up and we have ran the max number of tasks + // for the current tick or there are no more tasks to run. + // + // - a thread is notified that the parker is available but a third + // thread acquires it before the notified thread can. + // + let at_most = 4; + + let rt1 = Arc::new(Builder::new_current_thread().build().unwrap()); + let rt2 = rt1.clone(); + let rt3 = rt1.clone(); + + let th1 = thread::spawn(move || assert_at_most_num_polls(rt1, at_most)); + let th2 = thread::spawn(move || assert_at_most_num_polls(rt2, at_most)); + let th3 = thread::spawn(move || assert_at_most_num_polls(rt3, at_most)); + + th1.join().unwrap(); + th2.join().unwrap(); + th3.join().unwrap(); + }); +} + +#[test] +fn assert_no_unnecessary_polls() { + loom::model(|| { + // // After we poll outer future, woken should reset to false + let rt = Builder::new_current_thread().build().unwrap(); + let (tx, rx) = oneshot::channel(); + let pending_cnt = Arc::new(AtomicUsize::new(0)); + + rt.spawn(async move { + for _ in 0..24 { + task::yield_now().await; + } + tx.send(()).unwrap(); + }); + + let pending_cnt_clone = pending_cnt.clone(); + rt.block_on(async move { + // use task::yield_now() to ensure woken set to true + // ResetFuture will be polled at most once + // Here comes two cases + // 1. recv no message from channel, ResetFuture will be polled + // but get Pending and we record ResetFuture.pending_cnt ++. + // Then when message arrive, ResetFuture returns Ready. So we + // expect ResetFuture.pending_cnt = 1 + // 2. recv message from channel, ResetFuture returns Ready immediately. + // We expect ResetFuture.pending_cnt = 0 + task::yield_now().await; + ResetFuture { + rx, + pending_cnt: pending_cnt_clone, + } + .await; + }); + + let pending_cnt = pending_cnt.load(Acquire); + assert!(pending_cnt <= 1); + }); +} + +struct BlockedFuture { + rx: Receiver<()>, + num_polls: Arc<AtomicUsize>, +} + +impl Future for BlockedFuture { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + self.num_polls.fetch_add(1, Release); + + match Pin::new(&mut self.rx).poll(cx) { + Poll::Pending => Poll::Pending, + _ => Poll::Ready(()), + } + } +} + +struct ResetFuture { + rx: Receiver<()>, + pending_cnt: Arc<AtomicUsize>, +} + +impl Future for ResetFuture { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + match Pin::new(&mut self.rx).poll(cx) { + Poll::Pending => { + self.pending_cnt.fetch_add(1, Release); + Poll::Pending + } + _ => Poll::Ready(()), + } + } +} diff --git a/third_party/rust/tokio/src/runtime/tests/loom_blocking.rs b/third_party/rust/tokio/src/runtime/tests/loom_blocking.rs new file mode 100644 index 0000000000..89de85e436 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/loom_blocking.rs @@ -0,0 +1,81 @@ +use crate::runtime::{self, Runtime}; + +use std::sync::Arc; + +#[test] +fn blocking_shutdown() { + loom::model(|| { + let v = Arc::new(()); + + let rt = mk_runtime(1); + { + let _enter = rt.enter(); + for _ in 0..2 { + let v = v.clone(); + crate::task::spawn_blocking(move || { + assert!(1 < Arc::strong_count(&v)); + }); + } + } + + drop(rt); + assert_eq!(1, Arc::strong_count(&v)); + }); +} + +#[test] +fn spawn_mandatory_blocking_should_always_run() { + use crate::runtime::tests::loom_oneshot; + loom::model(|| { + let rt = runtime::Builder::new_current_thread().build().unwrap(); + + let (tx, rx) = loom_oneshot::channel(); + let _enter = rt.enter(); + runtime::spawn_blocking(|| {}); + runtime::spawn_mandatory_blocking(move || { + let _ = tx.send(()); + }) + .unwrap(); + + drop(rt); + + // This call will deadlock if `spawn_mandatory_blocking` doesn't run. + let () = rx.recv(); + }); +} + +#[test] +fn spawn_mandatory_blocking_should_run_even_when_shutting_down_from_other_thread() { + use crate::runtime::tests::loom_oneshot; + loom::model(|| { + let rt = runtime::Builder::new_current_thread().build().unwrap(); + let handle = rt.handle().clone(); + + // Drop the runtime in a different thread + { + loom::thread::spawn(move || { + drop(rt); + }); + } + + let _enter = handle.enter(); + let (tx, rx) = loom_oneshot::channel(); + let handle = runtime::spawn_mandatory_blocking(move || { + let _ = tx.send(()); + }); + + // handle.is_some() means that `spawn_mandatory_blocking` + // promised us to run the blocking task + if handle.is_some() { + // This call will deadlock if `spawn_mandatory_blocking` doesn't run. + let () = rx.recv(); + } + }); +} + +fn mk_runtime(num_threads: usize) -> Runtime { + runtime::Builder::new_multi_thread() + .worker_threads(num_threads) + .build() + .unwrap() +} diff --git a/third_party/rust/tokio/src/runtime/tests/loom_join_set.rs b/third_party/rust/tokio/src/runtime/tests/loom_join_set.rs new file mode 100644 index 0000000000..e87ddb0140 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/loom_join_set.rs @@ -0,0 +1,82 @@ +use crate::runtime::Builder; +use crate::task::JoinSet; + +#[test] +fn test_join_set() { + loom::model(|| { + let rt = Builder::new_multi_thread() + .worker_threads(1) + .build() + .unwrap(); + let mut set = JoinSet::new(); + + rt.block_on(async { + assert_eq!(set.len(), 0); + set.spawn(async { () }); + assert_eq!(set.len(), 1); + set.spawn(async { () }); + assert_eq!(set.len(), 2); + let () = set.join_one().await.unwrap().unwrap(); + assert_eq!(set.len(), 1); + set.spawn(async { () }); + assert_eq!(set.len(), 2); + let () = set.join_one().await.unwrap().unwrap(); + assert_eq!(set.len(), 1); + let () = set.join_one().await.unwrap().unwrap(); + assert_eq!(set.len(), 0); + set.spawn(async { () }); + assert_eq!(set.len(), 1); + }); + + drop(set); + drop(rt); + }); +} + +#[test] +fn abort_all_during_completion() { + use std::sync::{ + atomic::{AtomicBool, Ordering::SeqCst}, + Arc, + }; + + // These booleans assert that at least one execution had the task complete first, and that at + // least one execution had the task be cancelled before it completed. + let complete_happened = Arc::new(AtomicBool::new(false)); + let cancel_happened = Arc::new(AtomicBool::new(false)); + + { + let complete_happened = complete_happened.clone(); + let cancel_happened = cancel_happened.clone(); + loom::model(move || { + let rt = Builder::new_multi_thread() + .worker_threads(1) + .build() + .unwrap(); + + let mut set = JoinSet::new(); + + rt.block_on(async { + set.spawn(async { () }); + set.abort_all(); + + match set.join_one().await { + Ok(Some(())) => complete_happened.store(true, SeqCst), + Err(err) if err.is_cancelled() => cancel_happened.store(true, SeqCst), + Err(err) => panic!("fail: {}", err), + Ok(None) => { + unreachable!("Aborting the task does not remove it from the JoinSet.") + } + } + + assert!(matches!(set.join_one().await, Ok(None))); + }); + + drop(set); + drop(rt); + }); + } + + assert!(complete_happened.load(SeqCst)); + assert!(cancel_happened.load(SeqCst)); +} diff --git a/third_party/rust/tokio/src/runtime/tests/loom_local.rs b/third_party/rust/tokio/src/runtime/tests/loom_local.rs new file mode 100644 index 0000000000..d9a07a45f0 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/loom_local.rs @@ -0,0 +1,47 @@ +use crate::runtime::tests::loom_oneshot as oneshot; +use crate::runtime::Builder; +use crate::task::LocalSet; + +use std::task::Poll; + +/// Waking a runtime will attempt to push a task into a queue of notifications +/// in the runtime, however the tasks in such a queue usually have a reference +/// to the runtime itself. This means that if they are not properly removed at +/// runtime shutdown, this will cause a memory leak. +/// +/// This test verifies that waking something during shutdown of a LocalSet does +/// not result in tasks lingering in the queue once shutdown is complete. This +/// is verified using loom's leak finder. +#[test] +fn wake_during_shutdown() { + loom::model(|| { + let rt = Builder::new_current_thread().build().unwrap(); + let ls = LocalSet::new(); + + let (send, recv) = oneshot::channel(); + + ls.spawn_local(async move { + let mut send = Some(send); + + let () = futures::future::poll_fn(|cx| { + if let Some(send) = send.take() { + send.send(cx.waker().clone()); + } + + Poll::Pending + }) + .await; + }); + + let handle = loom::thread::spawn(move || { + let waker = recv.recv(); + waker.wake(); + }); + + ls.block_on(&rt, crate::task::yield_now()); + + drop(ls); + handle.join().unwrap(); + drop(rt); + }); +} diff --git a/third_party/rust/tokio/src/runtime/tests/loom_oneshot.rs b/third_party/rust/tokio/src/runtime/tests/loom_oneshot.rs new file mode 100644 index 0000000000..87eb638642 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/loom_oneshot.rs @@ -0,0 +1,48 @@ +use crate::loom::sync::{Arc, Mutex}; +use loom::sync::Notify; + +pub(crate) fn channel<T>() -> (Sender<T>, Receiver<T>) { + let inner = Arc::new(Inner { + notify: Notify::new(), + value: Mutex::new(None), + }); + + let tx = Sender { + inner: inner.clone(), + }; + let rx = Receiver { inner }; + + (tx, rx) +} + +pub(crate) struct Sender<T> { + inner: Arc<Inner<T>>, +} + +pub(crate) struct Receiver<T> { + inner: Arc<Inner<T>>, +} + +struct Inner<T> { + notify: Notify, + value: Mutex<Option<T>>, +} + +impl<T> Sender<T> { + pub(crate) fn send(self, value: T) { + *self.inner.value.lock() = Some(value); + self.inner.notify.notify(); + } +} + +impl<T> Receiver<T> { + pub(crate) fn recv(self) -> T { + loop { + if let Some(v) = self.inner.value.lock().take() { + return v; + } + + self.inner.notify.wait(); + } + } +} diff --git a/third_party/rust/tokio/src/runtime/tests/loom_pool.rs b/third_party/rust/tokio/src/runtime/tests/loom_pool.rs new file mode 100644 index 0000000000..b3ecd43124 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/loom_pool.rs @@ -0,0 +1,430 @@ +/// Full runtime loom tests. These are heavy tests and take significant time to +/// run on CI. +/// +/// Use `LOOM_MAX_PREEMPTIONS=1` to do a "quick" run as a smoke test. +/// +/// In order to speed up the C +use crate::future::poll_fn; +use crate::runtime::tests::loom_oneshot as oneshot; +use crate::runtime::{self, Runtime}; +use crate::{spawn, task}; +use tokio_test::assert_ok; + +use loom::sync::atomic::{AtomicBool, AtomicUsize}; +use loom::sync::Arc; + +use pin_project_lite::pin_project; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::Ordering::{Relaxed, SeqCst}; +use std::task::{Context, Poll}; + +mod atomic_take { + use loom::sync::atomic::AtomicBool; + use std::mem::MaybeUninit; + use std::sync::atomic::Ordering::SeqCst; + + pub(super) struct AtomicTake<T> { + inner: MaybeUninit<T>, + taken: AtomicBool, + } + + impl<T> AtomicTake<T> { + pub(super) fn new(value: T) -> Self { + Self { + inner: MaybeUninit::new(value), + taken: AtomicBool::new(false), + } + } + + pub(super) fn take(&self) -> Option<T> { + // safety: Only one thread will see the boolean change from false + // to true, so that thread is able to take the value. + match self.taken.fetch_or(true, SeqCst) { + false => unsafe { Some(std::ptr::read(self.inner.as_ptr())) }, + true => None, + } + } + } + + impl<T> Drop for AtomicTake<T> { + fn drop(&mut self) { + drop(self.take()); + } + } +} + +#[derive(Clone)] +struct AtomicOneshot<T> { + value: std::sync::Arc<atomic_take::AtomicTake<oneshot::Sender<T>>>, +} +impl<T> AtomicOneshot<T> { + fn new(sender: oneshot::Sender<T>) -> Self { + Self { + value: std::sync::Arc::new(atomic_take::AtomicTake::new(sender)), + } + } + + fn assert_send(&self, value: T) { + self.value.take().unwrap().send(value); + } +} + +/// Tests are divided into groups to make the runs faster on CI. +mod group_a { + use super::*; + + #[test] + fn racy_shutdown() { + loom::model(|| { + let pool = mk_pool(1); + + // here's the case we want to exercise: + // + // a worker that still has tasks in its local queue gets sent to the blocking pool (due to + // block_in_place). the blocking pool is shut down, so drops the worker. the worker's + // shutdown method never gets run. + // + // we do this by spawning two tasks on one worker, the first of which does block_in_place, + // and then immediately drop the pool. + + pool.spawn(track(async { + crate::task::block_in_place(|| {}); + })); + pool.spawn(track(async {})); + drop(pool); + }); + } + + #[test] + fn pool_multi_spawn() { + loom::model(|| { + let pool = mk_pool(2); + let c1 = Arc::new(AtomicUsize::new(0)); + + let (tx, rx) = oneshot::channel(); + let tx1 = AtomicOneshot::new(tx); + + // Spawn a task + let c2 = c1.clone(); + let tx2 = tx1.clone(); + pool.spawn(track(async move { + spawn(track(async move { + if 1 == c1.fetch_add(1, Relaxed) { + tx1.assert_send(()); + } + })); + })); + + // Spawn a second task + pool.spawn(track(async move { + spawn(track(async move { + if 1 == c2.fetch_add(1, Relaxed) { + tx2.assert_send(()); + } + })); + })); + + rx.recv(); + }); + } + + fn only_blocking_inner(first_pending: bool) { + loom::model(move || { + let pool = mk_pool(1); + let (block_tx, block_rx) = oneshot::channel(); + + pool.spawn(track(async move { + crate::task::block_in_place(move || { + block_tx.send(()); + }); + if first_pending { + task::yield_now().await + } + })); + + block_rx.recv(); + drop(pool); + }); + } + + #[test] + fn only_blocking_without_pending() { + only_blocking_inner(false) + } + + #[test] + fn only_blocking_with_pending() { + only_blocking_inner(true) + } +} + +mod group_b { + use super::*; + + fn blocking_and_regular_inner(first_pending: bool) { + const NUM: usize = 3; + loom::model(move || { + let pool = mk_pool(1); + let cnt = Arc::new(AtomicUsize::new(0)); + + let (block_tx, block_rx) = oneshot::channel(); + let (done_tx, done_rx) = oneshot::channel(); + let done_tx = AtomicOneshot::new(done_tx); + + pool.spawn(track(async move { + crate::task::block_in_place(move || { + block_tx.send(()); + }); + if first_pending { + task::yield_now().await + } + })); + + for _ in 0..NUM { + let cnt = cnt.clone(); + let done_tx = done_tx.clone(); + + pool.spawn(track(async move { + if NUM == cnt.fetch_add(1, Relaxed) + 1 { + done_tx.assert_send(()); + } + })); + } + + done_rx.recv(); + block_rx.recv(); + + drop(pool); + }); + } + + #[test] + fn blocking_and_regular() { + blocking_and_regular_inner(false); + } + + #[test] + fn blocking_and_regular_with_pending() { + blocking_and_regular_inner(true); + } + + #[test] + fn join_output() { + loom::model(|| { + let rt = mk_pool(1); + + rt.block_on(async { + let t = crate::spawn(track(async { "hello" })); + + let out = assert_ok!(t.await); + assert_eq!("hello", out.into_inner()); + }); + }); + } + + #[test] + fn poll_drop_handle_then_drop() { + loom::model(|| { + let rt = mk_pool(1); + + rt.block_on(async move { + let mut t = crate::spawn(track(async { "hello" })); + + poll_fn(|cx| { + let _ = Pin::new(&mut t).poll(cx); + Poll::Ready(()) + }) + .await; + }); + }) + } + + #[test] + fn complete_block_on_under_load() { + loom::model(|| { + let pool = mk_pool(1); + + pool.block_on(async { + // Trigger a re-schedule + crate::spawn(track(async { + for _ in 0..2 { + task::yield_now().await; + } + })); + + gated2(true).await + }); + }); + } + + #[test] + fn shutdown_with_notification() { + use crate::sync::oneshot; + + loom::model(|| { + let rt = mk_pool(2); + let (done_tx, done_rx) = oneshot::channel::<()>(); + + rt.spawn(track(async move { + let (tx, rx) = oneshot::channel::<()>(); + + crate::spawn(async move { + crate::task::spawn_blocking(move || { + let _ = tx.send(()); + }); + + let _ = done_rx.await; + }); + + let _ = rx.await; + + let _ = done_tx.send(()); + })); + }); + } +} + +mod group_c { + use super::*; + + #[test] + fn pool_shutdown() { + loom::model(|| { + let pool = mk_pool(2); + + pool.spawn(track(async move { + gated2(true).await; + })); + + pool.spawn(track(async move { + gated2(false).await; + })); + + drop(pool); + }); + } +} + +mod group_d { + use super::*; + + #[test] + fn pool_multi_notify() { + loom::model(|| { + let pool = mk_pool(2); + + let c1 = Arc::new(AtomicUsize::new(0)); + + let (done_tx, done_rx) = oneshot::channel(); + let done_tx1 = AtomicOneshot::new(done_tx); + let done_tx2 = done_tx1.clone(); + + // Spawn a task + let c2 = c1.clone(); + pool.spawn(track(async move { + gated().await; + gated().await; + + if 1 == c1.fetch_add(1, Relaxed) { + done_tx1.assert_send(()); + } + })); + + // Spawn a second task + pool.spawn(track(async move { + gated().await; + gated().await; + + if 1 == c2.fetch_add(1, Relaxed) { + done_tx2.assert_send(()); + } + })); + + done_rx.recv(); + }); + } +} + +fn mk_pool(num_threads: usize) -> Runtime { + runtime::Builder::new_multi_thread() + .worker_threads(num_threads) + .build() + .unwrap() +} + +fn gated() -> impl Future<Output = &'static str> { + gated2(false) +} + +fn gated2(thread: bool) -> impl Future<Output = &'static str> { + use loom::thread; + use std::sync::Arc; + + let gate = Arc::new(AtomicBool::new(false)); + let mut fired = false; + + poll_fn(move |cx| { + if !fired { + let gate = gate.clone(); + let waker = cx.waker().clone(); + + if thread { + thread::spawn(move || { + gate.store(true, SeqCst); + waker.wake_by_ref(); + }); + } else { + spawn(track(async move { + gate.store(true, SeqCst); + waker.wake_by_ref(); + })); + } + + fired = true; + + return Poll::Pending; + } + + if gate.load(SeqCst) { + Poll::Ready("hello world") + } else { + Poll::Pending + } + }) +} + +fn track<T: Future>(f: T) -> Track<T> { + Track { + inner: f, + arc: Arc::new(()), + } +} + +pin_project! { + struct Track<T> { + #[pin] + inner: T, + // Arc is used to hook into loom's leak tracking. + arc: Arc<()>, + } +} + +impl<T> Track<T> { + fn into_inner(self) -> T { + self.inner + } +} + +impl<T: Future> Future for Track<T> { + type Output = Track<T::Output>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + + Poll::Ready(Track { + inner: ready!(me.inner.poll(cx)), + arc: me.arc.clone(), + }) + } +} diff --git a/third_party/rust/tokio/src/runtime/tests/loom_queue.rs b/third_party/rust/tokio/src/runtime/tests/loom_queue.rs new file mode 100644 index 0000000000..b5f78d7ebe --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/loom_queue.rs @@ -0,0 +1,208 @@ +use crate::runtime::blocking::NoopSchedule; +use crate::runtime::task::Inject; +use crate::runtime::{queue, MetricsBatch}; + +use loom::thread; + +#[test] +fn basic() { + loom::model(|| { + let (steal, mut local) = queue::local(); + let inject = Inject::new(); + let mut metrics = MetricsBatch::new(); + + let th = thread::spawn(move || { + let mut metrics = MetricsBatch::new(); + let (_, mut local) = queue::local(); + let mut n = 0; + + for _ in 0..3 { + if steal.steal_into(&mut local, &mut metrics).is_some() { + n += 1; + } + + while local.pop().is_some() { + n += 1; + } + } + + n + }); + + let mut n = 0; + + for _ in 0..2 { + for _ in 0..2 { + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + } + + if local.pop().is_some() { + n += 1; + } + + // Push another task + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + + while local.pop().is_some() { + n += 1; + } + } + + while inject.pop().is_some() { + n += 1; + } + + n += th.join().unwrap(); + + assert_eq!(6, n); + }); +} + +#[test] +fn steal_overflow() { + loom::model(|| { + let (steal, mut local) = queue::local(); + let inject = Inject::new(); + let mut metrics = MetricsBatch::new(); + + let th = thread::spawn(move || { + let mut metrics = MetricsBatch::new(); + let (_, mut local) = queue::local(); + let mut n = 0; + + if steal.steal_into(&mut local, &mut metrics).is_some() { + n += 1; + } + + while local.pop().is_some() { + n += 1; + } + + n + }); + + let mut n = 0; + + // push a task, pop a task + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + + if local.pop().is_some() { + n += 1; + } + + for _ in 0..6 { + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + } + + n += th.join().unwrap(); + + while local.pop().is_some() { + n += 1; + } + + while inject.pop().is_some() { + n += 1; + } + + assert_eq!(7, n); + }); +} + +#[test] +fn multi_stealer() { + const NUM_TASKS: usize = 5; + + fn steal_tasks(steal: queue::Steal<NoopSchedule>) -> usize { + let mut metrics = MetricsBatch::new(); + let (_, mut local) = queue::local(); + + if steal.steal_into(&mut local, &mut metrics).is_none() { + return 0; + } + + let mut n = 1; + + while local.pop().is_some() { + n += 1; + } + + n + } + + loom::model(|| { + let (steal, mut local) = queue::local(); + let inject = Inject::new(); + let mut metrics = MetricsBatch::new(); + + // Push work + for _ in 0..NUM_TASKS { + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + } + + let th1 = { + let steal = steal.clone(); + thread::spawn(move || steal_tasks(steal)) + }; + + let th2 = thread::spawn(move || steal_tasks(steal)); + + let mut n = 0; + + while local.pop().is_some() { + n += 1; + } + + while inject.pop().is_some() { + n += 1; + } + + n += th1.join().unwrap(); + n += th2.join().unwrap(); + + assert_eq!(n, NUM_TASKS); + }); +} + +#[test] +fn chained_steal() { + loom::model(|| { + let mut metrics = MetricsBatch::new(); + let (s1, mut l1) = queue::local(); + let (s2, mut l2) = queue::local(); + let inject = Inject::new(); + + // Load up some tasks + for _ in 0..4 { + let (task, _) = super::unowned(async {}); + l1.push_back(task, &inject, &mut metrics); + + let (task, _) = super::unowned(async {}); + l2.push_back(task, &inject, &mut metrics); + } + + // Spawn a task to steal from **our** queue + let th = thread::spawn(move || { + let mut metrics = MetricsBatch::new(); + let (_, mut local) = queue::local(); + s1.steal_into(&mut local, &mut metrics); + + while local.pop().is_some() {} + }); + + // Drain our tasks, then attempt to steal + while l1.pop().is_some() {} + + s2.steal_into(&mut l1, &mut metrics); + + th.join().unwrap(); + + while l1.pop().is_some() {} + while l2.pop().is_some() {} + while inject.pop().is_some() {} + }); +} diff --git a/third_party/rust/tokio/src/runtime/tests/loom_shutdown_join.rs b/third_party/rust/tokio/src/runtime/tests/loom_shutdown_join.rs new file mode 100644 index 0000000000..6fbc4bfded --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/loom_shutdown_join.rs @@ -0,0 +1,28 @@ +use crate::runtime::{Builder, Handle}; + +#[test] +fn join_handle_cancel_on_shutdown() { + let mut builder = loom::model::Builder::new(); + builder.preemption_bound = Some(2); + builder.check(|| { + use futures::future::FutureExt; + + let rt = Builder::new_multi_thread() + .worker_threads(2) + .build() + .unwrap(); + + let handle = rt.block_on(async move { Handle::current() }); + + let jh1 = handle.spawn(futures::future::pending::<()>()); + + drop(rt); + + let jh2 = handle.spawn(futures::future::pending::<()>()); + + let err1 = jh1.now_or_never().unwrap().unwrap_err(); + let err2 = jh2.now_or_never().unwrap().unwrap_err(); + assert!(err1.is_cancelled()); + assert!(err2.is_cancelled()); + }); +} diff --git a/third_party/rust/tokio/src/runtime/tests/mod.rs b/third_party/rust/tokio/src/runtime/tests/mod.rs new file mode 100644 index 0000000000..4b49698a86 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/mod.rs @@ -0,0 +1,50 @@ +use self::unowned_wrapper::unowned; + +mod unowned_wrapper { + use crate::runtime::blocking::NoopSchedule; + use crate::runtime::task::{JoinHandle, Notified}; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(crate) fn unowned<T>(task: T) -> (Notified<NoopSchedule>, JoinHandle<T::Output>) + where + T: std::future::Future + Send + 'static, + T::Output: Send + 'static, + { + use tracing::Instrument; + let span = tracing::trace_span!("test_span"); + let task = task.instrument(span); + let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule); + (task.into_notified(), handle) + } + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + pub(crate) fn unowned<T>(task: T) -> (Notified<NoopSchedule>, JoinHandle<T::Output>) + where + T: std::future::Future + Send + 'static, + T::Output: Send + 'static, + { + let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule); + (task.into_notified(), handle) + } +} + +cfg_loom! { + mod loom_basic_scheduler; + mod loom_blocking; + mod loom_local; + mod loom_oneshot; + mod loom_pool; + mod loom_queue; + mod loom_shutdown_join; + mod loom_join_set; +} + +cfg_not_loom! { + mod queue; + + #[cfg(not(miri))] + mod task_combinations; + + #[cfg(miri)] + mod task; +} diff --git a/third_party/rust/tokio/src/runtime/tests/queue.rs b/third_party/rust/tokio/src/runtime/tests/queue.rs new file mode 100644 index 0000000000..0fd1e0c6d9 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/queue.rs @@ -0,0 +1,248 @@ +use crate::runtime::queue; +use crate::runtime::task::{self, Inject, Schedule, Task}; +use crate::runtime::MetricsBatch; + +use std::thread; +use std::time::Duration; + +#[allow(unused)] +macro_rules! assert_metrics { + ($metrics:ident, $field:ident == $v:expr) => {{ + use crate::runtime::WorkerMetrics; + use std::sync::atomic::Ordering::Relaxed; + + let worker = WorkerMetrics::new(); + $metrics.submit(&worker); + + let expect = $v; + let actual = worker.$field.load(Relaxed); + + assert!(actual == expect, "expect = {}; actual = {}", expect, actual) + }}; +} + +#[test] +fn fits_256() { + let (_, mut local) = queue::local(); + let inject = Inject::new(); + let mut metrics = MetricsBatch::new(); + + for _ in 0..256 { + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + } + + cfg_metrics! { + assert_metrics!(metrics, overflow_count == 0); + } + + assert!(inject.pop().is_none()); + + while local.pop().is_some() {} +} + +#[test] +fn overflow() { + let (_, mut local) = queue::local(); + let inject = Inject::new(); + let mut metrics = MetricsBatch::new(); + + for _ in 0..257 { + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + } + + cfg_metrics! { + assert_metrics!(metrics, overflow_count == 1); + } + + let mut n = 0; + + while inject.pop().is_some() { + n += 1; + } + + while local.pop().is_some() { + n += 1; + } + + assert_eq!(n, 257); +} + +#[test] +fn steal_batch() { + let mut metrics = MetricsBatch::new(); + + let (steal1, mut local1) = queue::local(); + let (_, mut local2) = queue::local(); + let inject = Inject::new(); + + for _ in 0..4 { + let (task, _) = super::unowned(async {}); + local1.push_back(task, &inject, &mut metrics); + } + + assert!(steal1.steal_into(&mut local2, &mut metrics).is_some()); + + cfg_metrics! { + assert_metrics!(metrics, steal_count == 2); + } + + for _ in 0..1 { + assert!(local2.pop().is_some()); + } + + assert!(local2.pop().is_none()); + + for _ in 0..2 { + assert!(local1.pop().is_some()); + } + + assert!(local1.pop().is_none()); +} + +const fn normal_or_miri(normal: usize, miri: usize) -> usize { + if cfg!(miri) { + miri + } else { + normal + } +} + +#[test] +fn stress1() { + const NUM_ITER: usize = 1; + const NUM_STEAL: usize = normal_or_miri(1_000, 10); + const NUM_LOCAL: usize = normal_or_miri(1_000, 10); + const NUM_PUSH: usize = normal_or_miri(500, 10); + const NUM_POP: usize = normal_or_miri(250, 10); + + let mut metrics = MetricsBatch::new(); + + for _ in 0..NUM_ITER { + let (steal, mut local) = queue::local(); + let inject = Inject::new(); + + let th = thread::spawn(move || { + let mut metrics = MetricsBatch::new(); + let (_, mut local) = queue::local(); + let mut n = 0; + + for _ in 0..NUM_STEAL { + if steal.steal_into(&mut local, &mut metrics).is_some() { + n += 1; + } + + while local.pop().is_some() { + n += 1; + } + + thread::yield_now(); + } + + cfg_metrics! { + assert_metrics!(metrics, steal_count == n as _); + } + + n + }); + + let mut n = 0; + + for _ in 0..NUM_LOCAL { + for _ in 0..NUM_PUSH { + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + } + + for _ in 0..NUM_POP { + if local.pop().is_some() { + n += 1; + } else { + break; + } + } + } + + while inject.pop().is_some() { + n += 1; + } + + n += th.join().unwrap(); + + assert_eq!(n, NUM_LOCAL * NUM_PUSH); + } +} + +#[test] +fn stress2() { + const NUM_ITER: usize = 1; + const NUM_TASKS: usize = normal_or_miri(1_000_000, 50); + const NUM_STEAL: usize = normal_or_miri(1_000, 10); + + let mut metrics = MetricsBatch::new(); + + for _ in 0..NUM_ITER { + let (steal, mut local) = queue::local(); + let inject = Inject::new(); + + let th = thread::spawn(move || { + let mut stats = MetricsBatch::new(); + let (_, mut local) = queue::local(); + let mut n = 0; + + for _ in 0..NUM_STEAL { + if steal.steal_into(&mut local, &mut stats).is_some() { + n += 1; + } + + while local.pop().is_some() { + n += 1; + } + + thread::sleep(Duration::from_micros(10)); + } + + n + }); + + let mut num_pop = 0; + + for i in 0..NUM_TASKS { + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + + if i % 128 == 0 && local.pop().is_some() { + num_pop += 1; + } + + while inject.pop().is_some() { + num_pop += 1; + } + } + + num_pop += th.join().unwrap(); + + while local.pop().is_some() { + num_pop += 1; + } + + while inject.pop().is_some() { + num_pop += 1; + } + + assert_eq!(num_pop, NUM_TASKS); + } +} + +struct Runtime; + +impl Schedule for Runtime { + fn release(&self, _task: &Task<Self>) -> Option<Task<Self>> { + None + } + + fn schedule(&self, _task: task::Notified<Self>) { + unreachable!(); + } +} diff --git a/third_party/rust/tokio/src/runtime/tests/task.rs b/third_party/rust/tokio/src/runtime/tests/task.rs new file mode 100644 index 0000000000..04e1b56e77 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/task.rs @@ -0,0 +1,288 @@ +use crate::runtime::blocking::NoopSchedule; +use crate::runtime::task::{self, unowned, JoinHandle, OwnedTasks, Schedule, Task}; +use crate::util::TryLock; + +use std::collections::VecDeque; +use std::future::Future; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; + +struct AssertDropHandle { + is_dropped: Arc<AtomicBool>, +} +impl AssertDropHandle { + #[track_caller] + fn assert_dropped(&self) { + assert!(self.is_dropped.load(Ordering::SeqCst)); + } + + #[track_caller] + fn assert_not_dropped(&self) { + assert!(!self.is_dropped.load(Ordering::SeqCst)); + } +} + +struct AssertDrop { + is_dropped: Arc<AtomicBool>, +} +impl AssertDrop { + fn new() -> (Self, AssertDropHandle) { + let shared = Arc::new(AtomicBool::new(false)); + ( + AssertDrop { + is_dropped: shared.clone(), + }, + AssertDropHandle { + is_dropped: shared.clone(), + }, + ) + } +} +impl Drop for AssertDrop { + fn drop(&mut self) { + self.is_dropped.store(true, Ordering::SeqCst); + } +} + +// A Notified does not shut down on drop, but it is dropped once the ref-count +// hits zero. +#[test] +fn create_drop1() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + drop(notified); + handle.assert_not_dropped(); + drop(join); + handle.assert_dropped(); +} + +#[test] +fn create_drop2() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + drop(join); + handle.assert_not_dropped(); + drop(notified); + handle.assert_dropped(); +} + +// Shutting down through Notified works +#[test] +fn create_shutdown1() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + drop(join); + handle.assert_not_dropped(); + notified.shutdown(); + handle.assert_dropped(); +} + +#[test] +fn create_shutdown2() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + handle.assert_not_dropped(); + notified.shutdown(); + handle.assert_dropped(); + drop(join); +} + +#[test] +fn unowned_poll() { + let (task, _) = unowned(async {}, NoopSchedule); + task.run(); +} + +#[test] +fn schedule() { + with(|rt| { + rt.spawn(async { + crate::task::yield_now().await; + }); + + assert_eq!(2, rt.tick()); + rt.shutdown(); + }) +} + +#[test] +fn shutdown() { + with(|rt| { + rt.spawn(async { + loop { + crate::task::yield_now().await; + } + }); + + rt.tick_max(1); + + rt.shutdown(); + }) +} + +#[test] +fn shutdown_immediately() { + with(|rt| { + rt.spawn(async { + loop { + crate::task::yield_now().await; + } + }); + + rt.shutdown(); + }) +} + +#[test] +fn spawn_during_shutdown() { + static DID_SPAWN: AtomicBool = AtomicBool::new(false); + + struct SpawnOnDrop(Runtime); + impl Drop for SpawnOnDrop { + fn drop(&mut self) { + DID_SPAWN.store(true, Ordering::SeqCst); + self.0.spawn(async {}); + } + } + + with(|rt| { + let rt2 = rt.clone(); + rt.spawn(async move { + let _spawn_on_drop = SpawnOnDrop(rt2); + + loop { + crate::task::yield_now().await; + } + }); + + rt.tick_max(1); + rt.shutdown(); + }); + + assert!(DID_SPAWN.load(Ordering::SeqCst)); +} + +fn with(f: impl FnOnce(Runtime)) { + struct Reset; + + impl Drop for Reset { + fn drop(&mut self) { + let _rt = CURRENT.try_lock().unwrap().take(); + } + } + + let _reset = Reset; + + let rt = Runtime(Arc::new(Inner { + owned: OwnedTasks::new(), + core: TryLock::new(Core { + queue: VecDeque::new(), + }), + })); + + *CURRENT.try_lock().unwrap() = Some(rt.clone()); + f(rt) +} + +#[derive(Clone)] +struct Runtime(Arc<Inner>); + +struct Inner { + core: TryLock<Core>, + owned: OwnedTasks<Runtime>, +} + +struct Core { + queue: VecDeque<task::Notified<Runtime>>, +} + +static CURRENT: TryLock<Option<Runtime>> = TryLock::new(None); + +impl Runtime { + fn spawn<T>(&self, future: T) -> JoinHandle<T::Output> + where + T: 'static + Send + Future, + T::Output: 'static + Send, + { + let (handle, notified) = self.0.owned.bind(future, self.clone()); + + if let Some(notified) = notified { + self.schedule(notified); + } + + handle + } + + fn tick(&self) -> usize { + self.tick_max(usize::MAX) + } + + fn tick_max(&self, max: usize) -> usize { + let mut n = 0; + + while !self.is_empty() && n < max { + let task = self.next_task(); + n += 1; + let task = self.0.owned.assert_owner(task); + task.run(); + } + + n + } + + fn is_empty(&self) -> bool { + self.0.core.try_lock().unwrap().queue.is_empty() + } + + fn next_task(&self) -> task::Notified<Runtime> { + self.0.core.try_lock().unwrap().queue.pop_front().unwrap() + } + + fn shutdown(&self) { + let mut core = self.0.core.try_lock().unwrap(); + + self.0.owned.close_and_shutdown_all(); + + while let Some(task) = core.queue.pop_back() { + drop(task); + } + + drop(core); + + assert!(self.0.owned.is_empty()); + } +} + +impl Schedule for Runtime { + fn release(&self, task: &Task<Self>) -> Option<Task<Self>> { + self.0.owned.remove(task) + } + + fn schedule(&self, task: task::Notified<Self>) { + self.0.core.try_lock().unwrap().queue.push_back(task); + } +} diff --git a/third_party/rust/tokio/src/runtime/tests/task_combinations.rs b/third_party/rust/tokio/src/runtime/tests/task_combinations.rs new file mode 100644 index 0000000000..76ce2330c2 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/task_combinations.rs @@ -0,0 +1,380 @@ +use std::future::Future; +use std::panic; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use crate::runtime::Builder; +use crate::sync::oneshot; +use crate::task::JoinHandle; + +use futures::future::FutureExt; + +// Enums for each option in the combinations being tested + +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiRuntime { + CurrentThread, + Multi1, + Multi2, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiLocalSet { + Yes, + No, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiTask { + PanicOnRun, + PanicOnDrop, + PanicOnRunAndDrop, + NoPanic, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiOutput { + PanicOnDrop, + NoPanic, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiJoinInterest { + Polled, + NotPolled, +} +#[allow(clippy::enum_variant_names)] // we aren't using glob imports +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiJoinHandle { + DropImmediately = 1, + DropFirstPoll = 2, + DropAfterNoConsume = 3, + DropAfterConsume = 4, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiAbort { + NotAborted = 0, + AbortedImmediately = 1, + AbortedFirstPoll = 2, + AbortedAfterFinish = 3, + AbortedAfterConsumeOutput = 4, +} + +#[test] +fn test_combinations() { + let mut rt = &[ + CombiRuntime::CurrentThread, + CombiRuntime::Multi1, + CombiRuntime::Multi2, + ][..]; + + if cfg!(miri) { + rt = &[CombiRuntime::CurrentThread]; + } + + let ls = [CombiLocalSet::Yes, CombiLocalSet::No]; + let task = [ + CombiTask::NoPanic, + CombiTask::PanicOnRun, + CombiTask::PanicOnDrop, + CombiTask::PanicOnRunAndDrop, + ]; + let output = [CombiOutput::NoPanic, CombiOutput::PanicOnDrop]; + let ji = [CombiJoinInterest::Polled, CombiJoinInterest::NotPolled]; + let jh = [ + CombiJoinHandle::DropImmediately, + CombiJoinHandle::DropFirstPoll, + CombiJoinHandle::DropAfterNoConsume, + CombiJoinHandle::DropAfterConsume, + ]; + let abort = [ + CombiAbort::NotAborted, + CombiAbort::AbortedImmediately, + CombiAbort::AbortedFirstPoll, + CombiAbort::AbortedAfterFinish, + CombiAbort::AbortedAfterConsumeOutput, + ]; + + for rt in rt.iter().copied() { + for ls in ls.iter().copied() { + for task in task.iter().copied() { + for output in output.iter().copied() { + for ji in ji.iter().copied() { + for jh in jh.iter().copied() { + for abort in abort.iter().copied() { + test_combination(rt, ls, task, output, ji, jh, abort); + } + } + } + } + } + } + } +} + +fn test_combination( + rt: CombiRuntime, + ls: CombiLocalSet, + task: CombiTask, + output: CombiOutput, + ji: CombiJoinInterest, + jh: CombiJoinHandle, + abort: CombiAbort, +) { + if (jh as usize) < (abort as usize) { + // drop before abort not possible + return; + } + if (task == CombiTask::PanicOnDrop) && (output == CombiOutput::PanicOnDrop) { + // this causes double panic + return; + } + if (task == CombiTask::PanicOnRunAndDrop) && (abort != CombiAbort::AbortedImmediately) { + // this causes double panic + return; + } + + println!("Runtime {:?}, LocalSet {:?}, Task {:?}, Output {:?}, JoinInterest {:?}, JoinHandle {:?}, Abort {:?}", rt, ls, task, output, ji, jh, abort); + + // A runtime optionally with a LocalSet + struct Rt { + rt: crate::runtime::Runtime, + ls: Option<crate::task::LocalSet>, + } + impl Rt { + fn new(rt: CombiRuntime, ls: CombiLocalSet) -> Self { + let rt = match rt { + CombiRuntime::CurrentThread => Builder::new_current_thread().build().unwrap(), + CombiRuntime::Multi1 => Builder::new_multi_thread() + .worker_threads(1) + .build() + .unwrap(), + CombiRuntime::Multi2 => Builder::new_multi_thread() + .worker_threads(2) + .build() + .unwrap(), + }; + + let ls = match ls { + CombiLocalSet::Yes => Some(crate::task::LocalSet::new()), + CombiLocalSet::No => None, + }; + + Self { rt, ls } + } + fn block_on<T>(&self, task: T) -> T::Output + where + T: Future, + { + match &self.ls { + Some(ls) => ls.block_on(&self.rt, task), + None => self.rt.block_on(task), + } + } + fn spawn<T>(&self, task: T) -> JoinHandle<T::Output> + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + match &self.ls { + Some(ls) => ls.spawn_local(task), + None => self.rt.spawn(task), + } + } + } + + // The type used for the output of the future + struct Output { + panic_on_drop: bool, + on_drop: Option<oneshot::Sender<()>>, + } + impl Output { + fn disarm(&mut self) { + self.panic_on_drop = false; + } + } + impl Drop for Output { + fn drop(&mut self) { + let _ = self.on_drop.take().unwrap().send(()); + if self.panic_on_drop { + panic!("Panicking in Output"); + } + } + } + + // A wrapper around the future that is spawned + struct FutWrapper<F> { + inner: F, + on_drop: Option<oneshot::Sender<()>>, + panic_on_drop: bool, + } + impl<F: Future> Future for FutWrapper<F> { + type Output = F::Output; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> { + unsafe { + let me = Pin::into_inner_unchecked(self); + let inner = Pin::new_unchecked(&mut me.inner); + inner.poll(cx) + } + } + } + impl<F> Drop for FutWrapper<F> { + fn drop(&mut self) { + let _: Result<(), ()> = self.on_drop.take().unwrap().send(()); + if self.panic_on_drop { + panic!("Panicking in FutWrapper"); + } + } + } + + // The channels passed to the task + struct Signals { + on_first_poll: Option<oneshot::Sender<()>>, + wait_complete: Option<oneshot::Receiver<()>>, + on_output_drop: Option<oneshot::Sender<()>>, + } + + // The task we will spawn + async fn my_task(mut signal: Signals, task: CombiTask, out: CombiOutput) -> Output { + // Signal that we have been polled once + let _ = signal.on_first_poll.take().unwrap().send(()); + + // Wait for a signal, then complete the future + let _ = signal.wait_complete.take().unwrap().await; + + // If the task gets past wait_complete without yielding, then aborts + // may not be caught without this yield_now. + crate::task::yield_now().await; + + if task == CombiTask::PanicOnRun || task == CombiTask::PanicOnRunAndDrop { + panic!("Panicking in my_task on {:?}", std::thread::current().id()); + } + + Output { + panic_on_drop: out == CombiOutput::PanicOnDrop, + on_drop: signal.on_output_drop.take(), + } + } + + let rt = Rt::new(rt, ls); + + let (on_first_poll, wait_first_poll) = oneshot::channel(); + let (on_complete, wait_complete) = oneshot::channel(); + let (on_future_drop, wait_future_drop) = oneshot::channel(); + let (on_output_drop, wait_output_drop) = oneshot::channel(); + let signal = Signals { + on_first_poll: Some(on_first_poll), + wait_complete: Some(wait_complete), + on_output_drop: Some(on_output_drop), + }; + + // === Spawn task === + let mut handle = Some(rt.spawn(FutWrapper { + inner: my_task(signal, task, output), + on_drop: Some(on_future_drop), + panic_on_drop: task == CombiTask::PanicOnDrop || task == CombiTask::PanicOnRunAndDrop, + })); + + // Keep track of whether the task has been killed with an abort + let mut aborted = false; + + // If we want to poll the JoinHandle, do it now + if ji == CombiJoinInterest::Polled { + assert!( + handle.as_mut().unwrap().now_or_never().is_none(), + "Polling handle succeeded" + ); + } + + if abort == CombiAbort::AbortedImmediately { + handle.as_mut().unwrap().abort(); + aborted = true; + } + if jh == CombiJoinHandle::DropImmediately { + drop(handle.take().unwrap()); + } + + // === Wait for first poll === + let got_polled = rt.block_on(wait_first_poll).is_ok(); + if !got_polled { + // it's possible that we are aborted but still got polled + assert!( + aborted, + "Task completed without ever being polled but was not aborted." + ); + } + + if abort == CombiAbort::AbortedFirstPoll { + handle.as_mut().unwrap().abort(); + aborted = true; + } + if jh == CombiJoinHandle::DropFirstPoll { + drop(handle.take().unwrap()); + } + + // Signal the future that it can return now + let _ = on_complete.send(()); + // === Wait for future to be dropped === + assert!( + rt.block_on(wait_future_drop).is_ok(), + "The future should always be dropped." + ); + + if abort == CombiAbort::AbortedAfterFinish { + // Don't set aborted to true here as the task already finished + handle.as_mut().unwrap().abort(); + } + if jh == CombiJoinHandle::DropAfterNoConsume { + // The runtime will usually have dropped every ref-count at this point, + // in which case dropping the JoinHandle drops the output. + // + // (But it might race and still hold a ref-count) + let panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { + drop(handle.take().unwrap()); + })); + if panic.is_err() { + assert!( + (output == CombiOutput::PanicOnDrop) + && (!matches!(task, CombiTask::PanicOnRun | CombiTask::PanicOnRunAndDrop)) + && !aborted, + "Dropping JoinHandle shouldn't panic here" + ); + } + } + + // Check whether we drop after consuming the output + if jh == CombiJoinHandle::DropAfterConsume { + // Using as_mut() to not immediately drop the handle + let result = rt.block_on(handle.as_mut().unwrap()); + + match result { + Ok(mut output) => { + // Don't panic here. + output.disarm(); + assert!(!aborted, "Task was aborted but returned output"); + } + Err(err) if err.is_cancelled() => assert!(aborted, "Cancelled output but not aborted"), + Err(err) if err.is_panic() => { + assert!( + (task == CombiTask::PanicOnRun) + || (task == CombiTask::PanicOnDrop) + || (task == CombiTask::PanicOnRunAndDrop) + || (output == CombiOutput::PanicOnDrop), + "Panic but nothing should panic" + ); + } + _ => unreachable!(), + } + + let handle = handle.take().unwrap(); + if abort == CombiAbort::AbortedAfterConsumeOutput { + handle.abort(); + } + drop(handle); + } + + // The output should have been dropped now. Check whether the output + // object was created at all. + let output_created = rt.block_on(wait_output_drop).is_ok(); + assert_eq!( + output_created, + (!matches!(task, CombiTask::PanicOnRun | CombiTask::PanicOnRunAndDrop)) && !aborted, + "Creation of output object" + ); +} diff --git a/third_party/rust/tokio/src/runtime/thread_pool/idle.rs b/third_party/rust/tokio/src/runtime/thread_pool/idle.rs new file mode 100644 index 0000000000..a57bf6a0b1 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/thread_pool/idle.rs @@ -0,0 +1,226 @@ +//! Coordinates idling workers + +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::Mutex; + +use std::fmt; +use std::sync::atomic::Ordering::{self, SeqCst}; + +pub(super) struct Idle { + /// Tracks both the number of searching workers and the number of unparked + /// workers. + /// + /// Used as a fast-path to avoid acquiring the lock when needed. + state: AtomicUsize, + + /// Sleeping workers + sleepers: Mutex<Vec<usize>>, + + /// Total number of workers. + num_workers: usize, +} + +const UNPARK_SHIFT: usize = 16; +const UNPARK_MASK: usize = !SEARCH_MASK; +const SEARCH_MASK: usize = (1 << UNPARK_SHIFT) - 1; + +#[derive(Copy, Clone)] +struct State(usize); + +impl Idle { + pub(super) fn new(num_workers: usize) -> Idle { + let init = State::new(num_workers); + + Idle { + state: AtomicUsize::new(init.into()), + sleepers: Mutex::new(Vec::with_capacity(num_workers)), + num_workers, + } + } + + /// If there are no workers actively searching, returns the index of a + /// worker currently sleeping. + pub(super) fn worker_to_notify(&self) -> Option<usize> { + // If at least one worker is spinning, work being notified will + // eventually be found. A searching thread will find **some** work and + // notify another worker, eventually leading to our work being found. + // + // For this to happen, this load must happen before the thread + // transitioning `num_searching` to zero. Acquire / Release does not + // provide sufficient guarantees, so this load is done with `SeqCst` and + // will pair with the `fetch_sub(1)` when transitioning out of + // searching. + if !self.notify_should_wakeup() { + return None; + } + + // Acquire the lock + let mut sleepers = self.sleepers.lock(); + + // Check again, now that the lock is acquired + if !self.notify_should_wakeup() { + return None; + } + + // A worker should be woken up, atomically increment the number of + // searching workers as well as the number of unparked workers. + State::unpark_one(&self.state, 1); + + // Get the worker to unpark + let ret = sleepers.pop(); + debug_assert!(ret.is_some()); + + ret + } + + /// Returns `true` if the worker needs to do a final check for submitted + /// work. + pub(super) fn transition_worker_to_parked(&self, worker: usize, is_searching: bool) -> bool { + // Acquire the lock + let mut sleepers = self.sleepers.lock(); + + // Decrement the number of unparked threads + let ret = State::dec_num_unparked(&self.state, is_searching); + + // Track the sleeping worker + sleepers.push(worker); + + ret + } + + pub(super) fn transition_worker_to_searching(&self) -> bool { + let state = State::load(&self.state, SeqCst); + if 2 * state.num_searching() >= self.num_workers { + return false; + } + + // It is possible for this routine to allow more than 50% of the workers + // to search. That is OK. Limiting searchers is only an optimization to + // prevent too much contention. + State::inc_num_searching(&self.state, SeqCst); + true + } + + /// A lightweight transition from searching -> running. + /// + /// Returns `true` if this is the final searching worker. The caller + /// **must** notify a new worker. + pub(super) fn transition_worker_from_searching(&self) -> bool { + State::dec_num_searching(&self.state) + } + + /// Unpark a specific worker. This happens if tasks are submitted from + /// within the worker's park routine. + /// + /// Returns `true` if the worker was parked before calling the method. + pub(super) fn unpark_worker_by_id(&self, worker_id: usize) -> bool { + let mut sleepers = self.sleepers.lock(); + + for index in 0..sleepers.len() { + if sleepers[index] == worker_id { + sleepers.swap_remove(index); + + // Update the state accordingly while the lock is held. + State::unpark_one(&self.state, 0); + + return true; + } + } + + false + } + + /// Returns `true` if `worker_id` is contained in the sleep set. + pub(super) fn is_parked(&self, worker_id: usize) -> bool { + let sleepers = self.sleepers.lock(); + sleepers.contains(&worker_id) + } + + fn notify_should_wakeup(&self) -> bool { + let state = State(self.state.fetch_add(0, SeqCst)); + state.num_searching() == 0 && state.num_unparked() < self.num_workers + } +} + +impl State { + fn new(num_workers: usize) -> State { + // All workers start in the unparked state + let ret = State(num_workers << UNPARK_SHIFT); + debug_assert_eq!(num_workers, ret.num_unparked()); + debug_assert_eq!(0, ret.num_searching()); + ret + } + + fn load(cell: &AtomicUsize, ordering: Ordering) -> State { + State(cell.load(ordering)) + } + + fn unpark_one(cell: &AtomicUsize, num_searching: usize) { + cell.fetch_add(num_searching | (1 << UNPARK_SHIFT), SeqCst); + } + + fn inc_num_searching(cell: &AtomicUsize, ordering: Ordering) { + cell.fetch_add(1, ordering); + } + + /// Returns `true` if this is the final searching worker + fn dec_num_searching(cell: &AtomicUsize) -> bool { + let state = State(cell.fetch_sub(1, SeqCst)); + state.num_searching() == 1 + } + + /// Track a sleeping worker + /// + /// Returns `true` if this is the final searching worker. + fn dec_num_unparked(cell: &AtomicUsize, is_searching: bool) -> bool { + let mut dec = 1 << UNPARK_SHIFT; + + if is_searching { + dec += 1; + } + + let prev = State(cell.fetch_sub(dec, SeqCst)); + is_searching && prev.num_searching() == 1 + } + + /// Number of workers currently searching + fn num_searching(self) -> usize { + self.0 & SEARCH_MASK + } + + /// Number of workers currently unparked + fn num_unparked(self) -> usize { + (self.0 & UNPARK_MASK) >> UNPARK_SHIFT + } +} + +impl From<usize> for State { + fn from(src: usize) -> State { + State(src) + } +} + +impl From<State> for usize { + fn from(src: State) -> usize { + src.0 + } +} + +impl fmt::Debug for State { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("worker::State") + .field("num_unparked", &self.num_unparked()) + .field("num_searching", &self.num_searching()) + .finish() + } +} + +#[test] +fn test_state() { + assert_eq!(0, UNPARK_MASK & SEARCH_MASK); + assert_eq!(0, !(UNPARK_MASK | SEARCH_MASK)); + + let state = State::new(10); + assert_eq!(10, state.num_unparked()); + assert_eq!(0, state.num_searching()); +} diff --git a/third_party/rust/tokio/src/runtime/thread_pool/mod.rs b/third_party/rust/tokio/src/runtime/thread_pool/mod.rs new file mode 100644 index 0000000000..d3f46517cb --- /dev/null +++ b/third_party/rust/tokio/src/runtime/thread_pool/mod.rs @@ -0,0 +1,136 @@ +//! Threadpool + +mod idle; +use self::idle::Idle; + +mod worker; +pub(crate) use worker::Launch; + +pub(crate) use worker::block_in_place; + +use crate::loom::sync::Arc; +use crate::runtime::task::JoinHandle; +use crate::runtime::{Callback, Parker}; + +use std::fmt; +use std::future::Future; + +/// Work-stealing based thread pool for executing futures. +pub(crate) struct ThreadPool { + spawner: Spawner, +} + +/// Submits futures to the associated thread pool for execution. +/// +/// A `Spawner` instance is a handle to a single thread pool that allows the owner +/// of the handle to spawn futures onto the thread pool. +/// +/// The `Spawner` handle is *only* used for spawning new futures. It does not +/// impact the lifecycle of the thread pool in any way. The thread pool may +/// shut down while there are outstanding `Spawner` instances. +/// +/// `Spawner` instances are obtained by calling [`ThreadPool::spawner`]. +/// +/// [`ThreadPool::spawner`]: method@ThreadPool::spawner +#[derive(Clone)] +pub(crate) struct Spawner { + shared: Arc<worker::Shared>, +} + +// ===== impl ThreadPool ===== + +impl ThreadPool { + pub(crate) fn new( + size: usize, + parker: Parker, + before_park: Option<Callback>, + after_unpark: Option<Callback>, + ) -> (ThreadPool, Launch) { + let (shared, launch) = worker::create(size, parker, before_park, after_unpark); + let spawner = Spawner { shared }; + let thread_pool = ThreadPool { spawner }; + + (thread_pool, launch) + } + + /// Returns reference to `Spawner`. + /// + /// The `Spawner` handle can be cloned and enables spawning tasks from other + /// threads. + pub(crate) fn spawner(&self) -> &Spawner { + &self.spawner + } + + /// Blocks the current thread waiting for the future to complete. + /// + /// The future will execute on the current thread, but all spawned tasks + /// will be executed on the thread pool. + pub(crate) fn block_on<F>(&self, future: F) -> F::Output + where + F: Future, + { + let mut enter = crate::runtime::enter(true); + enter.block_on(future).expect("failed to park thread") + } +} + +impl fmt::Debug for ThreadPool { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("ThreadPool").finish() + } +} + +impl Drop for ThreadPool { + fn drop(&mut self) { + self.spawner.shutdown(); + } +} + +// ==== impl Spawner ===== + +impl Spawner { + /// Spawns a future onto the thread pool + pub(crate) fn spawn<F>(&self, future: F) -> JoinHandle<F::Output> + where + F: crate::future::Future + Send + 'static, + F::Output: Send + 'static, + { + worker::Shared::bind_new_task(&self.shared, future) + } + + pub(crate) fn shutdown(&mut self) { + self.shared.close(); + } +} + +cfg_metrics! { + use crate::runtime::{SchedulerMetrics, WorkerMetrics}; + + impl Spawner { + pub(crate) fn num_workers(&self) -> usize { + self.shared.worker_metrics.len() + } + + pub(crate) fn scheduler_metrics(&self) -> &SchedulerMetrics { + &self.shared.scheduler_metrics + } + + pub(crate) fn worker_metrics(&self, worker: usize) -> &WorkerMetrics { + &self.shared.worker_metrics[worker] + } + + pub(crate) fn injection_queue_depth(&self) -> usize { + self.shared.injection_queue_depth() + } + + pub(crate) fn worker_local_queue_depth(&self, worker: usize) -> usize { + self.shared.worker_local_queue_depth(worker) + } + } +} + +impl fmt::Debug for Spawner { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Spawner").finish() + } +} diff --git a/third_party/rust/tokio/src/runtime/thread_pool/worker.rs b/third_party/rust/tokio/src/runtime/thread_pool/worker.rs new file mode 100644 index 0000000000..7e4989701e --- /dev/null +++ b/third_party/rust/tokio/src/runtime/thread_pool/worker.rs @@ -0,0 +1,848 @@ +//! A scheduler is initialized with a fixed number of workers. Each worker is +//! driven by a thread. Each worker has a "core" which contains data such as the +//! run queue and other state. When `block_in_place` is called, the worker's +//! "core" is handed off to a new thread allowing the scheduler to continue to +//! make progress while the originating thread blocks. +//! +//! # Shutdown +//! +//! Shutting down the runtime involves the following steps: +//! +//! 1. The Shared::close method is called. This closes the inject queue and +//! OwnedTasks instance and wakes up all worker threads. +//! +//! 2. Each worker thread observes the close signal next time it runs +//! Core::maintenance by checking whether the inject queue is closed. +//! The Core::is_shutdown flag is set to true. +//! +//! 3. The worker thread calls `pre_shutdown` in parallel. Here, the worker +//! will keep removing tasks from OwnedTasks until it is empty. No new +//! tasks can be pushed to the OwnedTasks during or after this step as it +//! was closed in step 1. +//! +//! 5. The workers call Shared::shutdown to enter the single-threaded phase of +//! shutdown. These calls will push their core to Shared::shutdown_cores, +//! and the last thread to push its core will finish the shutdown procedure. +//! +//! 6. The local run queue of each core is emptied, then the inject queue is +//! emptied. +//! +//! At this point, shutdown has completed. It is not possible for any of the +//! collections to contain any tasks at this point, as each collection was +//! closed first, then emptied afterwards. +//! +//! ## Spawns during shutdown +//! +//! When spawning tasks during shutdown, there are two cases: +//! +//! * The spawner observes the OwnedTasks being open, and the inject queue is +//! closed. +//! * The spawner observes the OwnedTasks being closed and doesn't check the +//! inject queue. +//! +//! The first case can only happen if the OwnedTasks::bind call happens before +//! or during step 1 of shutdown. In this case, the runtime will clean up the +//! task in step 3 of shutdown. +//! +//! In the latter case, the task was not spawned and the task is immediately +//! cancelled by the spawner. +//! +//! The correctness of shutdown requires both the inject queue and OwnedTasks +//! collection to have a closed bit. With a close bit on only the inject queue, +//! spawning could run in to a situation where a task is successfully bound long +//! after the runtime has shut down. With a close bit on only the OwnedTasks, +//! the first spawning situation could result in the notification being pushed +//! to the inject queue after step 6 of shutdown, which would leave a task in +//! the inject queue indefinitely. This would be a ref-count cycle and a memory +//! leak. + +use crate::coop; +use crate::future::Future; +use crate::loom::rand::seed; +use crate::loom::sync::{Arc, Mutex}; +use crate::park::{Park, Unpark}; +use crate::runtime; +use crate::runtime::enter::EnterContext; +use crate::runtime::park::{Parker, Unparker}; +use crate::runtime::task::{Inject, JoinHandle, OwnedTasks}; +use crate::runtime::thread_pool::Idle; +use crate::runtime::{queue, task, Callback, MetricsBatch, SchedulerMetrics, WorkerMetrics}; +use crate::util::atomic_cell::AtomicCell; +use crate::util::FastRand; + +use std::cell::RefCell; +use std::time::Duration; + +/// A scheduler worker +pub(super) struct Worker { + /// Reference to shared state + shared: Arc<Shared>, + + /// Index holding this worker's remote state + index: usize, + + /// Used to hand-off a worker's core to another thread. + core: AtomicCell<Core>, +} + +/// Core data +struct Core { + /// Used to schedule bookkeeping tasks every so often. + tick: u8, + + /// When a task is scheduled from a worker, it is stored in this slot. The + /// worker will check this slot for a task **before** checking the run + /// queue. This effectively results in the **last** scheduled task to be run + /// next (LIFO). This is an optimization for message passing patterns and + /// helps to reduce latency. + lifo_slot: Option<Notified>, + + /// The worker-local run queue. + run_queue: queue::Local<Arc<Shared>>, + + /// True if the worker is currently searching for more work. Searching + /// involves attempting to steal from other workers. + is_searching: bool, + + /// True if the scheduler is being shutdown + is_shutdown: bool, + + /// Parker + /// + /// Stored in an `Option` as the parker is added / removed to make the + /// borrow checker happy. + park: Option<Parker>, + + /// Batching metrics so they can be submitted to RuntimeMetrics. + metrics: MetricsBatch, + + /// Fast random number generator. + rand: FastRand, +} + +/// State shared across all workers +pub(super) struct Shared { + /// Per-worker remote state. All other workers have access to this and is + /// how they communicate between each other. + remotes: Box<[Remote]>, + + /// Submits work to the scheduler while **not** currently on a worker thread. + inject: Inject<Arc<Shared>>, + + /// Coordinates idle workers + idle: Idle, + + /// Collection of all active tasks spawned onto this executor. + owned: OwnedTasks<Arc<Shared>>, + + /// Cores that have observed the shutdown signal + /// + /// The core is **not** placed back in the worker to avoid it from being + /// stolen by a thread that was spawned as part of `block_in_place`. + #[allow(clippy::vec_box)] // we're moving an already-boxed value + shutdown_cores: Mutex<Vec<Box<Core>>>, + + /// Callback for a worker parking itself + before_park: Option<Callback>, + /// Callback for a worker unparking itself + after_unpark: Option<Callback>, + + /// Collects metrics from the runtime. + pub(super) scheduler_metrics: SchedulerMetrics, + + pub(super) worker_metrics: Box<[WorkerMetrics]>, +} + +/// Used to communicate with a worker from other threads. +struct Remote { + /// Steals tasks from this worker. + steal: queue::Steal<Arc<Shared>>, + + /// Unparks the associated worker thread + unpark: Unparker, +} + +/// Thread-local context +struct Context { + /// Worker + worker: Arc<Worker>, + + /// Core data + core: RefCell<Option<Box<Core>>>, +} + +/// Starts the workers +pub(crate) struct Launch(Vec<Arc<Worker>>); + +/// Running a task may consume the core. If the core is still available when +/// running the task completes, it is returned. Otherwise, the worker will need +/// to stop processing. +type RunResult = Result<Box<Core>, ()>; + +/// A task handle +type Task = task::Task<Arc<Shared>>; + +/// A notified task handle +type Notified = task::Notified<Arc<Shared>>; + +// Tracks thread-local state +scoped_thread_local!(static CURRENT: Context); + +pub(super) fn create( + size: usize, + park: Parker, + before_park: Option<Callback>, + after_unpark: Option<Callback>, +) -> (Arc<Shared>, Launch) { + let mut cores = vec![]; + let mut remotes = vec![]; + let mut worker_metrics = vec![]; + + // Create the local queues + for _ in 0..size { + let (steal, run_queue) = queue::local(); + + let park = park.clone(); + let unpark = park.unpark(); + + cores.push(Box::new(Core { + tick: 0, + lifo_slot: None, + run_queue, + is_searching: false, + is_shutdown: false, + park: Some(park), + metrics: MetricsBatch::new(), + rand: FastRand::new(seed()), + })); + + remotes.push(Remote { steal, unpark }); + worker_metrics.push(WorkerMetrics::new()); + } + + let shared = Arc::new(Shared { + remotes: remotes.into_boxed_slice(), + inject: Inject::new(), + idle: Idle::new(size), + owned: OwnedTasks::new(), + shutdown_cores: Mutex::new(vec![]), + before_park, + after_unpark, + scheduler_metrics: SchedulerMetrics::new(), + worker_metrics: worker_metrics.into_boxed_slice(), + }); + + let mut launch = Launch(vec![]); + + for (index, core) in cores.drain(..).enumerate() { + launch.0.push(Arc::new(Worker { + shared: shared.clone(), + index, + core: AtomicCell::new(Some(core)), + })); + } + + (shared, launch) +} + +pub(crate) fn block_in_place<F, R>(f: F) -> R +where + F: FnOnce() -> R, +{ + // Try to steal the worker core back + struct Reset(coop::Budget); + + impl Drop for Reset { + fn drop(&mut self) { + CURRENT.with(|maybe_cx| { + if let Some(cx) = maybe_cx { + let core = cx.worker.core.take(); + let mut cx_core = cx.core.borrow_mut(); + assert!(cx_core.is_none()); + *cx_core = core; + + // Reset the task budget as we are re-entering the + // runtime. + coop::set(self.0); + } + }); + } + } + + let mut had_entered = false; + + CURRENT.with(|maybe_cx| { + match (crate::runtime::enter::context(), maybe_cx.is_some()) { + (EnterContext::Entered { .. }, true) => { + // We are on a thread pool runtime thread, so we just need to + // set up blocking. + had_entered = true; + } + (EnterContext::Entered { allow_blocking }, false) => { + // We are on an executor, but _not_ on the thread pool. That is + // _only_ okay if we are in a thread pool runtime's block_on + // method: + if allow_blocking { + had_entered = true; + return; + } else { + // This probably means we are on the basic_scheduler or in a + // LocalSet, where it is _not_ okay to block. + panic!("can call blocking only when running on the multi-threaded runtime"); + } + } + (EnterContext::NotEntered, true) => { + // This is a nested call to block_in_place (we already exited). + // All the necessary setup has already been done. + return; + } + (EnterContext::NotEntered, false) => { + // We are outside of the tokio runtime, so blocking is fine. + // We can also skip all of the thread pool blocking setup steps. + return; + } + } + + let cx = maybe_cx.expect("no .is_some() == false cases above should lead here"); + + // Get the worker core. If none is set, then blocking is fine! + let core = match cx.core.borrow_mut().take() { + Some(core) => core, + None => return, + }; + + // The parker should be set here + assert!(core.park.is_some()); + + // In order to block, the core must be sent to another thread for + // execution. + // + // First, move the core back into the worker's shared core slot. + cx.worker.core.set(core); + + // Next, clone the worker handle and send it to a new thread for + // processing. + // + // Once the blocking task is done executing, we will attempt to + // steal the core back. + let worker = cx.worker.clone(); + runtime::spawn_blocking(move || run(worker)); + }); + + if had_entered { + // Unset the current task's budget. Blocking sections are not + // constrained by task budgets. + let _reset = Reset(coop::stop()); + + crate::runtime::enter::exit(f) + } else { + f() + } +} + +/// After how many ticks is the global queue polled. This helps to ensure +/// fairness. +/// +/// The number is fairly arbitrary. I believe this value was copied from golang. +const GLOBAL_POLL_INTERVAL: u8 = 61; + +impl Launch { + pub(crate) fn launch(mut self) { + for worker in self.0.drain(..) { + runtime::spawn_blocking(move || run(worker)); + } + } +} + +fn run(worker: Arc<Worker>) { + // Acquire a core. If this fails, then another thread is running this + // worker and there is nothing further to do. + let core = match worker.core.take() { + Some(core) => core, + None => return, + }; + + // Set the worker context. + let cx = Context { + worker, + core: RefCell::new(None), + }; + + let _enter = crate::runtime::enter(true); + + CURRENT.set(&cx, || { + // This should always be an error. It only returns a `Result` to support + // using `?` to short circuit. + assert!(cx.run(core).is_err()); + }); +} + +impl Context { + fn run(&self, mut core: Box<Core>) -> RunResult { + while !core.is_shutdown { + // Increment the tick + core.tick(); + + // Run maintenance, if needed + core = self.maintenance(core); + + // First, check work available to the current worker. + if let Some(task) = core.next_task(&self.worker) { + core = self.run_task(task, core)?; + continue; + } + + // There is no more **local** work to process, try to steal work + // from other workers. + if let Some(task) = core.steal_work(&self.worker) { + core = self.run_task(task, core)?; + } else { + // Wait for work + core = self.park(core); + } + } + + core.pre_shutdown(&self.worker); + + // Signal shutdown + self.worker.shared.shutdown(core); + Err(()) + } + + fn run_task(&self, task: Notified, mut core: Box<Core>) -> RunResult { + let task = self.worker.shared.owned.assert_owner(task); + + // Make sure the worker is not in the **searching** state. This enables + // another idle worker to try to steal work. + core.transition_from_searching(&self.worker); + + // Make the core available to the runtime context + core.metrics.incr_poll_count(); + *self.core.borrow_mut() = Some(core); + + // Run the task + coop::budget(|| { + task.run(); + + // As long as there is budget remaining and a task exists in the + // `lifo_slot`, then keep running. + loop { + // Check if we still have the core. If not, the core was stolen + // by another worker. + let mut core = match self.core.borrow_mut().take() { + Some(core) => core, + None => return Err(()), + }; + + // Check for a task in the LIFO slot + let task = match core.lifo_slot.take() { + Some(task) => task, + None => return Ok(core), + }; + + if coop::has_budget_remaining() { + // Run the LIFO task, then loop + core.metrics.incr_poll_count(); + *self.core.borrow_mut() = Some(core); + let task = self.worker.shared.owned.assert_owner(task); + task.run(); + } else { + // Not enough budget left to run the LIFO task, push it to + // the back of the queue and return. + core.run_queue + .push_back(task, self.worker.inject(), &mut core.metrics); + return Ok(core); + } + } + }) + } + + fn maintenance(&self, mut core: Box<Core>) -> Box<Core> { + if core.tick % GLOBAL_POLL_INTERVAL == 0 { + // Call `park` with a 0 timeout. This enables the I/O driver, timer, ... + // to run without actually putting the thread to sleep. + core = self.park_timeout(core, Some(Duration::from_millis(0))); + + // Run regularly scheduled maintenance + core.maintenance(&self.worker); + } + + core + } + + fn park(&self, mut core: Box<Core>) -> Box<Core> { + if let Some(f) = &self.worker.shared.before_park { + f(); + } + + if core.transition_to_parked(&self.worker) { + while !core.is_shutdown { + core.metrics.about_to_park(); + core = self.park_timeout(core, None); + core.metrics.returned_from_park(); + + // Run regularly scheduled maintenance + core.maintenance(&self.worker); + + if core.transition_from_parked(&self.worker) { + break; + } + } + } + + if let Some(f) = &self.worker.shared.after_unpark { + f(); + } + core + } + + fn park_timeout(&self, mut core: Box<Core>, duration: Option<Duration>) -> Box<Core> { + // Take the parker out of core + let mut park = core.park.take().expect("park missing"); + + // Store `core` in context + *self.core.borrow_mut() = Some(core); + + // Park thread + if let Some(timeout) = duration { + park.park_timeout(timeout).expect("park failed"); + } else { + park.park().expect("park failed"); + } + + // Remove `core` from context + core = self.core.borrow_mut().take().expect("core missing"); + + // Place `park` back in `core` + core.park = Some(park); + + // If there are tasks available to steal, but this worker is not + // looking for tasks to steal, notify another worker. + if !core.is_searching && core.run_queue.is_stealable() { + self.worker.shared.notify_parked(); + } + + core + } +} + +impl Core { + /// Increment the tick + fn tick(&mut self) { + self.tick = self.tick.wrapping_add(1); + } + + /// Return the next notified task available to this worker. + fn next_task(&mut self, worker: &Worker) -> Option<Notified> { + if self.tick % GLOBAL_POLL_INTERVAL == 0 { + worker.inject().pop().or_else(|| self.next_local_task()) + } else { + self.next_local_task().or_else(|| worker.inject().pop()) + } + } + + fn next_local_task(&mut self) -> Option<Notified> { + self.lifo_slot.take().or_else(|| self.run_queue.pop()) + } + + fn steal_work(&mut self, worker: &Worker) -> Option<Notified> { + if !self.transition_to_searching(worker) { + return None; + } + + let num = worker.shared.remotes.len(); + // Start from a random worker + let start = self.rand.fastrand_n(num as u32) as usize; + + for i in 0..num { + let i = (start + i) % num; + + // Don't steal from ourself! We know we don't have work. + if i == worker.index { + continue; + } + + let target = &worker.shared.remotes[i]; + if let Some(task) = target + .steal + .steal_into(&mut self.run_queue, &mut self.metrics) + { + return Some(task); + } + } + + // Fallback on checking the global queue + worker.shared.inject.pop() + } + + fn transition_to_searching(&mut self, worker: &Worker) -> bool { + if !self.is_searching { + self.is_searching = worker.shared.idle.transition_worker_to_searching(); + } + + self.is_searching + } + + fn transition_from_searching(&mut self, worker: &Worker) { + if !self.is_searching { + return; + } + + self.is_searching = false; + worker.shared.transition_worker_from_searching(); + } + + /// Prepares the worker state for parking. + /// + /// Returns true if the transition happend, false if there is work to do first. + fn transition_to_parked(&mut self, worker: &Worker) -> bool { + // Workers should not park if they have work to do + if self.lifo_slot.is_some() || self.run_queue.has_tasks() { + return false; + } + + // When the final worker transitions **out** of searching to parked, it + // must check all the queues one last time in case work materialized + // between the last work scan and transitioning out of searching. + let is_last_searcher = worker + .shared + .idle + .transition_worker_to_parked(worker.index, self.is_searching); + + // The worker is no longer searching. Setting this is the local cache + // only. + self.is_searching = false; + + if is_last_searcher { + worker.shared.notify_if_work_pending(); + } + + true + } + + /// Returns `true` if the transition happened. + fn transition_from_parked(&mut self, worker: &Worker) -> bool { + // If a task is in the lifo slot, then we must unpark regardless of + // being notified + if self.lifo_slot.is_some() { + // When a worker wakes, it should only transition to the "searching" + // state when the wake originates from another worker *or* a new task + // is pushed. We do *not* want the worker to transition to "searching" + // when it wakes when the I/O driver receives new events. + self.is_searching = !worker.shared.idle.unpark_worker_by_id(worker.index); + return true; + } + + if worker.shared.idle.is_parked(worker.index) { + return false; + } + + // When unparked, the worker is in the searching state. + self.is_searching = true; + true + } + + /// Runs maintenance work such as checking the pool's state. + fn maintenance(&mut self, worker: &Worker) { + self.metrics + .submit(&worker.shared.worker_metrics[worker.index]); + + if !self.is_shutdown { + // Check if the scheduler has been shutdown + self.is_shutdown = worker.inject().is_closed(); + } + } + + /// Signals all tasks to shut down, and waits for them to complete. Must run + /// before we enter the single-threaded phase of shutdown processing. + fn pre_shutdown(&mut self, worker: &Worker) { + // Signal to all tasks to shut down. + worker.shared.owned.close_and_shutdown_all(); + + self.metrics + .submit(&worker.shared.worker_metrics[worker.index]); + } + + /// Shuts down the core. + fn shutdown(&mut self) { + // Take the core + let mut park = self.park.take().expect("park missing"); + + // Drain the queue + while self.next_local_task().is_some() {} + + park.shutdown(); + } +} + +impl Worker { + /// Returns a reference to the scheduler's injection queue. + fn inject(&self) -> &Inject<Arc<Shared>> { + &self.shared.inject + } +} + +impl task::Schedule for Arc<Shared> { + fn release(&self, task: &Task) -> Option<Task> { + self.owned.remove(task) + } + + fn schedule(&self, task: Notified) { + (**self).schedule(task, false); + } + + fn yield_now(&self, task: Notified) { + (**self).schedule(task, true); + } +} + +impl Shared { + pub(super) fn bind_new_task<T>(me: &Arc<Self>, future: T) -> JoinHandle<T::Output> + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + let (handle, notified) = me.owned.bind(future, me.clone()); + + if let Some(notified) = notified { + me.schedule(notified, false); + } + + handle + } + + pub(super) fn schedule(&self, task: Notified, is_yield: bool) { + CURRENT.with(|maybe_cx| { + if let Some(cx) = maybe_cx { + // Make sure the task is part of the **current** scheduler. + if self.ptr_eq(&cx.worker.shared) { + // And the current thread still holds a core + if let Some(core) = cx.core.borrow_mut().as_mut() { + self.schedule_local(core, task, is_yield); + return; + } + } + } + + // Otherwise, use the inject queue. + self.inject.push(task); + self.scheduler_metrics.inc_remote_schedule_count(); + self.notify_parked(); + }) + } + + fn schedule_local(&self, core: &mut Core, task: Notified, is_yield: bool) { + core.metrics.inc_local_schedule_count(); + + // Spawning from the worker thread. If scheduling a "yield" then the + // task must always be pushed to the back of the queue, enabling other + // tasks to be executed. If **not** a yield, then there is more + // flexibility and the task may go to the front of the queue. + let should_notify = if is_yield { + core.run_queue + .push_back(task, &self.inject, &mut core.metrics); + true + } else { + // Push to the LIFO slot + let prev = core.lifo_slot.take(); + let ret = prev.is_some(); + + if let Some(prev) = prev { + core.run_queue + .push_back(prev, &self.inject, &mut core.metrics); + } + + core.lifo_slot = Some(task); + + ret + }; + + // Only notify if not currently parked. If `park` is `None`, then the + // scheduling is from a resource driver. As notifications often come in + // batches, the notification is delayed until the park is complete. + if should_notify && core.park.is_some() { + self.notify_parked(); + } + } + + pub(super) fn close(&self) { + if self.inject.close() { + self.notify_all(); + } + } + + fn notify_parked(&self) { + if let Some(index) = self.idle.worker_to_notify() { + self.remotes[index].unpark.unpark(); + } + } + + fn notify_all(&self) { + for remote in &self.remotes[..] { + remote.unpark.unpark(); + } + } + + fn notify_if_work_pending(&self) { + for remote in &self.remotes[..] { + if !remote.steal.is_empty() { + self.notify_parked(); + return; + } + } + + if !self.inject.is_empty() { + self.notify_parked(); + } + } + + fn transition_worker_from_searching(&self) { + if self.idle.transition_worker_from_searching() { + // We are the final searching worker. Because work was found, we + // need to notify another worker. + self.notify_parked(); + } + } + + /// Signals that a worker has observed the shutdown signal and has replaced + /// its core back into its handle. + /// + /// If all workers have reached this point, the final cleanup is performed. + fn shutdown(&self, core: Box<Core>) { + let mut cores = self.shutdown_cores.lock(); + cores.push(core); + + if cores.len() != self.remotes.len() { + return; + } + + debug_assert!(self.owned.is_empty()); + + for mut core in cores.drain(..) { + core.shutdown(); + } + + // Drain the injection queue + // + // We already shut down every task, so we can simply drop the tasks. + while let Some(task) = self.inject.pop() { + drop(task); + } + } + + fn ptr_eq(&self, other: &Shared) -> bool { + std::ptr::eq(self, other) + } +} + +cfg_metrics! { + impl Shared { + pub(super) fn injection_queue_depth(&self) -> usize { + self.inject.len() + } + + pub(super) fn worker_local_queue_depth(&self, worker: usize) -> usize { + self.remotes[worker].steal.len() + } + } +} |