diff options
Diffstat (limited to 'third_party/rust/tokio/src/runtime/task')
-rw-r--r-- | third_party/rust/tokio/src/runtime/task/core.rs | 267 | ||||
-rw-r--r-- | third_party/rust/tokio/src/runtime/task/error.rs | 146 | ||||
-rw-r--r-- | third_party/rust/tokio/src/runtime/task/harness.rs | 485 | ||||
-rw-r--r-- | third_party/rust/tokio/src/runtime/task/inject.rs | 220 | ||||
-rw-r--r-- | third_party/rust/tokio/src/runtime/task/join.rs | 275 | ||||
-rw-r--r-- | third_party/rust/tokio/src/runtime/task/list.rs | 297 | ||||
-rw-r--r-- | third_party/rust/tokio/src/runtime/task/mod.rs | 445 | ||||
-rw-r--r-- | third_party/rust/tokio/src/runtime/task/raw.rs | 165 | ||||
-rw-r--r-- | third_party/rust/tokio/src/runtime/task/state.rs | 595 | ||||
-rw-r--r-- | third_party/rust/tokio/src/runtime/task/waker.rs | 130 |
10 files changed, 3025 insertions, 0 deletions
diff --git a/third_party/rust/tokio/src/runtime/task/core.rs b/third_party/rust/tokio/src/runtime/task/core.rs new file mode 100644 index 0000000000..776e8341f3 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/core.rs @@ -0,0 +1,267 @@ +//! Core task module. +//! +//! # Safety +//! +//! The functions in this module are private to the `task` module. All of them +//! should be considered `unsafe` to use, but are not marked as such since it +//! would be too noisy. +//! +//! Make sure to consult the relevant safety section of each function before +//! use. + +use crate::future::Future; +use crate::loom::cell::UnsafeCell; +use crate::runtime::task::raw::{self, Vtable}; +use crate::runtime::task::state::State; +use crate::runtime::task::Schedule; +use crate::util::linked_list; + +use std::pin::Pin; +use std::ptr::NonNull; +use std::task::{Context, Poll, Waker}; + +/// The task cell. Contains the components of the task. +/// +/// It is critical for `Header` to be the first field as the task structure will +/// be referenced by both *mut Cell and *mut Header. +#[repr(C)] +pub(super) struct Cell<T: Future, S> { + /// Hot task state data + pub(super) header: Header, + + /// Either the future or output, depending on the execution stage. + pub(super) core: Core<T, S>, + + /// Cold data + pub(super) trailer: Trailer, +} + +pub(super) struct CoreStage<T: Future> { + stage: UnsafeCell<Stage<T>>, +} + +/// The core of the task. +/// +/// Holds the future or output, depending on the stage of execution. +pub(super) struct Core<T: Future, S> { + /// Scheduler used to drive this future. + pub(super) scheduler: S, + + /// Either the future or the output. + pub(super) stage: CoreStage<T>, +} + +/// Crate public as this is also needed by the pool. +#[repr(C)] +pub(crate) struct Header { + /// Task state. + pub(super) state: State, + + pub(super) owned: UnsafeCell<linked_list::Pointers<Header>>, + + /// Pointer to next task, used with the injection queue. + pub(super) queue_next: UnsafeCell<Option<NonNull<Header>>>, + + /// Table of function pointers for executing actions on the task. + pub(super) vtable: &'static Vtable, + + /// This integer contains the id of the OwnedTasks or LocalOwnedTasks that + /// this task is stored in. If the task is not in any list, should be the + /// id of the list that it was previously in, or zero if it has never been + /// in any list. + /// + /// Once a task has been bound to a list, it can never be bound to another + /// list, even if removed from the first list. + /// + /// The id is not unset when removed from a list because we want to be able + /// to read the id without synchronization, even if it is concurrently being + /// removed from the list. + pub(super) owner_id: UnsafeCell<u64>, + + /// The tracing ID for this instrumented task. + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) id: Option<tracing::Id>, +} + +unsafe impl Send for Header {} +unsafe impl Sync for Header {} + +/// Cold data is stored after the future. +pub(super) struct Trailer { + /// Consumer task waiting on completion of this task. + pub(super) waker: UnsafeCell<Option<Waker>>, +} + +/// Either the future or the output. +pub(super) enum Stage<T: Future> { + Running(T), + Finished(super::Result<T::Output>), + Consumed, +} + +impl<T: Future, S: Schedule> Cell<T, S> { + /// Allocates a new task cell, containing the header, trailer, and core + /// structures. + pub(super) fn new(future: T, scheduler: S, state: State) -> Box<Cell<T, S>> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let id = future.id(); + Box::new(Cell { + header: Header { + state, + owned: UnsafeCell::new(linked_list::Pointers::new()), + queue_next: UnsafeCell::new(None), + vtable: raw::vtable::<T, S>(), + owner_id: UnsafeCell::new(0), + #[cfg(all(tokio_unstable, feature = "tracing"))] + id, + }, + core: Core { + scheduler, + stage: CoreStage { + stage: UnsafeCell::new(Stage::Running(future)), + }, + }, + trailer: Trailer { + waker: UnsafeCell::new(None), + }, + }) + } +} + +impl<T: Future> CoreStage<T> { + pub(super) fn with_mut<R>(&self, f: impl FnOnce(*mut Stage<T>) -> R) -> R { + self.stage.with_mut(f) + } + + /// Polls the future. + /// + /// # Safety + /// + /// The caller must ensure it is safe to mutate the `state` field. This + /// requires ensuring mutual exclusion between any concurrent thread that + /// might modify the future or output field. + /// + /// The mutual exclusion is implemented by `Harness` and the `Lifecycle` + /// component of the task state. + /// + /// `self` must also be pinned. This is handled by storing the task on the + /// heap. + pub(super) fn poll(&self, mut cx: Context<'_>) -> Poll<T::Output> { + let res = { + self.stage.with_mut(|ptr| { + // Safety: The caller ensures mutual exclusion to the field. + let future = match unsafe { &mut *ptr } { + Stage::Running(future) => future, + _ => unreachable!("unexpected stage"), + }; + + // Safety: The caller ensures the future is pinned. + let future = unsafe { Pin::new_unchecked(future) }; + + future.poll(&mut cx) + }) + }; + + if res.is_ready() { + self.drop_future_or_output(); + } + + res + } + + /// Drops the future. + /// + /// # Safety + /// + /// The caller must ensure it is safe to mutate the `stage` field. + pub(super) fn drop_future_or_output(&self) { + // Safety: the caller ensures mutual exclusion to the field. + unsafe { + self.set_stage(Stage::Consumed); + } + } + + /// Stores the task output. + /// + /// # Safety + /// + /// The caller must ensure it is safe to mutate the `stage` field. + pub(super) fn store_output(&self, output: super::Result<T::Output>) { + // Safety: the caller ensures mutual exclusion to the field. + unsafe { + self.set_stage(Stage::Finished(output)); + } + } + + /// Takes the task output. + /// + /// # Safety + /// + /// The caller must ensure it is safe to mutate the `stage` field. + pub(super) fn take_output(&self) -> super::Result<T::Output> { + use std::mem; + + self.stage.with_mut(|ptr| { + // Safety:: the caller ensures mutual exclusion to the field. + match mem::replace(unsafe { &mut *ptr }, Stage::Consumed) { + Stage::Finished(output) => output, + _ => panic!("JoinHandle polled after completion"), + } + }) + } + + unsafe fn set_stage(&self, stage: Stage<T>) { + self.stage.with_mut(|ptr| *ptr = stage) + } +} + +cfg_rt_multi_thread! { + impl Header { + pub(super) unsafe fn set_next(&self, next: Option<NonNull<Header>>) { + self.queue_next.with_mut(|ptr| *ptr = next); + } + } +} + +impl Header { + // safety: The caller must guarantee exclusive access to this field, and + // must ensure that the id is either 0 or the id of the OwnedTasks + // containing this task. + pub(super) unsafe fn set_owner_id(&self, owner: u64) { + self.owner_id.with_mut(|ptr| *ptr = owner); + } + + pub(super) fn get_owner_id(&self) -> u64 { + // safety: If there are concurrent writes, then that write has violated + // the safety requirements on `set_owner_id`. + unsafe { self.owner_id.with(|ptr| *ptr) } + } +} + +impl Trailer { + pub(super) unsafe fn set_waker(&self, waker: Option<Waker>) { + self.waker.with_mut(|ptr| { + *ptr = waker; + }); + } + + pub(super) unsafe fn will_wake(&self, waker: &Waker) -> bool { + self.waker + .with(|ptr| (*ptr).as_ref().unwrap().will_wake(waker)) + } + + pub(super) fn wake_join(&self) { + self.waker.with(|ptr| match unsafe { &*ptr } { + Some(waker) => waker.wake_by_ref(), + None => panic!("waker missing"), + }); + } +} + +#[test] +#[cfg(not(loom))] +fn header_lte_cache_line() { + use std::mem::size_of; + + assert!(size_of::<Header>() <= 8 * size_of::<*const ()>()); +} diff --git a/third_party/rust/tokio/src/runtime/task/error.rs b/third_party/rust/tokio/src/runtime/task/error.rs new file mode 100644 index 0000000000..1a8129b2b6 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/error.rs @@ -0,0 +1,146 @@ +use std::any::Any; +use std::fmt; +use std::io; + +use crate::util::SyncWrapper; + +cfg_rt! { + /// Task failed to execute to completion. + pub struct JoinError { + repr: Repr, + } +} + +enum Repr { + Cancelled, + Panic(SyncWrapper<Box<dyn Any + Send + 'static>>), +} + +impl JoinError { + pub(crate) fn cancelled() -> JoinError { + JoinError { + repr: Repr::Cancelled, + } + } + + pub(crate) fn panic(err: Box<dyn Any + Send + 'static>) -> JoinError { + JoinError { + repr: Repr::Panic(SyncWrapper::new(err)), + } + } + + /// Returns true if the error was caused by the task being cancelled. + pub fn is_cancelled(&self) -> bool { + matches!(&self.repr, Repr::Cancelled) + } + + /// Returns true if the error was caused by the task panicking. + /// + /// # Examples + /// + /// ``` + /// use std::panic; + /// + /// #[tokio::main] + /// async fn main() { + /// let err = tokio::spawn(async { + /// panic!("boom"); + /// }).await.unwrap_err(); + /// + /// assert!(err.is_panic()); + /// } + /// ``` + pub fn is_panic(&self) -> bool { + matches!(&self.repr, Repr::Panic(_)) + } + + /// Consumes the join error, returning the object with which the task panicked. + /// + /// # Panics + /// + /// `into_panic()` panics if the `Error` does not represent the underlying + /// task terminating with a panic. Use `is_panic` to check the error reason + /// or `try_into_panic` for a variant that does not panic. + /// + /// # Examples + /// + /// ```should_panic + /// use std::panic; + /// + /// #[tokio::main] + /// async fn main() { + /// let err = tokio::spawn(async { + /// panic!("boom"); + /// }).await.unwrap_err(); + /// + /// if err.is_panic() { + /// // Resume the panic on the main task + /// panic::resume_unwind(err.into_panic()); + /// } + /// } + /// ``` + pub fn into_panic(self) -> Box<dyn Any + Send + 'static> { + self.try_into_panic() + .expect("`JoinError` reason is not a panic.") + } + + /// Consumes the join error, returning the object with which the task + /// panicked if the task terminated due to a panic. Otherwise, `self` is + /// returned. + /// + /// # Examples + /// + /// ```should_panic + /// use std::panic; + /// + /// #[tokio::main] + /// async fn main() { + /// let err = tokio::spawn(async { + /// panic!("boom"); + /// }).await.unwrap_err(); + /// + /// if let Ok(reason) = err.try_into_panic() { + /// // Resume the panic on the main task + /// panic::resume_unwind(reason); + /// } + /// } + /// ``` + pub fn try_into_panic(self) -> Result<Box<dyn Any + Send + 'static>, JoinError> { + match self.repr { + Repr::Panic(p) => Ok(p.into_inner()), + _ => Err(self), + } + } +} + +impl fmt::Display for JoinError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.repr { + Repr::Cancelled => write!(fmt, "cancelled"), + Repr::Panic(_) => write!(fmt, "panic"), + } + } +} + +impl fmt::Debug for JoinError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.repr { + Repr::Cancelled => write!(fmt, "JoinError::Cancelled"), + Repr::Panic(_) => write!(fmt, "JoinError::Panic(...)"), + } + } +} + +impl std::error::Error for JoinError {} + +impl From<JoinError> for io::Error { + fn from(src: JoinError) -> io::Error { + io::Error::new( + io::ErrorKind::Other, + match src.repr { + Repr::Cancelled => "task was cancelled", + Repr::Panic(_) => "task panicked", + }, + ) + } +} diff --git a/third_party/rust/tokio/src/runtime/task/harness.rs b/third_party/rust/tokio/src/runtime/task/harness.rs new file mode 100644 index 0000000000..261dccea41 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/harness.rs @@ -0,0 +1,485 @@ +use crate::future::Future; +use crate::runtime::task::core::{Cell, Core, CoreStage, Header, Trailer}; +use crate::runtime::task::state::Snapshot; +use crate::runtime::task::waker::waker_ref; +use crate::runtime::task::{JoinError, Notified, Schedule, Task}; + +use std::mem; +use std::mem::ManuallyDrop; +use std::panic; +use std::ptr::NonNull; +use std::task::{Context, Poll, Waker}; + +/// Typed raw task handle. +pub(super) struct Harness<T: Future, S: 'static> { + cell: NonNull<Cell<T, S>>, +} + +impl<T, S> Harness<T, S> +where + T: Future, + S: 'static, +{ + pub(super) unsafe fn from_raw(ptr: NonNull<Header>) -> Harness<T, S> { + Harness { + cell: ptr.cast::<Cell<T, S>>(), + } + } + + fn header_ptr(&self) -> NonNull<Header> { + self.cell.cast() + } + + fn header(&self) -> &Header { + unsafe { &self.cell.as_ref().header } + } + + fn trailer(&self) -> &Trailer { + unsafe { &self.cell.as_ref().trailer } + } + + fn core(&self) -> &Core<T, S> { + unsafe { &self.cell.as_ref().core } + } +} + +impl<T, S> Harness<T, S> +where + T: Future, + S: Schedule, +{ + /// Polls the inner future. A ref-count is consumed. + /// + /// All necessary state checks and transitions are performed. + /// Panics raised while polling the future are handled. + pub(super) fn poll(self) { + // We pass our ref-count to `poll_inner`. + match self.poll_inner() { + PollFuture::Notified => { + // The `poll_inner` call has given us two ref-counts back. + // We give one of them to a new task and call `yield_now`. + self.core() + .scheduler + .yield_now(Notified(self.get_new_task())); + + // The remaining ref-count is now dropped. We kept the extra + // ref-count until now to ensure that even if the `yield_now` + // call drops the provided task, the task isn't deallocated + // before after `yield_now` returns. + self.drop_reference(); + } + PollFuture::Complete => { + self.complete(); + } + PollFuture::Dealloc => { + self.dealloc(); + } + PollFuture::Done => (), + } + } + + /// Polls the task and cancel it if necessary. This takes ownership of a + /// ref-count. + /// + /// If the return value is Notified, the caller is given ownership of two + /// ref-counts. + /// + /// If the return value is Complete, the caller is given ownership of a + /// single ref-count, which should be passed on to `complete`. + /// + /// If the return value is Dealloc, then this call consumed the last + /// ref-count and the caller should call `dealloc`. + /// + /// Otherwise the ref-count is consumed and the caller should not access + /// `self` again. + fn poll_inner(&self) -> PollFuture { + use super::state::{TransitionToIdle, TransitionToRunning}; + + match self.header().state.transition_to_running() { + TransitionToRunning::Success => { + let header_ptr = self.header_ptr(); + let waker_ref = waker_ref::<T, S>(&header_ptr); + let cx = Context::from_waker(&*waker_ref); + let res = poll_future(&self.core().stage, cx); + + if res == Poll::Ready(()) { + // The future completed. Move on to complete the task. + return PollFuture::Complete; + } + + match self.header().state.transition_to_idle() { + TransitionToIdle::Ok => PollFuture::Done, + TransitionToIdle::OkNotified => PollFuture::Notified, + TransitionToIdle::OkDealloc => PollFuture::Dealloc, + TransitionToIdle::Cancelled => { + // The transition to idle failed because the task was + // cancelled during the poll. + + cancel_task(&self.core().stage); + PollFuture::Complete + } + } + } + TransitionToRunning::Cancelled => { + cancel_task(&self.core().stage); + PollFuture::Complete + } + TransitionToRunning::Failed => PollFuture::Done, + TransitionToRunning::Dealloc => PollFuture::Dealloc, + } + } + + /// Forcibly shuts down the task. + /// + /// Attempt to transition to `Running` in order to forcibly shutdown the + /// task. If the task is currently running or in a state of completion, then + /// there is nothing further to do. When the task completes running, it will + /// notice the `CANCELLED` bit and finalize the task. + pub(super) fn shutdown(self) { + if !self.header().state.transition_to_shutdown() { + // The task is concurrently running. No further work needed. + self.drop_reference(); + return; + } + + // By transitioning the lifecycle to `Running`, we have permission to + // drop the future. + cancel_task(&self.core().stage); + self.complete(); + } + + pub(super) fn dealloc(self) { + // Release the join waker, if there is one. + self.trailer().waker.with_mut(drop); + + // Check causality + self.core().stage.with_mut(drop); + + unsafe { + drop(Box::from_raw(self.cell.as_ptr())); + } + } + + // ===== join handle ===== + + /// Read the task output into `dst`. + pub(super) fn try_read_output(self, dst: &mut Poll<super::Result<T::Output>>, waker: &Waker) { + if can_read_output(self.header(), self.trailer(), waker) { + *dst = Poll::Ready(self.core().stage.take_output()); + } + } + + /// Try to set the waker notified when the task is complete. Returns true if + /// the task has already completed. If this call returns false, then the + /// waker will not be notified. + pub(super) fn try_set_join_waker(self, waker: &Waker) -> bool { + can_read_output(self.header(), self.trailer(), waker) + } + + pub(super) fn drop_join_handle_slow(self) { + // Try to unset `JOIN_INTEREST`. This must be done as a first step in + // case the task concurrently completed. + if self.header().state.unset_join_interested().is_err() { + // It is our responsibility to drop the output. This is critical as + // the task output may not be `Send` and as such must remain with + // the scheduler or `JoinHandle`. i.e. if the output remains in the + // task structure until the task is deallocated, it may be dropped + // by a Waker on any arbitrary thread. + // + // Panics are delivered to the user via the `JoinHandle`. Given that + // they are dropping the `JoinHandle`, we assume they are not + // interested in the panic and swallow it. + let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + self.core().stage.drop_future_or_output(); + })); + } + + // Drop the `JoinHandle` reference, possibly deallocating the task + self.drop_reference(); + } + + /// Remotely aborts the task. + /// + /// The caller should hold a ref-count, but we do not consume it. + /// + /// This is similar to `shutdown` except that it asks the runtime to perform + /// the shutdown. This is necessary to avoid the shutdown happening in the + /// wrong thread for non-Send tasks. + pub(super) fn remote_abort(self) { + if self.header().state.transition_to_notified_and_cancel() { + // The transition has created a new ref-count, which we turn into + // a Notified and pass to the task. + // + // Since the caller holds a ref-count, the task cannot be destroyed + // before the call to `schedule` returns even if the call drops the + // `Notified` internally. + self.core() + .scheduler + .schedule(Notified(self.get_new_task())); + } + } + + // ===== waker behavior ===== + + /// This call consumes a ref-count and notifies the task. This will create a + /// new Notified and submit it if necessary. + /// + /// The caller does not need to hold a ref-count besides the one that was + /// passed to this call. + pub(super) fn wake_by_val(self) { + use super::state::TransitionToNotifiedByVal; + + match self.header().state.transition_to_notified_by_val() { + TransitionToNotifiedByVal::Submit => { + // The caller has given us a ref-count, and the transition has + // created a new ref-count, so we now hold two. We turn the new + // ref-count Notified and pass it to the call to `schedule`. + // + // The old ref-count is retained for now to ensure that the task + // is not dropped during the call to `schedule` if the call + // drops the task it was given. + self.core() + .scheduler + .schedule(Notified(self.get_new_task())); + + // Now that we have completed the call to schedule, we can + // release our ref-count. + self.drop_reference(); + } + TransitionToNotifiedByVal::Dealloc => { + self.dealloc(); + } + TransitionToNotifiedByVal::DoNothing => {} + } + } + + /// This call notifies the task. It will not consume any ref-counts, but the + /// caller should hold a ref-count. This will create a new Notified and + /// submit it if necessary. + pub(super) fn wake_by_ref(&self) { + use super::state::TransitionToNotifiedByRef; + + match self.header().state.transition_to_notified_by_ref() { + TransitionToNotifiedByRef::Submit => { + // The transition above incremented the ref-count for a new task + // and the caller also holds a ref-count. The caller's ref-count + // ensures that the task is not destroyed even if the new task + // is dropped before `schedule` returns. + self.core() + .scheduler + .schedule(Notified(self.get_new_task())); + } + TransitionToNotifiedByRef::DoNothing => {} + } + } + + pub(super) fn drop_reference(self) { + if self.header().state.ref_dec() { + self.dealloc(); + } + } + + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) fn id(&self) -> Option<&tracing::Id> { + self.header().id.as_ref() + } + + // ====== internal ====== + + /// Completes the task. This method assumes that the state is RUNNING. + fn complete(self) { + // The future has completed and its output has been written to the task + // stage. We transition from running to complete. + + let snapshot = self.header().state.transition_to_complete(); + + // We catch panics here in case dropping the future or waking the + // JoinHandle panics. + let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + if !snapshot.is_join_interested() { + // The `JoinHandle` is not interested in the output of + // this task. It is our responsibility to drop the + // output. + self.core().stage.drop_future_or_output(); + } else if snapshot.has_join_waker() { + // Notify the join handle. The previous transition obtains the + // lock on the waker cell. + self.trailer().wake_join(); + } + })); + + // The task has completed execution and will no longer be scheduled. + let num_release = self.release(); + + if self.header().state.transition_to_terminal(num_release) { + self.dealloc(); + } + } + + /// Releases the task from the scheduler. Returns the number of ref-counts + /// that should be decremented. + fn release(&self) -> usize { + // We don't actually increment the ref-count here, but the new task is + // never destroyed, so that's ok. + let me = ManuallyDrop::new(self.get_new_task()); + + if let Some(task) = self.core().scheduler.release(&me) { + mem::forget(task); + 2 + } else { + 1 + } + } + + /// Creates a new task that holds its own ref-count. + /// + /// # Safety + /// + /// Any use of `self` after this call must ensure that a ref-count to the + /// task holds the task alive until after the use of `self`. Passing the + /// returned Task to any method on `self` is unsound if dropping the Task + /// could drop `self` before the call on `self` returned. + fn get_new_task(&self) -> Task<S> { + // safety: The header is at the beginning of the cell, so this cast is + // safe. + unsafe { Task::from_raw(self.cell.cast()) } + } +} + +fn can_read_output(header: &Header, trailer: &Trailer, waker: &Waker) -> bool { + // Load a snapshot of the current task state + let snapshot = header.state.load(); + + debug_assert!(snapshot.is_join_interested()); + + if !snapshot.is_complete() { + // The waker must be stored in the task struct. + let res = if snapshot.has_join_waker() { + // There already is a waker stored in the struct. If it matches + // the provided waker, then there is no further work to do. + // Otherwise, the waker must be swapped. + let will_wake = unsafe { + // Safety: when `JOIN_INTEREST` is set, only `JOIN_HANDLE` + // may mutate the `waker` field. + trailer.will_wake(waker) + }; + + if will_wake { + // The task is not complete **and** the waker is up to date, + // there is nothing further that needs to be done. + return false; + } + + // Unset the `JOIN_WAKER` to gain mutable access to the `waker` + // field then update the field with the new join worker. + // + // This requires two atomic operations, unsetting the bit and + // then resetting it. If the task transitions to complete + // concurrently to either one of those operations, then setting + // the join waker fails and we proceed to reading the task + // output. + header + .state + .unset_waker() + .and_then(|snapshot| set_join_waker(header, trailer, waker.clone(), snapshot)) + } else { + set_join_waker(header, trailer, waker.clone(), snapshot) + }; + + match res { + Ok(_) => return false, + Err(snapshot) => { + assert!(snapshot.is_complete()); + } + } + } + true +} + +fn set_join_waker( + header: &Header, + trailer: &Trailer, + waker: Waker, + snapshot: Snapshot, +) -> Result<Snapshot, Snapshot> { + assert!(snapshot.is_join_interested()); + assert!(!snapshot.has_join_waker()); + + // Safety: Only the `JoinHandle` may set the `waker` field. When + // `JOIN_INTEREST` is **not** set, nothing else will touch the field. + unsafe { + trailer.set_waker(Some(waker)); + } + + // Update the `JoinWaker` state accordingly + let res = header.state.set_join_waker(); + + // If the state could not be updated, then clear the join waker + if res.is_err() { + unsafe { + trailer.set_waker(None); + } + } + + res +} + +enum PollFuture { + Complete, + Notified, + Done, + Dealloc, +} + +/// Cancels the task and store the appropriate error in the stage field. +fn cancel_task<T: Future>(stage: &CoreStage<T>) { + // Drop the future from a panic guard. + let res = panic::catch_unwind(panic::AssertUnwindSafe(|| { + stage.drop_future_or_output(); + })); + + match res { + Ok(()) => { + stage.store_output(Err(JoinError::cancelled())); + } + Err(panic) => { + stage.store_output(Err(JoinError::panic(panic))); + } + } +} + +/// Polls the future. If the future completes, the output is written to the +/// stage field. +fn poll_future<T: Future>(core: &CoreStage<T>, cx: Context<'_>) -> Poll<()> { + // Poll the future. + let output = panic::catch_unwind(panic::AssertUnwindSafe(|| { + struct Guard<'a, T: Future> { + core: &'a CoreStage<T>, + } + impl<'a, T: Future> Drop for Guard<'a, T> { + fn drop(&mut self) { + // If the future panics on poll, we drop it inside the panic + // guard. + self.core.drop_future_or_output(); + } + } + let guard = Guard { core }; + let res = guard.core.poll(cx); + mem::forget(guard); + res + })); + + // Prepare output for being placed in the core stage. + let output = match output { + Ok(Poll::Pending) => return Poll::Pending, + Ok(Poll::Ready(output)) => Ok(output), + Err(panic) => Err(JoinError::panic(panic)), + }; + + // Catch and ignore panics if the future panics on drop. + let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + core.store_output(output); + })); + + Poll::Ready(()) +} diff --git a/third_party/rust/tokio/src/runtime/task/inject.rs b/third_party/rust/tokio/src/runtime/task/inject.rs new file mode 100644 index 0000000000..1585e13a01 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/inject.rs @@ -0,0 +1,220 @@ +//! Inject queue used to send wakeups to a work-stealing scheduler + +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::Mutex; +use crate::runtime::task; + +use std::marker::PhantomData; +use std::ptr::NonNull; +use std::sync::atomic::Ordering::{Acquire, Release}; + +/// Growable, MPMC queue used to inject new tasks into the scheduler and as an +/// overflow queue when the local, fixed-size, array queue overflows. +pub(crate) struct Inject<T: 'static> { + /// Pointers to the head and tail of the queue. + pointers: Mutex<Pointers>, + + /// Number of pending tasks in the queue. This helps prevent unnecessary + /// locking in the hot path. + len: AtomicUsize, + + _p: PhantomData<T>, +} + +struct Pointers { + /// True if the queue is closed. + is_closed: bool, + + /// Linked-list head. + head: Option<NonNull<task::Header>>, + + /// Linked-list tail. + tail: Option<NonNull<task::Header>>, +} + +unsafe impl<T> Send for Inject<T> {} +unsafe impl<T> Sync for Inject<T> {} + +impl<T: 'static> Inject<T> { + pub(crate) fn new() -> Inject<T> { + Inject { + pointers: Mutex::new(Pointers { + is_closed: false, + head: None, + tail: None, + }), + len: AtomicUsize::new(0), + _p: PhantomData, + } + } + + pub(crate) fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Closes the injection queue, returns `true` if the queue is open when the + /// transition is made. + pub(crate) fn close(&self) -> bool { + let mut p = self.pointers.lock(); + + if p.is_closed { + return false; + } + + p.is_closed = true; + true + } + + pub(crate) fn is_closed(&self) -> bool { + self.pointers.lock().is_closed + } + + pub(crate) fn len(&self) -> usize { + self.len.load(Acquire) + } + + /// Pushes a value into the queue. + /// + /// This does nothing if the queue is closed. + pub(crate) fn push(&self, task: task::Notified<T>) { + // Acquire queue lock + let mut p = self.pointers.lock(); + + if p.is_closed { + return; + } + + // safety: only mutated with the lock held + let len = unsafe { self.len.unsync_load() }; + let task = task.into_raw(); + + // The next pointer should already be null + debug_assert!(get_next(task).is_none()); + + if let Some(tail) = p.tail { + // safety: Holding the Notified for a task guarantees exclusive + // access to the `queue_next` field. + set_next(tail, Some(task)); + } else { + p.head = Some(task); + } + + p.tail = Some(task); + + self.len.store(len + 1, Release); + } + + /// Pushes several values into the queue. + #[inline] + pub(crate) fn push_batch<I>(&self, mut iter: I) + where + I: Iterator<Item = task::Notified<T>>, + { + let first = match iter.next() { + Some(first) => first.into_raw(), + None => return, + }; + + // Link up all the tasks. + let mut prev = first; + let mut counter = 1; + + // We are going to be called with an `std::iter::Chain`, and that + // iterator overrides `for_each` to something that is easier for the + // compiler to optimize than a loop. + iter.for_each(|next| { + let next = next.into_raw(); + + // safety: Holding the Notified for a task guarantees exclusive + // access to the `queue_next` field. + set_next(prev, Some(next)); + prev = next; + counter += 1; + }); + + // Now that the tasks are linked together, insert them into the + // linked list. + self.push_batch_inner(first, prev, counter); + } + + /// Inserts several tasks that have been linked together into the queue. + /// + /// The provided head and tail may be be the same task. In this case, a + /// single task is inserted. + #[inline] + fn push_batch_inner( + &self, + batch_head: NonNull<task::Header>, + batch_tail: NonNull<task::Header>, + num: usize, + ) { + debug_assert!(get_next(batch_tail).is_none()); + + let mut p = self.pointers.lock(); + + if let Some(tail) = p.tail { + set_next(tail, Some(batch_head)); + } else { + p.head = Some(batch_head); + } + + p.tail = Some(batch_tail); + + // Increment the count. + // + // safety: All updates to the len atomic are guarded by the mutex. As + // such, a non-atomic load followed by a store is safe. + let len = unsafe { self.len.unsync_load() }; + + self.len.store(len + num, Release); + } + + pub(crate) fn pop(&self) -> Option<task::Notified<T>> { + // Fast path, if len == 0, then there are no values + if self.is_empty() { + return None; + } + + let mut p = self.pointers.lock(); + + // It is possible to hit null here if another thread popped the last + // task between us checking `len` and acquiring the lock. + let task = p.head?; + + p.head = get_next(task); + + if p.head.is_none() { + p.tail = None; + } + + set_next(task, None); + + // Decrement the count. + // + // safety: All updates to the len atomic are guarded by the mutex. As + // such, a non-atomic load followed by a store is safe. + self.len + .store(unsafe { self.len.unsync_load() } - 1, Release); + + // safety: a `Notified` is pushed into the queue and now it is popped! + Some(unsafe { task::Notified::from_raw(task) }) + } +} + +impl<T: 'static> Drop for Inject<T> { + fn drop(&mut self) { + if !std::thread::panicking() { + assert!(self.pop().is_none(), "queue not empty"); + } + } +} + +fn get_next(header: NonNull<task::Header>) -> Option<NonNull<task::Header>> { + unsafe { header.as_ref().queue_next.with(|ptr| *ptr) } +} + +fn set_next(header: NonNull<task::Header>, val: Option<NonNull<task::Header>>) { + unsafe { + header.as_ref().set_next(val); + } +} diff --git a/third_party/rust/tokio/src/runtime/task/join.rs b/third_party/rust/tokio/src/runtime/task/join.rs new file mode 100644 index 0000000000..8beed2eaac --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/join.rs @@ -0,0 +1,275 @@ +use crate::runtime::task::RawTask; + +use std::fmt; +use std::future::Future; +use std::marker::PhantomData; +use std::panic::{RefUnwindSafe, UnwindSafe}; +use std::pin::Pin; +use std::task::{Context, Poll, Waker}; + +cfg_rt! { + /// An owned permission to join on a task (await its termination). + /// + /// This can be thought of as the equivalent of [`std::thread::JoinHandle`] for + /// a task rather than a thread. + /// + /// A `JoinHandle` *detaches* the associated task when it is dropped, which + /// means that there is no longer any handle to the task, and no way to `join` + /// on it. + /// + /// This `struct` is created by the [`task::spawn`] and [`task::spawn_blocking`] + /// functions. + /// + /// # Examples + /// + /// Creation from [`task::spawn`]: + /// + /// ``` + /// use tokio::task; + /// + /// # async fn doc() { + /// let join_handle: task::JoinHandle<_> = task::spawn(async { + /// // some work here + /// }); + /// # } + /// ``` + /// + /// Creation from [`task::spawn_blocking`]: + /// + /// ``` + /// use tokio::task; + /// + /// # async fn doc() { + /// let join_handle: task::JoinHandle<_> = task::spawn_blocking(|| { + /// // some blocking work here + /// }); + /// # } + /// ``` + /// + /// The generic parameter `T` in `JoinHandle<T>` is the return type of the spawned task. + /// If the return value is an i32, the join handle has type `JoinHandle<i32>`: + /// + /// ``` + /// use tokio::task; + /// + /// # async fn doc() { + /// let join_handle: task::JoinHandle<i32> = task::spawn(async { + /// 5 + 3 + /// }); + /// # } + /// + /// ``` + /// + /// If the task does not have a return value, the join handle has type `JoinHandle<()>`: + /// + /// ``` + /// use tokio::task; + /// + /// # async fn doc() { + /// let join_handle: task::JoinHandle<()> = task::spawn(async { + /// println!("I return nothing."); + /// }); + /// # } + /// ``` + /// + /// Note that `handle.await` doesn't give you the return type directly. It is wrapped in a + /// `Result` because panics in the spawned task are caught by Tokio. The `?` operator has + /// to be double chained to extract the returned value: + /// + /// ``` + /// use tokio::task; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let join_handle: task::JoinHandle<Result<i32, io::Error>> = tokio::spawn(async { + /// Ok(5 + 3) + /// }); + /// + /// let result = join_handle.await??; + /// assert_eq!(result, 8); + /// Ok(()) + /// } + /// ``` + /// + /// If the task panics, the error is a [`JoinError`] that contains the panic: + /// + /// ``` + /// use tokio::task; + /// use std::io; + /// use std::panic; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let join_handle: task::JoinHandle<Result<i32, io::Error>> = tokio::spawn(async { + /// panic!("boom"); + /// }); + /// + /// let err = join_handle.await.unwrap_err(); + /// assert!(err.is_panic()); + /// Ok(()) + /// } + /// + /// ``` + /// Child being detached and outliving its parent: + /// + /// ```no_run + /// use tokio::task; + /// use tokio::time; + /// use std::time::Duration; + /// + /// # #[tokio::main] async fn main() { + /// let original_task = task::spawn(async { + /// let _detached_task = task::spawn(async { + /// // Here we sleep to make sure that the first task returns before. + /// time::sleep(Duration::from_millis(10)).await; + /// // This will be called, even though the JoinHandle is dropped. + /// println!("♫ Still alive ♫"); + /// }); + /// }); + /// + /// original_task.await.expect("The task being joined has panicked"); + /// println!("Original task is joined."); + /// + /// // We make sure that the new task has time to run, before the main + /// // task returns. + /// + /// time::sleep(Duration::from_millis(1000)).await; + /// # } + /// ``` + /// + /// [`task::spawn`]: crate::task::spawn() + /// [`task::spawn_blocking`]: crate::task::spawn_blocking + /// [`std::thread::JoinHandle`]: std::thread::JoinHandle + /// [`JoinError`]: crate::task::JoinError + pub struct JoinHandle<T> { + raw: Option<RawTask>, + _p: PhantomData<T>, + } +} + +unsafe impl<T: Send> Send for JoinHandle<T> {} +unsafe impl<T: Send> Sync for JoinHandle<T> {} + +impl<T> UnwindSafe for JoinHandle<T> {} +impl<T> RefUnwindSafe for JoinHandle<T> {} + +impl<T> JoinHandle<T> { + pub(super) fn new(raw: RawTask) -> JoinHandle<T> { + JoinHandle { + raw: Some(raw), + _p: PhantomData, + } + } + + /// Abort the task associated with the handle. + /// + /// Awaiting a cancelled task might complete as usual if the task was + /// already completed at the time it was cancelled, but most likely it + /// will fail with a [cancelled] `JoinError`. + /// + /// ```rust + /// use tokio::time; + /// + /// #[tokio::main] + /// async fn main() { + /// let mut handles = Vec::new(); + /// + /// handles.push(tokio::spawn(async { + /// time::sleep(time::Duration::from_secs(10)).await; + /// true + /// })); + /// + /// handles.push(tokio::spawn(async { + /// time::sleep(time::Duration::from_secs(10)).await; + /// false + /// })); + /// + /// for handle in &handles { + /// handle.abort(); + /// } + /// + /// for handle in handles { + /// assert!(handle.await.unwrap_err().is_cancelled()); + /// } + /// } + /// ``` + /// [cancelled]: method@super::error::JoinError::is_cancelled + pub fn abort(&self) { + if let Some(raw) = self.raw { + raw.remote_abort(); + } + } + + /// Set the waker that is notified when the task completes. + pub(crate) fn set_join_waker(&mut self, waker: &Waker) { + if let Some(raw) = self.raw { + if raw.try_set_join_waker(waker) { + // In this case the task has already completed. We wake the waker immediately. + waker.wake_by_ref(); + } + } + } +} + +impl<T> Unpin for JoinHandle<T> {} + +impl<T> Future for JoinHandle<T> { + type Output = super::Result<T>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let mut ret = Poll::Pending; + + // Keep track of task budget + let coop = ready!(crate::coop::poll_proceed(cx)); + + // Raw should always be set. If it is not, this is due to polling after + // completion + let raw = self + .raw + .as_ref() + .expect("polling after `JoinHandle` already completed"); + + // Try to read the task output. If the task is not yet complete, the + // waker is stored and is notified once the task does complete. + // + // The function must go via the vtable, which requires erasing generic + // types. To do this, the function "return" is placed on the stack + // **before** calling the function and is passed into the function using + // `*mut ()`. + // + // Safety: + // + // The type of `T` must match the task's output type. + unsafe { + raw.try_read_output(&mut ret as *mut _ as *mut (), cx.waker()); + } + + if ret.is_ready() { + coop.made_progress(); + } + + ret + } +} + +impl<T> Drop for JoinHandle<T> { + fn drop(&mut self) { + if let Some(raw) = self.raw.take() { + if raw.header().state.drop_join_handle_fast().is_ok() { + return; + } + + raw.drop_join_handle_slow(); + } + } +} + +impl<T> fmt::Debug for JoinHandle<T> +where + T: fmt::Debug, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("JoinHandle").finish() + } +} diff --git a/third_party/rust/tokio/src/runtime/task/list.rs b/third_party/rust/tokio/src/runtime/task/list.rs new file mode 100644 index 0000000000..7758f8db7a --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/list.rs @@ -0,0 +1,297 @@ +//! This module has containers for storing the tasks spawned on a scheduler. The +//! `OwnedTasks` container is thread-safe but can only store tasks that +//! implement Send. The `LocalOwnedTasks` container is not thread safe, but can +//! store non-Send tasks. +//! +//! The collections can be closed to prevent adding new tasks during shutdown of +//! the scheduler with the collection. + +use crate::future::Future; +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::Mutex; +use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, Task}; +use crate::util::linked_list::{Link, LinkedList}; + +use std::marker::PhantomData; + +// The id from the module below is used to verify whether a given task is stored +// in this OwnedTasks, or some other task. The counter starts at one so we can +// use zero for tasks not owned by any list. +// +// The safety checks in this file can technically be violated if the counter is +// overflown, but the checks are not supposed to ever fail unless there is a +// bug in Tokio, so we accept that certain bugs would not be caught if the two +// mixed up runtimes happen to have the same id. + +cfg_has_atomic_u64! { + use std::sync::atomic::{AtomicU64, Ordering}; + + static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1); + + fn get_next_id() -> u64 { + loop { + let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed); + if id != 0 { + return id; + } + } + } +} + +cfg_not_has_atomic_u64! { + use std::sync::atomic::{AtomicU32, Ordering}; + + static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1); + + fn get_next_id() -> u64 { + loop { + let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed); + if id != 0 { + return u64::from(id); + } + } + } +} + +pub(crate) struct OwnedTasks<S: 'static> { + inner: Mutex<OwnedTasksInner<S>>, + id: u64, +} +pub(crate) struct LocalOwnedTasks<S: 'static> { + inner: UnsafeCell<OwnedTasksInner<S>>, + id: u64, + _not_send_or_sync: PhantomData<*const ()>, +} +struct OwnedTasksInner<S: 'static> { + list: LinkedList<Task<S>, <Task<S> as Link>::Target>, + closed: bool, +} + +impl<S: 'static> OwnedTasks<S> { + pub(crate) fn new() -> Self { + Self { + inner: Mutex::new(OwnedTasksInner { + list: LinkedList::new(), + closed: false, + }), + id: get_next_id(), + } + } + + /// Binds the provided task to this OwnedTasks instance. This fails if the + /// OwnedTasks has been closed. + pub(crate) fn bind<T>( + &self, + task: T, + scheduler: S, + ) -> (JoinHandle<T::Output>, Option<Notified<S>>) + where + S: Schedule, + T: Future + Send + 'static, + T::Output: Send + 'static, + { + let (task, notified, join) = super::new_task(task, scheduler); + + unsafe { + // safety: We just created the task, so we have exclusive access + // to the field. + task.header().set_owner_id(self.id); + } + + let mut lock = self.inner.lock(); + if lock.closed { + drop(lock); + drop(notified); + task.shutdown(); + (join, None) + } else { + lock.list.push_front(task); + (join, Some(notified)) + } + } + + /// Asserts that the given task is owned by this OwnedTasks and convert it to + /// a LocalNotified, giving the thread permission to poll this task. + #[inline] + pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> { + assert_eq!(task.header().get_owner_id(), self.id); + + // safety: All tasks bound to this OwnedTasks are Send, so it is safe + // to poll it on this thread no matter what thread we are on. + LocalNotified { + task: task.0, + _not_send: PhantomData, + } + } + + /// Shuts down all tasks in the collection. This call also closes the + /// collection, preventing new items from being added. + pub(crate) fn close_and_shutdown_all(&self) + where + S: Schedule, + { + // The first iteration of the loop was unrolled so it can set the + // closed bool. + let first_task = { + let mut lock = self.inner.lock(); + lock.closed = true; + lock.list.pop_back() + }; + match first_task { + Some(task) => task.shutdown(), + None => return, + } + + loop { + let task = match self.inner.lock().list.pop_back() { + Some(task) => task, + None => return, + }; + + task.shutdown(); + } + } + + pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> { + let task_id = task.header().get_owner_id(); + if task_id == 0 { + // The task is unowned. + return None; + } + + assert_eq!(task_id, self.id); + + // safety: We just checked that the provided task is not in some other + // linked list. + unsafe { self.inner.lock().list.remove(task.header().into()) } + } + + pub(crate) fn is_empty(&self) -> bool { + self.inner.lock().list.is_empty() + } +} + +impl<S: 'static> LocalOwnedTasks<S> { + pub(crate) fn new() -> Self { + Self { + inner: UnsafeCell::new(OwnedTasksInner { + list: LinkedList::new(), + closed: false, + }), + id: get_next_id(), + _not_send_or_sync: PhantomData, + } + } + + pub(crate) fn bind<T>( + &self, + task: T, + scheduler: S, + ) -> (JoinHandle<T::Output>, Option<Notified<S>>) + where + S: Schedule, + T: Future + 'static, + T::Output: 'static, + { + let (task, notified, join) = super::new_task(task, scheduler); + + unsafe { + // safety: We just created the task, so we have exclusive access + // to the field. + task.header().set_owner_id(self.id); + } + + if self.is_closed() { + drop(notified); + task.shutdown(); + (join, None) + } else { + self.with_inner(|inner| { + inner.list.push_front(task); + }); + (join, Some(notified)) + } + } + + /// Shuts down all tasks in the collection. This call also closes the + /// collection, preventing new items from being added. + pub(crate) fn close_and_shutdown_all(&self) + where + S: Schedule, + { + self.with_inner(|inner| inner.closed = true); + + while let Some(task) = self.with_inner(|inner| inner.list.pop_back()) { + task.shutdown(); + } + } + + pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> { + let task_id = task.header().get_owner_id(); + if task_id == 0 { + // The task is unowned. + return None; + } + + assert_eq!(task_id, self.id); + + self.with_inner(|inner| + // safety: We just checked that the provided task is not in some + // other linked list. + unsafe { inner.list.remove(task.header().into()) }) + } + + /// Asserts that the given task is owned by this LocalOwnedTasks and convert + /// it to a LocalNotified, giving the thread permission to poll this task. + #[inline] + pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> { + assert_eq!(task.header().get_owner_id(), self.id); + + // safety: The task was bound to this LocalOwnedTasks, and the + // LocalOwnedTasks is not Send or Sync, so we are on the right thread + // for polling this task. + LocalNotified { + task: task.0, + _not_send: PhantomData, + } + } + + #[inline] + fn with_inner<F, T>(&self, f: F) -> T + where + F: FnOnce(&mut OwnedTasksInner<S>) -> T, + { + // safety: This type is not Sync, so concurrent calls of this method + // can't happen. Furthermore, all uses of this method in this file make + // sure that they don't call `with_inner` recursively. + self.inner.with_mut(|ptr| unsafe { f(&mut *ptr) }) + } + + pub(crate) fn is_closed(&self) -> bool { + self.with_inner(|inner| inner.closed) + } + + pub(crate) fn is_empty(&self) -> bool { + self.with_inner(|inner| inner.list.is_empty()) + } +} + +#[cfg(all(test))] +mod tests { + use super::*; + + // This test may run in parallel with other tests, so we only test that ids + // come in increasing order. + #[test] + fn test_id_not_broken() { + let mut last_id = get_next_id(); + assert_ne!(last_id, 0); + + for _ in 0..1000 { + let next_id = get_next_id(); + assert_ne!(next_id, 0); + assert!(last_id < next_id); + last_id = next_id; + } + } +} diff --git a/third_party/rust/tokio/src/runtime/task/mod.rs b/third_party/rust/tokio/src/runtime/task/mod.rs new file mode 100644 index 0000000000..2a492dc985 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/mod.rs @@ -0,0 +1,445 @@ +//! The task module. +//! +//! The task module contains the code that manages spawned tasks and provides a +//! safe API for the rest of the runtime to use. Each task in a runtime is +//! stored in an OwnedTasks or LocalOwnedTasks object. +//! +//! # Task reference types +//! +//! A task is usually referenced by multiple handles, and there are several +//! types of handles. +//! +//! * OwnedTask - tasks stored in an OwnedTasks or LocalOwnedTasks are of this +//! reference type. +//! +//! * JoinHandle - each task has a JoinHandle that allows access to the output +//! of the task. +//! +//! * Waker - every waker for a task has this reference type. There can be any +//! number of waker references. +//! +//! * Notified - tracks whether the task is notified. +//! +//! * Unowned - this task reference type is used for tasks not stored in any +//! runtime. Mainly used for blocking tasks, but also in tests. +//! +//! The task uses a reference count to keep track of how many active references +//! exist. The Unowned reference type takes up two ref-counts. All other +//! reference types take up a single ref-count. +//! +//! Besides the waker type, each task has at most one of each reference type. +//! +//! # State +//! +//! The task stores its state in an atomic usize with various bitfields for the +//! necessary information. The state has the following bitfields: +//! +//! * RUNNING - Tracks whether the task is currently being polled or cancelled. +//! This bit functions as a lock around the task. +//! +//! * COMPLETE - Is one once the future has fully completed and has been +//! dropped. Never unset once set. Never set together with RUNNING. +//! +//! * NOTIFIED - Tracks whether a Notified object currently exists. +//! +//! * CANCELLED - Is set to one for tasks that should be cancelled as soon as +//! possible. May take any value for completed tasks. +//! +//! * JOIN_INTEREST - Is set to one if there exists a JoinHandle. +//! +//! * JOIN_WAKER - Is set to one if the JoinHandle has set a waker. +//! +//! The rest of the bits are used for the ref-count. +//! +//! # Fields in the task +//! +//! The task has various fields. This section describes how and when it is safe +//! to access a field. +//! +//! * The state field is accessed with atomic instructions. +//! +//! * The OwnedTask reference has exclusive access to the `owned` field. +//! +//! * The Notified reference has exclusive access to the `queue_next` field. +//! +//! * The `owner_id` field can be set as part of construction of the task, but +//! is otherwise immutable and anyone can access the field immutably without +//! synchronization. +//! +//! * If COMPLETE is one, then the JoinHandle has exclusive access to the +//! stage field. If COMPLETE is zero, then the RUNNING bitfield functions as +//! a lock for the stage field, and it can be accessed only by the thread +//! that set RUNNING to one. +//! +//! * If JOIN_WAKER is zero, then the JoinHandle has exclusive access to the +//! join handle waker. If JOIN_WAKER and COMPLETE are both one, then the +//! thread that set COMPLETE to one has exclusive access to the join handle +//! waker. +//! +//! All other fields are immutable and can be accessed immutably without +//! synchronization by anyone. +//! +//! # Safety +//! +//! This section goes through various situations and explains why the API is +//! safe in that situation. +//! +//! ## Polling or dropping the future +//! +//! Any mutable access to the future happens after obtaining a lock by modifying +//! the RUNNING field, so exclusive access is ensured. +//! +//! When the task completes, exclusive access to the output is transferred to +//! the JoinHandle. If the JoinHandle is already dropped when the transition to +//! complete happens, the thread performing that transition retains exclusive +//! access to the output and should immediately drop it. +//! +//! ## Non-Send futures +//! +//! If a future is not Send, then it is bound to a LocalOwnedTasks. The future +//! will only ever be polled or dropped given a LocalNotified or inside a call +//! to LocalOwnedTasks::shutdown_all. In either case, it is guaranteed that the +//! future is on the right thread. +//! +//! If the task is never removed from the LocalOwnedTasks, then it is leaked, so +//! there is no risk that the task is dropped on some other thread when the last +//! ref-count drops. +//! +//! ## Non-Send output +//! +//! When a task completes, the output is placed in the stage of the task. Then, +//! a transition that sets COMPLETE to true is performed, and the value of +//! JOIN_INTEREST when this transition happens is read. +//! +//! If JOIN_INTEREST is zero when the transition to COMPLETE happens, then the +//! output is immediately dropped. +//! +//! If JOIN_INTEREST is one when the transition to COMPLETE happens, then the +//! JoinHandle is responsible for cleaning up the output. If the output is not +//! Send, then this happens: +//! +//! 1. The output is created on the thread that the future was polled on. Since +//! only non-Send futures can have non-Send output, the future was polled on +//! the thread that the future was spawned from. +//! 2. Since JoinHandle<Output> is not Send if Output is not Send, the +//! JoinHandle is also on the thread that the future was spawned from. +//! 3. Thus, the JoinHandle will not move the output across threads when it +//! takes or drops the output. +//! +//! ## Recursive poll/shutdown +//! +//! Calling poll from inside a shutdown call or vice-versa is not prevented by +//! the API exposed by the task module, so this has to be safe. In either case, +//! the lock in the RUNNING bitfield makes the inner call return immediately. If +//! the inner call is a `shutdown` call, then the CANCELLED bit is set, and the +//! poll call will notice it when the poll finishes, and the task is cancelled +//! at that point. + +// Some task infrastructure is here to support `JoinSet`, which is currently +// unstable. This should be removed once `JoinSet` is stabilized. +#![cfg_attr(not(tokio_unstable), allow(dead_code))] + +mod core; +use self::core::Cell; +use self::core::Header; + +mod error; +#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 +pub use self::error::JoinError; + +mod harness; +use self::harness::Harness; + +cfg_rt_multi_thread! { + mod inject; + pub(super) use self::inject::Inject; +} + +mod join; +#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 +pub use self::join::JoinHandle; + +mod list; +pub(crate) use self::list::{LocalOwnedTasks, OwnedTasks}; + +mod raw; +use self::raw::RawTask; + +mod state; +use self::state::State; + +mod waker; + +use crate::future::Future; +use crate::util::linked_list; + +use std::marker::PhantomData; +use std::ptr::NonNull; +use std::{fmt, mem}; + +/// An owned handle to the task, tracked by ref count. +#[repr(transparent)] +pub(crate) struct Task<S: 'static> { + raw: RawTask, + _p: PhantomData<S>, +} + +unsafe impl<S> Send for Task<S> {} +unsafe impl<S> Sync for Task<S> {} + +/// A task was notified. +#[repr(transparent)] +pub(crate) struct Notified<S: 'static>(Task<S>); + +// safety: This type cannot be used to touch the task without first verifying +// that the value is on a thread where it is safe to poll the task. +unsafe impl<S: Schedule> Send for Notified<S> {} +unsafe impl<S: Schedule> Sync for Notified<S> {} + +/// A non-Send variant of Notified with the invariant that it is on a thread +/// where it is safe to poll it. +#[repr(transparent)] +pub(crate) struct LocalNotified<S: 'static> { + task: Task<S>, + _not_send: PhantomData<*const ()>, +} + +/// A task that is not owned by any OwnedTasks. Used for blocking tasks. +/// This type holds two ref-counts. +pub(crate) struct UnownedTask<S: 'static> { + raw: RawTask, + _p: PhantomData<S>, +} + +// safety: This type can only be created given a Send task. +unsafe impl<S> Send for UnownedTask<S> {} +unsafe impl<S> Sync for UnownedTask<S> {} + +/// Task result sent back. +pub(crate) type Result<T> = std::result::Result<T, JoinError>; + +pub(crate) trait Schedule: Sync + Sized + 'static { + /// The task has completed work and is ready to be released. The scheduler + /// should release it immediately and return it. The task module will batch + /// the ref-dec with setting other options. + /// + /// If the scheduler has already released the task, then None is returned. + fn release(&self, task: &Task<Self>) -> Option<Task<Self>>; + + /// Schedule the task + fn schedule(&self, task: Notified<Self>); + + /// Schedule the task to run in the near future, yielding the thread to + /// other tasks. + fn yield_now(&self, task: Notified<Self>) { + self.schedule(task); + } +} + +cfg_rt! { + /// This is the constructor for a new task. Three references to the task are + /// created. The first task reference is usually put into an OwnedTasks + /// immediately. The Notified is sent to the scheduler as an ordinary + /// notification. + fn new_task<T, S>( + task: T, + scheduler: S + ) -> (Task<S>, Notified<S>, JoinHandle<T::Output>) + where + S: Schedule, + T: Future + 'static, + T::Output: 'static, + { + let raw = RawTask::new::<T, S>(task, scheduler); + let task = Task { + raw, + _p: PhantomData, + }; + let notified = Notified(Task { + raw, + _p: PhantomData, + }); + let join = JoinHandle::new(raw); + + (task, notified, join) + } + + /// Creates a new task with an associated join handle. This method is used + /// only when the task is not going to be stored in an `OwnedTasks` list. + /// + /// Currently only blocking tasks use this method. + pub(crate) fn unowned<T, S>(task: T, scheduler: S) -> (UnownedTask<S>, JoinHandle<T::Output>) + where + S: Schedule, + T: Send + Future + 'static, + T::Output: Send + 'static, + { + let (task, notified, join) = new_task(task, scheduler); + + // This transfers the ref-count of task and notified into an UnownedTask. + // This is valid because an UnownedTask holds two ref-counts. + let unowned = UnownedTask { + raw: task.raw, + _p: PhantomData, + }; + std::mem::forget(task); + std::mem::forget(notified); + + (unowned, join) + } +} + +impl<S: 'static> Task<S> { + unsafe fn from_raw(ptr: NonNull<Header>) -> Task<S> { + Task { + raw: RawTask::from_raw(ptr), + _p: PhantomData, + } + } + + fn header(&self) -> &Header { + self.raw.header() + } +} + +impl<S: 'static> Notified<S> { + fn header(&self) -> &Header { + self.0.header() + } +} + +cfg_rt_multi_thread! { + impl<S: 'static> Notified<S> { + unsafe fn from_raw(ptr: NonNull<Header>) -> Notified<S> { + Notified(Task::from_raw(ptr)) + } + } + + impl<S: 'static> Task<S> { + fn into_raw(self) -> NonNull<Header> { + let ret = self.raw.header_ptr(); + mem::forget(self); + ret + } + } + + impl<S: 'static> Notified<S> { + fn into_raw(self) -> NonNull<Header> { + self.0.into_raw() + } + } +} + +impl<S: Schedule> Task<S> { + /// Pre-emptively cancels the task as part of the shutdown process. + pub(crate) fn shutdown(self) { + let raw = self.raw; + mem::forget(self); + raw.shutdown(); + } +} + +impl<S: Schedule> LocalNotified<S> { + /// Runs the task. + pub(crate) fn run(self) { + let raw = self.task.raw; + mem::forget(self); + raw.poll(); + } +} + +impl<S: Schedule> UnownedTask<S> { + // Used in test of the inject queue. + #[cfg(test)] + #[cfg_attr(target_arch = "wasm32", allow(dead_code))] + pub(super) fn into_notified(self) -> Notified<S> { + Notified(self.into_task()) + } + + fn into_task(self) -> Task<S> { + // Convert into a task. + let task = Task { + raw: self.raw, + _p: PhantomData, + }; + mem::forget(self); + + // Drop a ref-count since an UnownedTask holds two. + task.header().state.ref_dec(); + + task + } + + pub(crate) fn run(self) { + let raw = self.raw; + mem::forget(self); + + // Transfer one ref-count to a Task object. + let task = Task::<S> { + raw, + _p: PhantomData, + }; + + // Use the other ref-count to poll the task. + raw.poll(); + // Decrement our extra ref-count + drop(task); + } + + pub(crate) fn shutdown(self) { + self.into_task().shutdown() + } +} + +impl<S: 'static> Drop for Task<S> { + fn drop(&mut self) { + // Decrement the ref count + if self.header().state.ref_dec() { + // Deallocate if this is the final ref count + self.raw.dealloc(); + } + } +} + +impl<S: 'static> Drop for UnownedTask<S> { + fn drop(&mut self) { + // Decrement the ref count + if self.raw.header().state.ref_dec_twice() { + // Deallocate if this is the final ref count + self.raw.dealloc(); + } + } +} + +impl<S> fmt::Debug for Task<S> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "Task({:p})", self.header()) + } +} + +impl<S> fmt::Debug for Notified<S> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "task::Notified({:p})", self.0.header()) + } +} + +/// # Safety +/// +/// Tasks are pinned. +unsafe impl<S> linked_list::Link for Task<S> { + type Handle = Task<S>; + type Target = Header; + + fn as_raw(handle: &Task<S>) -> NonNull<Header> { + handle.raw.header_ptr() + } + + unsafe fn from_raw(ptr: NonNull<Header>) -> Task<S> { + Task::from_raw(ptr) + } + + unsafe fn pointers(target: NonNull<Header>) -> NonNull<linked_list::Pointers<Header>> { + // Not super great as it avoids some of looms checking... + NonNull::from(target.as_ref().owned.with_mut(|ptr| &mut *ptr)) + } +} diff --git a/third_party/rust/tokio/src/runtime/task/raw.rs b/third_party/rust/tokio/src/runtime/task/raw.rs new file mode 100644 index 0000000000..2e4420b5c1 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/raw.rs @@ -0,0 +1,165 @@ +use crate::future::Future; +use crate::runtime::task::{Cell, Harness, Header, Schedule, State}; + +use std::ptr::NonNull; +use std::task::{Poll, Waker}; + +/// Raw task handle +pub(super) struct RawTask { + ptr: NonNull<Header>, +} + +pub(super) struct Vtable { + /// Polls the future. + pub(super) poll: unsafe fn(NonNull<Header>), + + /// Deallocates the memory. + pub(super) dealloc: unsafe fn(NonNull<Header>), + + /// Reads the task output, if complete. + pub(super) try_read_output: unsafe fn(NonNull<Header>, *mut (), &Waker), + + /// Try to set the waker notified when the task is complete. Returns true if + /// the task has already completed. If this call returns false, then the + /// waker will not be notified. + pub(super) try_set_join_waker: unsafe fn(NonNull<Header>, &Waker) -> bool, + + /// The join handle has been dropped. + pub(super) drop_join_handle_slow: unsafe fn(NonNull<Header>), + + /// The task is remotely aborted. + pub(super) remote_abort: unsafe fn(NonNull<Header>), + + /// Scheduler is being shutdown. + pub(super) shutdown: unsafe fn(NonNull<Header>), +} + +/// Get the vtable for the requested `T` and `S` generics. +pub(super) fn vtable<T: Future, S: Schedule>() -> &'static Vtable { + &Vtable { + poll: poll::<T, S>, + dealloc: dealloc::<T, S>, + try_read_output: try_read_output::<T, S>, + try_set_join_waker: try_set_join_waker::<T, S>, + drop_join_handle_slow: drop_join_handle_slow::<T, S>, + remote_abort: remote_abort::<T, S>, + shutdown: shutdown::<T, S>, + } +} + +impl RawTask { + pub(super) fn new<T, S>(task: T, scheduler: S) -> RawTask + where + T: Future, + S: Schedule, + { + let ptr = Box::into_raw(Cell::<_, S>::new(task, scheduler, State::new())); + let ptr = unsafe { NonNull::new_unchecked(ptr as *mut Header) }; + + RawTask { ptr } + } + + pub(super) unsafe fn from_raw(ptr: NonNull<Header>) -> RawTask { + RawTask { ptr } + } + + pub(super) fn header_ptr(&self) -> NonNull<Header> { + self.ptr + } + + /// Returns a reference to the task's meta structure. + /// + /// Safe as `Header` is `Sync`. + pub(super) fn header(&self) -> &Header { + unsafe { self.ptr.as_ref() } + } + + /// Safety: mutual exclusion is required to call this function. + pub(super) fn poll(self) { + let vtable = self.header().vtable; + unsafe { (vtable.poll)(self.ptr) } + } + + pub(super) fn dealloc(self) { + let vtable = self.header().vtable; + unsafe { + (vtable.dealloc)(self.ptr); + } + } + + /// Safety: `dst` must be a `*mut Poll<super::Result<T::Output>>` where `T` + /// is the future stored by the task. + pub(super) unsafe fn try_read_output(self, dst: *mut (), waker: &Waker) { + let vtable = self.header().vtable; + (vtable.try_read_output)(self.ptr, dst, waker); + } + + pub(super) fn try_set_join_waker(self, waker: &Waker) -> bool { + let vtable = self.header().vtable; + unsafe { (vtable.try_set_join_waker)(self.ptr, waker) } + } + + pub(super) fn drop_join_handle_slow(self) { + let vtable = self.header().vtable; + unsafe { (vtable.drop_join_handle_slow)(self.ptr) } + } + + pub(super) fn shutdown(self) { + let vtable = self.header().vtable; + unsafe { (vtable.shutdown)(self.ptr) } + } + + pub(super) fn remote_abort(self) { + let vtable = self.header().vtable; + unsafe { (vtable.remote_abort)(self.ptr) } + } +} + +impl Clone for RawTask { + fn clone(&self) -> Self { + RawTask { ptr: self.ptr } + } +} + +impl Copy for RawTask {} + +unsafe fn poll<T: Future, S: Schedule>(ptr: NonNull<Header>) { + let harness = Harness::<T, S>::from_raw(ptr); + harness.poll(); +} + +unsafe fn dealloc<T: Future, S: Schedule>(ptr: NonNull<Header>) { + let harness = Harness::<T, S>::from_raw(ptr); + harness.dealloc(); +} + +unsafe fn try_read_output<T: Future, S: Schedule>( + ptr: NonNull<Header>, + dst: *mut (), + waker: &Waker, +) { + let out = &mut *(dst as *mut Poll<super::Result<T::Output>>); + + let harness = Harness::<T, S>::from_raw(ptr); + harness.try_read_output(out, waker); +} + +unsafe fn try_set_join_waker<T: Future, S: Schedule>(ptr: NonNull<Header>, waker: &Waker) -> bool { + let harness = Harness::<T, S>::from_raw(ptr); + harness.try_set_join_waker(waker) +} + +unsafe fn drop_join_handle_slow<T: Future, S: Schedule>(ptr: NonNull<Header>) { + let harness = Harness::<T, S>::from_raw(ptr); + harness.drop_join_handle_slow() +} + +unsafe fn remote_abort<T: Future, S: Schedule>(ptr: NonNull<Header>) { + let harness = Harness::<T, S>::from_raw(ptr); + harness.remote_abort() +} + +unsafe fn shutdown<T: Future, S: Schedule>(ptr: NonNull<Header>) { + let harness = Harness::<T, S>::from_raw(ptr); + harness.shutdown() +} diff --git a/third_party/rust/tokio/src/runtime/task/state.rs b/third_party/rust/tokio/src/runtime/task/state.rs new file mode 100644 index 0000000000..c2d5b28eac --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/state.rs @@ -0,0 +1,595 @@ +use crate::loom::sync::atomic::AtomicUsize; + +use std::fmt; +use std::sync::atomic::Ordering::{AcqRel, Acquire, Release}; +use std::usize; + +pub(super) struct State { + val: AtomicUsize, +} + +/// Current state value. +#[derive(Copy, Clone)] +pub(super) struct Snapshot(usize); + +type UpdateResult = Result<Snapshot, Snapshot>; + +/// The task is currently being run. +const RUNNING: usize = 0b0001; + +/// The task is complete. +/// +/// Once this bit is set, it is never unset. +const COMPLETE: usize = 0b0010; + +/// Extracts the task's lifecycle value from the state. +const LIFECYCLE_MASK: usize = 0b11; + +/// Flag tracking if the task has been pushed into a run queue. +const NOTIFIED: usize = 0b100; + +/// The join handle is still around. +#[allow(clippy::unusual_byte_groupings)] // https://github.com/rust-lang/rust-clippy/issues/6556 +const JOIN_INTEREST: usize = 0b1_000; + +/// A join handle waker has been set. +#[allow(clippy::unusual_byte_groupings)] // https://github.com/rust-lang/rust-clippy/issues/6556 +const JOIN_WAKER: usize = 0b10_000; + +/// The task has been forcibly cancelled. +#[allow(clippy::unusual_byte_groupings)] // https://github.com/rust-lang/rust-clippy/issues/6556 +const CANCELLED: usize = 0b100_000; + +/// All bits. +const STATE_MASK: usize = LIFECYCLE_MASK | NOTIFIED | JOIN_INTEREST | JOIN_WAKER | CANCELLED; + +/// Bits used by the ref count portion of the state. +const REF_COUNT_MASK: usize = !STATE_MASK; + +/// Number of positions to shift the ref count. +const REF_COUNT_SHIFT: usize = REF_COUNT_MASK.count_zeros() as usize; + +/// One ref count. +const REF_ONE: usize = 1 << REF_COUNT_SHIFT; + +/// State a task is initialized with. +/// +/// A task is initialized with three references: +/// +/// * A reference that will be stored in an OwnedTasks or LocalOwnedTasks. +/// * A reference that will be sent to the scheduler as an ordinary notification. +/// * A reference for the JoinHandle. +/// +/// As the task starts with a `JoinHandle`, `JOIN_INTEREST` is set. +/// As the task starts with a `Notified`, `NOTIFIED` is set. +const INITIAL_STATE: usize = (REF_ONE * 3) | JOIN_INTEREST | NOTIFIED; + +#[must_use] +pub(super) enum TransitionToRunning { + Success, + Cancelled, + Failed, + Dealloc, +} + +#[must_use] +pub(super) enum TransitionToIdle { + Ok, + OkNotified, + OkDealloc, + Cancelled, +} + +#[must_use] +pub(super) enum TransitionToNotifiedByVal { + DoNothing, + Submit, + Dealloc, +} + +#[must_use] +pub(super) enum TransitionToNotifiedByRef { + DoNothing, + Submit, +} + +/// All transitions are performed via RMW operations. This establishes an +/// unambiguous modification order. +impl State { + /// Returns a task's initial state. + pub(super) fn new() -> State { + // The raw task returned by this method has a ref-count of three. See + // the comment on INITIAL_STATE for more. + State { + val: AtomicUsize::new(INITIAL_STATE), + } + } + + /// Loads the current state, establishes `Acquire` ordering. + pub(super) fn load(&self) -> Snapshot { + Snapshot(self.val.load(Acquire)) + } + + /// Attempts to transition the lifecycle to `Running`. This sets the + /// notified bit to false so notifications during the poll can be detected. + pub(super) fn transition_to_running(&self) -> TransitionToRunning { + self.fetch_update_action(|mut next| { + let action; + assert!(next.is_notified()); + + if !next.is_idle() { + // This happens if the task is either currently running or if it + // has already completed, e.g. if it was cancelled during + // shutdown. Consume the ref-count and return. + next.ref_dec(); + if next.ref_count() == 0 { + action = TransitionToRunning::Dealloc; + } else { + action = TransitionToRunning::Failed; + } + } else { + // We are able to lock the RUNNING bit. + next.set_running(); + next.unset_notified(); + + if next.is_cancelled() { + action = TransitionToRunning::Cancelled; + } else { + action = TransitionToRunning::Success; + } + } + (action, Some(next)) + }) + } + + /// Transitions the task from `Running` -> `Idle`. + /// + /// Returns `true` if the transition to `Idle` is successful, `false` otherwise. + /// The transition to `Idle` fails if the task has been flagged to be + /// cancelled. + pub(super) fn transition_to_idle(&self) -> TransitionToIdle { + self.fetch_update_action(|curr| { + assert!(curr.is_running()); + + if curr.is_cancelled() { + return (TransitionToIdle::Cancelled, None); + } + + let mut next = curr; + let action; + next.unset_running(); + + if !next.is_notified() { + // Polling the future consumes the ref-count of the Notified. + next.ref_dec(); + if next.ref_count() == 0 { + action = TransitionToIdle::OkDealloc; + } else { + action = TransitionToIdle::Ok; + } + } else { + // The caller will schedule a new notification, so we create a + // new ref-count for the notification. Our own ref-count is kept + // for now, and the caller will drop it shortly. + next.ref_inc(); + action = TransitionToIdle::OkNotified; + } + + (action, Some(next)) + }) + } + + /// Transitions the task from `Running` -> `Complete`. + pub(super) fn transition_to_complete(&self) -> Snapshot { + const DELTA: usize = RUNNING | COMPLETE; + + let prev = Snapshot(self.val.fetch_xor(DELTA, AcqRel)); + assert!(prev.is_running()); + assert!(!prev.is_complete()); + + Snapshot(prev.0 ^ DELTA) + } + + /// Transitions from `Complete` -> `Terminal`, decrementing the reference + /// count the specified number of times. + /// + /// Returns true if the task should be deallocated. + pub(super) fn transition_to_terminal(&self, count: usize) -> bool { + let prev = Snapshot(self.val.fetch_sub(count * REF_ONE, AcqRel)); + assert!( + prev.ref_count() >= count, + "current: {}, sub: {}", + prev.ref_count(), + count + ); + prev.ref_count() == count + } + + /// Transitions the state to `NOTIFIED`. + /// + /// If no task needs to be submitted, a ref-count is consumed. + /// + /// If a task needs to be submitted, the ref-count is incremented for the + /// new Notified. + pub(super) fn transition_to_notified_by_val(&self) -> TransitionToNotifiedByVal { + self.fetch_update_action(|mut snapshot| { + let action; + + if snapshot.is_running() { + // If the task is running, we mark it as notified, but we should + // not submit anything as the thread currently running the + // future is responsible for that. + snapshot.set_notified(); + snapshot.ref_dec(); + + // The thread that set the running bit also holds a ref-count. + assert!(snapshot.ref_count() > 0); + + action = TransitionToNotifiedByVal::DoNothing; + } else if snapshot.is_complete() || snapshot.is_notified() { + // We do not need to submit any notifications, but we have to + // decrement the ref-count. + snapshot.ref_dec(); + + if snapshot.ref_count() == 0 { + action = TransitionToNotifiedByVal::Dealloc; + } else { + action = TransitionToNotifiedByVal::DoNothing; + } + } else { + // We create a new notified that we can submit. The caller + // retains ownership of the ref-count they passed in. + snapshot.set_notified(); + snapshot.ref_inc(); + action = TransitionToNotifiedByVal::Submit; + } + + (action, Some(snapshot)) + }) + } + + /// Transitions the state to `NOTIFIED`. + pub(super) fn transition_to_notified_by_ref(&self) -> TransitionToNotifiedByRef { + self.fetch_update_action(|mut snapshot| { + if snapshot.is_complete() || snapshot.is_notified() { + // There is nothing to do in this case. + (TransitionToNotifiedByRef::DoNothing, None) + } else if snapshot.is_running() { + // If the task is running, we mark it as notified, but we should + // not submit as the thread currently running the future is + // responsible for that. + snapshot.set_notified(); + (TransitionToNotifiedByRef::DoNothing, Some(snapshot)) + } else { + // The task is idle and not notified. We should submit a + // notification. + snapshot.set_notified(); + snapshot.ref_inc(); + (TransitionToNotifiedByRef::Submit, Some(snapshot)) + } + }) + } + + /// Sets the cancelled bit and transitions the state to `NOTIFIED` if idle. + /// + /// Returns `true` if the task needs to be submitted to the pool for + /// execution. + pub(super) fn transition_to_notified_and_cancel(&self) -> bool { + self.fetch_update_action(|mut snapshot| { + if snapshot.is_cancelled() || snapshot.is_complete() { + // Aborts to completed or cancelled tasks are no-ops. + (false, None) + } else if snapshot.is_running() { + // If the task is running, we mark it as cancelled. The thread + // running the task will notice the cancelled bit when it + // stops polling and it will kill the task. + // + // The set_notified() call is not strictly necessary but it will + // in some cases let a wake_by_ref call return without having + // to perform a compare_exchange. + snapshot.set_notified(); + snapshot.set_cancelled(); + (false, Some(snapshot)) + } else { + // The task is idle. We set the cancelled and notified bits and + // submit a notification if the notified bit was not already + // set. + snapshot.set_cancelled(); + if !snapshot.is_notified() { + snapshot.set_notified(); + snapshot.ref_inc(); + (true, Some(snapshot)) + } else { + (false, Some(snapshot)) + } + } + }) + } + + /// Sets the `CANCELLED` bit and attempts to transition to `Running`. + /// + /// Returns `true` if the transition to `Running` succeeded. + pub(super) fn transition_to_shutdown(&self) -> bool { + let mut prev = Snapshot(0); + + let _ = self.fetch_update(|mut snapshot| { + prev = snapshot; + + if snapshot.is_idle() { + snapshot.set_running(); + } + + // If the task was not idle, the thread currently running the task + // will notice the cancelled bit and cancel it once the poll + // completes. + snapshot.set_cancelled(); + Some(snapshot) + }); + + prev.is_idle() + } + + /// Optimistically tries to swap the state assuming the join handle is + /// __immediately__ dropped on spawn. + pub(super) fn drop_join_handle_fast(&self) -> Result<(), ()> { + use std::sync::atomic::Ordering::Relaxed; + + // Relaxed is acceptable as if this function is called and succeeds, + // then nothing has been done w/ the join handle. + // + // The moment the join handle is used (polled), the `JOIN_WAKER` flag is + // set, at which point the CAS will fail. + // + // Given this, there is no risk if this operation is reordered. + self.val + .compare_exchange_weak( + INITIAL_STATE, + (INITIAL_STATE - REF_ONE) & !JOIN_INTEREST, + Release, + Relaxed, + ) + .map(|_| ()) + .map_err(|_| ()) + } + + /// Tries to unset the JOIN_INTEREST flag. + /// + /// Returns `Ok` if the operation happens before the task transitions to a + /// completed state, `Err` otherwise. + pub(super) fn unset_join_interested(&self) -> UpdateResult { + self.fetch_update(|curr| { + assert!(curr.is_join_interested()); + + if curr.is_complete() { + return None; + } + + let mut next = curr; + next.unset_join_interested(); + + Some(next) + }) + } + + /// Sets the `JOIN_WAKER` bit. + /// + /// Returns `Ok` if the bit is set, `Err` otherwise. This operation fails if + /// the task has completed. + pub(super) fn set_join_waker(&self) -> UpdateResult { + self.fetch_update(|curr| { + assert!(curr.is_join_interested()); + assert!(!curr.has_join_waker()); + + if curr.is_complete() { + return None; + } + + let mut next = curr; + next.set_join_waker(); + + Some(next) + }) + } + + /// Unsets the `JOIN_WAKER` bit. + /// + /// Returns `Ok` has been unset, `Err` otherwise. This operation fails if + /// the task has completed. + pub(super) fn unset_waker(&self) -> UpdateResult { + self.fetch_update(|curr| { + assert!(curr.is_join_interested()); + assert!(curr.has_join_waker()); + + if curr.is_complete() { + return None; + } + + let mut next = curr; + next.unset_join_waker(); + + Some(next) + }) + } + + pub(super) fn ref_inc(&self) { + use std::process; + use std::sync::atomic::Ordering::Relaxed; + + // Using a relaxed ordering is alright here, as knowledge of the + // original reference prevents other threads from erroneously deleting + // the object. + // + // As explained in the [Boost documentation][1], Increasing the + // reference counter can always be done with memory_order_relaxed: New + // references to an object can only be formed from an existing + // reference, and passing an existing reference from one thread to + // another must already provide any required synchronization. + // + // [1]: (www.boost.org/doc/libs/1_55_0/doc/html/atomic/usage_examples.html) + let prev = self.val.fetch_add(REF_ONE, Relaxed); + + // If the reference count overflowed, abort. + if prev > isize::MAX as usize { + process::abort(); + } + } + + /// Returns `true` if the task should be released. + pub(super) fn ref_dec(&self) -> bool { + let prev = Snapshot(self.val.fetch_sub(REF_ONE, AcqRel)); + assert!(prev.ref_count() >= 1); + prev.ref_count() == 1 + } + + /// Returns `true` if the task should be released. + pub(super) fn ref_dec_twice(&self) -> bool { + let prev = Snapshot(self.val.fetch_sub(2 * REF_ONE, AcqRel)); + assert!(prev.ref_count() >= 2); + prev.ref_count() == 2 + } + + fn fetch_update_action<F, T>(&self, mut f: F) -> T + where + F: FnMut(Snapshot) -> (T, Option<Snapshot>), + { + let mut curr = self.load(); + + loop { + let (output, next) = f(curr); + let next = match next { + Some(next) => next, + None => return output, + }; + + let res = self.val.compare_exchange(curr.0, next.0, AcqRel, Acquire); + + match res { + Ok(_) => return output, + Err(actual) => curr = Snapshot(actual), + } + } + } + + fn fetch_update<F>(&self, mut f: F) -> Result<Snapshot, Snapshot> + where + F: FnMut(Snapshot) -> Option<Snapshot>, + { + let mut curr = self.load(); + + loop { + let next = match f(curr) { + Some(next) => next, + None => return Err(curr), + }; + + let res = self.val.compare_exchange(curr.0, next.0, AcqRel, Acquire); + + match res { + Ok(_) => return Ok(next), + Err(actual) => curr = Snapshot(actual), + } + } + } +} + +// ===== impl Snapshot ===== + +impl Snapshot { + /// Returns `true` if the task is in an idle state. + pub(super) fn is_idle(self) -> bool { + self.0 & (RUNNING | COMPLETE) == 0 + } + + /// Returns `true` if the task has been flagged as notified. + pub(super) fn is_notified(self) -> bool { + self.0 & NOTIFIED == NOTIFIED + } + + fn unset_notified(&mut self) { + self.0 &= !NOTIFIED + } + + fn set_notified(&mut self) { + self.0 |= NOTIFIED + } + + pub(super) fn is_running(self) -> bool { + self.0 & RUNNING == RUNNING + } + + fn set_running(&mut self) { + self.0 |= RUNNING; + } + + fn unset_running(&mut self) { + self.0 &= !RUNNING; + } + + pub(super) fn is_cancelled(self) -> bool { + self.0 & CANCELLED == CANCELLED + } + + fn set_cancelled(&mut self) { + self.0 |= CANCELLED; + } + + /// Returns `true` if the task's future has completed execution. + pub(super) fn is_complete(self) -> bool { + self.0 & COMPLETE == COMPLETE + } + + pub(super) fn is_join_interested(self) -> bool { + self.0 & JOIN_INTEREST == JOIN_INTEREST + } + + fn unset_join_interested(&mut self) { + self.0 &= !JOIN_INTEREST + } + + pub(super) fn has_join_waker(self) -> bool { + self.0 & JOIN_WAKER == JOIN_WAKER + } + + fn set_join_waker(&mut self) { + self.0 |= JOIN_WAKER; + } + + fn unset_join_waker(&mut self) { + self.0 &= !JOIN_WAKER + } + + pub(super) fn ref_count(self) -> usize { + (self.0 & REF_COUNT_MASK) >> REF_COUNT_SHIFT + } + + fn ref_inc(&mut self) { + assert!(self.0 <= isize::MAX as usize); + self.0 += REF_ONE; + } + + pub(super) fn ref_dec(&mut self) { + assert!(self.ref_count() > 0); + self.0 -= REF_ONE + } +} + +impl fmt::Debug for State { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + let snapshot = self.load(); + snapshot.fmt(fmt) + } +} + +impl fmt::Debug for Snapshot { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Snapshot") + .field("is_running", &self.is_running()) + .field("is_complete", &self.is_complete()) + .field("is_notified", &self.is_notified()) + .field("is_cancelled", &self.is_cancelled()) + .field("is_join_interested", &self.is_join_interested()) + .field("has_join_waker", &self.has_join_waker()) + .field("ref_count", &self.ref_count()) + .finish() + } +} diff --git a/third_party/rust/tokio/src/runtime/task/waker.rs b/third_party/rust/tokio/src/runtime/task/waker.rs new file mode 100644 index 0000000000..74a29f4a84 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/waker.rs @@ -0,0 +1,130 @@ +use crate::future::Future; +use crate::runtime::task::harness::Harness; +use crate::runtime::task::{Header, Schedule}; + +use std::marker::PhantomData; +use std::mem::ManuallyDrop; +use std::ops; +use std::ptr::NonNull; +use std::task::{RawWaker, RawWakerVTable, Waker}; + +pub(super) struct WakerRef<'a, S: 'static> { + waker: ManuallyDrop<Waker>, + _p: PhantomData<(&'a Header, S)>, +} + +/// Returns a `WakerRef` which avoids having to pre-emptively increase the +/// refcount if there is no need to do so. +pub(super) fn waker_ref<T, S>(header: &NonNull<Header>) -> WakerRef<'_, S> +where + T: Future, + S: Schedule, +{ + // `Waker::will_wake` uses the VTABLE pointer as part of the check. This + // means that `will_wake` will always return false when using the current + // task's waker. (discussion at rust-lang/rust#66281). + // + // To fix this, we use a single vtable. Since we pass in a reference at this + // point and not an *owned* waker, we must ensure that `drop` is never + // called on this waker instance. This is done by wrapping it with + // `ManuallyDrop` and then never calling drop. + let waker = unsafe { ManuallyDrop::new(Waker::from_raw(raw_waker::<T, S>(*header))) }; + + WakerRef { + waker, + _p: PhantomData, + } +} + +impl<S> ops::Deref for WakerRef<'_, S> { + type Target = Waker; + + fn deref(&self) -> &Waker { + &self.waker + } +} + +cfg_trace! { + macro_rules! trace { + ($harness:expr, $op:expr) => { + if let Some(id) = $harness.id() { + tracing::trace!( + target: "tokio::task::waker", + op = $op, + task.id = id.into_u64(), + ); + } + } + } +} + +cfg_not_trace! { + macro_rules! trace { + ($harness:expr, $op:expr) => { + // noop + let _ = &$harness; + } + } +} + +unsafe fn clone_waker<T, S>(ptr: *const ()) -> RawWaker +where + T: Future, + S: Schedule, +{ + let header = ptr as *const Header; + let ptr = NonNull::new_unchecked(ptr as *mut Header); + let harness = Harness::<T, S>::from_raw(ptr); + trace!(harness, "waker.clone"); + (*header).state.ref_inc(); + raw_waker::<T, S>(ptr) +} + +unsafe fn drop_waker<T, S>(ptr: *const ()) +where + T: Future, + S: Schedule, +{ + let ptr = NonNull::new_unchecked(ptr as *mut Header); + let harness = Harness::<T, S>::from_raw(ptr); + trace!(harness, "waker.drop"); + harness.drop_reference(); +} + +unsafe fn wake_by_val<T, S>(ptr: *const ()) +where + T: Future, + S: Schedule, +{ + let ptr = NonNull::new_unchecked(ptr as *mut Header); + let harness = Harness::<T, S>::from_raw(ptr); + trace!(harness, "waker.wake"); + harness.wake_by_val(); +} + +// Wake without consuming the waker +unsafe fn wake_by_ref<T, S>(ptr: *const ()) +where + T: Future, + S: Schedule, +{ + let ptr = NonNull::new_unchecked(ptr as *mut Header); + let harness = Harness::<T, S>::from_raw(ptr); + trace!(harness, "waker.wake_by_ref"); + harness.wake_by_ref(); +} + +fn raw_waker<T, S>(header: NonNull<Header>) -> RawWaker +where + T: Future, + S: Schedule, +{ + let ptr = header.as_ptr() as *const (); + let vtable = &RawWakerVTable::new( + clone_waker::<T, S>, + wake_by_val::<T, S>, + wake_by_ref::<T, S>, + drop_waker::<T, S>, + ); + RawWaker::new(ptr, vtable) +} |