diff options
Diffstat (limited to 'third_party/rust/tokio/src/sync')
39 files changed, 14859 insertions, 0 deletions
diff --git a/third_party/rust/tokio/src/sync/barrier.rs b/third_party/rust/tokio/src/sync/barrier.rs new file mode 100644 index 0000000000..dfc76a40eb --- /dev/null +++ b/third_party/rust/tokio/src/sync/barrier.rs @@ -0,0 +1,206 @@ +use crate::loom::sync::Mutex; +use crate::sync::watch; +#[cfg(all(tokio_unstable, feature = "tracing"))] +use crate::util::trace; + +/// A barrier enables multiple tasks to synchronize the beginning of some computation. +/// +/// ``` +/// # #[tokio::main] +/// # async fn main() { +/// use tokio::sync::Barrier; +/// use std::sync::Arc; +/// +/// let mut handles = Vec::with_capacity(10); +/// let barrier = Arc::new(Barrier::new(10)); +/// for _ in 0..10 { +/// let c = barrier.clone(); +/// // The same messages will be printed together. +/// // You will NOT see any interleaving. +/// handles.push(tokio::spawn(async move { +/// println!("before wait"); +/// let wait_result = c.wait().await; +/// println!("after wait"); +/// wait_result +/// })); +/// } +/// +/// // Will not resolve until all "after wait" messages have been printed +/// let mut num_leaders = 0; +/// for handle in handles { +/// let wait_result = handle.await.unwrap(); +/// if wait_result.is_leader() { +/// num_leaders += 1; +/// } +/// } +/// +/// // Exactly one barrier will resolve as the "leader" +/// assert_eq!(num_leaders, 1); +/// # } +/// ``` +#[derive(Debug)] +pub struct Barrier { + state: Mutex<BarrierState>, + wait: watch::Receiver<usize>, + n: usize, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, +} + +#[derive(Debug)] +struct BarrierState { + waker: watch::Sender<usize>, + arrived: usize, + generation: usize, +} + +impl Barrier { + /// Creates a new barrier that can block a given number of tasks. + /// + /// A barrier will block `n`-1 tasks which call [`Barrier::wait`] and then wake up all + /// tasks at once when the `n`th task calls `wait`. + #[track_caller] + pub fn new(mut n: usize) -> Barrier { + let (waker, wait) = crate::sync::watch::channel(0); + + if n == 0 { + // if n is 0, it's not clear what behavior the user wants. + // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every + // .wait() immediately unblocks, so we adopt that here as well. + n = 1; + } + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = { + let location = std::panic::Location::caller(); + let resource_span = tracing::trace_span!( + "runtime.resource", + concrete_type = "Barrier", + kind = "Sync", + loc.file = location.file(), + loc.line = location.line(), + loc.col = location.column(), + ); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + size = n, + ); + + tracing::trace!( + target: "runtime::resource::state_update", + arrived = 0, + ) + }); + resource_span + }; + + Barrier { + state: Mutex::new(BarrierState { + waker, + arrived: 0, + generation: 1, + }), + n, + wait, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: resource_span, + } + } + + /// Does not resolve until all tasks have rendezvoused here. + /// + /// Barriers are re-usable after all tasks have rendezvoused once, and can + /// be used continuously. + /// + /// A single (arbitrary) future will receive a [`BarrierWaitResult`] that returns `true` from + /// [`BarrierWaitResult::is_leader`] when returning from this function, and all other tasks + /// will receive a result that will return `false` from `is_leader`. + pub async fn wait(&self) -> BarrierWaitResult { + #[cfg(all(tokio_unstable, feature = "tracing"))] + return trace::async_op( + || self.wait_internal(), + self.resource_span.clone(), + "Barrier::wait", + "poll", + false, + ) + .await; + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + return self.wait_internal().await; + } + async fn wait_internal(&self) -> BarrierWaitResult { + // NOTE: we are taking a _synchronous_ lock here. + // It is okay to do so because the critical section is fast and never yields, so it cannot + // deadlock even if another future is concurrently holding the lock. + // It is _desireable_ to do so as synchronous Mutexes are, at least in theory, faster than + // the asynchronous counter-parts, so we should use them where possible [citation needed]. + // NOTE: the extra scope here is so that the compiler doesn't think `state` is held across + // a yield point, and thus marks the returned future as !Send. + let generation = { + let mut state = self.state.lock(); + let generation = state.generation; + state.arrived += 1; + #[cfg(all(tokio_unstable, feature = "tracing"))] + tracing::trace!( + target: "runtime::resource::state_update", + arrived = 1, + arrived.op = "add", + ); + #[cfg(all(tokio_unstable, feature = "tracing"))] + tracing::trace!( + target: "runtime::resource::async_op::state_update", + arrived = true, + ); + if state.arrived == self.n { + #[cfg(all(tokio_unstable, feature = "tracing"))] + tracing::trace!( + target: "runtime::resource::async_op::state_update", + is_leader = true, + ); + // we are the leader for this generation + // wake everyone, increment the generation, and return + state + .waker + .send(state.generation) + .expect("there is at least one receiver"); + state.arrived = 0; + state.generation += 1; + return BarrierWaitResult(true); + } + + generation + }; + + // we're going to have to wait for the last of the generation to arrive + let mut wait = self.wait.clone(); + + loop { + let _ = wait.changed().await; + + // note that the first time through the loop, this _will_ yield a generation + // immediately, since we cloned a receiver that has never seen any values. + if *wait.borrow() >= generation { + break; + } + } + + BarrierWaitResult(false) + } +} + +/// A `BarrierWaitResult` is returned by `wait` when all tasks in the `Barrier` have rendezvoused. +#[derive(Debug, Clone)] +pub struct BarrierWaitResult(bool); + +impl BarrierWaitResult { + /// Returns `true` if this task from wait is the "leader task". + /// + /// Only one task will have `true` returned from their result, all other tasks will have + /// `false` returned. + pub fn is_leader(&self) -> bool { + self.0 + } +} diff --git a/third_party/rust/tokio/src/sync/batch_semaphore.rs b/third_party/rust/tokio/src/sync/batch_semaphore.rs new file mode 100644 index 0000000000..4f5effff31 --- /dev/null +++ b/third_party/rust/tokio/src/sync/batch_semaphore.rs @@ -0,0 +1,727 @@ +#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))] +//! # Implementation Details. +//! +//! The semaphore is implemented using an intrusive linked list of waiters. An +//! atomic counter tracks the number of available permits. If the semaphore does +//! not contain the required number of permits, the task attempting to acquire +//! permits places its waker at the end of a queue. When new permits are made +//! available (such as by releasing an initial acquisition), they are assigned +//! to the task at the front of the queue, waking that task if its requested +//! number of permits is met. +//! +//! Because waiters are enqueued at the back of the linked list and dequeued +//! from the front, the semaphore is fair. Tasks trying to acquire large numbers +//! of permits at a time will always be woken eventually, even if many other +//! tasks are acquiring smaller numbers of permits. This means that in a +//! use-case like tokio's read-write lock, writers will not be starved by +//! readers. +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::{Mutex, MutexGuard}; +use crate::util::linked_list::{self, LinkedList}; +#[cfg(all(tokio_unstable, feature = "tracing"))] +use crate::util::trace; +use crate::util::WakeList; + +use std::future::Future; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::ptr::NonNull; +use std::sync::atomic::Ordering::*; +use std::task::Poll::*; +use std::task::{Context, Poll, Waker}; +use std::{cmp, fmt}; + +/// An asynchronous counting semaphore which permits waiting on multiple permits at once. +pub(crate) struct Semaphore { + waiters: Mutex<Waitlist>, + /// The current number of available permits in the semaphore. + permits: AtomicUsize, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, +} + +struct Waitlist { + queue: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>, + closed: bool, +} + +/// Error returned from the [`Semaphore::try_acquire`] function. +/// +/// [`Semaphore::try_acquire`]: crate::sync::Semaphore::try_acquire +#[derive(Debug, PartialEq)] +pub enum TryAcquireError { + /// The semaphore has been [closed] and cannot issue new permits. + /// + /// [closed]: crate::sync::Semaphore::close + Closed, + + /// The semaphore has no available permits. + NoPermits, +} +/// Error returned from the [`Semaphore::acquire`] function. +/// +/// An `acquire` operation can only fail if the semaphore has been +/// [closed]. +/// +/// [closed]: crate::sync::Semaphore::close +/// [`Semaphore::acquire`]: crate::sync::Semaphore::acquire +#[derive(Debug)] +pub struct AcquireError(()); + +pub(crate) struct Acquire<'a> { + node: Waiter, + semaphore: &'a Semaphore, + num_permits: u32, + queued: bool, +} + +/// An entry in the wait queue. +struct Waiter { + /// The current state of the waiter. + /// + /// This is either the number of remaining permits required by + /// the waiter, or a flag indicating that the waiter is not yet queued. + state: AtomicUsize, + + /// The waker to notify the task awaiting permits. + /// + /// # Safety + /// + /// This may only be accessed while the wait queue is locked. + waker: UnsafeCell<Option<Waker>>, + + /// Intrusive linked-list pointers. + /// + /// # Safety + /// + /// This may only be accessed while the wait queue is locked. + /// + /// TODO: Ideally, we would be able to use loom to enforce that + /// this isn't accessed concurrently. However, it is difficult to + /// use a `UnsafeCell` here, since the `Link` trait requires _returning_ + /// references to `Pointers`, and `UnsafeCell` requires that checked access + /// take place inside a closure. We should consider changing `Pointers` to + /// use `UnsafeCell` internally. + pointers: linked_list::Pointers<Waiter>, + + #[cfg(all(tokio_unstable, feature = "tracing"))] + ctx: trace::AsyncOpTracingCtx, + + /// Should not be `Unpin`. + _p: PhantomPinned, +} + +impl Semaphore { + /// The maximum number of permits which a semaphore can hold. + /// + /// Note that this reserves three bits of flags in the permit counter, but + /// we only actually use one of them. However, the previous semaphore + /// implementation used three bits, so we will continue to reserve them to + /// avoid a breaking change if additional flags need to be added in the + /// future. + pub(crate) const MAX_PERMITS: usize = std::usize::MAX >> 3; + const CLOSED: usize = 1; + // The least-significant bit in the number of permits is reserved to use + // as a flag indicating that the semaphore has been closed. Consequently + // PERMIT_SHIFT is used to leave that bit for that purpose. + const PERMIT_SHIFT: usize = 1; + + /// Creates a new semaphore with the initial number of permits + /// + /// Maximum number of permits on 32-bit platforms is `1<<29`. + pub(crate) fn new(permits: usize) -> Self { + assert!( + permits <= Self::MAX_PERMITS, + "a semaphore may not have more than MAX_PERMITS permits ({})", + Self::MAX_PERMITS + ); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = { + let resource_span = tracing::trace_span!( + "runtime.resource", + concrete_type = "Semaphore", + kind = "Sync", + is_internal = true + ); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + permits = permits, + permits.op = "override", + ) + }); + resource_span + }; + + Self { + permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT), + waiters: Mutex::new(Waitlist { + queue: LinkedList::new(), + closed: false, + }), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Creates a new semaphore with the initial number of permits. + /// + /// Maximum number of permits on 32-bit platforms is `1<<29`. + /// + /// If the specified number of permits exceeds the maximum permit amount + /// Then the value will get clamped to the maximum number of permits. + #[cfg(all(feature = "parking_lot", not(all(loom, test))))] + pub(crate) const fn const_new(mut permits: usize) -> Self { + // NOTE: assertions and by extension panics are still being worked on: https://github.com/rust-lang/rust/issues/74925 + // currently we just clamp the permit count when it exceeds the max + permits &= Self::MAX_PERMITS; + + Self { + permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT), + waiters: Mutex::const_new(Waitlist { + queue: LinkedList::new(), + closed: false, + }), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span::none(), + } + } + + /// Returns the current number of available permits. + pub(crate) fn available_permits(&self) -> usize { + self.permits.load(Acquire) >> Self::PERMIT_SHIFT + } + + /// Adds `added` new permits to the semaphore. + /// + /// The maximum number of permits is `usize::MAX >> 3`, and this function will panic if the limit is exceeded. + pub(crate) fn release(&self, added: usize) { + if added == 0 { + return; + } + + // Assign permits to the wait queue + self.add_permits_locked(added, self.waiters.lock()); + } + + /// Closes the semaphore. This prevents the semaphore from issuing new + /// permits and notifies all pending waiters. + pub(crate) fn close(&self) { + let mut waiters = self.waiters.lock(); + // If the semaphore's permits counter has enough permits for an + // unqueued waiter to acquire all the permits it needs immediately, + // it won't touch the wait list. Therefore, we have to set a bit on + // the permit counter as well. However, we must do this while + // holding the lock --- otherwise, if we set the bit and then wait + // to acquire the lock we'll enter an inconsistent state where the + // permit counter is closed, but the wait list is not. + self.permits.fetch_or(Self::CLOSED, Release); + waiters.closed = true; + while let Some(mut waiter) = waiters.queue.pop_back() { + let waker = unsafe { waiter.as_mut().waker.with_mut(|waker| (*waker).take()) }; + if let Some(waker) = waker { + waker.wake(); + } + } + } + + /// Returns true if the semaphore is closed. + pub(crate) fn is_closed(&self) -> bool { + self.permits.load(Acquire) & Self::CLOSED == Self::CLOSED + } + + pub(crate) fn try_acquire(&self, num_permits: u32) -> Result<(), TryAcquireError> { + assert!( + num_permits as usize <= Self::MAX_PERMITS, + "a semaphore may not have more than MAX_PERMITS permits ({})", + Self::MAX_PERMITS + ); + let num_permits = (num_permits as usize) << Self::PERMIT_SHIFT; + let mut curr = self.permits.load(Acquire); + loop { + // Has the semaphore closed? + if curr & Self::CLOSED == Self::CLOSED { + return Err(TryAcquireError::Closed); + } + + // Are there enough permits remaining? + if curr < num_permits { + return Err(TryAcquireError::NoPermits); + } + + let next = curr - num_permits; + + match self.permits.compare_exchange(curr, next, AcqRel, Acquire) { + Ok(_) => { + // TODO: Instrument once issue has been solved} + return Ok(()); + } + Err(actual) => curr = actual, + } + } + } + + pub(crate) fn acquire(&self, num_permits: u32) -> Acquire<'_> { + Acquire::new(self, num_permits) + } + + /// Release `rem` permits to the semaphore's wait list, starting from the + /// end of the queue. + /// + /// If `rem` exceeds the number of permits needed by the wait list, the + /// remainder are assigned back to the semaphore. + fn add_permits_locked(&self, mut rem: usize, waiters: MutexGuard<'_, Waitlist>) { + let mut wakers = WakeList::new(); + let mut lock = Some(waiters); + let mut is_empty = false; + while rem > 0 { + let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock()); + 'inner: while wakers.can_push() { + // Was the waiter assigned enough permits to wake it? + match waiters.queue.last() { + Some(waiter) => { + if !waiter.assign_permits(&mut rem) { + break 'inner; + } + } + None => { + is_empty = true; + // If we assigned permits to all the waiters in the queue, and there are + // still permits left over, assign them back to the semaphore. + break 'inner; + } + }; + let mut waiter = waiters.queue.pop_back().unwrap(); + if let Some(waker) = + unsafe { waiter.as_mut().waker.with_mut(|waker| (*waker).take()) } + { + wakers.push(waker); + } + } + + if rem > 0 && is_empty { + let permits = rem; + assert!( + permits <= Self::MAX_PERMITS, + "cannot add more than MAX_PERMITS permits ({})", + Self::MAX_PERMITS + ); + let prev = self.permits.fetch_add(rem << Self::PERMIT_SHIFT, Release); + let prev = prev >> Self::PERMIT_SHIFT; + assert!( + prev + permits <= Self::MAX_PERMITS, + "number of added permits ({}) would overflow MAX_PERMITS ({})", + rem, + Self::MAX_PERMITS + ); + + // add remaining permits back + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + permits = rem, + permits.op = "add", + ) + }); + + rem = 0; + } + + drop(waiters); // release the lock + + wakers.wake_all(); + } + + assert_eq!(rem, 0); + } + + fn poll_acquire( + &self, + cx: &mut Context<'_>, + num_permits: u32, + node: Pin<&mut Waiter>, + queued: bool, + ) -> Poll<Result<(), AcquireError>> { + let mut acquired = 0; + + let needed = if queued { + node.state.load(Acquire) << Self::PERMIT_SHIFT + } else { + (num_permits as usize) << Self::PERMIT_SHIFT + }; + + let mut lock = None; + // First, try to take the requested number of permits from the + // semaphore. + let mut curr = self.permits.load(Acquire); + let mut waiters = loop { + // Has the semaphore closed? + if curr & Self::CLOSED > 0 { + return Ready(Err(AcquireError::closed())); + } + + let mut remaining = 0; + let total = curr + .checked_add(acquired) + .expect("number of permits must not overflow"); + let (next, acq) = if total >= needed { + let next = curr - (needed - acquired); + (next, needed >> Self::PERMIT_SHIFT) + } else { + remaining = (needed - acquired) - curr; + (0, curr >> Self::PERMIT_SHIFT) + }; + + if remaining > 0 && lock.is_none() { + // No permits were immediately available, so this permit will + // (probably) need to wait. We'll need to acquire a lock on the + // wait queue before continuing. We need to do this _before_ the + // CAS that sets the new value of the semaphore's `permits` + // counter. Otherwise, if we subtract the permits and then + // acquire the lock, we might miss additional permits being + // added while waiting for the lock. + lock = Some(self.waiters.lock()); + } + + match self.permits.compare_exchange(curr, next, AcqRel, Acquire) { + Ok(_) => { + acquired += acq; + if remaining == 0 { + if !queued { + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + permits = acquired, + permits.op = "sub", + ); + tracing::trace!( + target: "runtime::resource::async_op::state_update", + permits_obtained = acquired, + permits.op = "add", + ) + }); + + return Ready(Ok(())); + } else if lock.is_none() { + break self.waiters.lock(); + } + } + break lock.expect("lock must be acquired before waiting"); + } + Err(actual) => curr = actual, + } + }; + + if waiters.closed { + return Ready(Err(AcquireError::closed())); + } + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + permits = acquired, + permits.op = "sub", + ) + }); + + if node.assign_permits(&mut acquired) { + self.add_permits_locked(acquired, waiters); + return Ready(Ok(())); + } + + assert_eq!(acquired, 0); + + // Otherwise, register the waker & enqueue the node. + node.waker.with_mut(|waker| { + // Safety: the wait list is locked, so we may modify the waker. + let waker = unsafe { &mut *waker }; + // Do we need to register the new waker? + if waker + .as_ref() + .map(|waker| !waker.will_wake(cx.waker())) + .unwrap_or(true) + { + *waker = Some(cx.waker().clone()); + } + }); + + // If the waiter is not already in the wait queue, enqueue it. + if !queued { + let node = unsafe { + let node = Pin::into_inner_unchecked(node) as *mut _; + NonNull::new_unchecked(node) + }; + + waiters.queue.push_front(node); + } + + Pending + } +} + +impl fmt::Debug for Semaphore { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Semaphore") + .field("permits", &self.available_permits()) + .finish() + } +} + +impl Waiter { + fn new( + num_permits: u32, + #[cfg(all(tokio_unstable, feature = "tracing"))] ctx: trace::AsyncOpTracingCtx, + ) -> Self { + Waiter { + waker: UnsafeCell::new(None), + state: AtomicUsize::new(num_permits as usize), + pointers: linked_list::Pointers::new(), + #[cfg(all(tokio_unstable, feature = "tracing"))] + ctx, + _p: PhantomPinned, + } + } + + /// Assign permits to the waiter. + /// + /// Returns `true` if the waiter should be removed from the queue + fn assign_permits(&self, n: &mut usize) -> bool { + let mut curr = self.state.load(Acquire); + loop { + let assign = cmp::min(curr, *n); + let next = curr - assign; + match self.state.compare_exchange(curr, next, AcqRel, Acquire) { + Ok(_) => { + *n -= assign; + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.ctx.async_op_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::async_op::state_update", + permits_obtained = assign, + permits.op = "add", + ); + }); + return next == 0; + } + Err(actual) => curr = actual, + } + } + } +} + +impl Future for Acquire<'_> { + type Output = Result<(), AcquireError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let _resource_span = self.node.ctx.resource_span.clone().entered(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + let _async_op_span = self.node.ctx.async_op_span.clone().entered(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + let _async_op_poll_span = self.node.ctx.async_op_poll_span.clone().entered(); + + let (node, semaphore, needed, queued) = self.project(); + + // First, ensure the current task has enough budget to proceed. + #[cfg(all(tokio_unstable, feature = "tracing"))] + let coop = ready!(trace_poll_op!( + "poll_acquire", + crate::coop::poll_proceed(cx), + )); + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let coop = ready!(crate::coop::poll_proceed(cx)); + + let result = match semaphore.poll_acquire(cx, needed, node, *queued) { + Pending => { + *queued = true; + Pending + } + Ready(r) => { + coop.made_progress(); + r?; + *queued = false; + Ready(Ok(())) + } + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + return trace_poll_op!("poll_acquire", result); + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + return result; + } +} + +impl<'a> Acquire<'a> { + fn new(semaphore: &'a Semaphore, num_permits: u32) -> Self { + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + return Self { + node: Waiter::new(num_permits), + semaphore, + num_permits, + queued: false, + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + return semaphore.resource_span.in_scope(|| { + let async_op_span = + tracing::trace_span!("runtime.resource.async_op", source = "Acquire::new"); + let async_op_poll_span = async_op_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::async_op::state_update", + permits_requested = num_permits, + permits.op = "override", + ); + + tracing::trace!( + target: "runtime::resource::async_op::state_update", + permits_obtained = 0 as usize, + permits.op = "override", + ); + + tracing::trace_span!("runtime.resource.async_op.poll") + }); + + let ctx = trace::AsyncOpTracingCtx { + async_op_span, + async_op_poll_span, + resource_span: semaphore.resource_span.clone(), + }; + + Self { + node: Waiter::new(num_permits, ctx), + semaphore, + num_permits, + queued: false, + } + }); + } + + fn project(self: Pin<&mut Self>) -> (Pin<&mut Waiter>, &Semaphore, u32, &mut bool) { + fn is_unpin<T: Unpin>() {} + unsafe { + // Safety: all fields other than `node` are `Unpin` + + is_unpin::<&Semaphore>(); + is_unpin::<&mut bool>(); + is_unpin::<u32>(); + + let this = self.get_unchecked_mut(); + ( + Pin::new_unchecked(&mut this.node), + this.semaphore, + this.num_permits, + &mut this.queued, + ) + } + } +} + +impl Drop for Acquire<'_> { + fn drop(&mut self) { + // If the future is completed, there is no node in the wait list, so we + // can skip acquiring the lock. + if !self.queued { + return; + } + + // This is where we ensure safety. The future is being dropped, + // which means we must ensure that the waiter entry is no longer stored + // in the linked list. + let mut waiters = self.semaphore.waiters.lock(); + + // remove the entry from the list + let node = NonNull::from(&mut self.node); + // Safety: we have locked the wait list. + unsafe { waiters.queue.remove(node) }; + + let acquired_permits = self.num_permits as usize - self.node.state.load(Acquire); + if acquired_permits > 0 { + self.semaphore.add_permits_locked(acquired_permits, waiters); + } + } +} + +// Safety: the `Acquire` future is not `Sync` automatically because it contains +// a `Waiter`, which, in turn, contains an `UnsafeCell`. However, the +// `UnsafeCell` is only accessed when the future is borrowed mutably (either in +// `poll` or in `drop`). Therefore, it is safe (although not particularly +// _useful_) for the future to be borrowed immutably across threads. +unsafe impl Sync for Acquire<'_> {} + +// ===== impl AcquireError ==== + +impl AcquireError { + fn closed() -> AcquireError { + AcquireError(()) + } +} + +impl fmt::Display for AcquireError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "semaphore closed") + } +} + +impl std::error::Error for AcquireError {} + +// ===== impl TryAcquireError ===== + +impl TryAcquireError { + /// Returns `true` if the error was caused by a closed semaphore. + #[allow(dead_code)] // may be used later! + pub(crate) fn is_closed(&self) -> bool { + matches!(self, TryAcquireError::Closed) + } + + /// Returns `true` if the error was caused by calling `try_acquire` on a + /// semaphore with no available permits. + #[allow(dead_code)] // may be used later! + pub(crate) fn is_no_permits(&self) -> bool { + matches!(self, TryAcquireError::NoPermits) + } +} + +impl fmt::Display for TryAcquireError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TryAcquireError::Closed => write!(fmt, "semaphore closed"), + TryAcquireError::NoPermits => write!(fmt, "no permits available"), + } + } +} + +impl std::error::Error for TryAcquireError {} + +/// # Safety +/// +/// `Waiter` is forced to be !Unpin. +unsafe impl linked_list::Link for Waiter { + // XXX: ideally, we would be able to use `Pin` here, to enforce the + // invariant that list entries may not move while in the list. However, we + // can't do this currently, as using `Pin<&'a mut Waiter>` as the `Handle` + // type would require `Semaphore` to be generic over a lifetime. We can't + // use `Pin<*mut Waiter>`, as raw pointers are `Unpin` regardless of whether + // or not they dereference to an `!Unpin` target. + type Handle = NonNull<Waiter>; + type Target = Waiter; + + fn as_raw(handle: &Self::Handle) -> NonNull<Waiter> { + *handle + } + + unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> { + ptr + } + + unsafe fn pointers(mut target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> { + NonNull::from(&mut target.as_mut().pointers) + } +} diff --git a/third_party/rust/tokio/src/sync/broadcast.rs b/third_party/rust/tokio/src/sync/broadcast.rs new file mode 100644 index 0000000000..0d9cd3bc17 --- /dev/null +++ b/third_party/rust/tokio/src/sync/broadcast.rs @@ -0,0 +1,1078 @@ +//! A multi-producer, multi-consumer broadcast queue. Each sent value is seen by +//! all consumers. +//! +//! A [`Sender`] is used to broadcast values to **all** connected [`Receiver`] +//! values. [`Sender`] handles are clone-able, allowing concurrent send and +//! receive actions. [`Sender`] and [`Receiver`] are both `Send` and `Sync` as +//! long as `T` is also `Send` or `Sync` respectively. +//! +//! When a value is sent, **all** [`Receiver`] handles are notified and will +//! receive the value. The value is stored once inside the channel and cloned on +//! demand for each receiver. Once all receivers have received a clone of the +//! value, the value is released from the channel. +//! +//! A channel is created by calling [`channel`], specifying the maximum number +//! of messages the channel can retain at any given time. +//! +//! New [`Receiver`] handles are created by calling [`Sender::subscribe`]. The +//! returned [`Receiver`] will receive values sent **after** the call to +//! `subscribe`. +//! +//! ## Lagging +//! +//! As sent messages must be retained until **all** [`Receiver`] handles receive +//! a clone, broadcast channels are susceptible to the "slow receiver" problem. +//! In this case, all but one receiver are able to receive values at the rate +//! they are sent. Because one receiver is stalled, the channel starts to fill +//! up. +//! +//! This broadcast channel implementation handles this case by setting a hard +//! upper bound on the number of values the channel may retain at any given +//! time. This upper bound is passed to the [`channel`] function as an argument. +//! +//! If a value is sent when the channel is at capacity, the oldest value +//! currently held by the channel is released. This frees up space for the new +//! value. Any receiver that has not yet seen the released value will return +//! [`RecvError::Lagged`] the next time [`recv`] is called. +//! +//! Once [`RecvError::Lagged`] is returned, the lagging receiver's position is +//! updated to the oldest value contained by the channel. The next call to +//! [`recv`] will return this value. +//! +//! This behavior enables a receiver to detect when it has lagged so far behind +//! that data has been dropped. The caller may decide how to respond to this: +//! either by aborting its task or by tolerating lost messages and resuming +//! consumption of the channel. +//! +//! ## Closing +//! +//! When **all** [`Sender`] handles have been dropped, no new values may be +//! sent. At this point, the channel is "closed". Once a receiver has received +//! all values retained by the channel, the next call to [`recv`] will return +//! with [`RecvError::Closed`]. +//! +//! [`Sender`]: crate::sync::broadcast::Sender +//! [`Sender::subscribe`]: crate::sync::broadcast::Sender::subscribe +//! [`Receiver`]: crate::sync::broadcast::Receiver +//! [`channel`]: crate::sync::broadcast::channel +//! [`RecvError::Lagged`]: crate::sync::broadcast::error::RecvError::Lagged +//! [`RecvError::Closed`]: crate::sync::broadcast::error::RecvError::Closed +//! [`recv`]: crate::sync::broadcast::Receiver::recv +//! +//! # Examples +//! +//! Basic usage +//! +//! ``` +//! use tokio::sync::broadcast; +//! +//! #[tokio::main] +//! async fn main() { +//! let (tx, mut rx1) = broadcast::channel(16); +//! let mut rx2 = tx.subscribe(); +//! +//! tokio::spawn(async move { +//! assert_eq!(rx1.recv().await.unwrap(), 10); +//! assert_eq!(rx1.recv().await.unwrap(), 20); +//! }); +//! +//! tokio::spawn(async move { +//! assert_eq!(rx2.recv().await.unwrap(), 10); +//! assert_eq!(rx2.recv().await.unwrap(), 20); +//! }); +//! +//! tx.send(10).unwrap(); +//! tx.send(20).unwrap(); +//! } +//! ``` +//! +//! Handling lag +//! +//! ``` +//! use tokio::sync::broadcast; +//! +//! #[tokio::main] +//! async fn main() { +//! let (tx, mut rx) = broadcast::channel(2); +//! +//! tx.send(10).unwrap(); +//! tx.send(20).unwrap(); +//! tx.send(30).unwrap(); +//! +//! // The receiver lagged behind +//! assert!(rx.recv().await.is_err()); +//! +//! // At this point, we can abort or continue with lost messages +//! +//! assert_eq!(20, rx.recv().await.unwrap()); +//! assert_eq!(30, rx.recv().await.unwrap()); +//! } +//! ``` + +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::{Arc, Mutex, RwLock, RwLockReadGuard}; +use crate::util::linked_list::{self, LinkedList}; + +use std::fmt; +use std::future::Future; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::ptr::NonNull; +use std::sync::atomic::Ordering::SeqCst; +use std::task::{Context, Poll, Waker}; +use std::usize; + +/// Sending-half of the [`broadcast`] channel. +/// +/// May be used from many threads. Messages can be sent with +/// [`send`][Sender::send]. +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::broadcast; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, mut rx1) = broadcast::channel(16); +/// let mut rx2 = tx.subscribe(); +/// +/// tokio::spawn(async move { +/// assert_eq!(rx1.recv().await.unwrap(), 10); +/// assert_eq!(rx1.recv().await.unwrap(), 20); +/// }); +/// +/// tokio::spawn(async move { +/// assert_eq!(rx2.recv().await.unwrap(), 10); +/// assert_eq!(rx2.recv().await.unwrap(), 20); +/// }); +/// +/// tx.send(10).unwrap(); +/// tx.send(20).unwrap(); +/// } +/// ``` +/// +/// [`broadcast`]: crate::sync::broadcast +pub struct Sender<T> { + shared: Arc<Shared<T>>, +} + +/// Receiving-half of the [`broadcast`] channel. +/// +/// Must not be used concurrently. Messages may be retrieved using +/// [`recv`][Receiver::recv]. +/// +/// To turn this receiver into a `Stream`, you can use the [`BroadcastStream`] +/// wrapper. +/// +/// [`BroadcastStream`]: https://docs.rs/tokio-stream/0.1/tokio_stream/wrappers/struct.BroadcastStream.html +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::broadcast; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, mut rx1) = broadcast::channel(16); +/// let mut rx2 = tx.subscribe(); +/// +/// tokio::spawn(async move { +/// assert_eq!(rx1.recv().await.unwrap(), 10); +/// assert_eq!(rx1.recv().await.unwrap(), 20); +/// }); +/// +/// tokio::spawn(async move { +/// assert_eq!(rx2.recv().await.unwrap(), 10); +/// assert_eq!(rx2.recv().await.unwrap(), 20); +/// }); +/// +/// tx.send(10).unwrap(); +/// tx.send(20).unwrap(); +/// } +/// ``` +/// +/// [`broadcast`]: crate::sync::broadcast +pub struct Receiver<T> { + /// State shared with all receivers and senders. + shared: Arc<Shared<T>>, + + /// Next position to read from + next: u64, +} + +pub mod error { + //! Broadcast error types + + use std::fmt; + + /// Error returned by from the [`send`] function on a [`Sender`]. + /// + /// A **send** operation can only fail if there are no active receivers, + /// implying that the message could never be received. The error contains the + /// message being sent as a payload so it can be recovered. + /// + /// [`send`]: crate::sync::broadcast::Sender::send + /// [`Sender`]: crate::sync::broadcast::Sender + #[derive(Debug)] + pub struct SendError<T>(pub T); + + impl<T> fmt::Display for SendError<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "channel closed") + } + } + + impl<T: fmt::Debug> std::error::Error for SendError<T> {} + + /// An error returned from the [`recv`] function on a [`Receiver`]. + /// + /// [`recv`]: crate::sync::broadcast::Receiver::recv + /// [`Receiver`]: crate::sync::broadcast::Receiver + #[derive(Debug, PartialEq)] + pub enum RecvError { + /// There are no more active senders implying no further messages will ever + /// be sent. + Closed, + + /// The receiver lagged too far behind. Attempting to receive again will + /// return the oldest message still retained by the channel. + /// + /// Includes the number of skipped messages. + Lagged(u64), + } + + impl fmt::Display for RecvError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + RecvError::Closed => write!(f, "channel closed"), + RecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt), + } + } + } + + impl std::error::Error for RecvError {} + + /// An error returned from the [`try_recv`] function on a [`Receiver`]. + /// + /// [`try_recv`]: crate::sync::broadcast::Receiver::try_recv + /// [`Receiver`]: crate::sync::broadcast::Receiver + #[derive(Debug, PartialEq)] + pub enum TryRecvError { + /// The channel is currently empty. There are still active + /// [`Sender`] handles, so data may yet become available. + /// + /// [`Sender`]: crate::sync::broadcast::Sender + Empty, + + /// There are no more active senders implying no further messages will ever + /// be sent. + Closed, + + /// The receiver lagged too far behind and has been forcibly disconnected. + /// Attempting to receive again will return the oldest message still + /// retained by the channel. + /// + /// Includes the number of skipped messages. + Lagged(u64), + } + + impl fmt::Display for TryRecvError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TryRecvError::Empty => write!(f, "channel empty"), + TryRecvError::Closed => write!(f, "channel closed"), + TryRecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt), + } + } + } + + impl std::error::Error for TryRecvError {} +} + +use self::error::*; + +/// Data shared between senders and receivers. +struct Shared<T> { + /// slots in the channel. + buffer: Box<[RwLock<Slot<T>>]>, + + /// Mask a position -> index. + mask: usize, + + /// Tail of the queue. Includes the rx wait list. + tail: Mutex<Tail>, + + /// Number of outstanding Sender handles. + num_tx: AtomicUsize, +} + +/// Next position to write a value. +struct Tail { + /// Next position to write to. + pos: u64, + + /// Number of active receivers. + rx_cnt: usize, + + /// True if the channel is closed. + closed: bool, + + /// Receivers waiting for a value. + waiters: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>, +} + +/// Slot in the buffer. +struct Slot<T> { + /// Remaining number of receivers that are expected to see this value. + /// + /// When this goes to zero, the value is released. + /// + /// An atomic is used as it is mutated concurrently with the slot read lock + /// acquired. + rem: AtomicUsize, + + /// Uniquely identifies the `send` stored in the slot. + pos: u64, + + /// True signals the channel is closed. + closed: bool, + + /// The value being broadcast. + /// + /// The value is set by `send` when the write lock is held. When a reader + /// drops, `rem` is decremented. When it hits zero, the value is dropped. + val: UnsafeCell<Option<T>>, +} + +/// An entry in the wait queue. +struct Waiter { + /// True if queued. + queued: bool, + + /// Task waiting on the broadcast channel. + waker: Option<Waker>, + + /// Intrusive linked-list pointers. + pointers: linked_list::Pointers<Waiter>, + + /// Should not be `Unpin`. + _p: PhantomPinned, +} + +struct RecvGuard<'a, T> { + slot: RwLockReadGuard<'a, Slot<T>>, +} + +/// Receive a value future. +struct Recv<'a, T> { + /// Receiver being waited on. + receiver: &'a mut Receiver<T>, + + /// Entry in the waiter `LinkedList`. + waiter: UnsafeCell<Waiter>, +} + +unsafe impl<'a, T: Send> Send for Recv<'a, T> {} +unsafe impl<'a, T: Send> Sync for Recv<'a, T> {} + +/// Max number of receivers. Reserve space to lock. +const MAX_RECEIVERS: usize = usize::MAX >> 2; + +/// Create a bounded, multi-producer, multi-consumer channel where each sent +/// value is broadcasted to all active receivers. +/// +/// All data sent on [`Sender`] will become available on every active +/// [`Receiver`] in the same order as it was sent. +/// +/// The `Sender` can be cloned to `send` to the same channel from multiple +/// points in the process or it can be used concurrently from an `Arc`. New +/// `Receiver` handles are created by calling [`Sender::subscribe`]. +/// +/// If all [`Receiver`] handles are dropped, the `send` method will return a +/// [`SendError`]. Similarly, if all [`Sender`] handles are dropped, the [`recv`] +/// method will return a [`RecvError`]. +/// +/// [`Sender`]: crate::sync::broadcast::Sender +/// [`Sender::subscribe`]: crate::sync::broadcast::Sender::subscribe +/// [`Receiver`]: crate::sync::broadcast::Receiver +/// [`recv`]: crate::sync::broadcast::Receiver::recv +/// [`SendError`]: crate::sync::broadcast::error::SendError +/// [`RecvError`]: crate::sync::broadcast::error::RecvError +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::broadcast; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, mut rx1) = broadcast::channel(16); +/// let mut rx2 = tx.subscribe(); +/// +/// tokio::spawn(async move { +/// assert_eq!(rx1.recv().await.unwrap(), 10); +/// assert_eq!(rx1.recv().await.unwrap(), 20); +/// }); +/// +/// tokio::spawn(async move { +/// assert_eq!(rx2.recv().await.unwrap(), 10); +/// assert_eq!(rx2.recv().await.unwrap(), 20); +/// }); +/// +/// tx.send(10).unwrap(); +/// tx.send(20).unwrap(); +/// } +/// ``` +pub fn channel<T: Clone>(mut capacity: usize) -> (Sender<T>, Receiver<T>) { + assert!(capacity > 0, "capacity is empty"); + assert!(capacity <= usize::MAX >> 1, "requested capacity too large"); + + // Round to a power of two + capacity = capacity.next_power_of_two(); + + let mut buffer = Vec::with_capacity(capacity); + + for i in 0..capacity { + buffer.push(RwLock::new(Slot { + rem: AtomicUsize::new(0), + pos: (i as u64).wrapping_sub(capacity as u64), + closed: false, + val: UnsafeCell::new(None), + })); + } + + let shared = Arc::new(Shared { + buffer: buffer.into_boxed_slice(), + mask: capacity - 1, + tail: Mutex::new(Tail { + pos: 0, + rx_cnt: 1, + closed: false, + waiters: LinkedList::new(), + }), + num_tx: AtomicUsize::new(1), + }); + + let rx = Receiver { + shared: shared.clone(), + next: 0, + }; + + let tx = Sender { shared }; + + (tx, rx) +} + +unsafe impl<T: Send> Send for Sender<T> {} +unsafe impl<T: Send> Sync for Sender<T> {} + +unsafe impl<T: Send> Send for Receiver<T> {} +unsafe impl<T: Send> Sync for Receiver<T> {} + +impl<T> Sender<T> { + /// Attempts to send a value to all active [`Receiver`] handles, returning + /// it back if it could not be sent. + /// + /// A successful send occurs when there is at least one active [`Receiver`] + /// handle. An unsuccessful send would be one where all associated + /// [`Receiver`] handles have already been dropped. + /// + /// # Return + /// + /// On success, the number of subscribed [`Receiver`] handles is returned. + /// This does not mean that this number of receivers will see the message as + /// a receiver may drop before receiving the message. + /// + /// # Note + /// + /// A return value of `Ok` **does not** mean that the sent value will be + /// observed by all or any of the active [`Receiver`] handles. [`Receiver`] + /// handles may be dropped before receiving the sent message. + /// + /// A return value of `Err` **does not** mean that future calls to `send` + /// will fail. New [`Receiver`] handles may be created by calling + /// [`subscribe`]. + /// + /// [`Receiver`]: crate::sync::broadcast::Receiver + /// [`subscribe`]: crate::sync::broadcast::Sender::subscribe + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx1) = broadcast::channel(16); + /// let mut rx2 = tx.subscribe(); + /// + /// tokio::spawn(async move { + /// assert_eq!(rx1.recv().await.unwrap(), 10); + /// assert_eq!(rx1.recv().await.unwrap(), 20); + /// }); + /// + /// tokio::spawn(async move { + /// assert_eq!(rx2.recv().await.unwrap(), 10); + /// assert_eq!(rx2.recv().await.unwrap(), 20); + /// }); + /// + /// tx.send(10).unwrap(); + /// tx.send(20).unwrap(); + /// } + /// ``` + pub fn send(&self, value: T) -> Result<usize, SendError<T>> { + self.send2(Some(value)) + .map_err(|SendError(maybe_v)| SendError(maybe_v.unwrap())) + } + + /// Creates a new [`Receiver`] handle that will receive values sent **after** + /// this call to `subscribe`. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, _rx) = broadcast::channel(16); + /// + /// // Will not be seen + /// tx.send(10).unwrap(); + /// + /// let mut rx = tx.subscribe(); + /// + /// tx.send(20).unwrap(); + /// + /// let value = rx.recv().await.unwrap(); + /// assert_eq!(20, value); + /// } + /// ``` + pub fn subscribe(&self) -> Receiver<T> { + let shared = self.shared.clone(); + new_receiver(shared) + } + + /// Returns the number of active receivers + /// + /// An active receiver is a [`Receiver`] handle returned from [`channel`] or + /// [`subscribe`]. These are the handles that will receive values sent on + /// this [`Sender`]. + /// + /// # Note + /// + /// It is not guaranteed that a sent message will reach this number of + /// receivers. Active receivers may never call [`recv`] again before + /// dropping. + /// + /// [`recv`]: crate::sync::broadcast::Receiver::recv + /// [`Receiver`]: crate::sync::broadcast::Receiver + /// [`Sender`]: crate::sync::broadcast::Sender + /// [`subscribe`]: crate::sync::broadcast::Sender::subscribe + /// [`channel`]: crate::sync::broadcast::channel + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, _rx1) = broadcast::channel(16); + /// + /// assert_eq!(1, tx.receiver_count()); + /// + /// let mut _rx2 = tx.subscribe(); + /// + /// assert_eq!(2, tx.receiver_count()); + /// + /// tx.send(10).unwrap(); + /// } + /// ``` + pub fn receiver_count(&self) -> usize { + let tail = self.shared.tail.lock(); + tail.rx_cnt + } + + fn send2(&self, value: Option<T>) -> Result<usize, SendError<Option<T>>> { + let mut tail = self.shared.tail.lock(); + + if tail.rx_cnt == 0 { + return Err(SendError(value)); + } + + // Position to write into + let pos = tail.pos; + let rem = tail.rx_cnt; + let idx = (pos & self.shared.mask as u64) as usize; + + // Update the tail position + tail.pos = tail.pos.wrapping_add(1); + + // Get the slot + let mut slot = self.shared.buffer[idx].write().unwrap(); + + // Track the position + slot.pos = pos; + + // Set remaining receivers + slot.rem.with_mut(|v| *v = rem); + + // Set the closed bit if the value is `None`; otherwise write the value + if value.is_none() { + tail.closed = true; + slot.closed = true; + } else { + slot.val.with_mut(|ptr| unsafe { *ptr = value }); + } + + // Release the slot lock before notifying the receivers. + drop(slot); + + tail.notify_rx(); + + // Release the mutex. This must happen after the slot lock is released, + // otherwise the writer lock bit could be cleared while another thread + // is in the critical section. + drop(tail); + + Ok(rem) + } +} + +fn new_receiver<T>(shared: Arc<Shared<T>>) -> Receiver<T> { + let mut tail = shared.tail.lock(); + + if tail.rx_cnt == MAX_RECEIVERS { + panic!("max receivers"); + } + + tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow"); + + let next = tail.pos; + + drop(tail); + + Receiver { shared, next } +} + +impl Tail { + fn notify_rx(&mut self) { + while let Some(mut waiter) = self.waiters.pop_back() { + // Safety: `waiters` lock is still held. + let waiter = unsafe { waiter.as_mut() }; + + assert!(waiter.queued); + waiter.queued = false; + + let waker = waiter.waker.take().unwrap(); + waker.wake(); + } + } +} + +impl<T> Clone for Sender<T> { + fn clone(&self) -> Sender<T> { + let shared = self.shared.clone(); + shared.num_tx.fetch_add(1, SeqCst); + + Sender { shared } + } +} + +impl<T> Drop for Sender<T> { + fn drop(&mut self) { + if 1 == self.shared.num_tx.fetch_sub(1, SeqCst) { + let _ = self.send2(None); + } + } +} + +impl<T> Receiver<T> { + /// Locks the next value if there is one. + fn recv_ref( + &mut self, + waiter: Option<(&UnsafeCell<Waiter>, &Waker)>, + ) -> Result<RecvGuard<'_, T>, TryRecvError> { + let idx = (self.next & self.shared.mask as u64) as usize; + + // The slot holding the next value to read + let mut slot = self.shared.buffer[idx].read().unwrap(); + + if slot.pos != self.next { + let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64); + + // The receiver has read all current values in the channel and there + // is no waiter to register + if waiter.is_none() && next_pos == self.next { + return Err(TryRecvError::Empty); + } + + // Release the `slot` lock before attempting to acquire the `tail` + // lock. This is required because `send2` acquires the tail lock + // first followed by the slot lock. Acquiring the locks in reverse + // order here would result in a potential deadlock: `recv_ref` + // acquires the `slot` lock and attempts to acquire the `tail` lock + // while `send2` acquired the `tail` lock and attempts to acquire + // the slot lock. + drop(slot); + + let mut tail = self.shared.tail.lock(); + + // Acquire slot lock again + slot = self.shared.buffer[idx].read().unwrap(); + + // Make sure the position did not change. This could happen in the + // unlikely event that the buffer is wrapped between dropping the + // read lock and acquiring the tail lock. + if slot.pos != self.next { + let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64); + + if next_pos == self.next { + // Store the waker + if let Some((waiter, waker)) = waiter { + // Safety: called while locked. + unsafe { + // Only queue if not already queued + waiter.with_mut(|ptr| { + // If there is no waker **or** if the currently + // stored waker references a **different** task, + // track the tasks' waker to be notified on + // receipt of a new value. + match (*ptr).waker { + Some(ref w) if w.will_wake(waker) => {} + _ => { + (*ptr).waker = Some(waker.clone()); + } + } + + if !(*ptr).queued { + (*ptr).queued = true; + tail.waiters.push_front(NonNull::new_unchecked(&mut *ptr)); + } + }); + } + } + + return Err(TryRecvError::Empty); + } + + // At this point, the receiver has lagged behind the sender by + // more than the channel capacity. The receiver will attempt to + // catch up by skipping dropped messages and setting the + // internal cursor to the **oldest** message stored by the + // channel. + // + // However, finding the oldest position is a bit more + // complicated than `tail-position - buffer-size`. When + // the channel is closed, the tail position is incremented to + // signal a new `None` message, but `None` is not stored in the + // channel itself (see issue #2425 for why). + // + // To account for this, if the channel is closed, the tail + // position is decremented by `buffer-size + 1`. + let mut adjust = 0; + if tail.closed { + adjust = 1 + } + let next = tail + .pos + .wrapping_sub(self.shared.buffer.len() as u64 + adjust); + + let missed = next.wrapping_sub(self.next); + + drop(tail); + + // The receiver is slow but no values have been missed + if missed == 0 { + self.next = self.next.wrapping_add(1); + + return Ok(RecvGuard { slot }); + } + + self.next = next; + + return Err(TryRecvError::Lagged(missed)); + } + } + + self.next = self.next.wrapping_add(1); + + if slot.closed { + return Err(TryRecvError::Closed); + } + + Ok(RecvGuard { slot }) + } +} + +impl<T: Clone> Receiver<T> { + /// Receives the next value for this receiver. + /// + /// Each [`Receiver`] handle will receive a clone of all values sent + /// **after** it has subscribed. + /// + /// `Err(RecvError::Closed)` is returned when all `Sender` halves have + /// dropped, indicating that no further values can be sent on the channel. + /// + /// If the [`Receiver`] handle falls behind, once the channel is full, newly + /// sent values will overwrite old values. At this point, a call to [`recv`] + /// will return with `Err(RecvError::Lagged)` and the [`Receiver`]'s + /// internal cursor is updated to point to the oldest value still held by + /// the channel. A subsequent call to [`recv`] will return this value + /// **unless** it has been since overwritten. + /// + /// # Cancel safety + /// + /// This method is cancel safe. If `recv` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// channel. + /// + /// [`Receiver`]: crate::sync::broadcast::Receiver + /// [`recv`]: crate::sync::broadcast::Receiver::recv + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx1) = broadcast::channel(16); + /// let mut rx2 = tx.subscribe(); + /// + /// tokio::spawn(async move { + /// assert_eq!(rx1.recv().await.unwrap(), 10); + /// assert_eq!(rx1.recv().await.unwrap(), 20); + /// }); + /// + /// tokio::spawn(async move { + /// assert_eq!(rx2.recv().await.unwrap(), 10); + /// assert_eq!(rx2.recv().await.unwrap(), 20); + /// }); + /// + /// tx.send(10).unwrap(); + /// tx.send(20).unwrap(); + /// } + /// ``` + /// + /// Handling lag + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = broadcast::channel(2); + /// + /// tx.send(10).unwrap(); + /// tx.send(20).unwrap(); + /// tx.send(30).unwrap(); + /// + /// // The receiver lagged behind + /// assert!(rx.recv().await.is_err()); + /// + /// // At this point, we can abort or continue with lost messages + /// + /// assert_eq!(20, rx.recv().await.unwrap()); + /// assert_eq!(30, rx.recv().await.unwrap()); + /// } + /// ``` + pub async fn recv(&mut self) -> Result<T, RecvError> { + let fut = Recv::new(self); + fut.await + } + + /// Attempts to return a pending value on this receiver without awaiting. + /// + /// This is useful for a flavor of "optimistic check" before deciding to + /// await on a receiver. + /// + /// Compared with [`recv`], this function has three failure cases instead of two + /// (one for closed, one for an empty buffer, one for a lagging receiver). + /// + /// `Err(TryRecvError::Closed)` is returned when all `Sender` halves have + /// dropped, indicating that no further values can be sent on the channel. + /// + /// If the [`Receiver`] handle falls behind, once the channel is full, newly + /// sent values will overwrite old values. At this point, a call to [`recv`] + /// will return with `Err(TryRecvError::Lagged)` and the [`Receiver`]'s + /// internal cursor is updated to point to the oldest value still held by + /// the channel. A subsequent call to [`try_recv`] will return this value + /// **unless** it has been since overwritten. If there are no values to + /// receive, `Err(TryRecvError::Empty)` is returned. + /// + /// [`recv`]: crate::sync::broadcast::Receiver::recv + /// [`try_recv`]: crate::sync::broadcast::Receiver::try_recv + /// [`Receiver`]: crate::sync::broadcast::Receiver + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = broadcast::channel(16); + /// + /// assert!(rx.try_recv().is_err()); + /// + /// tx.send(10).unwrap(); + /// + /// let value = rx.try_recv().unwrap(); + /// assert_eq!(10, value); + /// } + /// ``` + pub fn try_recv(&mut self) -> Result<T, TryRecvError> { + let guard = self.recv_ref(None)?; + guard.clone_value().ok_or(TryRecvError::Closed) + } +} + +impl<T> Drop for Receiver<T> { + fn drop(&mut self) { + let mut tail = self.shared.tail.lock(); + + tail.rx_cnt -= 1; + let until = tail.pos; + + drop(tail); + + while self.next < until { + match self.recv_ref(None) { + Ok(_) => {} + // The channel is closed + Err(TryRecvError::Closed) => break, + // Ignore lagging, we will catch up + Err(TryRecvError::Lagged(..)) => {} + // Can't be empty + Err(TryRecvError::Empty) => panic!("unexpected empty broadcast channel"), + } + } + } +} + +impl<'a, T> Recv<'a, T> { + fn new(receiver: &'a mut Receiver<T>) -> Recv<'a, T> { + Recv { + receiver, + waiter: UnsafeCell::new(Waiter { + queued: false, + waker: None, + pointers: linked_list::Pointers::new(), + _p: PhantomPinned, + }), + } + } + + /// A custom `project` implementation is used in place of `pin-project-lite` + /// as a custom drop implementation is needed. + fn project(self: Pin<&mut Self>) -> (&mut Receiver<T>, &UnsafeCell<Waiter>) { + unsafe { + // Safety: Receiver is Unpin + is_unpin::<&mut Receiver<T>>(); + + let me = self.get_unchecked_mut(); + (me.receiver, &me.waiter) + } + } +} + +impl<'a, T> Future for Recv<'a, T> +where + T: Clone, +{ + type Output = Result<T, RecvError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> { + let (receiver, waiter) = self.project(); + + let guard = match receiver.recv_ref(Some((waiter, cx.waker()))) { + Ok(value) => value, + Err(TryRecvError::Empty) => return Poll::Pending, + Err(TryRecvError::Lagged(n)) => return Poll::Ready(Err(RecvError::Lagged(n))), + Err(TryRecvError::Closed) => return Poll::Ready(Err(RecvError::Closed)), + }; + + Poll::Ready(guard.clone_value().ok_or(RecvError::Closed)) + } +} + +impl<'a, T> Drop for Recv<'a, T> { + fn drop(&mut self) { + // Acquire the tail lock. This is required for safety before accessing + // the waiter node. + let mut tail = self.receiver.shared.tail.lock(); + + // safety: tail lock is held + let queued = self.waiter.with(|ptr| unsafe { (*ptr).queued }); + + if queued { + // Remove the node + // + // safety: tail lock is held and the wait node is verified to be in + // the list. + unsafe { + self.waiter.with_mut(|ptr| { + tail.waiters.remove((&mut *ptr).into()); + }); + } + } + } +} + +/// # Safety +/// +/// `Waiter` is forced to be !Unpin. +unsafe impl linked_list::Link for Waiter { + type Handle = NonNull<Waiter>; + type Target = Waiter; + + fn as_raw(handle: &NonNull<Waiter>) -> NonNull<Waiter> { + *handle + } + + unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> { + ptr + } + + unsafe fn pointers(mut target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> { + NonNull::from(&mut target.as_mut().pointers) + } +} + +impl<T> fmt::Debug for Sender<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "broadcast::Sender") + } +} + +impl<T> fmt::Debug for Receiver<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "broadcast::Receiver") + } +} + +impl<'a, T> RecvGuard<'a, T> { + fn clone_value(&self) -> Option<T> + where + T: Clone, + { + self.slot.val.with(|ptr| unsafe { (*ptr).clone() }) + } +} + +impl<'a, T> Drop for RecvGuard<'a, T> { + fn drop(&mut self) { + // Decrement the remaining counter + if 1 == self.slot.rem.fetch_sub(1, SeqCst) { + // Safety: Last receiver, drop the value + self.slot.val.with_mut(|ptr| unsafe { *ptr = None }); + } + } +} + +fn is_unpin<T: Unpin>() {} diff --git a/third_party/rust/tokio/src/sync/mod.rs b/third_party/rust/tokio/src/sync/mod.rs new file mode 100644 index 0000000000..457e6ab294 --- /dev/null +++ b/third_party/rust/tokio/src/sync/mod.rs @@ -0,0 +1,499 @@ +#![cfg_attr(loom, allow(dead_code, unreachable_pub, unused_imports))] + +//! Synchronization primitives for use in asynchronous contexts. +//! +//! Tokio programs tend to be organized as a set of [tasks] where each task +//! operates independently and may be executed on separate physical threads. The +//! synchronization primitives provided in this module permit these independent +//! tasks to communicate together. +//! +//! [tasks]: crate::task +//! +//! # Message passing +//! +//! The most common form of synchronization in a Tokio program is message +//! passing. Two tasks operate independently and send messages to each other to +//! synchronize. Doing so has the advantage of avoiding shared state. +//! +//! Message passing is implemented using channels. A channel supports sending a +//! message from one producer task to one or more consumer tasks. There are a +//! few flavors of channels provided by Tokio. Each channel flavor supports +//! different message passing patterns. When a channel supports multiple +//! producers, many separate tasks may **send** messages. When a channel +//! supports multiple consumers, many different separate tasks may **receive** +//! messages. +//! +//! Tokio provides many different channel flavors as different message passing +//! patterns are best handled with different implementations. +//! +//! ## `oneshot` channel +//! +//! The [`oneshot` channel][oneshot] supports sending a **single** value from a +//! single producer to a single consumer. This channel is usually used to send +//! the result of a computation to a waiter. +//! +//! **Example:** using a [`oneshot` channel][oneshot] to receive the result of a +//! computation. +//! +//! ``` +//! use tokio::sync::oneshot; +//! +//! async fn some_computation() -> String { +//! "represents the result of the computation".to_string() +//! } +//! +//! #[tokio::main] +//! async fn main() { +//! let (tx, rx) = oneshot::channel(); +//! +//! tokio::spawn(async move { +//! let res = some_computation().await; +//! tx.send(res).unwrap(); +//! }); +//! +//! // Do other work while the computation is happening in the background +//! +//! // Wait for the computation result +//! let res = rx.await.unwrap(); +//! } +//! ``` +//! +//! Note, if the task produces a computation result as its final +//! action before terminating, the [`JoinHandle`] can be used to +//! receive that value instead of allocating resources for the +//! `oneshot` channel. Awaiting on [`JoinHandle`] returns `Result`. If +//! the task panics, the `Joinhandle` yields `Err` with the panic +//! cause. +//! +//! **Example:** +//! +//! ``` +//! async fn some_computation() -> String { +//! "the result of the computation".to_string() +//! } +//! +//! #[tokio::main] +//! async fn main() { +//! let join_handle = tokio::spawn(async move { +//! some_computation().await +//! }); +//! +//! // Do other work while the computation is happening in the background +//! +//! // Wait for the computation result +//! let res = join_handle.await.unwrap(); +//! } +//! ``` +//! +//! [oneshot]: oneshot +//! [`JoinHandle`]: crate::task::JoinHandle +//! +//! ## `mpsc` channel +//! +//! The [`mpsc` channel][mpsc] supports sending **many** values from **many** +//! producers to a single consumer. This channel is often used to send work to a +//! task or to receive the result of many computations. +//! +//! **Example:** using an mpsc to incrementally stream the results of a series +//! of computations. +//! +//! ``` +//! use tokio::sync::mpsc; +//! +//! async fn some_computation(input: u32) -> String { +//! format!("the result of computation {}", input) +//! } +//! +//! #[tokio::main] +//! async fn main() { +//! let (tx, mut rx) = mpsc::channel(100); +//! +//! tokio::spawn(async move { +//! for i in 0..10 { +//! let res = some_computation(i).await; +//! tx.send(res).await.unwrap(); +//! } +//! }); +//! +//! while let Some(res) = rx.recv().await { +//! println!("got = {}", res); +//! } +//! } +//! ``` +//! +//! The argument to `mpsc::channel` is the channel capacity. This is the maximum +//! number of values that can be stored in the channel pending receipt at any +//! given time. Properly setting this value is key in implementing robust +//! programs as the channel capacity plays a critical part in handling back +//! pressure. +//! +//! A common concurrency pattern for resource management is to spawn a task +//! dedicated to managing that resource and using message passing between other +//! tasks to interact with the resource. The resource may be anything that may +//! not be concurrently used. Some examples include a socket and program state. +//! For example, if multiple tasks need to send data over a single socket, spawn +//! a task to manage the socket and use a channel to synchronize. +//! +//! **Example:** sending data from many tasks over a single socket using message +//! passing. +//! +//! ```no_run +//! use tokio::io::{self, AsyncWriteExt}; +//! use tokio::net::TcpStream; +//! use tokio::sync::mpsc; +//! +//! #[tokio::main] +//! async fn main() -> io::Result<()> { +//! let mut socket = TcpStream::connect("www.example.com:1234").await?; +//! let (tx, mut rx) = mpsc::channel(100); +//! +//! for _ in 0..10 { +//! // Each task needs its own `tx` handle. This is done by cloning the +//! // original handle. +//! let tx = tx.clone(); +//! +//! tokio::spawn(async move { +//! tx.send(&b"data to write"[..]).await.unwrap(); +//! }); +//! } +//! +//! // The `rx` half of the channel returns `None` once **all** `tx` clones +//! // drop. To ensure `None` is returned, drop the handle owned by the +//! // current task. If this `tx` handle is not dropped, there will always +//! // be a single outstanding `tx` handle. +//! drop(tx); +//! +//! while let Some(res) = rx.recv().await { +//! socket.write_all(res).await?; +//! } +//! +//! Ok(()) +//! } +//! ``` +//! +//! The [`mpsc`][mpsc] and [`oneshot`][oneshot] channels can be combined to +//! provide a request / response type synchronization pattern with a shared +//! resource. A task is spawned to synchronize a resource and waits on commands +//! received on a [`mpsc`][mpsc] channel. Each command includes a +//! [`oneshot`][oneshot] `Sender` on which the result of the command is sent. +//! +//! **Example:** use a task to synchronize a `u64` counter. Each task sends an +//! "fetch and increment" command. The counter value **before** the increment is +//! sent over the provided `oneshot` channel. +//! +//! ``` +//! use tokio::sync::{oneshot, mpsc}; +//! use Command::Increment; +//! +//! enum Command { +//! Increment, +//! // Other commands can be added here +//! } +//! +//! #[tokio::main] +//! async fn main() { +//! let (cmd_tx, mut cmd_rx) = mpsc::channel::<(Command, oneshot::Sender<u64>)>(100); +//! +//! // Spawn a task to manage the counter +//! tokio::spawn(async move { +//! let mut counter: u64 = 0; +//! +//! while let Some((cmd, response)) = cmd_rx.recv().await { +//! match cmd { +//! Increment => { +//! let prev = counter; +//! counter += 1; +//! response.send(prev).unwrap(); +//! } +//! } +//! } +//! }); +//! +//! let mut join_handles = vec![]; +//! +//! // Spawn tasks that will send the increment command. +//! for _ in 0..10 { +//! let cmd_tx = cmd_tx.clone(); +//! +//! join_handles.push(tokio::spawn(async move { +//! let (resp_tx, resp_rx) = oneshot::channel(); +//! +//! cmd_tx.send((Increment, resp_tx)).await.ok().unwrap(); +//! let res = resp_rx.await.unwrap(); +//! +//! println!("previous value = {}", res); +//! })); +//! } +//! +//! // Wait for all tasks to complete +//! for join_handle in join_handles.drain(..) { +//! join_handle.await.unwrap(); +//! } +//! } +//! ``` +//! +//! [mpsc]: mpsc +//! +//! ## `broadcast` channel +//! +//! The [`broadcast` channel] supports sending **many** values from +//! **many** producers to **many** consumers. Each consumer will receive +//! **each** value. This channel can be used to implement "fan out" style +//! patterns common with pub / sub or "chat" systems. +//! +//! This channel tends to be used less often than `oneshot` and `mpsc` but still +//! has its use cases. +//! +//! Basic usage +//! +//! ``` +//! use tokio::sync::broadcast; +//! +//! #[tokio::main] +//! async fn main() { +//! let (tx, mut rx1) = broadcast::channel(16); +//! let mut rx2 = tx.subscribe(); +//! +//! tokio::spawn(async move { +//! assert_eq!(rx1.recv().await.unwrap(), 10); +//! assert_eq!(rx1.recv().await.unwrap(), 20); +//! }); +//! +//! tokio::spawn(async move { +//! assert_eq!(rx2.recv().await.unwrap(), 10); +//! assert_eq!(rx2.recv().await.unwrap(), 20); +//! }); +//! +//! tx.send(10).unwrap(); +//! tx.send(20).unwrap(); +//! } +//! ``` +//! +//! [`broadcast` channel]: crate::sync::broadcast +//! +//! ## `watch` channel +//! +//! The [`watch` channel] supports sending **many** values from a **single** +//! producer to **many** consumers. However, only the **most recent** value is +//! stored in the channel. Consumers are notified when a new value is sent, but +//! there is no guarantee that consumers will see **all** values. +//! +//! The [`watch` channel] is similar to a [`broadcast` channel] with capacity 1. +//! +//! Use cases for the [`watch` channel] include broadcasting configuration +//! changes or signalling program state changes, such as transitioning to +//! shutdown. +//! +//! **Example:** use a [`watch` channel] to notify tasks of configuration +//! changes. In this example, a configuration file is checked periodically. When +//! the file changes, the configuration changes are signalled to consumers. +//! +//! ``` +//! use tokio::sync::watch; +//! use tokio::time::{self, Duration, Instant}; +//! +//! use std::io; +//! +//! #[derive(Debug, Clone, Eq, PartialEq)] +//! struct Config { +//! timeout: Duration, +//! } +//! +//! impl Config { +//! async fn load_from_file() -> io::Result<Config> { +//! // file loading and deserialization logic here +//! # Ok(Config { timeout: Duration::from_secs(1) }) +//! } +//! } +//! +//! async fn my_async_operation() { +//! // Do something here +//! } +//! +//! #[tokio::main] +//! async fn main() { +//! // Load initial configuration value +//! let mut config = Config::load_from_file().await.unwrap(); +//! +//! // Create the watch channel, initialized with the loaded configuration +//! let (tx, rx) = watch::channel(config.clone()); +//! +//! // Spawn a task to monitor the file. +//! tokio::spawn(async move { +//! loop { +//! // Wait 10 seconds between checks +//! time::sleep(Duration::from_secs(10)).await; +//! +//! // Load the configuration file +//! let new_config = Config::load_from_file().await.unwrap(); +//! +//! // If the configuration changed, send the new config value +//! // on the watch channel. +//! if new_config != config { +//! tx.send(new_config.clone()).unwrap(); +//! config = new_config; +//! } +//! } +//! }); +//! +//! let mut handles = vec![]; +//! +//! // Spawn tasks that runs the async operation for at most `timeout`. If +//! // the timeout elapses, restart the operation. +//! // +//! // The task simultaneously watches the `Config` for changes. When the +//! // timeout duration changes, the timeout is updated without restarting +//! // the in-flight operation. +//! for _ in 0..5 { +//! // Clone a config watch handle for use in this task +//! let mut rx = rx.clone(); +//! +//! let handle = tokio::spawn(async move { +//! // Start the initial operation and pin the future to the stack. +//! // Pinning to the stack is required to resume the operation +//! // across multiple calls to `select!` +//! let op = my_async_operation(); +//! tokio::pin!(op); +//! +//! // Get the initial config value +//! let mut conf = rx.borrow().clone(); +//! +//! let mut op_start = Instant::now(); +//! let sleep = time::sleep_until(op_start + conf.timeout); +//! tokio::pin!(sleep); +//! +//! loop { +//! tokio::select! { +//! _ = &mut sleep => { +//! // The operation elapsed. Restart it +//! op.set(my_async_operation()); +//! +//! // Track the new start time +//! op_start = Instant::now(); +//! +//! // Restart the timeout +//! sleep.set(time::sleep_until(op_start + conf.timeout)); +//! } +//! _ = rx.changed() => { +//! conf = rx.borrow().clone(); +//! +//! // The configuration has been updated. Update the +//! // `sleep` using the new `timeout` value. +//! sleep.as_mut().reset(op_start + conf.timeout); +//! } +//! _ = &mut op => { +//! // The operation completed! +//! return +//! } +//! } +//! } +//! }); +//! +//! handles.push(handle); +//! } +//! +//! for handle in handles.drain(..) { +//! handle.await.unwrap(); +//! } +//! } +//! ``` +//! +//! [`watch` channel]: mod@crate::sync::watch +//! [`broadcast` channel]: mod@crate::sync::broadcast +//! +//! # State synchronization +//! +//! The remaining synchronization primitives focus on synchronizing state. +//! These are asynchronous equivalents to versions provided by `std`. They +//! operate in a similar way as their `std` counterparts but will wait +//! asynchronously instead of blocking the thread. +//! +//! * [`Barrier`](Barrier) Ensures multiple tasks will wait for each other to +//! reach a point in the program, before continuing execution all together. +//! +//! * [`Mutex`](Mutex) Mutual Exclusion mechanism, which ensures that at most +//! one thread at a time is able to access some data. +//! +//! * [`Notify`](Notify) Basic task notification. `Notify` supports notifying a +//! receiving task without sending data. In this case, the task wakes up and +//! resumes processing. +//! +//! * [`RwLock`](RwLock) Provides a mutual exclusion mechanism which allows +//! multiple readers at the same time, while allowing only one writer at a +//! time. In some cases, this can be more efficient than a mutex. +//! +//! * [`Semaphore`](Semaphore) Limits the amount of concurrency. A semaphore +//! holds a number of permits, which tasks may request in order to enter a +//! critical section. Semaphores are useful for implementing limiting or +//! bounding of any kind. + +cfg_sync! { + /// Named future types. + pub mod futures { + pub use super::notify::Notified; + } + + mod barrier; + pub use barrier::{Barrier, BarrierWaitResult}; + + pub mod broadcast; + + pub mod mpsc; + + mod mutex; + pub use mutex::{Mutex, MutexGuard, TryLockError, OwnedMutexGuard, MappedMutexGuard}; + + pub(crate) mod notify; + pub use notify::Notify; + + pub mod oneshot; + + pub(crate) mod batch_semaphore; + pub use batch_semaphore::{AcquireError, TryAcquireError}; + + mod semaphore; + pub use semaphore::{Semaphore, SemaphorePermit, OwnedSemaphorePermit}; + + mod rwlock; + pub use rwlock::RwLock; + pub use rwlock::owned_read_guard::OwnedRwLockReadGuard; + pub use rwlock::owned_write_guard::OwnedRwLockWriteGuard; + pub use rwlock::owned_write_guard_mapped::OwnedRwLockMappedWriteGuard; + pub use rwlock::read_guard::RwLockReadGuard; + pub use rwlock::write_guard::RwLockWriteGuard; + pub use rwlock::write_guard_mapped::RwLockMappedWriteGuard; + + mod task; + pub(crate) use task::AtomicWaker; + + mod once_cell; + pub use self::once_cell::{OnceCell, SetError}; + + pub mod watch; +} + +cfg_not_sync! { + cfg_fs! { + pub(crate) mod batch_semaphore; + mod mutex; + pub(crate) use mutex::Mutex; + } + + #[cfg(any(feature = "rt", feature = "signal", all(unix, feature = "process")))] + pub(crate) mod notify; + + #[cfg(any(feature = "rt", all(windows, feature = "process")))] + pub(crate) mod oneshot; + + cfg_atomic_waker_impl! { + mod task; + pub(crate) use task::AtomicWaker; + } + + #[cfg(any(feature = "signal", all(unix, feature = "process")))] + pub(crate) mod watch; +} + +/// Unit tests +#[cfg(test)] +mod tests; diff --git a/third_party/rust/tokio/src/sync/mpsc/block.rs b/third_party/rust/tokio/src/sync/mpsc/block.rs new file mode 100644 index 0000000000..58f4a9f6cc --- /dev/null +++ b/third_party/rust/tokio/src/sync/mpsc/block.rs @@ -0,0 +1,385 @@ +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize}; + +use std::mem::MaybeUninit; +use std::ops; +use std::ptr::{self, NonNull}; +use std::sync::atomic::Ordering::{self, AcqRel, Acquire, Release}; + +/// A block in a linked list. +/// +/// Each block in the list can hold up to `BLOCK_CAP` messages. +pub(crate) struct Block<T> { + /// The start index of this block. + /// + /// Slots in this block have indices in `start_index .. start_index + BLOCK_CAP`. + start_index: usize, + + /// The next block in the linked list. + next: AtomicPtr<Block<T>>, + + /// Bitfield tracking slots that are ready to have their values consumed. + ready_slots: AtomicUsize, + + /// The observed `tail_position` value *after* the block has been passed by + /// `block_tail`. + observed_tail_position: UnsafeCell<usize>, + + /// Array containing values pushed into the block. Values are stored in a + /// continuous array in order to improve cache line behavior when reading. + /// The values must be manually dropped. + values: Values<T>, +} + +pub(crate) enum Read<T> { + Value(T), + Closed, +} + +struct Values<T>([UnsafeCell<MaybeUninit<T>>; BLOCK_CAP]); + +use super::BLOCK_CAP; + +/// Masks an index to get the block identifier. +const BLOCK_MASK: usize = !(BLOCK_CAP - 1); + +/// Masks an index to get the value offset in a block. +const SLOT_MASK: usize = BLOCK_CAP - 1; + +/// Flag tracking that a block has gone through the sender's release routine. +/// +/// When this is set, the receiver may consider freeing the block. +const RELEASED: usize = 1 << BLOCK_CAP; + +/// Flag tracking all senders dropped. +/// +/// When this flag is set, the send half of the channel has closed. +const TX_CLOSED: usize = RELEASED << 1; + +/// Mask covering all bits used to track slot readiness. +const READY_MASK: usize = RELEASED - 1; + +/// Returns the index of the first slot in the block referenced by `slot_index`. +#[inline(always)] +pub(crate) fn start_index(slot_index: usize) -> usize { + BLOCK_MASK & slot_index +} + +/// Returns the offset into the block referenced by `slot_index`. +#[inline(always)] +pub(crate) fn offset(slot_index: usize) -> usize { + SLOT_MASK & slot_index +} + +impl<T> Block<T> { + pub(crate) fn new(start_index: usize) -> Block<T> { + Block { + // The absolute index in the channel of the first slot in the block. + start_index, + + // Pointer to the next block in the linked list. + next: AtomicPtr::new(ptr::null_mut()), + + ready_slots: AtomicUsize::new(0), + + observed_tail_position: UnsafeCell::new(0), + + // Value storage + values: unsafe { Values::uninitialized() }, + } + } + + /// Returns `true` if the block matches the given index. + pub(crate) fn is_at_index(&self, index: usize) -> bool { + debug_assert!(offset(index) == 0); + self.start_index == index + } + + /// Returns the number of blocks between `self` and the block at the + /// specified index. + /// + /// `start_index` must represent a block *after* `self`. + pub(crate) fn distance(&self, other_index: usize) -> usize { + debug_assert!(offset(other_index) == 0); + other_index.wrapping_sub(self.start_index) / BLOCK_CAP + } + + /// Reads the value at the given offset. + /// + /// Returns `None` if the slot is empty. + /// + /// # Safety + /// + /// To maintain safety, the caller must ensure: + /// + /// * No concurrent access to the slot. + pub(crate) unsafe fn read(&self, slot_index: usize) -> Option<Read<T>> { + let offset = offset(slot_index); + + let ready_bits = self.ready_slots.load(Acquire); + + if !is_ready(ready_bits, offset) { + if is_tx_closed(ready_bits) { + return Some(Read::Closed); + } + + return None; + } + + // Get the value + let value = self.values[offset].with(|ptr| ptr::read(ptr)); + + Some(Read::Value(value.assume_init())) + } + + /// Writes a value to the block at the given offset. + /// + /// # Safety + /// + /// To maintain safety, the caller must ensure: + /// + /// * The slot is empty. + /// * No concurrent access to the slot. + pub(crate) unsafe fn write(&self, slot_index: usize, value: T) { + // Get the offset into the block + let slot_offset = offset(slot_index); + + self.values[slot_offset].with_mut(|ptr| { + ptr::write(ptr, MaybeUninit::new(value)); + }); + + // Release the value. After this point, the slot ref may no longer + // be used. It is possible for the receiver to free the memory at + // any point. + self.set_ready(slot_offset); + } + + /// Signal to the receiver that the sender half of the list is closed. + pub(crate) unsafe fn tx_close(&self) { + self.ready_slots.fetch_or(TX_CLOSED, Release); + } + + /// Resets the block to a blank state. This enables reusing blocks in the + /// channel. + /// + /// # Safety + /// + /// To maintain safety, the caller must ensure: + /// + /// * All slots are empty. + /// * The caller holds a unique pointer to the block. + pub(crate) unsafe fn reclaim(&mut self) { + self.start_index = 0; + self.next = AtomicPtr::new(ptr::null_mut()); + self.ready_slots = AtomicUsize::new(0); + } + + /// Releases the block to the rx half for freeing. + /// + /// This function is called by the tx half once it can be guaranteed that no + /// more senders will attempt to access the block. + /// + /// # Safety + /// + /// To maintain safety, the caller must ensure: + /// + /// * The block will no longer be accessed by any sender. + pub(crate) unsafe fn tx_release(&self, tail_position: usize) { + // Track the observed tail_position. Any sender targeting a greater + // tail_position is guaranteed to not access this block. + self.observed_tail_position + .with_mut(|ptr| *ptr = tail_position); + + // Set the released bit, signalling to the receiver that it is safe to + // free the block's memory as soon as all slots **prior** to + // `observed_tail_position` have been filled. + self.ready_slots.fetch_or(RELEASED, Release); + } + + /// Mark a slot as ready + fn set_ready(&self, slot: usize) { + let mask = 1 << slot; + self.ready_slots.fetch_or(mask, Release); + } + + /// Returns `true` when all slots have their `ready` bits set. + /// + /// This indicates that the block is in its final state and will no longer + /// be mutated. + /// + /// # Implementation + /// + /// The implementation walks each slot checking the `ready` flag. It might + /// be that it would make more sense to coalesce ready flags as bits in a + /// single atomic cell. However, this could have negative impact on cache + /// behavior as there would be many more mutations to a single slot. + pub(crate) fn is_final(&self) -> bool { + self.ready_slots.load(Acquire) & READY_MASK == READY_MASK + } + + /// Returns the `observed_tail_position` value, if set + pub(crate) fn observed_tail_position(&self) -> Option<usize> { + if 0 == RELEASED & self.ready_slots.load(Acquire) { + None + } else { + Some(self.observed_tail_position.with(|ptr| unsafe { *ptr })) + } + } + + /// Loads the next block + pub(crate) fn load_next(&self, ordering: Ordering) -> Option<NonNull<Block<T>>> { + let ret = NonNull::new(self.next.load(ordering)); + + debug_assert!(unsafe { + ret.map(|block| block.as_ref().start_index == self.start_index.wrapping_add(BLOCK_CAP)) + .unwrap_or(true) + }); + + ret + } + + /// Pushes `block` as the next block in the link. + /// + /// Returns Ok if successful, otherwise, a pointer to the next block in + /// the list is returned. + /// + /// This requires that the next pointer is null. + /// + /// # Ordering + /// + /// This performs a compare-and-swap on `next` using AcqRel ordering. + /// + /// # Safety + /// + /// To maintain safety, the caller must ensure: + /// + /// * `block` is not freed until it has been removed from the list. + pub(crate) unsafe fn try_push( + &self, + block: &mut NonNull<Block<T>>, + success: Ordering, + failure: Ordering, + ) -> Result<(), NonNull<Block<T>>> { + block.as_mut().start_index = self.start_index.wrapping_add(BLOCK_CAP); + + let next_ptr = self + .next + .compare_exchange(ptr::null_mut(), block.as_ptr(), success, failure) + .unwrap_or_else(|x| x); + + match NonNull::new(next_ptr) { + Some(next_ptr) => Err(next_ptr), + None => Ok(()), + } + } + + /// Grows the `Block` linked list by allocating and appending a new block. + /// + /// The next block in the linked list is returned. This may or may not be + /// the one allocated by the function call. + /// + /// # Implementation + /// + /// It is assumed that `self.next` is null. A new block is allocated with + /// `start_index` set to be the next block. A compare-and-swap is performed + /// with AcqRel memory ordering. If the compare-and-swap is successful, the + /// newly allocated block is released to other threads walking the block + /// linked list. If the compare-and-swap fails, the current thread acquires + /// the next block in the linked list, allowing the current thread to access + /// the slots. + pub(crate) fn grow(&self) -> NonNull<Block<T>> { + // Create the new block. It is assumed that the block will become the + // next one after `&self`. If this turns out to not be the case, + // `start_index` is updated accordingly. + let new_block = Box::new(Block::new(self.start_index + BLOCK_CAP)); + + let mut new_block = unsafe { NonNull::new_unchecked(Box::into_raw(new_block)) }; + + // Attempt to store the block. The first compare-and-swap attempt is + // "unrolled" due to minor differences in logic + // + // `AcqRel` is used as the ordering **only** when attempting the + // compare-and-swap on self.next. + // + // If the compare-and-swap fails, then the actual value of the cell is + // returned from this function and accessed by the caller. Given this, + // the memory must be acquired. + // + // `Release` ensures that the newly allocated block is available to + // other threads acquiring the next pointer. + let next = NonNull::new( + self.next + .compare_exchange(ptr::null_mut(), new_block.as_ptr(), AcqRel, Acquire) + .unwrap_or_else(|x| x), + ); + + let next = match next { + Some(next) => next, + None => { + // The compare-and-swap succeeded and the newly allocated block + // is successfully pushed. + return new_block; + } + }; + + // There already is a next block in the linked list. The newly allocated + // block could be dropped and the discovered next block returned; + // however, that would be wasteful. Instead, the linked list is walked + // by repeatedly attempting to compare-and-swap the pointer into the + // `next` register until the compare-and-swap succeed. + // + // Care is taken to update new_block's start_index field as appropriate. + + let mut curr = next; + + // TODO: Should this iteration be capped? + loop { + let actual = unsafe { curr.as_ref().try_push(&mut new_block, AcqRel, Acquire) }; + + curr = match actual { + Ok(_) => { + return next; + } + Err(curr) => curr, + }; + + crate::loom::thread::yield_now(); + } + } +} + +/// Returns `true` if the specified slot has a value ready to be consumed. +fn is_ready(bits: usize, slot: usize) -> bool { + let mask = 1 << slot; + mask == mask & bits +} + +/// Returns `true` if the closed flag has been set. +fn is_tx_closed(bits: usize) -> bool { + TX_CLOSED == bits & TX_CLOSED +} + +impl<T> Values<T> { + unsafe fn uninitialized() -> Values<T> { + let mut vals = MaybeUninit::uninit(); + + // When fuzzing, `UnsafeCell` needs to be initialized. + if_loom! { + let p = vals.as_mut_ptr() as *mut UnsafeCell<MaybeUninit<T>>; + for i in 0..BLOCK_CAP { + p.add(i) + .write(UnsafeCell::new(MaybeUninit::uninit())); + } + } + + Values(vals.assume_init()) + } +} + +impl<T> ops::Index<usize> for Values<T> { + type Output = UnsafeCell<MaybeUninit<T>>; + + fn index(&self, index: usize) -> &Self::Output { + self.0.index(index) + } +} diff --git a/third_party/rust/tokio/src/sync/mpsc/bounded.rs b/third_party/rust/tokio/src/sync/mpsc/bounded.rs new file mode 100644 index 0000000000..ddded8ebb3 --- /dev/null +++ b/third_party/rust/tokio/src/sync/mpsc/bounded.rs @@ -0,0 +1,1197 @@ +use crate::sync::batch_semaphore::{self as semaphore, TryAcquireError}; +use crate::sync::mpsc::chan; +use crate::sync::mpsc::error::{SendError, TryRecvError, TrySendError}; + +cfg_time! { + use crate::sync::mpsc::error::SendTimeoutError; + use crate::time::Duration; +} + +use std::fmt; +use std::task::{Context, Poll}; + +/// Sends values to the associated `Receiver`. +/// +/// Instances are created by the [`channel`](channel) function. +/// +/// To convert the `Sender` into a `Sink` or use it in a poll function, you can +/// use the [`PollSender`] utility. +/// +/// [`PollSender`]: https://docs.rs/tokio-util/0.6/tokio_util/sync/struct.PollSender.html +pub struct Sender<T> { + chan: chan::Tx<T, Semaphore>, +} + +/// Permits to send one value into the channel. +/// +/// `Permit` values are returned by [`Sender::reserve()`] and [`Sender::try_reserve()`] +/// and are used to guarantee channel capacity before generating a message to send. +/// +/// [`Sender::reserve()`]: Sender::reserve +/// [`Sender::try_reserve()`]: Sender::try_reserve +pub struct Permit<'a, T> { + chan: &'a chan::Tx<T, Semaphore>, +} + +/// Owned permit to send one value into the channel. +/// +/// This is identical to the [`Permit`] type, except that it moves the sender +/// rather than borrowing it. +/// +/// `OwnedPermit` values are returned by [`Sender::reserve_owned()`] and +/// [`Sender::try_reserve_owned()`] and are used to guarantee channel capacity +/// before generating a message to send. +/// +/// [`Permit`]: Permit +/// [`Sender::reserve_owned()`]: Sender::reserve_owned +/// [`Sender::try_reserve_owned()`]: Sender::try_reserve_owned +pub struct OwnedPermit<T> { + chan: Option<chan::Tx<T, Semaphore>>, +} + +/// Receives values from the associated `Sender`. +/// +/// Instances are created by the [`channel`](channel) function. +/// +/// This receiver can be turned into a `Stream` using [`ReceiverStream`]. +/// +/// [`ReceiverStream`]: https://docs.rs/tokio-stream/0.1/tokio_stream/wrappers/struct.ReceiverStream.html +pub struct Receiver<T> { + /// The channel receiver. + chan: chan::Rx<T, Semaphore>, +} + +/// Creates a bounded mpsc channel for communicating between asynchronous tasks +/// with backpressure. +/// +/// The channel will buffer up to the provided number of messages. Once the +/// buffer is full, attempts to send new messages will wait until a message is +/// received from the channel. The provided buffer capacity must be at least 1. +/// +/// All data sent on `Sender` will become available on `Receiver` in the same +/// order as it was sent. +/// +/// The `Sender` can be cloned to `send` to the same channel from multiple code +/// locations. Only one `Receiver` is supported. +/// +/// If the `Receiver` is disconnected while trying to `send`, the `send` method +/// will return a `SendError`. Similarly, if `Sender` is disconnected while +/// trying to `recv`, the `recv` method will return `None`. +/// +/// # Panics +/// +/// Panics if the buffer capacity is 0. +/// +/// # Examples +/// +/// ```rust +/// use tokio::sync::mpsc; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, mut rx) = mpsc::channel(100); +/// +/// tokio::spawn(async move { +/// for i in 0..10 { +/// if let Err(_) = tx.send(i).await { +/// println!("receiver dropped"); +/// return; +/// } +/// } +/// }); +/// +/// while let Some(i) = rx.recv().await { +/// println!("got = {}", i); +/// } +/// } +/// ``` +pub fn channel<T>(buffer: usize) -> (Sender<T>, Receiver<T>) { + assert!(buffer > 0, "mpsc bounded channel requires buffer > 0"); + let semaphore = (semaphore::Semaphore::new(buffer), buffer); + let (tx, rx) = chan::channel(semaphore); + + let tx = Sender::new(tx); + let rx = Receiver::new(rx); + + (tx, rx) +} + +/// Channel semaphore is a tuple of the semaphore implementation and a `usize` +/// representing the channel bound. +type Semaphore = (semaphore::Semaphore, usize); + +impl<T> Receiver<T> { + pub(crate) fn new(chan: chan::Rx<T, Semaphore>) -> Receiver<T> { + Receiver { chan } + } + + /// Receives the next value for this receiver. + /// + /// This method returns `None` if the channel has been closed and there are + /// no remaining messages in the channel's buffer. This indicates that no + /// further values can ever be received from this `Receiver`. The channel is + /// closed when all senders have been dropped, or when [`close`] is called. + /// + /// If there are no messages in the channel's buffer, but the channel has + /// not yet been closed, this method will sleep until a message is sent or + /// the channel is closed. Note that if [`close`] is called, but there are + /// still outstanding [`Permits`] from before it was closed, the channel is + /// not considered closed by `recv` until the permits are released. + /// + /// # Cancel safety + /// + /// This method is cancel safe. If `recv` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// channel. + /// + /// [`close`]: Self::close + /// [`Permits`]: struct@crate::sync::mpsc::Permit + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(100); + /// + /// tokio::spawn(async move { + /// tx.send("hello").await.unwrap(); + /// }); + /// + /// assert_eq!(Some("hello"), rx.recv().await); + /// assert_eq!(None, rx.recv().await); + /// } + /// ``` + /// + /// Values are buffered: + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(100); + /// + /// tx.send("hello").await.unwrap(); + /// tx.send("world").await.unwrap(); + /// + /// assert_eq!(Some("hello"), rx.recv().await); + /// assert_eq!(Some("world"), rx.recv().await); + /// } + /// ``` + pub async fn recv(&mut self) -> Option<T> { + use crate::future::poll_fn; + poll_fn(|cx| self.chan.recv(cx)).await + } + + /// Tries to receive the next value for this receiver. + /// + /// This method returns the [`Empty`] error if the channel is currently + /// empty, but there are still outstanding [senders] or [permits]. + /// + /// This method returns the [`Disconnected`] error if the channel is + /// currently empty, and there are no outstanding [senders] or [permits]. + /// + /// Unlike the [`poll_recv`] method, this method will never return an + /// [`Empty`] error spuriously. + /// + /// [`Empty`]: crate::sync::mpsc::error::TryRecvError::Empty + /// [`Disconnected`]: crate::sync::mpsc::error::TryRecvError::Disconnected + /// [`poll_recv`]: Self::poll_recv + /// [senders]: crate::sync::mpsc::Sender + /// [permits]: crate::sync::mpsc::Permit + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// use tokio::sync::mpsc::error::TryRecvError; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(100); + /// + /// tx.send("hello").await.unwrap(); + /// + /// assert_eq!(Ok("hello"), rx.try_recv()); + /// assert_eq!(Err(TryRecvError::Empty), rx.try_recv()); + /// + /// tx.send("hello").await.unwrap(); + /// // Drop the last sender, closing the channel. + /// drop(tx); + /// + /// assert_eq!(Ok("hello"), rx.try_recv()); + /// assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv()); + /// } + /// ``` + pub fn try_recv(&mut self) -> Result<T, TryRecvError> { + self.chan.try_recv() + } + + /// Blocking receive to call outside of asynchronous contexts. + /// + /// This method returns `None` if the channel has been closed and there are + /// no remaining messages in the channel's buffer. This indicates that no + /// further values can ever be received from this `Receiver`. The channel is + /// closed when all senders have been dropped, or when [`close`] is called. + /// + /// If there are no messages in the channel's buffer, but the channel has + /// not yet been closed, this method will block until a message is sent or + /// the channel is closed. + /// + /// This method is intended for use cases where you are sending from + /// asynchronous code to synchronous code, and will work even if the sender + /// is not using [`blocking_send`] to send the message. + /// + /// Note that if [`close`] is called, but there are still outstanding + /// [`Permits`] from before it was closed, the channel is not considered + /// closed by `blocking_recv` until the permits are released. + /// + /// [`close`]: Self::close + /// [`Permits`]: struct@crate::sync::mpsc::Permit + /// [`blocking_send`]: fn@crate::sync::mpsc::Sender::blocking_send + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution + /// context. + /// + /// # Examples + /// + /// ``` + /// use std::thread; + /// use tokio::runtime::Runtime; + /// use tokio::sync::mpsc; + /// + /// fn main() { + /// let (tx, mut rx) = mpsc::channel::<u8>(10); + /// + /// let sync_code = thread::spawn(move || { + /// assert_eq!(Some(10), rx.blocking_recv()); + /// }); + /// + /// Runtime::new() + /// .unwrap() + /// .block_on(async move { + /// let _ = tx.send(10).await; + /// }); + /// sync_code.join().unwrap() + /// } + /// ``` + #[cfg(feature = "sync")] + pub fn blocking_recv(&mut self) -> Option<T> { + crate::future::block_on(self.recv()) + } + + /// Closes the receiving half of a channel without dropping it. + /// + /// This prevents any further messages from being sent on the channel while + /// still enabling the receiver to drain messages that are buffered. Any + /// outstanding [`Permit`] values will still be able to send messages. + /// + /// To guarantee that no messages are dropped, after calling `close()`, + /// `recv()` must be called until `None` is returned. If there are + /// outstanding [`Permit`] or [`OwnedPermit`] values, the `recv` method will + /// not return `None` until those are released. + /// + /// [`Permit`]: Permit + /// [`OwnedPermit`]: OwnedPermit + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(20); + /// + /// tokio::spawn(async move { + /// let mut i = 0; + /// while let Ok(permit) = tx.reserve().await { + /// permit.send(i); + /// i += 1; + /// } + /// }); + /// + /// rx.close(); + /// + /// while let Some(msg) = rx.recv().await { + /// println!("got {}", msg); + /// } + /// + /// // Channel closed and no messages are lost. + /// } + /// ``` + pub fn close(&mut self) { + self.chan.close(); + } + + /// Polls to receive the next message on this channel. + /// + /// This method returns: + /// + /// * `Poll::Pending` if no messages are available but the channel is not + /// closed, or if a spurious failure happens. + /// * `Poll::Ready(Some(message))` if a message is available. + /// * `Poll::Ready(None)` if the channel has been closed and all messages + /// sent before it was closed have been received. + /// + /// When the method returns `Poll::Pending`, the `Waker` in the provided + /// `Context` is scheduled to receive a wakeup when a message is sent on any + /// receiver, or when the channel is closed. Note that on multiple calls to + /// `poll_recv`, only the `Waker` from the `Context` passed to the most + /// recent call is scheduled to receive a wakeup. + /// + /// If this method returns `Poll::Pending` due to a spurious failure, then + /// the `Waker` will be notified when the situation causing the spurious + /// failure has been resolved. Note that receiving such a wakeup does not + /// guarantee that the next call will succeed — it could fail with another + /// spurious failure. + pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { + self.chan.recv(cx) + } +} + +impl<T> fmt::Debug for Receiver<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Receiver") + .field("chan", &self.chan) + .finish() + } +} + +impl<T> Unpin for Receiver<T> {} + +impl<T> Sender<T> { + pub(crate) fn new(chan: chan::Tx<T, Semaphore>) -> Sender<T> { + Sender { chan } + } + + /// Sends a value, waiting until there is capacity. + /// + /// A successful send occurs when it is determined that the other end of the + /// channel has not hung up already. An unsuccessful send would be one where + /// the corresponding receiver has already been closed. Note that a return + /// value of `Err` means that the data will never be received, but a return + /// value of `Ok` does not mean that the data will be received. It is + /// possible for the corresponding receiver to hang up immediately after + /// this function returns `Ok`. + /// + /// # Errors + /// + /// If the receive half of the channel is closed, either due to [`close`] + /// being called or the [`Receiver`] handle dropping, the function returns + /// an error. The error includes the value passed to `send`. + /// + /// [`close`]: Receiver::close + /// [`Receiver`]: Receiver + /// + /// # Cancel safety + /// + /// If `send` is used as the event in a [`tokio::select!`](crate::select) + /// statement and some other branch completes first, then it is guaranteed + /// that the message was not sent. + /// + /// This channel uses a queue to ensure that calls to `send` and `reserve` + /// complete in the order they were requested. Cancelling a call to + /// `send` makes you lose your place in the queue. + /// + /// # Examples + /// + /// In the following example, each call to `send` will block until the + /// previously sent value was received. + /// + /// ```rust + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// tokio::spawn(async move { + /// for i in 0..10 { + /// if let Err(_) = tx.send(i).await { + /// println!("receiver dropped"); + /// return; + /// } + /// } + /// }); + /// + /// while let Some(i) = rx.recv().await { + /// println!("got = {}", i); + /// } + /// } + /// ``` + pub async fn send(&self, value: T) -> Result<(), SendError<T>> { + match self.reserve().await { + Ok(permit) => { + permit.send(value); + Ok(()) + } + Err(_) => Err(SendError(value)), + } + } + + /// Completes when the receiver has dropped. + /// + /// This allows the producers to get notified when interest in the produced + /// values is canceled and immediately stop doing work. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once the channel is closed, it stays closed + /// forever and all future calls to `closed` will return immediately. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx1, rx) = mpsc::channel::<()>(1); + /// let tx2 = tx1.clone(); + /// let tx3 = tx1.clone(); + /// let tx4 = tx1.clone(); + /// let tx5 = tx1.clone(); + /// tokio::spawn(async move { + /// drop(rx); + /// }); + /// + /// futures::join!( + /// tx1.closed(), + /// tx2.closed(), + /// tx3.closed(), + /// tx4.closed(), + /// tx5.closed() + /// ); + /// println!("Receiver dropped"); + /// } + /// ``` + pub async fn closed(&self) { + self.chan.closed().await + } + + /// Attempts to immediately send a message on this `Sender` + /// + /// This method differs from [`send`] by returning immediately if the channel's + /// buffer is full or no receiver is waiting to acquire some data. Compared + /// with [`send`], this function has two failure cases instead of one (one for + /// disconnection, one for a full buffer). + /// + /// # Errors + /// + /// If the channel capacity has been reached, i.e., the channel has `n` + /// buffered values where `n` is the argument passed to [`channel`], then an + /// error is returned. + /// + /// If the receive half of the channel is closed, either due to [`close`] + /// being called or the [`Receiver`] handle dropping, the function returns + /// an error. The error includes the value passed to `send`. + /// + /// [`send`]: Sender::send + /// [`channel`]: channel + /// [`close`]: Receiver::close + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// // Create a channel with buffer size 1 + /// let (tx1, mut rx) = mpsc::channel(1); + /// let tx2 = tx1.clone(); + /// + /// tokio::spawn(async move { + /// tx1.send(1).await.unwrap(); + /// tx1.send(2).await.unwrap(); + /// // task waits until the receiver receives a value. + /// }); + /// + /// tokio::spawn(async move { + /// // This will return an error and send + /// // no message if the buffer is full + /// let _ = tx2.try_send(3); + /// }); + /// + /// let mut msg; + /// msg = rx.recv().await.unwrap(); + /// println!("message {} received", msg); + /// + /// msg = rx.recv().await.unwrap(); + /// println!("message {} received", msg); + /// + /// // Third message may have never been sent + /// match rx.recv().await { + /// Some(msg) => println!("message {} received", msg), + /// None => println!("the third message was never sent"), + /// } + /// } + /// ``` + pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> { + match self.chan.semaphore().0.try_acquire(1) { + Ok(_) => {} + Err(TryAcquireError::Closed) => return Err(TrySendError::Closed(message)), + Err(TryAcquireError::NoPermits) => return Err(TrySendError::Full(message)), + } + + // Send the message + self.chan.send(message); + Ok(()) + } + + /// Sends a value, waiting until there is capacity, but only for a limited time. + /// + /// Shares the same success and error conditions as [`send`], adding one more + /// condition for an unsuccessful send, which is when the provided timeout has + /// elapsed, and there is no capacity available. + /// + /// [`send`]: Sender::send + /// + /// # Errors + /// + /// If the receive half of the channel is closed, either due to [`close`] + /// being called or the [`Receiver`] having been dropped, + /// the function returns an error. The error includes the value passed to `send`. + /// + /// [`close`]: Receiver::close + /// [`Receiver`]: Receiver + /// + /// # Panics + /// + /// This function panics if it is called outside the context of a Tokio + /// runtime [with time enabled](crate::runtime::Builder::enable_time). + /// + /// # Examples + /// + /// In the following example, each call to `send_timeout` will block until the + /// previously sent value was received, unless the timeout has elapsed. + /// + /// ```rust + /// use tokio::sync::mpsc; + /// use tokio::time::{sleep, Duration}; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// tokio::spawn(async move { + /// for i in 0..10 { + /// if let Err(e) = tx.send_timeout(i, Duration::from_millis(100)).await { + /// println!("send error: #{:?}", e); + /// return; + /// } + /// } + /// }); + /// + /// while let Some(i) = rx.recv().await { + /// println!("got = {}", i); + /// sleep(Duration::from_millis(200)).await; + /// } + /// } + /// ``` + #[cfg(feature = "time")] + #[cfg_attr(docsrs, doc(cfg(feature = "time")))] + pub async fn send_timeout( + &self, + value: T, + timeout: Duration, + ) -> Result<(), SendTimeoutError<T>> { + let permit = match crate::time::timeout(timeout, self.reserve()).await { + Err(_) => { + return Err(SendTimeoutError::Timeout(value)); + } + Ok(Err(_)) => { + return Err(SendTimeoutError::Closed(value)); + } + Ok(Ok(permit)) => permit, + }; + + permit.send(value); + Ok(()) + } + + /// Blocking send to call outside of asynchronous contexts. + /// + /// This method is intended for use cases where you are sending from + /// synchronous code to asynchronous code, and will work even if the + /// receiver is not using [`blocking_recv`] to receive the message. + /// + /// [`blocking_recv`]: fn@crate::sync::mpsc::Receiver::blocking_recv + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution + /// context. + /// + /// # Examples + /// + /// ``` + /// use std::thread; + /// use tokio::runtime::Runtime; + /// use tokio::sync::mpsc; + /// + /// fn main() { + /// let (tx, mut rx) = mpsc::channel::<u8>(1); + /// + /// let sync_code = thread::spawn(move || { + /// tx.blocking_send(10).unwrap(); + /// }); + /// + /// Runtime::new().unwrap().block_on(async move { + /// assert_eq!(Some(10), rx.recv().await); + /// }); + /// sync_code.join().unwrap() + /// } + /// ``` + #[cfg(feature = "sync")] + pub fn blocking_send(&self, value: T) -> Result<(), SendError<T>> { + crate::future::block_on(self.send(value)) + } + + /// Checks if the channel has been closed. This happens when the + /// [`Receiver`] is dropped, or when the [`Receiver::close`] method is + /// called. + /// + /// [`Receiver`]: crate::sync::mpsc::Receiver + /// [`Receiver::close`]: crate::sync::mpsc::Receiver::close + /// + /// ``` + /// let (tx, rx) = tokio::sync::mpsc::channel::<()>(42); + /// assert!(!tx.is_closed()); + /// + /// let tx2 = tx.clone(); + /// assert!(!tx2.is_closed()); + /// + /// drop(rx); + /// assert!(tx.is_closed()); + /// assert!(tx2.is_closed()); + /// ``` + pub fn is_closed(&self) -> bool { + self.chan.is_closed() + } + + /// Waits for channel capacity. Once capacity to send one message is + /// available, it is reserved for the caller. + /// + /// If the channel is full, the function waits for the number of unreceived + /// messages to become less than the channel capacity. Capacity to send one + /// message is reserved for the caller. A [`Permit`] is returned to track + /// the reserved capacity. The [`send`] function on [`Permit`] consumes the + /// reserved capacity. + /// + /// Dropping [`Permit`] without sending a message releases the capacity back + /// to the channel. + /// + /// [`Permit`]: Permit + /// [`send`]: Permit::send + /// + /// # Cancel safety + /// + /// This channel uses a queue to ensure that calls to `send` and `reserve` + /// complete in the order they were requested. Cancelling a call to + /// `reserve` makes you lose your place in the queue. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// // Reserve capacity + /// let permit = tx.reserve().await.unwrap(); + /// + /// // Trying to send directly on the `tx` will fail due to no + /// // available capacity. + /// assert!(tx.try_send(123).is_err()); + /// + /// // Sending on the permit succeeds + /// permit.send(456); + /// + /// // The value sent on the permit is received + /// assert_eq!(rx.recv().await.unwrap(), 456); + /// } + /// ``` + pub async fn reserve(&self) -> Result<Permit<'_, T>, SendError<()>> { + self.reserve_inner().await?; + Ok(Permit { chan: &self.chan }) + } + + /// Waits for channel capacity, moving the `Sender` and returning an owned + /// permit. Once capacity to send one message is available, it is reserved + /// for the caller. + /// + /// This moves the sender _by value_, and returns an owned permit that can + /// be used to send a message into the channel. Unlike [`Sender::reserve`], + /// this method may be used in cases where the permit must be valid for the + /// `'static` lifetime. `Sender`s may be cloned cheaply (`Sender::clone` is + /// essentially a reference count increment, comparable to [`Arc::clone`]), + /// so when multiple [`OwnedPermit`]s are needed or the `Sender` cannot be + /// moved, it can be cloned prior to calling `reserve_owned`. + /// + /// If the channel is full, the function waits for the number of unreceived + /// messages to become less than the channel capacity. Capacity to send one + /// message is reserved for the caller. An [`OwnedPermit`] is returned to + /// track the reserved capacity. The [`send`] function on [`OwnedPermit`] + /// consumes the reserved capacity. + /// + /// Dropping the [`OwnedPermit`] without sending a message releases the + /// capacity back to the channel. + /// + /// # Cancel safety + /// + /// This channel uses a queue to ensure that calls to `send` and `reserve` + /// complete in the order they were requested. Cancelling a call to + /// `reserve_owned` makes you lose your place in the queue. + /// + /// # Examples + /// Sending a message using an [`OwnedPermit`]: + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// // Reserve capacity, moving the sender. + /// let permit = tx.reserve_owned().await.unwrap(); + /// + /// // Send a message, consuming the permit and returning + /// // the moved sender. + /// let tx = permit.send(123); + /// + /// // The value sent on the permit is received. + /// assert_eq!(rx.recv().await.unwrap(), 123); + /// + /// // The sender can now be used again. + /// tx.send(456).await.unwrap(); + /// } + /// ``` + /// + /// When multiple [`OwnedPermit`]s are needed, or the sender cannot be moved + /// by value, it can be inexpensively cloned before calling `reserve_owned`: + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// // Clone the sender and reserve capacity. + /// let permit = tx.clone().reserve_owned().await.unwrap(); + /// + /// // Trying to send directly on the `tx` will fail due to no + /// // available capacity. + /// assert!(tx.try_send(123).is_err()); + /// + /// // Sending on the permit succeeds. + /// permit.send(456); + /// + /// // The value sent on the permit is received + /// assert_eq!(rx.recv().await.unwrap(), 456); + /// } + /// ``` + /// + /// [`Sender::reserve`]: Sender::reserve + /// [`OwnedPermit`]: OwnedPermit + /// [`send`]: OwnedPermit::send + /// [`Arc::clone`]: std::sync::Arc::clone + pub async fn reserve_owned(self) -> Result<OwnedPermit<T>, SendError<()>> { + self.reserve_inner().await?; + Ok(OwnedPermit { + chan: Some(self.chan), + }) + } + + async fn reserve_inner(&self) -> Result<(), SendError<()>> { + match self.chan.semaphore().0.acquire(1).await { + Ok(_) => Ok(()), + Err(_) => Err(SendError(())), + } + } + + /// Tries to acquire a slot in the channel without waiting for the slot to become + /// available. + /// + /// If the channel is full this function will return [`TrySendError`], otherwise + /// if there is a slot available it will return a [`Permit`] that will then allow you + /// to [`send`] on the channel with a guaranteed slot. This function is similar to + /// [`reserve`] except it does not await for the slot to become available. + /// + /// Dropping [`Permit`] without sending a message releases the capacity back + /// to the channel. + /// + /// [`Permit`]: Permit + /// [`send`]: Permit::send + /// [`reserve`]: Sender::reserve + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// // Reserve capacity + /// let permit = tx.try_reserve().unwrap(); + /// + /// // Trying to send directly on the `tx` will fail due to no + /// // available capacity. + /// assert!(tx.try_send(123).is_err()); + /// + /// // Trying to reserve an additional slot on the `tx` will + /// // fail because there is no capacity. + /// assert!(tx.try_reserve().is_err()); + /// + /// // Sending on the permit succeeds + /// permit.send(456); + /// + /// // The value sent on the permit is received + /// assert_eq!(rx.recv().await.unwrap(), 456); + /// + /// } + /// ``` + pub fn try_reserve(&self) -> Result<Permit<'_, T>, TrySendError<()>> { + match self.chan.semaphore().0.try_acquire(1) { + Ok(_) => {} + Err(TryAcquireError::Closed) => return Err(TrySendError::Closed(())), + Err(TryAcquireError::NoPermits) => return Err(TrySendError::Full(())), + } + + Ok(Permit { chan: &self.chan }) + } + + /// Tries to acquire a slot in the channel without waiting for the slot to become + /// available, returning an owned permit. + /// + /// This moves the sender _by value_, and returns an owned permit that can + /// be used to send a message into the channel. Unlike [`Sender::try_reserve`], + /// this method may be used in cases where the permit must be valid for the + /// `'static` lifetime. `Sender`s may be cloned cheaply (`Sender::clone` is + /// essentially a reference count increment, comparable to [`Arc::clone`]), + /// so when multiple [`OwnedPermit`]s are needed or the `Sender` cannot be + /// moved, it can be cloned prior to calling `try_reserve_owned`. + /// + /// If the channel is full this function will return a [`TrySendError`]. + /// Since the sender is taken by value, the `TrySendError` returned in this + /// case contains the sender, so that it may be used again. Otherwise, if + /// there is a slot available, this method will return an [`OwnedPermit`] + /// that can then be used to [`send`] on the channel with a guaranteed slot. + /// This function is similar to [`reserve_owned`] except it does not await + /// for the slot to become available. + /// + /// Dropping the [`OwnedPermit`] without sending a message releases the capacity back + /// to the channel. + /// + /// [`OwnedPermit`]: OwnedPermit + /// [`send`]: OwnedPermit::send + /// [`reserve_owned`]: Sender::reserve_owned + /// [`Arc::clone`]: std::sync::Arc::clone + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// // Reserve capacity + /// let permit = tx.clone().try_reserve_owned().unwrap(); + /// + /// // Trying to send directly on the `tx` will fail due to no + /// // available capacity. + /// assert!(tx.try_send(123).is_err()); + /// + /// // Trying to reserve an additional slot on the `tx` will + /// // fail because there is no capacity. + /// assert!(tx.try_reserve().is_err()); + /// + /// // Sending on the permit succeeds + /// permit.send(456); + /// + /// // The value sent on the permit is received + /// assert_eq!(rx.recv().await.unwrap(), 456); + /// + /// } + /// ``` + pub fn try_reserve_owned(self) -> Result<OwnedPermit<T>, TrySendError<Self>> { + match self.chan.semaphore().0.try_acquire(1) { + Ok(_) => {} + Err(TryAcquireError::Closed) => return Err(TrySendError::Closed(self)), + Err(TryAcquireError::NoPermits) => return Err(TrySendError::Full(self)), + } + + Ok(OwnedPermit { + chan: Some(self.chan), + }) + } + + /// Returns `true` if senders belong to the same channel. + /// + /// # Examples + /// + /// ``` + /// let (tx, rx) = tokio::sync::mpsc::channel::<()>(1); + /// let tx2 = tx.clone(); + /// assert!(tx.same_channel(&tx2)); + /// + /// let (tx3, rx3) = tokio::sync::mpsc::channel::<()>(1); + /// assert!(!tx3.same_channel(&tx2)); + /// ``` + pub fn same_channel(&self, other: &Self) -> bool { + self.chan.same_channel(&other.chan) + } + + /// Returns the current capacity of the channel. + /// + /// The capacity goes down when sending a value by calling [`send`] or by reserving capacity + /// with [`reserve`]. The capacity goes up when values are received by the [`Receiver`]. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel::<()>(5); + /// + /// assert_eq!(tx.capacity(), 5); + /// + /// // Making a reservation drops the capacity by one. + /// let permit = tx.reserve().await.unwrap(); + /// assert_eq!(tx.capacity(), 4); + /// + /// // Sending and receiving a value increases the capacity by one. + /// permit.send(()); + /// rx.recv().await.unwrap(); + /// assert_eq!(tx.capacity(), 5); + /// } + /// ``` + /// + /// [`send`]: Sender::send + /// [`reserve`]: Sender::reserve + pub fn capacity(&self) -> usize { + self.chan.semaphore().0.available_permits() + } +} + +impl<T> Clone for Sender<T> { + fn clone(&self) -> Self { + Sender { + chan: self.chan.clone(), + } + } +} + +impl<T> fmt::Debug for Sender<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Sender") + .field("chan", &self.chan) + .finish() + } +} + +// ===== impl Permit ===== + +impl<T> Permit<'_, T> { + /// Sends a value using the reserved capacity. + /// + /// Capacity for the message has already been reserved. The message is sent + /// to the receiver and the permit is consumed. The operation will succeed + /// even if the receiver half has been closed. See [`Receiver::close`] for + /// more details on performing a clean shutdown. + /// + /// [`Receiver::close`]: Receiver::close + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// // Reserve capacity + /// let permit = tx.reserve().await.unwrap(); + /// + /// // Trying to send directly on the `tx` will fail due to no + /// // available capacity. + /// assert!(tx.try_send(123).is_err()); + /// + /// // Send a message on the permit + /// permit.send(456); + /// + /// // The value sent on the permit is received + /// assert_eq!(rx.recv().await.unwrap(), 456); + /// } + /// ``` + pub fn send(self, value: T) { + use std::mem; + + self.chan.send(value); + + // Avoid the drop logic + mem::forget(self); + } +} + +impl<T> Drop for Permit<'_, T> { + fn drop(&mut self) { + use chan::Semaphore; + + let semaphore = self.chan.semaphore(); + + // Add the permit back to the semaphore + semaphore.add_permit(); + + // If this is the last sender for this channel, wake the receiver so + // that it can be notified that the channel is closed. + if semaphore.is_closed() && semaphore.is_idle() { + self.chan.wake_rx(); + } + } +} + +impl<T> fmt::Debug for Permit<'_, T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Permit") + .field("chan", &self.chan) + .finish() + } +} + +// ===== impl Permit ===== + +impl<T> OwnedPermit<T> { + /// Sends a value using the reserved capacity. + /// + /// Capacity for the message has already been reserved. The message is sent + /// to the receiver and the permit is consumed. The operation will succeed + /// even if the receiver half has been closed. See [`Receiver::close`] for + /// more details on performing a clean shutdown. + /// + /// Unlike [`Permit::send`], this method returns the [`Sender`] from which + /// the `OwnedPermit` was reserved. + /// + /// [`Receiver::close`]: Receiver::close + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// // Reserve capacity + /// let permit = tx.reserve_owned().await.unwrap(); + /// + /// // Send a message on the permit, returning the sender. + /// let tx = permit.send(456); + /// + /// // The value sent on the permit is received + /// assert_eq!(rx.recv().await.unwrap(), 456); + /// + /// // We may now reuse `tx` to send another message. + /// tx.send(789).await.unwrap(); + /// } + /// ``` + pub fn send(mut self, value: T) -> Sender<T> { + let chan = self.chan.take().unwrap_or_else(|| { + unreachable!("OwnedPermit channel is only taken when the permit is moved") + }); + chan.send(value); + + Sender { chan } + } + + /// Releases the reserved capacity *without* sending a message, returning the + /// [`Sender`]. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = mpsc::channel(1); + /// + /// // Clone the sender and reserve capacity + /// let permit = tx.clone().reserve_owned().await.unwrap(); + /// + /// // Trying to send on the original `tx` will fail, since the `permit` + /// // has reserved all the available capacity. + /// assert!(tx.try_send(123).is_err()); + /// + /// // Release the permit without sending a message, returning the clone + /// // of the sender. + /// let tx2 = permit.release(); + /// + /// // We may now reuse `tx` to send another message. + /// tx.send(789).await.unwrap(); + /// # drop(rx); drop(tx2); + /// } + /// ``` + /// + /// [`Sender`]: Sender + pub fn release(mut self) -> Sender<T> { + use chan::Semaphore; + + let chan = self.chan.take().unwrap_or_else(|| { + unreachable!("OwnedPermit channel is only taken when the permit is moved") + }); + + // Add the permit back to the semaphore + chan.semaphore().add_permit(); + Sender { chan } + } +} + +impl<T> Drop for OwnedPermit<T> { + fn drop(&mut self) { + use chan::Semaphore; + + // Are we still holding onto the sender? + if let Some(chan) = self.chan.take() { + let semaphore = chan.semaphore(); + + // Add the permit back to the semaphore + semaphore.add_permit(); + + // If this `OwnedPermit` is holding the last sender for this + // channel, wake the receiver so that it can be notified that the + // channel is closed. + if semaphore.is_closed() && semaphore.is_idle() { + chan.wake_rx(); + } + } + + // Otherwise, do nothing. + } +} + +impl<T> fmt::Debug for OwnedPermit<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("OwnedPermit") + .field("chan", &self.chan) + .finish() + } +} diff --git a/third_party/rust/tokio/src/sync/mpsc/chan.rs b/third_party/rust/tokio/src/sync/mpsc/chan.rs new file mode 100644 index 0000000000..c3007de89c --- /dev/null +++ b/third_party/rust/tokio/src/sync/mpsc/chan.rs @@ -0,0 +1,405 @@ +use crate::loom::cell::UnsafeCell; +use crate::loom::future::AtomicWaker; +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::Arc; +use crate::park::thread::CachedParkThread; +use crate::park::Park; +use crate::sync::mpsc::error::TryRecvError; +use crate::sync::mpsc::list; +use crate::sync::notify::Notify; + +use std::fmt; +use std::process; +use std::sync::atomic::Ordering::{AcqRel, Relaxed}; +use std::task::Poll::{Pending, Ready}; +use std::task::{Context, Poll}; + +/// Channel sender. +pub(crate) struct Tx<T, S> { + inner: Arc<Chan<T, S>>, +} + +impl<T, S: fmt::Debug> fmt::Debug for Tx<T, S> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Tx").field("inner", &self.inner).finish() + } +} + +/// Channel receiver. +pub(crate) struct Rx<T, S: Semaphore> { + inner: Arc<Chan<T, S>>, +} + +impl<T, S: Semaphore + fmt::Debug> fmt::Debug for Rx<T, S> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Rx").field("inner", &self.inner).finish() + } +} + +pub(crate) trait Semaphore { + fn is_idle(&self) -> bool; + + fn add_permit(&self); + + fn close(&self); + + fn is_closed(&self) -> bool; +} + +struct Chan<T, S> { + /// Notifies all tasks listening for the receiver being dropped. + notify_rx_closed: Notify, + + /// Handle to the push half of the lock-free list. + tx: list::Tx<T>, + + /// Coordinates access to channel's capacity. + semaphore: S, + + /// Receiver waker. Notified when a value is pushed into the channel. + rx_waker: AtomicWaker, + + /// Tracks the number of outstanding sender handles. + /// + /// When this drops to zero, the send half of the channel is closed. + tx_count: AtomicUsize, + + /// Only accessed by `Rx` handle. + rx_fields: UnsafeCell<RxFields<T>>, +} + +impl<T, S> fmt::Debug for Chan<T, S> +where + S: fmt::Debug, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Chan") + .field("tx", &self.tx) + .field("semaphore", &self.semaphore) + .field("rx_waker", &self.rx_waker) + .field("tx_count", &self.tx_count) + .field("rx_fields", &"...") + .finish() + } +} + +/// Fields only accessed by `Rx` handle. +struct RxFields<T> { + /// Channel receiver. This field is only accessed by the `Receiver` type. + list: list::Rx<T>, + + /// `true` if `Rx::close` is called. + rx_closed: bool, +} + +impl<T> fmt::Debug for RxFields<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("RxFields") + .field("list", &self.list) + .field("rx_closed", &self.rx_closed) + .finish() + } +} + +unsafe impl<T: Send, S: Send> Send for Chan<T, S> {} +unsafe impl<T: Send, S: Sync> Sync for Chan<T, S> {} + +pub(crate) fn channel<T, S: Semaphore>(semaphore: S) -> (Tx<T, S>, Rx<T, S>) { + let (tx, rx) = list::channel(); + + let chan = Arc::new(Chan { + notify_rx_closed: Notify::new(), + tx, + semaphore, + rx_waker: AtomicWaker::new(), + tx_count: AtomicUsize::new(1), + rx_fields: UnsafeCell::new(RxFields { + list: rx, + rx_closed: false, + }), + }); + + (Tx::new(chan.clone()), Rx::new(chan)) +} + +// ===== impl Tx ===== + +impl<T, S> Tx<T, S> { + fn new(chan: Arc<Chan<T, S>>) -> Tx<T, S> { + Tx { inner: chan } + } + + pub(super) fn semaphore(&self) -> &S { + &self.inner.semaphore + } + + /// Send a message and notify the receiver. + pub(crate) fn send(&self, value: T) { + self.inner.send(value); + } + + /// Wake the receive half + pub(crate) fn wake_rx(&self) { + self.inner.rx_waker.wake(); + } + + /// Returns `true` if senders belong to the same channel. + pub(crate) fn same_channel(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.inner, &other.inner) + } +} + +impl<T, S: Semaphore> Tx<T, S> { + pub(crate) fn is_closed(&self) -> bool { + self.inner.semaphore.is_closed() + } + + pub(crate) async fn closed(&self) { + // In order to avoid a race condition, we first request a notification, + // **then** check whether the semaphore is closed. If the semaphore is + // closed the notification request is dropped. + let notified = self.inner.notify_rx_closed.notified(); + + if self.inner.semaphore.is_closed() { + return; + } + notified.await; + } +} + +impl<T, S> Clone for Tx<T, S> { + fn clone(&self) -> Tx<T, S> { + // Using a Relaxed ordering here is sufficient as the caller holds a + // strong ref to `self`, preventing a concurrent decrement to zero. + self.inner.tx_count.fetch_add(1, Relaxed); + + Tx { + inner: self.inner.clone(), + } + } +} + +impl<T, S> Drop for Tx<T, S> { + fn drop(&mut self) { + if self.inner.tx_count.fetch_sub(1, AcqRel) != 1 { + return; + } + + // Close the list, which sends a `Close` message + self.inner.tx.close(); + + // Notify the receiver + self.wake_rx(); + } +} + +// ===== impl Rx ===== + +impl<T, S: Semaphore> Rx<T, S> { + fn new(chan: Arc<Chan<T, S>>) -> Rx<T, S> { + Rx { inner: chan } + } + + pub(crate) fn close(&mut self) { + self.inner.rx_fields.with_mut(|rx_fields_ptr| { + let rx_fields = unsafe { &mut *rx_fields_ptr }; + + if rx_fields.rx_closed { + return; + } + + rx_fields.rx_closed = true; + }); + + self.inner.semaphore.close(); + self.inner.notify_rx_closed.notify_waiters(); + } + + /// Receive the next value + pub(crate) fn recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { + use super::block::Read::*; + + // Keep track of task budget + let coop = ready!(crate::coop::poll_proceed(cx)); + + self.inner.rx_fields.with_mut(|rx_fields_ptr| { + let rx_fields = unsafe { &mut *rx_fields_ptr }; + + macro_rules! try_recv { + () => { + match rx_fields.list.pop(&self.inner.tx) { + Some(Value(value)) => { + self.inner.semaphore.add_permit(); + coop.made_progress(); + return Ready(Some(value)); + } + Some(Closed) => { + // TODO: This check may not be required as it most + // likely can only return `true` at this point. A + // channel is closed when all tx handles are + // dropped. Dropping a tx handle releases memory, + // which ensures that if dropping the tx handle is + // visible, then all messages sent are also visible. + assert!(self.inner.semaphore.is_idle()); + coop.made_progress(); + return Ready(None); + } + None => {} // fall through + } + }; + } + + try_recv!(); + + self.inner.rx_waker.register_by_ref(cx.waker()); + + // It is possible that a value was pushed between attempting to read + // and registering the task, so we have to check the channel a + // second time here. + try_recv!(); + + if rx_fields.rx_closed && self.inner.semaphore.is_idle() { + coop.made_progress(); + Ready(None) + } else { + Pending + } + }) + } + + /// Try to receive the next value. + pub(crate) fn try_recv(&mut self) -> Result<T, TryRecvError> { + use super::list::TryPopResult; + + self.inner.rx_fields.with_mut(|rx_fields_ptr| { + let rx_fields = unsafe { &mut *rx_fields_ptr }; + + macro_rules! try_recv { + () => { + match rx_fields.list.try_pop(&self.inner.tx) { + TryPopResult::Ok(value) => { + self.inner.semaphore.add_permit(); + return Ok(value); + } + TryPopResult::Closed => return Err(TryRecvError::Disconnected), + TryPopResult::Empty => return Err(TryRecvError::Empty), + TryPopResult::Busy => {} // fall through + } + }; + } + + try_recv!(); + + // If a previous `poll_recv` call has set a waker, we wake it here. + // This allows us to put our own CachedParkThread waker in the + // AtomicWaker slot instead. + // + // This is not a spurious wakeup to `poll_recv` since we just got a + // Busy from `try_pop`, which only happens if there are messages in + // the queue. + self.inner.rx_waker.wake(); + + // Park the thread until the problematic send has completed. + let mut park = CachedParkThread::new(); + let waker = park.unpark().into_waker(); + loop { + self.inner.rx_waker.register_by_ref(&waker); + // It is possible that the problematic send has now completed, + // so we have to check for messages again. + try_recv!(); + park.park().expect("park failed"); + } + }) + } +} + +impl<T, S: Semaphore> Drop for Rx<T, S> { + fn drop(&mut self) { + use super::block::Read::Value; + + self.close(); + + self.inner.rx_fields.with_mut(|rx_fields_ptr| { + let rx_fields = unsafe { &mut *rx_fields_ptr }; + + while let Some(Value(_)) = rx_fields.list.pop(&self.inner.tx) { + self.inner.semaphore.add_permit(); + } + }) + } +} + +// ===== impl Chan ===== + +impl<T, S> Chan<T, S> { + fn send(&self, value: T) { + // Push the value + self.tx.push(value); + + // Notify the rx task + self.rx_waker.wake(); + } +} + +impl<T, S> Drop for Chan<T, S> { + fn drop(&mut self) { + use super::block::Read::Value; + + // Safety: the only owner of the rx fields is Chan, and eing + // inside its own Drop means we're the last ones to touch it. + self.rx_fields.with_mut(|rx_fields_ptr| { + let rx_fields = unsafe { &mut *rx_fields_ptr }; + + while let Some(Value(_)) = rx_fields.list.pop(&self.tx) {} + unsafe { rx_fields.list.free_blocks() }; + }); + } +} + +// ===== impl Semaphore for (::Semaphore, capacity) ===== + +impl Semaphore for (crate::sync::batch_semaphore::Semaphore, usize) { + fn add_permit(&self) { + self.0.release(1) + } + + fn is_idle(&self) -> bool { + self.0.available_permits() == self.1 + } + + fn close(&self) { + self.0.close(); + } + + fn is_closed(&self) -> bool { + self.0.is_closed() + } +} + +// ===== impl Semaphore for AtomicUsize ===== + +use std::sync::atomic::Ordering::{Acquire, Release}; +use std::usize; + +impl Semaphore for AtomicUsize { + fn add_permit(&self) { + let prev = self.fetch_sub(2, Release); + + if prev >> 1 == 0 { + // Something went wrong + process::abort(); + } + } + + fn is_idle(&self) -> bool { + self.load(Acquire) >> 1 == 0 + } + + fn close(&self) { + self.fetch_or(1, Release); + } + + fn is_closed(&self) -> bool { + self.load(Acquire) & 1 == 1 + } +} diff --git a/third_party/rust/tokio/src/sync/mpsc/error.rs b/third_party/rust/tokio/src/sync/mpsc/error.rs new file mode 100644 index 0000000000..3fe6bac5e1 --- /dev/null +++ b/third_party/rust/tokio/src/sync/mpsc/error.rs @@ -0,0 +1,125 @@ +//! Channel error types. + +use std::error::Error; +use std::fmt; + +/// Error returned by the `Sender`. +#[derive(Debug)] +pub struct SendError<T>(pub T); + +impl<T> fmt::Display for SendError<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } +} + +impl<T: fmt::Debug> std::error::Error for SendError<T> {} + +// ===== TrySendError ===== + +/// This enumeration is the list of the possible error outcomes for the +/// [try_send](super::Sender::try_send) method. +#[derive(Debug, Eq, PartialEq)] +pub enum TrySendError<T> { + /// The data could not be sent on the channel because the channel is + /// currently full and sending would require blocking. + Full(T), + + /// The receive half of the channel was explicitly closed or has been + /// dropped. + Closed(T), +} + +impl<T: fmt::Debug> Error for TrySendError<T> {} + +impl<T> fmt::Display for TrySendError<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + fmt, + "{}", + match self { + TrySendError::Full(..) => "no available capacity", + TrySendError::Closed(..) => "channel closed", + } + ) + } +} + +impl<T> From<SendError<T>> for TrySendError<T> { + fn from(src: SendError<T>) -> TrySendError<T> { + TrySendError::Closed(src.0) + } +} + +// ===== TryRecvError ===== + +/// Error returned by `try_recv`. +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +pub enum TryRecvError { + /// This **channel** is currently empty, but the **Sender**(s) have not yet + /// disconnected, so data may yet become available. + Empty, + /// The **channel**'s sending half has become disconnected, and there will + /// never be any more data received on it. + Disconnected, +} + +impl fmt::Display for TryRecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + TryRecvError::Empty => "receiving on an empty channel".fmt(fmt), + TryRecvError::Disconnected => "receiving on a closed channel".fmt(fmt), + } + } +} + +impl Error for TryRecvError {} + +// ===== RecvError ===== + +/// Error returned by `Receiver`. +#[derive(Debug)] +#[doc(hidden)] +#[deprecated(note = "This type is unused because recv returns an Option.")] +pub struct RecvError(()); + +#[allow(deprecated)] +impl fmt::Display for RecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } +} + +#[allow(deprecated)] +impl Error for RecvError {} + +cfg_time! { + // ===== SendTimeoutError ===== + + #[derive(Debug, Eq, PartialEq)] + /// Error returned by [`Sender::send_timeout`](super::Sender::send_timeout)]. + pub enum SendTimeoutError<T> { + /// The data could not be sent on the channel because the channel is + /// full, and the timeout to send has elapsed. + Timeout(T), + + /// The receive half of the channel was explicitly closed or has been + /// dropped. + Closed(T), + } + + impl<T: fmt::Debug> Error for SendTimeoutError<T> {} + + impl<T> fmt::Display for SendTimeoutError<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + fmt, + "{}", + match self { + SendTimeoutError::Timeout(..) => "timed out waiting on send operation", + SendTimeoutError::Closed(..) => "channel closed", + } + ) + } + } +} diff --git a/third_party/rust/tokio/src/sync/mpsc/list.rs b/third_party/rust/tokio/src/sync/mpsc/list.rs new file mode 100644 index 0000000000..e4eeb45411 --- /dev/null +++ b/third_party/rust/tokio/src/sync/mpsc/list.rs @@ -0,0 +1,371 @@ +//! A concurrent, lock-free, FIFO list. + +use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize}; +use crate::loom::thread; +use crate::sync::mpsc::block::{self, Block}; + +use std::fmt; +use std::ptr::NonNull; +use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release}; + +/// List queue transmit handle. +pub(crate) struct Tx<T> { + /// Tail in the `Block` mpmc list. + block_tail: AtomicPtr<Block<T>>, + + /// Position to push the next message. This references a block and offset + /// into the block. + tail_position: AtomicUsize, +} + +/// List queue receive handle +pub(crate) struct Rx<T> { + /// Pointer to the block being processed. + head: NonNull<Block<T>>, + + /// Next slot index to process. + index: usize, + + /// Pointer to the next block pending release. + free_head: NonNull<Block<T>>, +} + +/// Return value of `Rx::try_pop`. +pub(crate) enum TryPopResult<T> { + /// Successfully popped a value. + Ok(T), + /// The channel is empty. + Empty, + /// The channel is empty and closed. + Closed, + /// The channel is not empty, but the first value is being written. + Busy, +} + +pub(crate) fn channel<T>() -> (Tx<T>, Rx<T>) { + // Create the initial block shared between the tx and rx halves. + let initial_block = Box::new(Block::new(0)); + let initial_block_ptr = Box::into_raw(initial_block); + + let tx = Tx { + block_tail: AtomicPtr::new(initial_block_ptr), + tail_position: AtomicUsize::new(0), + }; + + let head = NonNull::new(initial_block_ptr).unwrap(); + + let rx = Rx { + head, + index: 0, + free_head: head, + }; + + (tx, rx) +} + +impl<T> Tx<T> { + /// Pushes a value into the list. + pub(crate) fn push(&self, value: T) { + // First, claim a slot for the value. `Acquire` is used here to + // synchronize with the `fetch_add` in `reclaim_blocks`. + let slot_index = self.tail_position.fetch_add(1, Acquire); + + // Load the current block and write the value + let block = self.find_block(slot_index); + + unsafe { + // Write the value to the block + block.as_ref().write(slot_index, value); + } + } + + /// Closes the send half of the list. + /// + /// Similar process as pushing a value, but instead of writing the value & + /// setting the ready flag, the TX_CLOSED flag is set on the block. + pub(crate) fn close(&self) { + // First, claim a slot for the value. This is the last slot that will be + // claimed. + let slot_index = self.tail_position.fetch_add(1, Acquire); + + let block = self.find_block(slot_index); + + unsafe { block.as_ref().tx_close() } + } + + fn find_block(&self, slot_index: usize) -> NonNull<Block<T>> { + // The start index of the block that contains `index`. + let start_index = block::start_index(slot_index); + + // The index offset into the block + let offset = block::offset(slot_index); + + // Load the current head of the block + let mut block_ptr = self.block_tail.load(Acquire); + + let block = unsafe { &*block_ptr }; + + // Calculate the distance between the tail ptr and the target block + let distance = block.distance(start_index); + + // Decide if this call to `find_block` should attempt to update the + // `block_tail` pointer. + // + // Updating `block_tail` is not always performed in order to reduce + // contention. + // + // When set, as the routine walks the linked list, it attempts to update + // `block_tail`. If the update cannot be performed, `try_updating_tail` + // is unset. + let mut try_updating_tail = distance > offset; + + // Walk the linked list of blocks until the block with `start_index` is + // found. + loop { + let block = unsafe { &(*block_ptr) }; + + if block.is_at_index(start_index) { + return unsafe { NonNull::new_unchecked(block_ptr) }; + } + + let next_block = block + .load_next(Acquire) + // There is no allocated next block, grow the linked list. + .unwrap_or_else(|| block.grow()); + + // If the block is **not** final, then the tail pointer cannot be + // advanced any more. + try_updating_tail &= block.is_final(); + + if try_updating_tail { + // Advancing `block_tail` must happen when walking the linked + // list. `block_tail` may not advance passed any blocks that are + // not "final". At the point a block is finalized, it is unknown + // if there are any prior blocks that are unfinalized, which + // makes it impossible to advance `block_tail`. + // + // While walking the linked list, `block_tail` can be advanced + // as long as finalized blocks are traversed. + // + // Release ordering is used to ensure that any subsequent reads + // are able to see the memory pointed to by `block_tail`. + // + // Acquire is not needed as any "actual" value is not accessed. + // At this point, the linked list is walked to acquire blocks. + if self + .block_tail + .compare_exchange(block_ptr, next_block.as_ptr(), Release, Relaxed) + .is_ok() + { + // Synchronize with any senders + let tail_position = self.tail_position.fetch_add(0, Release); + + unsafe { + block.tx_release(tail_position); + } + } else { + // A concurrent sender is also working on advancing + // `block_tail` and this thread is falling behind. + // + // Stop trying to advance the tail pointer + try_updating_tail = false; + } + } + + block_ptr = next_block.as_ptr(); + + thread::yield_now(); + } + } + + pub(crate) unsafe fn reclaim_block(&self, mut block: NonNull<Block<T>>) { + // The block has been removed from the linked list and ownership + // is reclaimed. + // + // Before dropping the block, see if it can be reused by + // inserting it back at the end of the linked list. + // + // First, reset the data + block.as_mut().reclaim(); + + let mut reused = false; + + // Attempt to insert the block at the end + // + // Walk at most three times + // + let curr_ptr = self.block_tail.load(Acquire); + + // The pointer can never be null + debug_assert!(!curr_ptr.is_null()); + + let mut curr = NonNull::new_unchecked(curr_ptr); + + // TODO: Unify this logic with Block::grow + for _ in 0..3 { + match curr.as_ref().try_push(&mut block, AcqRel, Acquire) { + Ok(_) => { + reused = true; + break; + } + Err(next) => { + curr = next; + } + } + } + + if !reused { + let _ = Box::from_raw(block.as_ptr()); + } + } +} + +impl<T> fmt::Debug for Tx<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Tx") + .field("block_tail", &self.block_tail.load(Relaxed)) + .field("tail_position", &self.tail_position.load(Relaxed)) + .finish() + } +} + +impl<T> Rx<T> { + /// Pops the next value off the queue. + pub(crate) fn pop(&mut self, tx: &Tx<T>) -> Option<block::Read<T>> { + // Advance `head`, if needed + if !self.try_advancing_head() { + return None; + } + + self.reclaim_blocks(tx); + + unsafe { + let block = self.head.as_ref(); + + let ret = block.read(self.index); + + if let Some(block::Read::Value(..)) = ret { + self.index = self.index.wrapping_add(1); + } + + ret + } + } + + /// Pops the next value off the queue, detecting whether the block + /// is busy or empty on failure. + /// + /// This function exists because `Rx::pop` can return `None` even if the + /// channel's queue contains a message that has been completely written. + /// This can happen if the fully delivered message is behind another message + /// that is in the middle of being written to the block, since the channel + /// can't return the messages out of order. + pub(crate) fn try_pop(&mut self, tx: &Tx<T>) -> TryPopResult<T> { + let tail_position = tx.tail_position.load(Acquire); + let result = self.pop(tx); + + match result { + Some(block::Read::Value(t)) => TryPopResult::Ok(t), + Some(block::Read::Closed) => TryPopResult::Closed, + None if tail_position == self.index => TryPopResult::Empty, + None => TryPopResult::Busy, + } + } + + /// Tries advancing the block pointer to the block referenced by `self.index`. + /// + /// Returns `true` if successful, `false` if there is no next block to load. + fn try_advancing_head(&mut self) -> bool { + let block_index = block::start_index(self.index); + + loop { + let next_block = { + let block = unsafe { self.head.as_ref() }; + + if block.is_at_index(block_index) { + return true; + } + + block.load_next(Acquire) + }; + + let next_block = match next_block { + Some(next_block) => next_block, + None => { + return false; + } + }; + + self.head = next_block; + + thread::yield_now(); + } + } + + fn reclaim_blocks(&mut self, tx: &Tx<T>) { + while self.free_head != self.head { + unsafe { + // Get a handle to the block that will be freed and update + // `free_head` to point to the next block. + let block = self.free_head; + + let observed_tail_position = block.as_ref().observed_tail_position(); + + let required_index = match observed_tail_position { + Some(i) => i, + None => return, + }; + + if required_index > self.index { + return; + } + + // We may read the next pointer with `Relaxed` ordering as it is + // guaranteed that the `reclaim_blocks` routine trails the `recv` + // routine. Any memory accessed by `reclaim_blocks` has already + // been acquired by `recv`. + let next_block = block.as_ref().load_next(Relaxed); + + // Update the free list head + self.free_head = next_block.unwrap(); + + // Push the emptied block onto the back of the queue, making it + // available to senders. + tx.reclaim_block(block); + } + + thread::yield_now(); + } + } + + /// Effectively `Drop` all the blocks. Should only be called once, when + /// the list is dropping. + pub(super) unsafe fn free_blocks(&mut self) { + debug_assert_ne!(self.free_head, NonNull::dangling()); + + let mut cur = Some(self.free_head); + + #[cfg(debug_assertions)] + { + // to trigger the debug assert above so as to catch that we + // don't call `free_blocks` more than once. + self.free_head = NonNull::dangling(); + self.head = NonNull::dangling(); + } + + while let Some(block) = cur { + cur = block.as_ref().load_next(Relaxed); + drop(Box::from_raw(block.as_ptr())); + } + } +} + +impl<T> fmt::Debug for Rx<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Rx") + .field("head", &self.head) + .field("index", &self.index) + .field("free_head", &self.free_head) + .finish() + } +} diff --git a/third_party/rust/tokio/src/sync/mpsc/mod.rs b/third_party/rust/tokio/src/sync/mpsc/mod.rs new file mode 100644 index 0000000000..b1513a9da5 --- /dev/null +++ b/third_party/rust/tokio/src/sync/mpsc/mod.rs @@ -0,0 +1,115 @@ +#![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))] + +//! A multi-producer, single-consumer queue for sending values between +//! asynchronous tasks. +//! +//! This module provides two variants of the channel: bounded and unbounded. The +//! bounded variant has a limit on the number of messages that the channel can +//! store, and if this limit is reached, trying to send another message will +//! wait until a message is received from the channel. An unbounded channel has +//! an infinite capacity, so the `send` method will always complete immediately. +//! This makes the [`UnboundedSender`] usable from both synchronous and +//! asynchronous code. +//! +//! Similar to the `mpsc` channels provided by `std`, the channel constructor +//! functions provide separate send and receive handles, [`Sender`] and +//! [`Receiver`] for the bounded channel, [`UnboundedSender`] and +//! [`UnboundedReceiver`] for the unbounded channel. If there is no message to read, +//! the current task will be notified when a new value is sent. [`Sender`] and +//! [`UnboundedSender`] allow sending values into the channel. If the bounded +//! channel is at capacity, the send is rejected and the task will be notified +//! when additional capacity is available. In other words, the channel provides +//! backpressure. +//! +//! +//! # Disconnection +//! +//! When all [`Sender`] handles have been dropped, it is no longer +//! possible to send values into the channel. This is considered the termination +//! event of the stream. As such, `Receiver::poll` returns `Ok(Ready(None))`. +//! +//! If the [`Receiver`] handle is dropped, then messages can no longer +//! be read out of the channel. In this case, all further attempts to send will +//! result in an error. +//! +//! # Clean Shutdown +//! +//! When the [`Receiver`] is dropped, it is possible for unprocessed messages to +//! remain in the channel. Instead, it is usually desirable to perform a "clean" +//! shutdown. To do this, the receiver first calls `close`, which will prevent +//! any further messages to be sent into the channel. Then, the receiver +//! consumes the channel to completion, at which point the receiver can be +//! dropped. +//! +//! # Communicating between sync and async code +//! +//! When you want to communicate between synchronous and asynchronous code, there +//! are two situations to consider: +//! +//! **Bounded channel**: If you need a bounded channel, you should use a bounded +//! Tokio `mpsc` channel for both directions of communication. Instead of calling +//! the async [`send`][bounded-send] or [`recv`][bounded-recv] methods, in +//! synchronous code you will need to use the [`blocking_send`][blocking-send] or +//! [`blocking_recv`][blocking-recv] methods. +//! +//! **Unbounded channel**: You should use the kind of channel that matches where +//! the receiver is. So for sending a message _from async to sync_, you should +//! use [the standard library unbounded channel][std-unbounded] or +//! [crossbeam][crossbeam-unbounded]. Similarly, for sending a message _from sync +//! to async_, you should use an unbounded Tokio `mpsc` channel. +//! +//! Please be aware that the above remarks were written with the `mpsc` channel +//! in mind, but they can also be generalized to other kinds of channels. In +//! general, any channel method that isn't marked async can be called anywhere, +//! including outside of the runtime. For example, sending a message on a +//! oneshot channel from outside the runtime is perfectly fine. +//! +//! # Multiple runtimes +//! +//! The mpsc channel does not care about which runtime you use it in, and can be +//! used to send messages from one runtime to another. It can also be used in +//! non-Tokio runtimes. +//! +//! There is one exception to the above: the [`send_timeout`] must be used from +//! within a Tokio runtime, however it is still not tied to one specific Tokio +//! runtime, and the sender may be moved from one Tokio runtime to another. +//! +//! [`Sender`]: crate::sync::mpsc::Sender +//! [`Receiver`]: crate::sync::mpsc::Receiver +//! [bounded-send]: crate::sync::mpsc::Sender::send() +//! [bounded-recv]: crate::sync::mpsc::Receiver::recv() +//! [blocking-send]: crate::sync::mpsc::Sender::blocking_send() +//! [blocking-recv]: crate::sync::mpsc::Receiver::blocking_recv() +//! [`UnboundedSender`]: crate::sync::mpsc::UnboundedSender +//! [`UnboundedReceiver`]: crate::sync::mpsc::UnboundedReceiver +//! [`Handle::block_on`]: crate::runtime::Handle::block_on() +//! [std-unbounded]: std::sync::mpsc::channel +//! [crossbeam-unbounded]: https://docs.rs/crossbeam/*/crossbeam/channel/fn.unbounded.html +//! [`send_timeout`]: crate::sync::mpsc::Sender::send_timeout + +pub(super) mod block; + +mod bounded; +pub use self::bounded::{channel, OwnedPermit, Permit, Receiver, Sender}; + +mod chan; + +pub(super) mod list; + +mod unbounded; +pub use self::unbounded::{unbounded_channel, UnboundedReceiver, UnboundedSender}; + +pub mod error; + +/// The number of values a block can contain. +/// +/// This value must be a power of 2. It also must be smaller than the number of +/// bits in `usize`. +#[cfg(all(target_pointer_width = "64", not(loom)))] +const BLOCK_CAP: usize = 32; + +#[cfg(all(not(target_pointer_width = "64"), not(loom)))] +const BLOCK_CAP: usize = 16; + +#[cfg(loom)] +const BLOCK_CAP: usize = 2; diff --git a/third_party/rust/tokio/src/sync/mpsc/unbounded.rs b/third_party/rust/tokio/src/sync/mpsc/unbounded.rs new file mode 100644 index 0000000000..b133f9f35e --- /dev/null +++ b/third_party/rust/tokio/src/sync/mpsc/unbounded.rs @@ -0,0 +1,373 @@ +use crate::loom::sync::atomic::AtomicUsize; +use crate::sync::mpsc::chan; +use crate::sync::mpsc::error::{SendError, TryRecvError}; + +use std::fmt; +use std::task::{Context, Poll}; + +/// Send values to the associated `UnboundedReceiver`. +/// +/// Instances are created by the +/// [`unbounded_channel`](unbounded_channel) function. +pub struct UnboundedSender<T> { + chan: chan::Tx<T, Semaphore>, +} + +impl<T> Clone for UnboundedSender<T> { + fn clone(&self) -> Self { + UnboundedSender { + chan: self.chan.clone(), + } + } +} + +impl<T> fmt::Debug for UnboundedSender<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("UnboundedSender") + .field("chan", &self.chan) + .finish() + } +} + +/// Receive values from the associated `UnboundedSender`. +/// +/// Instances are created by the +/// [`unbounded_channel`](unbounded_channel) function. +/// +/// This receiver can be turned into a `Stream` using [`UnboundedReceiverStream`]. +/// +/// [`UnboundedReceiverStream`]: https://docs.rs/tokio-stream/0.1/tokio_stream/wrappers/struct.UnboundedReceiverStream.html +pub struct UnboundedReceiver<T> { + /// The channel receiver + chan: chan::Rx<T, Semaphore>, +} + +impl<T> fmt::Debug for UnboundedReceiver<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("UnboundedReceiver") + .field("chan", &self.chan) + .finish() + } +} + +/// Creates an unbounded mpsc channel for communicating between asynchronous +/// tasks without backpressure. +/// +/// A `send` on this channel will always succeed as long as the receive half has +/// not been closed. If the receiver falls behind, messages will be arbitrarily +/// buffered. +/// +/// **Note** that the amount of available system memory is an implicit bound to +/// the channel. Using an `unbounded` channel has the ability of causing the +/// process to run out of memory. In this case, the process will be aborted. +pub fn unbounded_channel<T>() -> (UnboundedSender<T>, UnboundedReceiver<T>) { + let (tx, rx) = chan::channel(AtomicUsize::new(0)); + + let tx = UnboundedSender::new(tx); + let rx = UnboundedReceiver::new(rx); + + (tx, rx) +} + +/// No capacity +type Semaphore = AtomicUsize; + +impl<T> UnboundedReceiver<T> { + pub(crate) fn new(chan: chan::Rx<T, Semaphore>) -> UnboundedReceiver<T> { + UnboundedReceiver { chan } + } + + /// Receives the next value for this receiver. + /// + /// `None` is returned when all `Sender` halves have dropped, indicating + /// that no further values can be sent on the channel. + /// + /// # Cancel safety + /// + /// This method is cancel safe. If `recv` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// channel. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::unbounded_channel(); + /// + /// tokio::spawn(async move { + /// tx.send("hello").unwrap(); + /// }); + /// + /// assert_eq!(Some("hello"), rx.recv().await); + /// assert_eq!(None, rx.recv().await); + /// } + /// ``` + /// + /// Values are buffered: + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::unbounded_channel(); + /// + /// tx.send("hello").unwrap(); + /// tx.send("world").unwrap(); + /// + /// assert_eq!(Some("hello"), rx.recv().await); + /// assert_eq!(Some("world"), rx.recv().await); + /// } + /// ``` + pub async fn recv(&mut self) -> Option<T> { + use crate::future::poll_fn; + + poll_fn(|cx| self.poll_recv(cx)).await + } + + /// Tries to receive the next value for this receiver. + /// + /// This method returns the [`Empty`] error if the channel is currently + /// empty, but there are still outstanding [senders] or [permits]. + /// + /// This method returns the [`Disconnected`] error if the channel is + /// currently empty, and there are no outstanding [senders] or [permits]. + /// + /// Unlike the [`poll_recv`] method, this method will never return an + /// [`Empty`] error spuriously. + /// + /// [`Empty`]: crate::sync::mpsc::error::TryRecvError::Empty + /// [`Disconnected`]: crate::sync::mpsc::error::TryRecvError::Disconnected + /// [`poll_recv`]: Self::poll_recv + /// [senders]: crate::sync::mpsc::Sender + /// [permits]: crate::sync::mpsc::Permit + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// use tokio::sync::mpsc::error::TryRecvError; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::unbounded_channel(); + /// + /// tx.send("hello").unwrap(); + /// + /// assert_eq!(Ok("hello"), rx.try_recv()); + /// assert_eq!(Err(TryRecvError::Empty), rx.try_recv()); + /// + /// tx.send("hello").unwrap(); + /// // Drop the last sender, closing the channel. + /// drop(tx); + /// + /// assert_eq!(Ok("hello"), rx.try_recv()); + /// assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv()); + /// } + /// ``` + pub fn try_recv(&mut self) -> Result<T, TryRecvError> { + self.chan.try_recv() + } + + /// Blocking receive to call outside of asynchronous contexts. + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution + /// context. + /// + /// # Examples + /// + /// ``` + /// use std::thread; + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::unbounded_channel::<u8>(); + /// + /// let sync_code = thread::spawn(move || { + /// assert_eq!(Some(10), rx.blocking_recv()); + /// }); + /// + /// let _ = tx.send(10); + /// sync_code.join().unwrap(); + /// } + /// ``` + #[cfg(feature = "sync")] + pub fn blocking_recv(&mut self) -> Option<T> { + crate::future::block_on(self.recv()) + } + + /// Closes the receiving half of a channel, without dropping it. + /// + /// This prevents any further messages from being sent on the channel while + /// still enabling the receiver to drain messages that are buffered. + pub fn close(&mut self) { + self.chan.close(); + } + + /// Polls to receive the next message on this channel. + /// + /// This method returns: + /// + /// * `Poll::Pending` if no messages are available but the channel is not + /// closed, or if a spurious failure happens. + /// * `Poll::Ready(Some(message))` if a message is available. + /// * `Poll::Ready(None)` if the channel has been closed and all messages + /// sent before it was closed have been received. + /// + /// When the method returns `Poll::Pending`, the `Waker` in the provided + /// `Context` is scheduled to receive a wakeup when a message is sent on any + /// receiver, or when the channel is closed. Note that on multiple calls to + /// `poll_recv`, only the `Waker` from the `Context` passed to the most + /// recent call is scheduled to receive a wakeup. + /// + /// If this method returns `Poll::Pending` due to a spurious failure, then + /// the `Waker` will be notified when the situation causing the spurious + /// failure has been resolved. Note that receiving such a wakeup does not + /// guarantee that the next call will succeed — it could fail with another + /// spurious failure. + pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { + self.chan.recv(cx) + } +} + +impl<T> UnboundedSender<T> { + pub(crate) fn new(chan: chan::Tx<T, Semaphore>) -> UnboundedSender<T> { + UnboundedSender { chan } + } + + /// Attempts to send a message on this `UnboundedSender` without blocking. + /// + /// This method is not marked async because sending a message to an unbounded channel + /// never requires any form of waiting. Because of this, the `send` method can be + /// used in both synchronous and asynchronous code without problems. + /// + /// If the receive half of the channel is closed, either due to [`close`] + /// being called or the [`UnboundedReceiver`] having been dropped, this + /// function returns an error. The error includes the value passed to `send`. + /// + /// [`close`]: UnboundedReceiver::close + /// [`UnboundedReceiver`]: UnboundedReceiver + pub fn send(&self, message: T) -> Result<(), SendError<T>> { + if !self.inc_num_messages() { + return Err(SendError(message)); + } + + self.chan.send(message); + Ok(()) + } + + fn inc_num_messages(&self) -> bool { + use std::process; + use std::sync::atomic::Ordering::{AcqRel, Acquire}; + + let mut curr = self.chan.semaphore().load(Acquire); + + loop { + if curr & 1 == 1 { + return false; + } + + if curr == usize::MAX ^ 1 { + // Overflowed the ref count. There is no safe way to recover, so + // abort the process. In practice, this should never happen. + process::abort() + } + + match self + .chan + .semaphore() + .compare_exchange(curr, curr + 2, AcqRel, Acquire) + { + Ok(_) => return true, + Err(actual) => { + curr = actual; + } + } + } + } + + /// Completes when the receiver has dropped. + /// + /// This allows the producers to get notified when interest in the produced + /// values is canceled and immediately stop doing work. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once the channel is closed, it stays closed + /// forever and all future calls to `closed` will return immediately. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx1, rx) = mpsc::unbounded_channel::<()>(); + /// let tx2 = tx1.clone(); + /// let tx3 = tx1.clone(); + /// let tx4 = tx1.clone(); + /// let tx5 = tx1.clone(); + /// tokio::spawn(async move { + /// drop(rx); + /// }); + /// + /// futures::join!( + /// tx1.closed(), + /// tx2.closed(), + /// tx3.closed(), + /// tx4.closed(), + /// tx5.closed() + /// ); + //// println!("Receiver dropped"); + /// } + /// ``` + pub async fn closed(&self) { + self.chan.closed().await + } + + /// Checks if the channel has been closed. This happens when the + /// [`UnboundedReceiver`] is dropped, or when the + /// [`UnboundedReceiver::close`] method is called. + /// + /// [`UnboundedReceiver`]: crate::sync::mpsc::UnboundedReceiver + /// [`UnboundedReceiver::close`]: crate::sync::mpsc::UnboundedReceiver::close + /// + /// ``` + /// let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<()>(); + /// assert!(!tx.is_closed()); + /// + /// let tx2 = tx.clone(); + /// assert!(!tx2.is_closed()); + /// + /// drop(rx); + /// assert!(tx.is_closed()); + /// assert!(tx2.is_closed()); + /// ``` + pub fn is_closed(&self) -> bool { + self.chan.is_closed() + } + + /// Returns `true` if senders belong to the same channel. + /// + /// # Examples + /// + /// ``` + /// let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<()>(); + /// let tx2 = tx.clone(); + /// assert!(tx.same_channel(&tx2)); + /// + /// let (tx3, rx3) = tokio::sync::mpsc::unbounded_channel::<()>(); + /// assert!(!tx3.same_channel(&tx2)); + /// ``` + pub fn same_channel(&self, other: &Self) -> bool { + self.chan.same_channel(&other.chan) + } +} diff --git a/third_party/rust/tokio/src/sync/mutex.rs b/third_party/rust/tokio/src/sync/mutex.rs new file mode 100644 index 0000000000..b8d5ba74e7 --- /dev/null +++ b/third_party/rust/tokio/src/sync/mutex.rs @@ -0,0 +1,967 @@ +#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))] + +use crate::sync::batch_semaphore as semaphore; +#[cfg(all(tokio_unstable, feature = "tracing"))] +use crate::util::trace; + +use std::cell::UnsafeCell; +use std::error::Error; +use std::ops::{Deref, DerefMut}; +use std::sync::Arc; +use std::{fmt, marker, mem}; + +/// An asynchronous `Mutex`-like type. +/// +/// This type acts similarly to [`std::sync::Mutex`], with two major +/// differences: [`lock`] is an async method so does not block, and the lock +/// guard is designed to be held across `.await` points. +/// +/// # Which kind of mutex should you use? +/// +/// Contrary to popular belief, it is ok and often preferred to use the ordinary +/// [`Mutex`][std] from the standard library in asynchronous code. +/// +/// The feature that the async mutex offers over the blocking mutex is the +/// ability to keep it locked across an `.await` point. This makes the async +/// mutex more expensive than the blocking mutex, so the blocking mutex should +/// be preferred in the cases where it can be used. The primary use case for the +/// async mutex is to provide shared mutable access to IO resources such as a +/// database connection. If the value behind the mutex is just data, it's +/// usually appropriate to use a blocking mutex such as the one in the standard +/// library or [`parking_lot`]. +/// +/// Note that, although the compiler will not prevent the std `Mutex` from holding +/// its guard across `.await` points in situations where the task is not movable +/// between threads, this virtually never leads to correct concurrent code in +/// practice as it can easily lead to deadlocks. +/// +/// A common pattern is to wrap the `Arc<Mutex<...>>` in a struct that provides +/// non-async methods for performing operations on the data within, and only +/// lock the mutex inside these methods. The [mini-redis] example provides an +/// illustration of this pattern. +/// +/// Additionally, when you _do_ want shared access to an IO resource, it is +/// often better to spawn a task to manage the IO resource, and to use message +/// passing to communicate with that task. +/// +/// [std]: std::sync::Mutex +/// [`parking_lot`]: https://docs.rs/parking_lot +/// [mini-redis]: https://github.com/tokio-rs/mini-redis/blob/master/src/db.rs +/// +/// # Examples: +/// +/// ```rust,no_run +/// use tokio::sync::Mutex; +/// use std::sync::Arc; +/// +/// #[tokio::main] +/// async fn main() { +/// let data1 = Arc::new(Mutex::new(0)); +/// let data2 = Arc::clone(&data1); +/// +/// tokio::spawn(async move { +/// let mut lock = data2.lock().await; +/// *lock += 1; +/// }); +/// +/// let mut lock = data1.lock().await; +/// *lock += 1; +/// } +/// ``` +/// +/// +/// ```rust,no_run +/// use tokio::sync::Mutex; +/// use std::sync::Arc; +/// +/// #[tokio::main] +/// async fn main() { +/// let count = Arc::new(Mutex::new(0)); +/// +/// for i in 0..5 { +/// let my_count = Arc::clone(&count); +/// tokio::spawn(async move { +/// for j in 0..10 { +/// let mut lock = my_count.lock().await; +/// *lock += 1; +/// println!("{} {} {}", i, j, lock); +/// } +/// }); +/// } +/// +/// loop { +/// if *count.lock().await >= 50 { +/// break; +/// } +/// } +/// println!("Count hit 50."); +/// } +/// ``` +/// There are a few things of note here to pay attention to in this example. +/// 1. The mutex is wrapped in an [`Arc`] to allow it to be shared across +/// threads. +/// 2. Each spawned task obtains a lock and releases it on every iteration. +/// 3. Mutation of the data protected by the Mutex is done by de-referencing +/// the obtained lock as seen on lines 12 and 19. +/// +/// Tokio's Mutex works in a simple FIFO (first in, first out) style where all +/// calls to [`lock`] complete in the order they were performed. In that way the +/// Mutex is "fair" and predictable in how it distributes the locks to inner +/// data. Locks are released and reacquired after every iteration, so basically, +/// each thread goes to the back of the line after it increments the value once. +/// Note that there's some unpredictability to the timing between when the +/// threads are started, but once they are going they alternate predictably. +/// Finally, since there is only a single valid lock at any given time, there is +/// no possibility of a race condition when mutating the inner value. +/// +/// Note that in contrast to [`std::sync::Mutex`], this implementation does not +/// poison the mutex when a thread holding the [`MutexGuard`] panics. In such a +/// case, the mutex will be unlocked. If the panic is caught, this might leave +/// the data protected by the mutex in an inconsistent state. +/// +/// [`Mutex`]: struct@Mutex +/// [`MutexGuard`]: struct@MutexGuard +/// [`Arc`]: struct@std::sync::Arc +/// [`std::sync::Mutex`]: struct@std::sync::Mutex +/// [`Send`]: trait@std::marker::Send +/// [`lock`]: method@Mutex::lock +pub struct Mutex<T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, + s: semaphore::Semaphore, + c: UnsafeCell<T>, +} + +/// A handle to a held `Mutex`. The guard can be held across any `.await` point +/// as it is [`Send`]. +/// +/// As long as you have this guard, you have exclusive access to the underlying +/// `T`. The guard internally borrows the `Mutex`, so the mutex will not be +/// dropped while a guard exists. +/// +/// The lock is automatically released whenever the guard is dropped, at which +/// point `lock` will succeed yet again. +pub struct MutexGuard<'a, T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, + lock: &'a Mutex<T>, +} + +/// An owned handle to a held `Mutex`. +/// +/// This guard is only available from a `Mutex` that is wrapped in an [`Arc`]. It +/// is identical to `MutexGuard`, except that rather than borrowing the `Mutex`, +/// it clones the `Arc`, incrementing the reference count. This means that +/// unlike `MutexGuard`, it will have the `'static` lifetime. +/// +/// As long as you have this guard, you have exclusive access to the underlying +/// `T`. The guard internally keeps a reference-counted pointer to the original +/// `Mutex`, so even if the lock goes away, the guard remains valid. +/// +/// The lock is automatically released whenever the guard is dropped, at which +/// point `lock` will succeed yet again. +/// +/// [`Arc`]: std::sync::Arc +pub struct OwnedMutexGuard<T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, + lock: Arc<Mutex<T>>, +} + +/// A handle to a held `Mutex` that has had a function applied to it via [`MutexGuard::map`]. +/// +/// This can be used to hold a subfield of the protected data. +/// +/// [`MutexGuard::map`]: method@MutexGuard::map +#[must_use = "if unused the Mutex will immediately unlock"] +pub struct MappedMutexGuard<'a, T: ?Sized> { + s: &'a semaphore::Semaphore, + data: *mut T, + // Needed to tell the borrow checker that we are holding a `&mut T` + marker: marker::PhantomData<&'a mut T>, +} + +// As long as T: Send, it's fine to send and share Mutex<T> between threads. +// If T was not Send, sending and sharing a Mutex<T> would be bad, since you can +// access T through Mutex<T>. +unsafe impl<T> Send for Mutex<T> where T: ?Sized + Send {} +unsafe impl<T> Sync for Mutex<T> where T: ?Sized + Send {} +unsafe impl<T> Sync for MutexGuard<'_, T> where T: ?Sized + Send + Sync {} +unsafe impl<T> Sync for OwnedMutexGuard<T> where T: ?Sized + Send + Sync {} +unsafe impl<'a, T> Sync for MappedMutexGuard<'a, T> where T: ?Sized + Sync + 'a {} +unsafe impl<'a, T> Send for MappedMutexGuard<'a, T> where T: ?Sized + Send + 'a {} + +/// Error returned from the [`Mutex::try_lock`], [`RwLock::try_read`] and +/// [`RwLock::try_write`] functions. +/// +/// `Mutex::try_lock` operation will only fail if the mutex is already locked. +/// +/// `RwLock::try_read` operation will only fail if the lock is currently held +/// by an exclusive writer. +/// +/// `RwLock::try_write` operation will if lock is held by any reader or by an +/// exclusive writer. +/// +/// [`Mutex::try_lock`]: Mutex::try_lock +/// [`RwLock::try_read`]: fn@super::RwLock::try_read +/// [`RwLock::try_write`]: fn@super::RwLock::try_write +#[derive(Debug)] +pub struct TryLockError(pub(super) ()); + +impl fmt::Display for TryLockError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "operation would block") + } +} + +impl Error for TryLockError {} + +#[test] +#[cfg(not(loom))] +fn bounds() { + fn check_send<T: Send>() {} + fn check_unpin<T: Unpin>() {} + // This has to take a value, since the async fn's return type is unnameable. + fn check_send_sync_val<T: Send + Sync>(_t: T) {} + fn check_send_sync<T: Send + Sync>() {} + fn check_static<T: 'static>() {} + fn check_static_val<T: 'static>(_t: T) {} + + check_send::<MutexGuard<'_, u32>>(); + check_send::<OwnedMutexGuard<u32>>(); + check_unpin::<Mutex<u32>>(); + check_send_sync::<Mutex<u32>>(); + check_static::<OwnedMutexGuard<u32>>(); + + let mutex = Mutex::new(1); + check_send_sync_val(mutex.lock()); + let arc_mutex = Arc::new(Mutex::new(1)); + check_send_sync_val(arc_mutex.clone().lock_owned()); + check_static_val(arc_mutex.lock_owned()); +} + +impl<T: ?Sized> Mutex<T> { + /// Creates a new lock in an unlocked state ready for use. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// + /// let lock = Mutex::new(5); + /// ``` + #[track_caller] + pub fn new(t: T) -> Self + where + T: Sized, + { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = { + let location = std::panic::Location::caller(); + + tracing::trace_span!( + "runtime.resource", + concrete_type = "Mutex", + kind = "Sync", + loc.file = location.file(), + loc.line = location.line(), + loc.col = location.column(), + ) + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let s = resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + locked = false, + ); + semaphore::Semaphore::new(1) + }); + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + let s = semaphore::Semaphore::new(1); + + Self { + c: UnsafeCell::new(t), + s, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Creates a new lock in an unlocked state ready for use. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// + /// static LOCK: Mutex<i32> = Mutex::const_new(5); + /// ``` + #[cfg(all(feature = "parking_lot", not(all(loom, test)),))] + #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] + pub const fn const_new(t: T) -> Self + where + T: Sized, + { + Self { + c: UnsafeCell::new(t), + s: semaphore::Semaphore::const_new(1), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span::none(), + } + } + + /// Locks this mutex, causing the current task to yield until the lock has + /// been acquired. When the lock has been acquired, function returns a + /// [`MutexGuard`]. + /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute locks in the order they + /// were requested. Cancelling a call to `lock` makes you lose your place in + /// the queue. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// + /// #[tokio::main] + /// async fn main() { + /// let mutex = Mutex::new(1); + /// + /// let mut n = mutex.lock().await; + /// *n = 2; + /// } + /// ``` + pub async fn lock(&self) -> MutexGuard<'_, T> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + trace::async_op( + || self.acquire(), + self.resource_span.clone(), + "Mutex::lock", + "poll", + false, + ) + .await; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + locked = true, + ); + }); + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + self.acquire().await; + + MutexGuard { + lock: self, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), + } + } + + /// Blockingly locks this `Mutex`. When the lock has been acquired, function returns a + /// [`MutexGuard`]. + /// + /// This method is intended for use cases where you + /// need to use this mutex in asynchronous code as well as in synchronous code. + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution context. + /// + /// - If you find yourself in an asynchronous execution context and needing + /// to call some (synchronous) function which performs one of these + /// `blocking_` operations, then consider wrapping that call inside + /// [`spawn_blocking()`][crate::runtime::Handle::spawn_blocking] + /// (or [`block_in_place()`][crate::task::block_in_place]). + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::Mutex; + /// + /// #[tokio::main] + /// async fn main() { + /// let mutex = Arc::new(Mutex::new(1)); + /// let lock = mutex.lock().await; + /// + /// let mutex1 = Arc::clone(&mutex); + /// let blocking_task = tokio::task::spawn_blocking(move || { + /// // This shall block until the `lock` is released. + /// let mut n = mutex1.blocking_lock(); + /// *n = 2; + /// }); + /// + /// assert_eq!(*lock, 1); + /// // Release the lock. + /// drop(lock); + /// + /// // Await the completion of the blocking task. + /// blocking_task.await.unwrap(); + /// + /// // Assert uncontended. + /// let n = mutex.try_lock().unwrap(); + /// assert_eq!(*n, 2); + /// } + /// + /// ``` + #[cfg(feature = "sync")] + pub fn blocking_lock(&self) -> MutexGuard<'_, T> { + crate::future::block_on(self.lock()) + } + + /// Locks this mutex, causing the current task to yield until the lock has + /// been acquired. When the lock has been acquired, this returns an + /// [`OwnedMutexGuard`]. + /// + /// This method is identical to [`Mutex::lock`], except that the returned + /// guard references the `Mutex` with an [`Arc`] rather than by borrowing + /// it. Therefore, the `Mutex` must be wrapped in an `Arc` to call this + /// method, and the guard will live for the `'static` lifetime, as it keeps + /// the `Mutex` alive by holding an `Arc`. + /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute locks in the order they + /// were requested. Cancelling a call to `lock_owned` makes you lose your + /// place in the queue. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// use std::sync::Arc; + /// + /// #[tokio::main] + /// async fn main() { + /// let mutex = Arc::new(Mutex::new(1)); + /// + /// let mut n = mutex.clone().lock_owned().await; + /// *n = 2; + /// } + /// ``` + /// + /// [`Arc`]: std::sync::Arc + pub async fn lock_owned(self: Arc<Self>) -> OwnedMutexGuard<T> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + trace::async_op( + || self.acquire(), + self.resource_span.clone(), + "Mutex::lock_owned", + "poll", + false, + ) + .await; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + locked = true, + ); + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + self.acquire().await; + + OwnedMutexGuard { + lock: self, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + async fn acquire(&self) { + self.s.acquire(1).await.unwrap_or_else(|_| { + // The semaphore was closed. but, we never explicitly close it, and + // we own it exclusively, which means that this can never happen. + unreachable!() + }); + } + + /// Attempts to acquire the lock, and returns [`TryLockError`] if the + /// lock is currently held somewhere else. + /// + /// [`TryLockError`]: TryLockError + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// # async fn dox() -> Result<(), tokio::sync::TryLockError> { + /// + /// let mutex = Mutex::new(1); + /// + /// let n = mutex.try_lock()?; + /// assert_eq!(*n, 1); + /// # Ok(()) + /// # } + /// ``` + pub fn try_lock(&self) -> Result<MutexGuard<'_, T>, TryLockError> { + match self.s.try_acquire(1) { + Ok(_) => { + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + locked = true, + ); + }); + + Ok(MutexGuard { + lock: self, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), + }) + } + Err(_) => Err(TryLockError(())), + } + } + + /// Returns a mutable reference to the underlying data. + /// + /// Since this call borrows the `Mutex` mutably, no actual locking needs to + /// take place -- the mutable borrow statically guarantees no locks exist. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// + /// fn main() { + /// let mut mutex = Mutex::new(1); + /// + /// let n = mutex.get_mut(); + /// *n = 2; + /// } + /// ``` + pub fn get_mut(&mut self) -> &mut T { + unsafe { + // Safety: This is https://github.com/rust-lang/rust/pull/76936 + &mut *self.c.get() + } + } + + /// Attempts to acquire the lock, and returns [`TryLockError`] if the lock + /// is currently held somewhere else. + /// + /// This method is identical to [`Mutex::try_lock`], except that the + /// returned guard references the `Mutex` with an [`Arc`] rather than by + /// borrowing it. Therefore, the `Mutex` must be wrapped in an `Arc` to call + /// this method, and the guard will live for the `'static` lifetime, as it + /// keeps the `Mutex` alive by holding an `Arc`. + /// + /// [`TryLockError`]: TryLockError + /// [`Arc`]: std::sync::Arc + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// use std::sync::Arc; + /// # async fn dox() -> Result<(), tokio::sync::TryLockError> { + /// + /// let mutex = Arc::new(Mutex::new(1)); + /// + /// let n = mutex.clone().try_lock_owned()?; + /// assert_eq!(*n, 1); + /// # Ok(()) + /// # } + pub fn try_lock_owned(self: Arc<Self>) -> Result<OwnedMutexGuard<T>, TryLockError> { + match self.s.try_acquire(1) { + Ok(_) => { + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + locked = true, + ); + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + + Ok(OwnedMutexGuard { + lock: self, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + }) + } + Err(_) => Err(TryLockError(())), + } + } + + /// Consumes the mutex, returning the underlying data. + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// + /// #[tokio::main] + /// async fn main() { + /// let mutex = Mutex::new(1); + /// + /// let n = mutex.into_inner(); + /// assert_eq!(n, 1); + /// } + /// ``` + pub fn into_inner(self) -> T + where + T: Sized, + { + self.c.into_inner() + } +} + +impl<T> From<T> for Mutex<T> { + fn from(s: T) -> Self { + Self::new(s) + } +} + +impl<T> Default for Mutex<T> +where + T: Default, +{ + fn default() -> Self { + Self::new(T::default()) + } +} + +impl<T: ?Sized> std::fmt::Debug for Mutex<T> +where + T: std::fmt::Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut d = f.debug_struct("Mutex"); + match self.try_lock() { + Ok(inner) => d.field("data", &&*inner), + Err(_) => d.field("data", &format_args!("<locked>")), + }; + d.finish() + } +} + +// === impl MutexGuard === + +impl<'a, T: ?Sized> MutexGuard<'a, T> { + /// Makes a new [`MappedMutexGuard`] for a component of the locked data. + /// + /// This operation cannot fail as the [`MutexGuard`] passed in already locked the mutex. + /// + /// This is an associated function that needs to be used as `MutexGuard::map(...)`. A method + /// would interfere with methods of the same name on the contents of the locked data. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{Mutex, MutexGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let foo = Mutex::new(Foo(1)); + /// + /// { + /// let mut mapped = MutexGuard::map(foo.lock().await, |f| &mut f.0); + /// *mapped = 2; + /// } + /// + /// assert_eq!(Foo(2), *foo.lock().await); + /// # } + /// ``` + /// + /// [`MutexGuard`]: struct@MutexGuard + /// [`MappedMutexGuard`]: struct@MappedMutexGuard + #[inline] + pub fn map<U, F>(mut this: Self, f: F) -> MappedMutexGuard<'a, U> + where + F: FnOnce(&mut T) -> &mut U, + { + let data = f(&mut *this) as *mut U; + let s = &this.lock.s; + mem::forget(this); + MappedMutexGuard { + s, + data, + marker: marker::PhantomData, + } + } + + /// Attempts to make a new [`MappedMutexGuard`] for a component of the locked data. The + /// original guard is returned if the closure returns `None`. + /// + /// This operation cannot fail as the [`MutexGuard`] passed in already locked the mutex. + /// + /// This is an associated function that needs to be used as `MutexGuard::try_map(...)`. A + /// method would interfere with methods of the same name on the contents of the locked data. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{Mutex, MutexGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let foo = Mutex::new(Foo(1)); + /// + /// { + /// let mut mapped = MutexGuard::try_map(foo.lock().await, |f| Some(&mut f.0)) + /// .expect("should not fail"); + /// *mapped = 2; + /// } + /// + /// assert_eq!(Foo(2), *foo.lock().await); + /// # } + /// ``` + /// + /// [`MutexGuard`]: struct@MutexGuard + /// [`MappedMutexGuard`]: struct@MappedMutexGuard + #[inline] + pub fn try_map<U, F>(mut this: Self, f: F) -> Result<MappedMutexGuard<'a, U>, Self> + where + F: FnOnce(&mut T) -> Option<&mut U>, + { + let data = match f(&mut *this) { + Some(data) => data as *mut U, + None => return Err(this), + }; + let s = &this.lock.s; + mem::forget(this); + Ok(MappedMutexGuard { + s, + data, + marker: marker::PhantomData, + }) + } + + /// Returns a reference to the original `Mutex`. + /// + /// ``` + /// use tokio::sync::{Mutex, MutexGuard}; + /// + /// async fn unlock_and_relock<'l>(guard: MutexGuard<'l, u32>) -> MutexGuard<'l, u32> { + /// println!("1. contains: {:?}", *guard); + /// let mutex = MutexGuard::mutex(&guard); + /// drop(guard); + /// let guard = mutex.lock().await; + /// println!("2. contains: {:?}", *guard); + /// guard + /// } + /// # + /// # #[tokio::main] + /// # async fn main() { + /// # let mutex = Mutex::new(0u32); + /// # let guard = mutex.lock().await; + /// # unlock_and_relock(guard).await; + /// # } + /// ``` + #[inline] + pub fn mutex(this: &Self) -> &'a Mutex<T> { + this.lock + } +} + +impl<T: ?Sized> Drop for MutexGuard<'_, T> { + fn drop(&mut self) { + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + locked = false, + ); + }); + self.lock.s.release(1); + } +} + +impl<T: ?Sized> Deref for MutexGuard<'_, T> { + type Target = T; + fn deref(&self) -> &Self::Target { + unsafe { &*self.lock.c.get() } + } +} + +impl<T: ?Sized> DerefMut for MutexGuard<'_, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.lock.c.get() } + } +} + +impl<T: ?Sized + fmt::Debug> fmt::Debug for MutexGuard<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<T: ?Sized + fmt::Display> fmt::Display for MutexGuard<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} + +// === impl OwnedMutexGuard === + +impl<T: ?Sized> OwnedMutexGuard<T> { + /// Returns a reference to the original `Arc<Mutex>`. + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{Mutex, OwnedMutexGuard}; + /// + /// async fn unlock_and_relock(guard: OwnedMutexGuard<u32>) -> OwnedMutexGuard<u32> { + /// println!("1. contains: {:?}", *guard); + /// let mutex: Arc<Mutex<u32>> = OwnedMutexGuard::mutex(&guard).clone(); + /// drop(guard); + /// let guard = mutex.lock_owned().await; + /// println!("2. contains: {:?}", *guard); + /// guard + /// } + /// # + /// # #[tokio::main] + /// # async fn main() { + /// # let mutex = Arc::new(Mutex::new(0u32)); + /// # let guard = mutex.lock_owned().await; + /// # unlock_and_relock(guard).await; + /// # } + /// ``` + #[inline] + pub fn mutex(this: &Self) -> &Arc<Mutex<T>> { + &this.lock + } +} + +impl<T: ?Sized> Drop for OwnedMutexGuard<T> { + fn drop(&mut self) { + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + locked = false, + ); + }); + self.lock.s.release(1) + } +} + +impl<T: ?Sized> Deref for OwnedMutexGuard<T> { + type Target = T; + fn deref(&self) -> &Self::Target { + unsafe { &*self.lock.c.get() } + } +} + +impl<T: ?Sized> DerefMut for OwnedMutexGuard<T> { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.lock.c.get() } + } +} + +impl<T: ?Sized + fmt::Debug> fmt::Debug for OwnedMutexGuard<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<T: ?Sized + fmt::Display> fmt::Display for OwnedMutexGuard<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} + +// === impl MappedMutexGuard === + +impl<'a, T: ?Sized> MappedMutexGuard<'a, T> { + /// Makes a new [`MappedMutexGuard`] for a component of the locked data. + /// + /// This operation cannot fail as the [`MappedMutexGuard`] passed in already locked the mutex. + /// + /// This is an associated function that needs to be used as `MappedMutexGuard::map(...)`. A + /// method would interfere with methods of the same name on the contents of the locked data. + /// + /// [`MappedMutexGuard`]: struct@MappedMutexGuard + #[inline] + pub fn map<U, F>(mut this: Self, f: F) -> MappedMutexGuard<'a, U> + where + F: FnOnce(&mut T) -> &mut U, + { + let data = f(&mut *this) as *mut U; + let s = this.s; + mem::forget(this); + MappedMutexGuard { + s, + data, + marker: marker::PhantomData, + } + } + + /// Attempts to make a new [`MappedMutexGuard`] for a component of the locked data. The + /// original guard is returned if the closure returns `None`. + /// + /// This operation cannot fail as the [`MappedMutexGuard`] passed in already locked the mutex. + /// + /// This is an associated function that needs to be used as `MappedMutexGuard::try_map(...)`. A + /// method would interfere with methods of the same name on the contents of the locked data. + /// + /// [`MappedMutexGuard`]: struct@MappedMutexGuard + #[inline] + pub fn try_map<U, F>(mut this: Self, f: F) -> Result<MappedMutexGuard<'a, U>, Self> + where + F: FnOnce(&mut T) -> Option<&mut U>, + { + let data = match f(&mut *this) { + Some(data) => data as *mut U, + None => return Err(this), + }; + let s = this.s; + mem::forget(this); + Ok(MappedMutexGuard { + s, + data, + marker: marker::PhantomData, + }) + } +} + +impl<'a, T: ?Sized> Drop for MappedMutexGuard<'a, T> { + fn drop(&mut self) { + self.s.release(1) + } +} + +impl<'a, T: ?Sized> Deref for MappedMutexGuard<'a, T> { + type Target = T; + fn deref(&self) -> &Self::Target { + unsafe { &*self.data } + } +} + +impl<'a, T: ?Sized> DerefMut for MappedMutexGuard<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.data } + } +} + +impl<'a, T: ?Sized + fmt::Debug> fmt::Debug for MappedMutexGuard<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<'a, T: ?Sized + fmt::Display> fmt::Display for MappedMutexGuard<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} diff --git a/third_party/rust/tokio/src/sync/notify.rs b/third_party/rust/tokio/src/sync/notify.rs new file mode 100644 index 0000000000..83d0de4fbe --- /dev/null +++ b/third_party/rust/tokio/src/sync/notify.rs @@ -0,0 +1,740 @@ +// Allow `unreachable_pub` warnings when sync is not enabled +// due to the usage of `Notify` within the `rt` feature set. +// When this module is compiled with `sync` enabled we will warn on +// this lint. When `rt` is enabled we use `pub(crate)` which +// triggers this warning but it is safe to ignore in this case. +#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))] + +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::Mutex; +use crate::util::linked_list::{self, LinkedList}; +use crate::util::WakeList; + +use std::cell::UnsafeCell; +use std::future::Future; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::ptr::NonNull; +use std::sync::atomic::Ordering::SeqCst; +use std::task::{Context, Poll, Waker}; + +type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>; + +/// Notifies a single task to wake up. +/// +/// `Notify` provides a basic mechanism to notify a single task of an event. +/// `Notify` itself does not carry any data. Instead, it is to be used to signal +/// another task to perform an operation. +/// +/// `Notify` can be thought of as a [`Semaphore`] starting with 0 permits. +/// [`notified().await`] waits for a permit to become available, and [`notify_one()`] +/// sets a permit **if there currently are no available permits**. +/// +/// The synchronization details of `Notify` are similar to +/// [`thread::park`][park] and [`Thread::unpark`][unpark] from std. A [`Notify`] +/// value contains a single permit. [`notified().await`] waits for the permit to +/// be made available, consumes the permit, and resumes. [`notify_one()`] sets the +/// permit, waking a pending task if there is one. +/// +/// If `notify_one()` is called **before** `notified().await`, then the next call to +/// `notified().await` will complete immediately, consuming the permit. Any +/// subsequent calls to `notified().await` will wait for a new permit. +/// +/// If `notify_one()` is called **multiple** times before `notified().await`, only a +/// **single** permit is stored. The next call to `notified().await` will +/// complete immediately, but the one after will wait for a new permit. +/// +/// # Examples +/// +/// Basic usage. +/// +/// ``` +/// use tokio::sync::Notify; +/// use std::sync::Arc; +/// +/// #[tokio::main] +/// async fn main() { +/// let notify = Arc::new(Notify::new()); +/// let notify2 = notify.clone(); +/// +/// let handle = tokio::spawn(async move { +/// notify2.notified().await; +/// println!("received notification"); +/// }); +/// +/// println!("sending notification"); +/// notify.notify_one(); +/// +/// // Wait for task to receive notification. +/// handle.await.unwrap(); +/// } +/// ``` +/// +/// Unbound mpsc channel. +/// +/// ``` +/// use tokio::sync::Notify; +/// +/// use std::collections::VecDeque; +/// use std::sync::Mutex; +/// +/// struct Channel<T> { +/// values: Mutex<VecDeque<T>>, +/// notify: Notify, +/// } +/// +/// impl<T> Channel<T> { +/// pub fn send(&self, value: T) { +/// self.values.lock().unwrap() +/// .push_back(value); +/// +/// // Notify the consumer a value is available +/// self.notify.notify_one(); +/// } +/// +/// pub async fn recv(&self) -> T { +/// loop { +/// // Drain values +/// if let Some(value) = self.values.lock().unwrap().pop_front() { +/// return value; +/// } +/// +/// // Wait for values to be available +/// self.notify.notified().await; +/// } +/// } +/// } +/// ``` +/// +/// [park]: std::thread::park +/// [unpark]: std::thread::Thread::unpark +/// [`notified().await`]: Notify::notified() +/// [`notify_one()`]: Notify::notify_one() +/// [`Semaphore`]: crate::sync::Semaphore +#[derive(Debug)] +pub struct Notify { + // This uses 2 bits to store one of `EMPTY`, + // `WAITING` or `NOTIFIED`. The rest of the bits + // are used to store the number of times `notify_waiters` + // was called. + state: AtomicUsize, + waiters: Mutex<WaitList>, +} + +#[derive(Debug, Clone, Copy)] +enum NotificationType { + // Notification triggered by calling `notify_waiters` + AllWaiters, + // Notification triggered by calling `notify_one` + OneWaiter, +} + +#[derive(Debug)] +#[repr(C)] // required by `linked_list::Link` impl +struct Waiter { + /// Intrusive linked-list pointers. + pointers: linked_list::Pointers<Waiter>, + + /// Waiting task's waker. + waker: Option<Waker>, + + /// `true` if the notification has been assigned to this waiter. + notified: Option<NotificationType>, + + /// Should not be `Unpin`. + _p: PhantomPinned, +} + +/// Future returned from [`Notify::notified()`] +#[derive(Debug)] +pub struct Notified<'a> { + /// The `Notify` being received on. + notify: &'a Notify, + + /// The current state of the receiving process. + state: State, + + /// Entry in the waiter `LinkedList`. + waiter: UnsafeCell<Waiter>, +} + +unsafe impl<'a> Send for Notified<'a> {} +unsafe impl<'a> Sync for Notified<'a> {} + +#[derive(Debug)] +enum State { + Init(usize), + Waiting, + Done, +} + +const NOTIFY_WAITERS_SHIFT: usize = 2; +const STATE_MASK: usize = (1 << NOTIFY_WAITERS_SHIFT) - 1; +const NOTIFY_WAITERS_CALLS_MASK: usize = !STATE_MASK; + +/// Initial "idle" state. +const EMPTY: usize = 0; + +/// One or more threads are currently waiting to be notified. +const WAITING: usize = 1; + +/// Pending notification. +const NOTIFIED: usize = 2; + +fn set_state(data: usize, state: usize) -> usize { + (data & NOTIFY_WAITERS_CALLS_MASK) | (state & STATE_MASK) +} + +fn get_state(data: usize) -> usize { + data & STATE_MASK +} + +fn get_num_notify_waiters_calls(data: usize) -> usize { + (data & NOTIFY_WAITERS_CALLS_MASK) >> NOTIFY_WAITERS_SHIFT +} + +fn inc_num_notify_waiters_calls(data: usize) -> usize { + data + (1 << NOTIFY_WAITERS_SHIFT) +} + +fn atomic_inc_num_notify_waiters_calls(data: &AtomicUsize) { + data.fetch_add(1 << NOTIFY_WAITERS_SHIFT, SeqCst); +} + +impl Notify { + /// Create a new `Notify`, initialized without a permit. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Notify; + /// + /// let notify = Notify::new(); + /// ``` + pub fn new() -> Notify { + Notify { + state: AtomicUsize::new(0), + waiters: Mutex::new(LinkedList::new()), + } + } + + /// Create a new `Notify`, initialized without a permit. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Notify; + /// + /// static NOTIFY: Notify = Notify::const_new(); + /// ``` + #[cfg(all(feature = "parking_lot", not(all(loom, test))))] + #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] + pub const fn const_new() -> Notify { + Notify { + state: AtomicUsize::new(0), + waiters: Mutex::const_new(LinkedList::new()), + } + } + + /// Wait for a notification. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn notified(&self); + /// ``` + /// + /// Each `Notify` value holds a single permit. If a permit is available from + /// an earlier call to [`notify_one()`], then `notified().await` will complete + /// immediately, consuming that permit. Otherwise, `notified().await` waits + /// for a permit to be made available by the next call to `notify_one()`. + /// + /// [`notify_one()`]: Notify::notify_one + /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute notifications in the order + /// they were requested. Cancelling a call to `notified` makes you lose your + /// place in the queue. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Notify; + /// use std::sync::Arc; + /// + /// #[tokio::main] + /// async fn main() { + /// let notify = Arc::new(Notify::new()); + /// let notify2 = notify.clone(); + /// + /// tokio::spawn(async move { + /// notify2.notified().await; + /// println!("received notification"); + /// }); + /// + /// println!("sending notification"); + /// notify.notify_one(); + /// } + /// ``` + pub fn notified(&self) -> Notified<'_> { + // we load the number of times notify_waiters + // was called and store that in our initial state + let state = self.state.load(SeqCst); + Notified { + notify: self, + state: State::Init(state >> NOTIFY_WAITERS_SHIFT), + waiter: UnsafeCell::new(Waiter { + pointers: linked_list::Pointers::new(), + waker: None, + notified: None, + _p: PhantomPinned, + }), + } + } + + /// Notifies a waiting task. + /// + /// If a task is currently waiting, that task is notified. Otherwise, a + /// permit is stored in this `Notify` value and the **next** call to + /// [`notified().await`] will complete immediately consuming the permit made + /// available by this call to `notify_one()`. + /// + /// At most one permit may be stored by `Notify`. Many sequential calls to + /// `notify_one` will result in a single permit being stored. The next call to + /// `notified().await` will complete immediately, but the one after that + /// will wait. + /// + /// [`notified().await`]: Notify::notified() + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Notify; + /// use std::sync::Arc; + /// + /// #[tokio::main] + /// async fn main() { + /// let notify = Arc::new(Notify::new()); + /// let notify2 = notify.clone(); + /// + /// tokio::spawn(async move { + /// notify2.notified().await; + /// println!("received notification"); + /// }); + /// + /// println!("sending notification"); + /// notify.notify_one(); + /// } + /// ``` + // Alias for old name in 0.x + #[cfg_attr(docsrs, doc(alias = "notify"))] + pub fn notify_one(&self) { + // Load the current state + let mut curr = self.state.load(SeqCst); + + // If the state is `EMPTY`, transition to `NOTIFIED` and return. + while let EMPTY | NOTIFIED = get_state(curr) { + // The compare-exchange from `NOTIFIED` -> `NOTIFIED` is intended. A + // happens-before synchronization must happen between this atomic + // operation and a task calling `notified().await`. + let new = set_state(curr, NOTIFIED); + let res = self.state.compare_exchange(curr, new, SeqCst, SeqCst); + + match res { + // No waiters, no further work to do + Ok(_) => return, + Err(actual) => { + curr = actual; + } + } + } + + // There are waiters, the lock must be acquired to notify. + let mut waiters = self.waiters.lock(); + + // The state must be reloaded while the lock is held. The state may only + // transition out of WAITING while the lock is held. + curr = self.state.load(SeqCst); + + if let Some(waker) = notify_locked(&mut waiters, &self.state, curr) { + drop(waiters); + waker.wake(); + } + } + + /// Notifies all waiting tasks. + /// + /// If a task is currently waiting, that task is notified. Unlike with + /// `notify_one()`, no permit is stored to be used by the next call to + /// `notified().await`. The purpose of this method is to notify all + /// already registered waiters. Registering for notification is done by + /// acquiring an instance of the `Notified` future via calling `notified()`. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Notify; + /// use std::sync::Arc; + /// + /// #[tokio::main] + /// async fn main() { + /// let notify = Arc::new(Notify::new()); + /// let notify2 = notify.clone(); + /// + /// let notified1 = notify.notified(); + /// let notified2 = notify.notified(); + /// + /// let handle = tokio::spawn(async move { + /// println!("sending notifications"); + /// notify2.notify_waiters(); + /// }); + /// + /// notified1.await; + /// notified2.await; + /// println!("received notifications"); + /// } + /// ``` + pub fn notify_waiters(&self) { + let mut wakers = WakeList::new(); + + // There are waiters, the lock must be acquired to notify. + let mut waiters = self.waiters.lock(); + + // The state must be reloaded while the lock is held. The state may only + // transition out of WAITING while the lock is held. + let curr = self.state.load(SeqCst); + + if let EMPTY | NOTIFIED = get_state(curr) { + // There are no waiting tasks. All we need to do is increment the + // number of times this method was called. + atomic_inc_num_notify_waiters_calls(&self.state); + return; + } + + // At this point, it is guaranteed that the state will not + // concurrently change, as holding the lock is required to + // transition **out** of `WAITING`. + 'outer: loop { + while wakers.can_push() { + match waiters.pop_back() { + Some(mut waiter) => { + // Safety: `waiters` lock is still held. + let waiter = unsafe { waiter.as_mut() }; + + assert!(waiter.notified.is_none()); + + waiter.notified = Some(NotificationType::AllWaiters); + + if let Some(waker) = waiter.waker.take() { + wakers.push(waker); + } + } + None => { + break 'outer; + } + } + } + + drop(waiters); + + wakers.wake_all(); + + // Acquire the lock again. + waiters = self.waiters.lock(); + } + + // All waiters will be notified, the state must be transitioned to + // `EMPTY`. As transitioning **from** `WAITING` requires the lock to be + // held, a `store` is sufficient. + let new = set_state(inc_num_notify_waiters_calls(curr), EMPTY); + self.state.store(new, SeqCst); + + // Release the lock before notifying + drop(waiters); + + wakers.wake_all(); + } +} + +impl Default for Notify { + fn default() -> Notify { + Notify::new() + } +} + +fn notify_locked(waiters: &mut WaitList, state: &AtomicUsize, curr: usize) -> Option<Waker> { + loop { + match get_state(curr) { + EMPTY | NOTIFIED => { + let res = state.compare_exchange(curr, set_state(curr, NOTIFIED), SeqCst, SeqCst); + + match res { + Ok(_) => return None, + Err(actual) => { + let actual_state = get_state(actual); + assert!(actual_state == EMPTY || actual_state == NOTIFIED); + state.store(set_state(actual, NOTIFIED), SeqCst); + return None; + } + } + } + WAITING => { + // At this point, it is guaranteed that the state will not + // concurrently change as holding the lock is required to + // transition **out** of `WAITING`. + // + // Get a pending waiter + let mut waiter = waiters.pop_back().unwrap(); + + // Safety: `waiters` lock is still held. + let waiter = unsafe { waiter.as_mut() }; + + assert!(waiter.notified.is_none()); + + waiter.notified = Some(NotificationType::OneWaiter); + let waker = waiter.waker.take(); + + if waiters.is_empty() { + // As this the **final** waiter in the list, the state + // must be transitioned to `EMPTY`. As transitioning + // **from** `WAITING` requires the lock to be held, a + // `store` is sufficient. + state.store(set_state(curr, EMPTY), SeqCst); + } + + return waker; + } + _ => unreachable!(), + } + } +} + +// ===== impl Notified ===== + +impl Notified<'_> { + /// A custom `project` implementation is used in place of `pin-project-lite` + /// as a custom drop implementation is needed. + fn project(self: Pin<&mut Self>) -> (&Notify, &mut State, &UnsafeCell<Waiter>) { + unsafe { + // Safety: both `notify` and `state` are `Unpin`. + + is_unpin::<&Notify>(); + is_unpin::<AtomicUsize>(); + + let me = self.get_unchecked_mut(); + (me.notify, &mut me.state, &me.waiter) + } + } +} + +impl Future for Notified<'_> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + use State::*; + + let (notify, state, waiter) = self.project(); + + loop { + match *state { + Init(initial_notify_waiters_calls) => { + let curr = notify.state.load(SeqCst); + + // Optimistically try acquiring a pending notification + let res = notify.state.compare_exchange( + set_state(curr, NOTIFIED), + set_state(curr, EMPTY), + SeqCst, + SeqCst, + ); + + if res.is_ok() { + // Acquired the notification + *state = Done; + return Poll::Ready(()); + } + + // Clone the waker before locking, a waker clone can be + // triggering arbitrary code. + let waker = cx.waker().clone(); + + // Acquire the lock and attempt to transition to the waiting + // state. + let mut waiters = notify.waiters.lock(); + + // Reload the state with the lock held + let mut curr = notify.state.load(SeqCst); + + // if notify_waiters has been called after the future + // was created, then we are done + if get_num_notify_waiters_calls(curr) != initial_notify_waiters_calls { + *state = Done; + return Poll::Ready(()); + } + + // Transition the state to WAITING. + loop { + match get_state(curr) { + EMPTY => { + // Transition to WAITING + let res = notify.state.compare_exchange( + set_state(curr, EMPTY), + set_state(curr, WAITING), + SeqCst, + SeqCst, + ); + + if let Err(actual) = res { + assert_eq!(get_state(actual), NOTIFIED); + curr = actual; + } else { + break; + } + } + WAITING => break, + NOTIFIED => { + // Try consuming the notification + let res = notify.state.compare_exchange( + set_state(curr, NOTIFIED), + set_state(curr, EMPTY), + SeqCst, + SeqCst, + ); + + match res { + Ok(_) => { + // Acquired the notification + *state = Done; + return Poll::Ready(()); + } + Err(actual) => { + assert_eq!(get_state(actual), EMPTY); + curr = actual; + } + } + } + _ => unreachable!(), + } + } + + // Safety: called while locked. + unsafe { + (*waiter.get()).waker = Some(waker); + } + + // Insert the waiter into the linked list + // + // safety: pointers from `UnsafeCell` are never null. + waiters.push_front(unsafe { NonNull::new_unchecked(waiter.get()) }); + + *state = Waiting; + + return Poll::Pending; + } + Waiting => { + // Currently in the "Waiting" state, implying the caller has + // a waiter stored in the waiter list (guarded by + // `notify.waiters`). In order to access the waker fields, + // we must hold the lock. + + let waiters = notify.waiters.lock(); + + // Safety: called while locked + let w = unsafe { &mut *waiter.get() }; + + if w.notified.is_some() { + // Our waker has been notified. Reset the fields and + // remove it from the list. + w.waker = None; + w.notified = None; + + *state = Done; + } else { + // Update the waker, if necessary. + if !w.waker.as_ref().unwrap().will_wake(cx.waker()) { + w.waker = Some(cx.waker().clone()); + } + + return Poll::Pending; + } + + // Explicit drop of the lock to indicate the scope that the + // lock is held. Because holding the lock is required to + // ensure safe access to fields not held within the lock, it + // is helpful to visualize the scope of the critical + // section. + drop(waiters); + } + Done => { + return Poll::Ready(()); + } + } + } + } +} + +impl Drop for Notified<'_> { + fn drop(&mut self) { + use State::*; + + // Safety: The type only transitions to a "Waiting" state when pinned. + let (notify, state, waiter) = unsafe { Pin::new_unchecked(self).project() }; + + // This is where we ensure safety. The `Notified` value is being + // dropped, which means we must ensure that the waiter entry is no + // longer stored in the linked list. + if let Waiting = *state { + let mut waiters = notify.waiters.lock(); + let mut notify_state = notify.state.load(SeqCst); + + // remove the entry from the list (if not already removed) + // + // safety: the waiter is only added to `waiters` by virtue of it + // being the only `LinkedList` available to the type. + unsafe { waiters.remove(NonNull::new_unchecked(waiter.get())) }; + + if waiters.is_empty() { + if let WAITING = get_state(notify_state) { + notify_state = set_state(notify_state, EMPTY); + notify.state.store(notify_state, SeqCst); + } + } + + // See if the node was notified but not received. In this case, if + // the notification was triggered via `notify_one`, it must be sent + // to the next waiter. + // + // Safety: with the entry removed from the linked list, there can be + // no concurrent access to the entry + if let Some(NotificationType::OneWaiter) = unsafe { (*waiter.get()).notified } { + if let Some(waker) = notify_locked(&mut waiters, ¬ify.state, notify_state) { + drop(waiters); + waker.wake(); + } + } + } + } +} + +/// # Safety +/// +/// `Waiter` is forced to be !Unpin. +unsafe impl linked_list::Link for Waiter { + type Handle = NonNull<Waiter>; + type Target = Waiter; + + fn as_raw(handle: &NonNull<Waiter>) -> NonNull<Waiter> { + *handle + } + + unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> { + ptr + } + + unsafe fn pointers(target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> { + target.cast() + } +} + +fn is_unpin<T: Unpin>() {} diff --git a/third_party/rust/tokio/src/sync/once_cell.rs b/third_party/rust/tokio/src/sync/once_cell.rs new file mode 100644 index 0000000000..d31a40e2c8 --- /dev/null +++ b/third_party/rust/tokio/src/sync/once_cell.rs @@ -0,0 +1,457 @@ +use super::{Semaphore, SemaphorePermit, TryAcquireError}; +use crate::loom::cell::UnsafeCell; +use std::error::Error; +use std::fmt; +use std::future::Future; +use std::mem::MaybeUninit; +use std::ops::Drop; +use std::ptr; +use std::sync::atomic::{AtomicBool, Ordering}; + +// This file contains an implementation of an OnceCell. The principle +// behind the safety the of the cell is that any thread with an `&OnceCell` may +// access the `value` field according the following rules: +// +// 1. When `value_set` is false, the `value` field may be modified by the +// thread holding the permit on the semaphore. +// 2. When `value_set` is true, the `value` field may be accessed immutably by +// any thread. +// +// It is an invariant that if the semaphore is closed, then `value_set` is true. +// The reverse does not necessarily hold — but if not, the semaphore may not +// have any available permits. +// +// A thread with a `&mut OnceCell` may modify the value in any way it wants as +// long as the invariants are upheld. + +/// A thread-safe cell that can be written to only once. +/// +/// A `OnceCell` is typically used for global variables that need to be +/// initialized once on first use, but need no further changes. The `OnceCell` +/// in Tokio allows the initialization procedure to be asynchronous. +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::OnceCell; +/// +/// async fn some_computation() -> u32 { +/// 1 + 1 +/// } +/// +/// static ONCE: OnceCell<u32> = OnceCell::const_new(); +/// +/// #[tokio::main] +/// async fn main() { +/// let result = ONCE.get_or_init(some_computation).await; +/// assert_eq!(*result, 2); +/// } +/// ``` +/// +/// It is often useful to write a wrapper method for accessing the value. +/// +/// ``` +/// use tokio::sync::OnceCell; +/// +/// static ONCE: OnceCell<u32> = OnceCell::const_new(); +/// +/// async fn get_global_integer() -> &'static u32 { +/// ONCE.get_or_init(|| async { +/// 1 + 1 +/// }).await +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// let result = get_global_integer().await; +/// assert_eq!(*result, 2); +/// } +/// ``` +pub struct OnceCell<T> { + value_set: AtomicBool, + value: UnsafeCell<MaybeUninit<T>>, + semaphore: Semaphore, +} + +impl<T> Default for OnceCell<T> { + fn default() -> OnceCell<T> { + OnceCell::new() + } +} + +impl<T: fmt::Debug> fmt::Debug for OnceCell<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("OnceCell") + .field("value", &self.get()) + .finish() + } +} + +impl<T: Clone> Clone for OnceCell<T> { + fn clone(&self) -> OnceCell<T> { + OnceCell::new_with(self.get().cloned()) + } +} + +impl<T: PartialEq> PartialEq for OnceCell<T> { + fn eq(&self, other: &OnceCell<T>) -> bool { + self.get() == other.get() + } +} + +impl<T: Eq> Eq for OnceCell<T> {} + +impl<T> Drop for OnceCell<T> { + fn drop(&mut self) { + if self.initialized_mut() { + unsafe { + self.value + .with_mut(|ptr| ptr::drop_in_place((&mut *ptr).as_mut_ptr())); + }; + } + } +} + +impl<T> From<T> for OnceCell<T> { + fn from(value: T) -> Self { + let semaphore = Semaphore::new(0); + semaphore.close(); + OnceCell { + value_set: AtomicBool::new(true), + value: UnsafeCell::new(MaybeUninit::new(value)), + semaphore, + } + } +} + +impl<T> OnceCell<T> { + /// Creates a new empty `OnceCell` instance. + pub fn new() -> Self { + OnceCell { + value_set: AtomicBool::new(false), + value: UnsafeCell::new(MaybeUninit::uninit()), + semaphore: Semaphore::new(1), + } + } + + /// Creates a new `OnceCell` that contains the provided value, if any. + /// + /// If the `Option` is `None`, this is equivalent to `OnceCell::new`. + /// + /// [`OnceCell::new`]: crate::sync::OnceCell::new + pub fn new_with(value: Option<T>) -> Self { + if let Some(v) = value { + OnceCell::from(v) + } else { + OnceCell::new() + } + } + + /// Creates a new empty `OnceCell` instance. + /// + /// Equivalent to `OnceCell::new`, except that it can be used in static + /// variables. + /// + /// # Example + /// + /// ``` + /// use tokio::sync::OnceCell; + /// + /// static ONCE: OnceCell<u32> = OnceCell::const_new(); + /// + /// async fn get_global_integer() -> &'static u32 { + /// ONCE.get_or_init(|| async { + /// 1 + 1 + /// }).await + /// } + /// + /// #[tokio::main] + /// async fn main() { + /// let result = get_global_integer().await; + /// assert_eq!(*result, 2); + /// } + /// ``` + #[cfg(all(feature = "parking_lot", not(all(loom, test))))] + #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] + pub const fn const_new() -> Self { + OnceCell { + value_set: AtomicBool::new(false), + value: UnsafeCell::new(MaybeUninit::uninit()), + semaphore: Semaphore::const_new(1), + } + } + + /// Returns `true` if the `OnceCell` currently contains a value, and `false` + /// otherwise. + pub fn initialized(&self) -> bool { + // Using acquire ordering so any threads that read a true from this + // atomic is able to read the value. + self.value_set.load(Ordering::Acquire) + } + + /// Returns `true` if the `OnceCell` currently contains a value, and `false` + /// otherwise. + fn initialized_mut(&mut self) -> bool { + *self.value_set.get_mut() + } + + // SAFETY: The OnceCell must not be empty. + unsafe fn get_unchecked(&self) -> &T { + &*self.value.with(|ptr| (*ptr).as_ptr()) + } + + // SAFETY: The OnceCell must not be empty. + unsafe fn get_unchecked_mut(&mut self) -> &mut T { + &mut *self.value.with_mut(|ptr| (*ptr).as_mut_ptr()) + } + + fn set_value(&self, value: T, permit: SemaphorePermit<'_>) -> &T { + // SAFETY: We are holding the only permit on the semaphore. + unsafe { + self.value.with_mut(|ptr| (*ptr).as_mut_ptr().write(value)); + } + + // Using release ordering so any threads that read a true from this + // atomic is able to read the value we just stored. + self.value_set.store(true, Ordering::Release); + self.semaphore.close(); + permit.forget(); + + // SAFETY: We just initialized the cell. + unsafe { self.get_unchecked() } + } + + /// Returns a reference to the value currently stored in the `OnceCell`, or + /// `None` if the `OnceCell` is empty. + pub fn get(&self) -> Option<&T> { + if self.initialized() { + Some(unsafe { self.get_unchecked() }) + } else { + None + } + } + + /// Returns a mutable reference to the value currently stored in the + /// `OnceCell`, or `None` if the `OnceCell` is empty. + /// + /// Since this call borrows the `OnceCell` mutably, it is safe to mutate the + /// value inside the `OnceCell` — the mutable borrow statically guarantees + /// no other references exist. + pub fn get_mut(&mut self) -> Option<&mut T> { + if self.initialized_mut() { + Some(unsafe { self.get_unchecked_mut() }) + } else { + None + } + } + + /// Sets the value of the `OnceCell` to the given value if the `OnceCell` is + /// empty. + /// + /// If the `OnceCell` already has a value, this call will fail with an + /// [`SetError::AlreadyInitializedError`]. + /// + /// If the `OnceCell` is empty, but some other task is currently trying to + /// set the value, this call will fail with [`SetError::InitializingError`]. + /// + /// [`SetError::AlreadyInitializedError`]: crate::sync::SetError::AlreadyInitializedError + /// [`SetError::InitializingError`]: crate::sync::SetError::InitializingError + pub fn set(&self, value: T) -> Result<(), SetError<T>> { + if self.initialized() { + return Err(SetError::AlreadyInitializedError(value)); + } + + // Another task might be initializing the cell, in which case + // `try_acquire` will return an error. If we succeed to acquire the + // permit, then we can set the value. + match self.semaphore.try_acquire() { + Ok(permit) => { + debug_assert!(!self.initialized()); + self.set_value(value, permit); + Ok(()) + } + Err(TryAcquireError::NoPermits) => { + // Some other task is holding the permit. That task is + // currently trying to initialize the value. + Err(SetError::InitializingError(value)) + } + Err(TryAcquireError::Closed) => { + // The semaphore was closed. Some other task has initialized + // the value. + Err(SetError::AlreadyInitializedError(value)) + } + } + } + + /// Gets the value currently in the `OnceCell`, or initialize it with the + /// given asynchronous operation. + /// + /// If some other task is currently working on initializing the `OnceCell`, + /// this call will wait for that other task to finish, then return the value + /// that the other task produced. + /// + /// If the provided operation is cancelled or panics, the initialization + /// attempt is cancelled. If there are other tasks waiting for the value to + /// be initialized, one of them will start another attempt at initializing + /// the value. + /// + /// This will deadlock if `f` tries to initialize the cell recursively. + pub async fn get_or_init<F, Fut>(&self, f: F) -> &T + where + F: FnOnce() -> Fut, + Fut: Future<Output = T>, + { + if self.initialized() { + // SAFETY: The OnceCell has been fully initialized. + unsafe { self.get_unchecked() } + } else { + // Here we try to acquire the semaphore permit. Holding the permit + // will allow us to set the value of the OnceCell, and prevents + // other tasks from initializing the OnceCell while we are holding + // it. + match self.semaphore.acquire().await { + Ok(permit) => { + debug_assert!(!self.initialized()); + + // If `f()` panics or `select!` is called, this + // `get_or_init` call is aborted and the semaphore permit is + // dropped. + let value = f().await; + + self.set_value(value, permit) + } + Err(_) => { + debug_assert!(self.initialized()); + + // SAFETY: The semaphore has been closed. This only happens + // when the OnceCell is fully initialized. + unsafe { self.get_unchecked() } + } + } + } + } + + /// Gets the value currently in the `OnceCell`, or initialize it with the + /// given asynchronous operation. + /// + /// If some other task is currently working on initializing the `OnceCell`, + /// this call will wait for that other task to finish, then return the value + /// that the other task produced. + /// + /// If the provided operation returns an error, is cancelled or panics, the + /// initialization attempt is cancelled. If there are other tasks waiting + /// for the value to be initialized, one of them will start another attempt + /// at initializing the value. + /// + /// This will deadlock if `f` tries to initialize the cell recursively. + pub async fn get_or_try_init<E, F, Fut>(&self, f: F) -> Result<&T, E> + where + F: FnOnce() -> Fut, + Fut: Future<Output = Result<T, E>>, + { + if self.initialized() { + // SAFETY: The OnceCell has been fully initialized. + unsafe { Ok(self.get_unchecked()) } + } else { + // Here we try to acquire the semaphore permit. Holding the permit + // will allow us to set the value of the OnceCell, and prevents + // other tasks from initializing the OnceCell while we are holding + // it. + match self.semaphore.acquire().await { + Ok(permit) => { + debug_assert!(!self.initialized()); + + // If `f()` panics or `select!` is called, this + // `get_or_try_init` call is aborted and the semaphore + // permit is dropped. + let value = f().await; + + match value { + Ok(value) => Ok(self.set_value(value, permit)), + Err(e) => Err(e), + } + } + Err(_) => { + debug_assert!(self.initialized()); + + // SAFETY: The semaphore has been closed. This only happens + // when the OnceCell is fully initialized. + unsafe { Ok(self.get_unchecked()) } + } + } + } + } + + /// Takes the value from the cell, destroying the cell in the process. + /// Returns `None` if the cell is empty. + pub fn into_inner(mut self) -> Option<T> { + if self.initialized_mut() { + // Set to uninitialized for the destructor of `OnceCell` to work properly + *self.value_set.get_mut() = false; + Some(unsafe { self.value.with(|ptr| ptr::read(ptr).assume_init()) }) + } else { + None + } + } + + /// Takes ownership of the current value, leaving the cell empty. Returns + /// `None` if the cell is empty. + pub fn take(&mut self) -> Option<T> { + std::mem::take(self).into_inner() + } +} + +// Since `get` gives us access to immutable references of the OnceCell, OnceCell +// can only be Sync if T is Sync, otherwise OnceCell would allow sharing +// references of !Sync values across threads. We need T to be Send in order for +// OnceCell to by Sync because we can use `set` on `&OnceCell<T>` to send values +// (of type T) across threads. +unsafe impl<T: Sync + Send> Sync for OnceCell<T> {} + +// Access to OnceCell's value is guarded by the semaphore permit +// and atomic operations on `value_set`, so as long as T itself is Send +// it's safe to send it to another thread +unsafe impl<T: Send> Send for OnceCell<T> {} + +/// Errors that can be returned from [`OnceCell::set`]. +/// +/// [`OnceCell::set`]: crate::sync::OnceCell::set +#[derive(Debug, PartialEq)] +pub enum SetError<T> { + /// The cell was already initialized when [`OnceCell::set`] was called. + /// + /// [`OnceCell::set`]: crate::sync::OnceCell::set + AlreadyInitializedError(T), + + /// The cell is currently being initialized. + InitializingError(T), +} + +impl<T> fmt::Display for SetError<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SetError::AlreadyInitializedError(_) => write!(f, "AlreadyInitializedError"), + SetError::InitializingError(_) => write!(f, "InitializingError"), + } + } +} + +impl<T: fmt::Debug> Error for SetError<T> {} + +impl<T> SetError<T> { + /// Whether `SetError` is `SetError::AlreadyInitializedError`. + pub fn is_already_init_err(&self) -> bool { + match self { + SetError::AlreadyInitializedError(_) => true, + SetError::InitializingError(_) => false, + } + } + + /// Whether `SetError` is `SetError::InitializingError` + pub fn is_initializing_err(&self) -> bool { + match self { + SetError::AlreadyInitializedError(_) => false, + SetError::InitializingError(_) => true, + } + } +} diff --git a/third_party/rust/tokio/src/sync/oneshot.rs b/third_party/rust/tokio/src/sync/oneshot.rs new file mode 100644 index 0000000000..2240074e73 --- /dev/null +++ b/third_party/rust/tokio/src/sync/oneshot.rs @@ -0,0 +1,1366 @@ +#![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))] + +//! A one-shot channel is used for sending a single message between +//! asynchronous tasks. The [`channel`] function is used to create a +//! [`Sender`] and [`Receiver`] handle pair that form the channel. +//! +//! The `Sender` handle is used by the producer to send the value. +//! The `Receiver` handle is used by the consumer to receive the value. +//! +//! Each handle can be used on separate tasks. +//! +//! Since the `send` method is not async, it can be used anywhere. This includes +//! sending between two runtimes, and using it from non-async code. +//! +//! # Examples +//! +//! ``` +//! use tokio::sync::oneshot; +//! +//! #[tokio::main] +//! async fn main() { +//! let (tx, rx) = oneshot::channel(); +//! +//! tokio::spawn(async move { +//! if let Err(_) = tx.send(3) { +//! println!("the receiver dropped"); +//! } +//! }); +//! +//! match rx.await { +//! Ok(v) => println!("got = {:?}", v), +//! Err(_) => println!("the sender dropped"), +//! } +//! } +//! ``` +//! +//! If the sender is dropped without sending, the receiver will fail with +//! [`error::RecvError`]: +//! +//! ``` +//! use tokio::sync::oneshot; +//! +//! #[tokio::main] +//! async fn main() { +//! let (tx, rx) = oneshot::channel::<u32>(); +//! +//! tokio::spawn(async move { +//! drop(tx); +//! }); +//! +//! match rx.await { +//! Ok(_) => panic!("This doesn't happen"), +//! Err(_) => println!("the sender dropped"), +//! } +//! } +//! ``` +//! +//! To use a oneshot channel in a `tokio::select!` loop, add `&mut` in front of +//! the channel. +//! +//! ``` +//! use tokio::sync::oneshot; +//! use tokio::time::{interval, sleep, Duration}; +//! +//! #[tokio::main] +//! # async fn _doc() {} +//! # #[tokio::main(flavor = "current_thread", start_paused = true)] +//! async fn main() { +//! let (send, mut recv) = oneshot::channel(); +//! let mut interval = interval(Duration::from_millis(100)); +//! +//! # let handle = +//! tokio::spawn(async move { +//! sleep(Duration::from_secs(1)).await; +//! send.send("shut down").unwrap(); +//! }); +//! +//! loop { +//! tokio::select! { +//! _ = interval.tick() => println!("Another 100ms"), +//! msg = &mut recv => { +//! println!("Got message: {}", msg.unwrap()); +//! break; +//! } +//! } +//! } +//! # handle.await.unwrap(); +//! } +//! ``` +//! +//! To use a `Sender` from a destructor, put it in an [`Option`] and call +//! [`Option::take`]. +//! +//! ``` +//! use tokio::sync::oneshot; +//! +//! struct SendOnDrop { +//! sender: Option<oneshot::Sender<&'static str>>, +//! } +//! impl Drop for SendOnDrop { +//! fn drop(&mut self) { +//! if let Some(sender) = self.sender.take() { +//! // Using `let _ =` to ignore send errors. +//! let _ = sender.send("I got dropped!"); +//! } +//! } +//! } +//! +//! #[tokio::main] +//! # async fn _doc() {} +//! # #[tokio::main(flavor = "current_thread")] +//! async fn main() { +//! let (send, recv) = oneshot::channel(); +//! +//! let send_on_drop = SendOnDrop { sender: Some(send) }; +//! drop(send_on_drop); +//! +//! assert_eq!(recv.await, Ok("I got dropped!")); +//! } +//! ``` + +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::Arc; +#[cfg(all(tokio_unstable, feature = "tracing"))] +use crate::util::trace; + +use std::fmt; +use std::future::Future; +use std::mem::MaybeUninit; +use std::pin::Pin; +use std::sync::atomic::Ordering::{self, AcqRel, Acquire}; +use std::task::Poll::{Pending, Ready}; +use std::task::{Context, Poll, Waker}; + +/// Sends a value to the associated [`Receiver`]. +/// +/// A pair of both a [`Sender`] and a [`Receiver`] are created by the +/// [`channel`](fn@channel) function. +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::oneshot; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, rx) = oneshot::channel(); +/// +/// tokio::spawn(async move { +/// if let Err(_) = tx.send(3) { +/// println!("the receiver dropped"); +/// } +/// }); +/// +/// match rx.await { +/// Ok(v) => println!("got = {:?}", v), +/// Err(_) => println!("the sender dropped"), +/// } +/// } +/// ``` +/// +/// If the sender is dropped without sending, the receiver will fail with +/// [`error::RecvError`]: +/// +/// ``` +/// use tokio::sync::oneshot; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, rx) = oneshot::channel::<u32>(); +/// +/// tokio::spawn(async move { +/// drop(tx); +/// }); +/// +/// match rx.await { +/// Ok(_) => panic!("This doesn't happen"), +/// Err(_) => println!("the sender dropped"), +/// } +/// } +/// ``` +/// +/// To use a `Sender` from a destructor, put it in an [`Option`] and call +/// [`Option::take`]. +/// +/// ``` +/// use tokio::sync::oneshot; +/// +/// struct SendOnDrop { +/// sender: Option<oneshot::Sender<&'static str>>, +/// } +/// impl Drop for SendOnDrop { +/// fn drop(&mut self) { +/// if let Some(sender) = self.sender.take() { +/// // Using `let _ =` to ignore send errors. +/// let _ = sender.send("I got dropped!"); +/// } +/// } +/// } +/// +/// #[tokio::main] +/// # async fn _doc() {} +/// # #[tokio::main(flavor = "current_thread")] +/// async fn main() { +/// let (send, recv) = oneshot::channel(); +/// +/// let send_on_drop = SendOnDrop { sender: Some(send) }; +/// drop(send_on_drop); +/// +/// assert_eq!(recv.await, Ok("I got dropped!")); +/// } +/// ``` +/// +/// [`Option`]: std::option::Option +/// [`Option::take`]: std::option::Option::take +#[derive(Debug)] +pub struct Sender<T> { + inner: Option<Arc<Inner<T>>>, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, +} + +/// Receives a value from the associated [`Sender`]. +/// +/// A pair of both a [`Sender`] and a [`Receiver`] are created by the +/// [`channel`](fn@channel) function. +/// +/// This channel has no `recv` method because the receiver itself implements the +/// [`Future`] trait. To receive a value, `.await` the `Receiver` object directly. +/// +/// [`Future`]: trait@std::future::Future +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::oneshot; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, rx) = oneshot::channel(); +/// +/// tokio::spawn(async move { +/// if let Err(_) = tx.send(3) { +/// println!("the receiver dropped"); +/// } +/// }); +/// +/// match rx.await { +/// Ok(v) => println!("got = {:?}", v), +/// Err(_) => println!("the sender dropped"), +/// } +/// } +/// ``` +/// +/// If the sender is dropped without sending, the receiver will fail with +/// [`error::RecvError`]: +/// +/// ``` +/// use tokio::sync::oneshot; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, rx) = oneshot::channel::<u32>(); +/// +/// tokio::spawn(async move { +/// drop(tx); +/// }); +/// +/// match rx.await { +/// Ok(_) => panic!("This doesn't happen"), +/// Err(_) => println!("the sender dropped"), +/// } +/// } +/// ``` +/// +/// To use a `Receiver` in a `tokio::select!` loop, add `&mut` in front of the +/// channel. +/// +/// ``` +/// use tokio::sync::oneshot; +/// use tokio::time::{interval, sleep, Duration}; +/// +/// #[tokio::main] +/// # async fn _doc() {} +/// # #[tokio::main(flavor = "current_thread", start_paused = true)] +/// async fn main() { +/// let (send, mut recv) = oneshot::channel(); +/// let mut interval = interval(Duration::from_millis(100)); +/// +/// # let handle = +/// tokio::spawn(async move { +/// sleep(Duration::from_secs(1)).await; +/// send.send("shut down").unwrap(); +/// }); +/// +/// loop { +/// tokio::select! { +/// _ = interval.tick() => println!("Another 100ms"), +/// msg = &mut recv => { +/// println!("Got message: {}", msg.unwrap()); +/// break; +/// } +/// } +/// } +/// # handle.await.unwrap(); +/// } +/// ``` +#[derive(Debug)] +pub struct Receiver<T> { + inner: Option<Arc<Inner<T>>>, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, + #[cfg(all(tokio_unstable, feature = "tracing"))] + async_op_span: tracing::Span, + #[cfg(all(tokio_unstable, feature = "tracing"))] + async_op_poll_span: tracing::Span, +} + +pub mod error { + //! Oneshot error types. + + use std::fmt; + + /// Error returned by the `Future` implementation for `Receiver`. + #[derive(Debug, Eq, PartialEq)] + pub struct RecvError(pub(super) ()); + + /// Error returned by the `try_recv` function on `Receiver`. + #[derive(Debug, Eq, PartialEq)] + pub enum TryRecvError { + /// The send half of the channel has not yet sent a value. + Empty, + + /// The send half of the channel was dropped without sending a value. + Closed, + } + + // ===== impl RecvError ===== + + impl fmt::Display for RecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } + } + + impl std::error::Error for RecvError {} + + // ===== impl TryRecvError ===== + + impl fmt::Display for TryRecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TryRecvError::Empty => write!(fmt, "channel empty"), + TryRecvError::Closed => write!(fmt, "channel closed"), + } + } + } + + impl std::error::Error for TryRecvError {} +} + +use self::error::*; + +struct Inner<T> { + /// Manages the state of the inner cell. + state: AtomicUsize, + + /// The value. This is set by `Sender` and read by `Receiver`. The state of + /// the cell is tracked by `state`. + value: UnsafeCell<Option<T>>, + + /// The task to notify when the receiver drops without consuming the value. + /// + /// ## Safety + /// + /// The `TX_TASK_SET` bit in the `state` field is set if this field is + /// initialized. If that bit is unset, this field may be uninitialized. + tx_task: Task, + + /// The task to notify when the value is sent. + /// + /// ## Safety + /// + /// The `RX_TASK_SET` bit in the `state` field is set if this field is + /// initialized. If that bit is unset, this field may be uninitialized. + rx_task: Task, +} + +struct Task(UnsafeCell<MaybeUninit<Waker>>); + +impl Task { + unsafe fn will_wake(&self, cx: &mut Context<'_>) -> bool { + self.with_task(|w| w.will_wake(cx.waker())) + } + + unsafe fn with_task<F, R>(&self, f: F) -> R + where + F: FnOnce(&Waker) -> R, + { + self.0.with(|ptr| { + let waker: *const Waker = (&*ptr).as_ptr(); + f(&*waker) + }) + } + + unsafe fn drop_task(&self) { + self.0.with_mut(|ptr| { + let ptr: *mut Waker = (&mut *ptr).as_mut_ptr(); + ptr.drop_in_place(); + }); + } + + unsafe fn set_task(&self, cx: &mut Context<'_>) { + self.0.with_mut(|ptr| { + let ptr: *mut Waker = (&mut *ptr).as_mut_ptr(); + ptr.write(cx.waker().clone()); + }); + } +} + +#[derive(Clone, Copy)] +struct State(usize); + +/// Creates a new one-shot channel for sending single values across asynchronous +/// tasks. +/// +/// The function returns separate "send" and "receive" handles. The `Sender` +/// handle is used by the producer to send the value. The `Receiver` handle is +/// used by the consumer to receive the value. +/// +/// Each handle can be used on separate tasks. +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::oneshot; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, rx) = oneshot::channel(); +/// +/// tokio::spawn(async move { +/// if let Err(_) = tx.send(3) { +/// println!("the receiver dropped"); +/// } +/// }); +/// +/// match rx.await { +/// Ok(v) => println!("got = {:?}", v), +/// Err(_) => println!("the sender dropped"), +/// } +/// } +/// ``` +#[track_caller] +pub fn channel<T>() -> (Sender<T>, Receiver<T>) { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = { + let location = std::panic::Location::caller(); + + let resource_span = tracing::trace_span!( + "runtime.resource", + concrete_type = "Sender|Receiver", + kind = "Sync", + loc.file = location.file(), + loc.line = location.line(), + loc.col = location.column(), + ); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + tx_dropped = false, + tx_dropped.op = "override", + ) + }); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + rx_dropped = false, + rx_dropped.op = "override", + ) + }); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + value_sent = false, + value_sent.op = "override", + ) + }); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + value_received = false, + value_received.op = "override", + ) + }); + + resource_span + }; + + let inner = Arc::new(Inner { + state: AtomicUsize::new(State::new().as_usize()), + value: UnsafeCell::new(None), + tx_task: Task(UnsafeCell::new(MaybeUninit::uninit())), + rx_task: Task(UnsafeCell::new(MaybeUninit::uninit())), + }); + + let tx = Sender { + inner: Some(inner.clone()), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: resource_span.clone(), + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let async_op_span = resource_span + .in_scope(|| tracing::trace_span!("runtime.resource.async_op", source = "Receiver::await")); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let async_op_poll_span = + async_op_span.in_scope(|| tracing::trace_span!("runtime.resource.async_op.poll")); + + let rx = Receiver { + inner: Some(inner), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: resource_span, + #[cfg(all(tokio_unstable, feature = "tracing"))] + async_op_span, + #[cfg(all(tokio_unstable, feature = "tracing"))] + async_op_poll_span, + }; + + (tx, rx) +} + +impl<T> Sender<T> { + /// Attempts to send a value on this channel, returning it back if it could + /// not be sent. + /// + /// This method consumes `self` as only one value may ever be sent on a oneshot + /// channel. It is not marked async because sending a message to an oneshot + /// channel never requires any form of waiting. Because of this, the `send` + /// method can be used in both synchronous and asynchronous code without + /// problems. + /// + /// A successful send occurs when it is determined that the other end of the + /// channel has not hung up already. An unsuccessful send would be one where + /// the corresponding receiver has already been deallocated. Note that a + /// return value of `Err` means that the data will never be received, but + /// a return value of `Ok` does *not* mean that the data will be received. + /// It is possible for the corresponding receiver to hang up immediately + /// after this function returns `Ok`. + /// + /// # Examples + /// + /// Send a value to another task + /// + /// ``` + /// use tokio::sync::oneshot; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = oneshot::channel(); + /// + /// tokio::spawn(async move { + /// if let Err(_) = tx.send(3) { + /// println!("the receiver dropped"); + /// } + /// }); + /// + /// match rx.await { + /// Ok(v) => println!("got = {:?}", v), + /// Err(_) => println!("the sender dropped"), + /// } + /// } + /// ``` + pub fn send(mut self, t: T) -> Result<(), T> { + let inner = self.inner.take().unwrap(); + + inner.value.with_mut(|ptr| unsafe { + // SAFETY: The receiver will not access the `UnsafeCell` unless the + // channel has been marked as "complete" (the `VALUE_SENT` state bit + // is set). + // That bit is only set by the sender later on in this method, and + // calling this method consumes `self`. Therefore, if it was possible to + // call this method, we know that the `VALUE_SENT` bit is unset, and + // the receiver is not currently accessing the `UnsafeCell`. + *ptr = Some(t); + }); + + if !inner.complete() { + unsafe { + // SAFETY: The receiver will not access the `UnsafeCell` unless + // the channel has been marked as "complete". Calling + // `complete()` will return true if this bit is set, and false + // if it is not set. Thus, if `complete()` returned false, it is + // safe for us to access the value, because we know that the + // receiver will not. + return Err(inner.consume_value().unwrap()); + } + } + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + value_sent = true, + value_sent.op = "override", + ) + }); + + Ok(()) + } + + /// Waits for the associated [`Receiver`] handle to close. + /// + /// A [`Receiver`] is closed by either calling [`close`] explicitly or the + /// [`Receiver`] value is dropped. + /// + /// This function is useful when paired with `select!` to abort a + /// computation when the receiver is no longer interested in the result. + /// + /// # Return + /// + /// Returns a `Future` which must be awaited on. + /// + /// [`Receiver`]: Receiver + /// [`close`]: Receiver::close + /// + /// # Examples + /// + /// Basic usage + /// + /// ``` + /// use tokio::sync::oneshot; + /// + /// #[tokio::main] + /// async fn main() { + /// let (mut tx, rx) = oneshot::channel::<()>(); + /// + /// tokio::spawn(async move { + /// drop(rx); + /// }); + /// + /// tx.closed().await; + /// println!("the receiver dropped"); + /// } + /// ``` + /// + /// Paired with select + /// + /// ``` + /// use tokio::sync::oneshot; + /// use tokio::time::{self, Duration}; + /// + /// async fn compute() -> String { + /// // Complex computation returning a `String` + /// # "hello".to_string() + /// } + /// + /// #[tokio::main] + /// async fn main() { + /// let (mut tx, rx) = oneshot::channel(); + /// + /// tokio::spawn(async move { + /// tokio::select! { + /// _ = tx.closed() => { + /// // The receiver dropped, no need to do any further work + /// } + /// value = compute() => { + /// // The send can fail if the channel was closed at the exact same + /// // time as when compute() finished, so just ignore the failure. + /// let _ = tx.send(value); + /// } + /// } + /// }); + /// + /// // Wait for up to 10 seconds + /// let _ = time::timeout(Duration::from_secs(10), rx).await; + /// } + /// ``` + pub async fn closed(&mut self) { + use crate::future::poll_fn; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + let closed = trace::async_op( + || poll_fn(|cx| self.poll_closed(cx)), + resource_span, + "Sender::closed", + "poll_closed", + false, + ); + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let closed = poll_fn(|cx| self.poll_closed(cx)); + + closed.await + } + + /// Returns `true` if the associated [`Receiver`] handle has been dropped. + /// + /// A [`Receiver`] is closed by either calling [`close`] explicitly or the + /// [`Receiver`] value is dropped. + /// + /// If `true` is returned, a call to `send` will always result in an error. + /// + /// [`Receiver`]: Receiver + /// [`close`]: Receiver::close + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::oneshot; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = oneshot::channel(); + /// + /// assert!(!tx.is_closed()); + /// + /// drop(rx); + /// + /// assert!(tx.is_closed()); + /// assert!(tx.send("never received").is_err()); + /// } + /// ``` + pub fn is_closed(&self) -> bool { + let inner = self.inner.as_ref().unwrap(); + + let state = State::load(&inner.state, Acquire); + state.is_closed() + } + + /// Checks whether the oneshot channel has been closed, and if not, schedules the + /// `Waker` in the provided `Context` to receive a notification when the channel is + /// closed. + /// + /// A [`Receiver`] is closed by either calling [`close`] explicitly, or when the + /// [`Receiver`] value is dropped. + /// + /// Note that on multiple calls to poll, only the `Waker` from the `Context` passed + /// to the most recent call will be scheduled to receive a wakeup. + /// + /// [`Receiver`]: struct@crate::sync::oneshot::Receiver + /// [`close`]: fn@crate::sync::oneshot::Receiver::close + /// + /// # Return value + /// + /// This function returns: + /// + /// * `Poll::Pending` if the channel is still open. + /// * `Poll::Ready(())` if the channel is closed. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::oneshot; + /// + /// use futures::future::poll_fn; + /// + /// #[tokio::main] + /// async fn main() { + /// let (mut tx, mut rx) = oneshot::channel::<()>(); + /// + /// tokio::spawn(async move { + /// rx.close(); + /// }); + /// + /// poll_fn(|cx| tx.poll_closed(cx)).await; + /// + /// println!("the receiver dropped"); + /// } + /// ``` + pub fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> { + // Keep track of task budget + let coop = ready!(crate::coop::poll_proceed(cx)); + + let inner = self.inner.as_ref().unwrap(); + + let mut state = State::load(&inner.state, Acquire); + + if state.is_closed() { + coop.made_progress(); + return Poll::Ready(()); + } + + if state.is_tx_task_set() { + let will_notify = unsafe { inner.tx_task.will_wake(cx) }; + + if !will_notify { + state = State::unset_tx_task(&inner.state); + + if state.is_closed() { + // Set the flag again so that the waker is released in drop + State::set_tx_task(&inner.state); + coop.made_progress(); + return Ready(()); + } else { + unsafe { inner.tx_task.drop_task() }; + } + } + } + + if !state.is_tx_task_set() { + // Attempt to set the task + unsafe { + inner.tx_task.set_task(cx); + } + + // Update the state + state = State::set_tx_task(&inner.state); + + if state.is_closed() { + coop.made_progress(); + return Ready(()); + } + } + + Pending + } +} + +impl<T> Drop for Sender<T> { + fn drop(&mut self) { + if let Some(inner) = self.inner.as_ref() { + inner.complete(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + tx_dropped = true, + tx_dropped.op = "override", + ) + }); + } + } +} + +impl<T> Receiver<T> { + /// Prevents the associated [`Sender`] handle from sending a value. + /// + /// Any `send` operation which happens after calling `close` is guaranteed + /// to fail. After calling `close`, [`try_recv`] should be called to + /// receive a value if one was sent **before** the call to `close` + /// completed. + /// + /// This function is useful to perform a graceful shutdown and ensure that a + /// value will not be sent into the channel and never received. + /// + /// `close` is no-op if a message is already received or the channel + /// is already closed. + /// + /// [`Sender`]: Sender + /// [`try_recv`]: Receiver::try_recv + /// + /// # Examples + /// + /// Prevent a value from being sent + /// + /// ``` + /// use tokio::sync::oneshot; + /// use tokio::sync::oneshot::error::TryRecvError; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = oneshot::channel(); + /// + /// assert!(!tx.is_closed()); + /// + /// rx.close(); + /// + /// assert!(tx.is_closed()); + /// assert!(tx.send("never received").is_err()); + /// + /// match rx.try_recv() { + /// Err(TryRecvError::Closed) => {} + /// _ => unreachable!(), + /// } + /// } + /// ``` + /// + /// Receive a value sent **before** calling `close` + /// + /// ``` + /// use tokio::sync::oneshot; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = oneshot::channel(); + /// + /// assert!(tx.send("will receive").is_ok()); + /// + /// rx.close(); + /// + /// let msg = rx.try_recv().unwrap(); + /// assert_eq!(msg, "will receive"); + /// } + /// ``` + pub fn close(&mut self) { + if let Some(inner) = self.inner.as_ref() { + inner.close(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + rx_dropped = true, + rx_dropped.op = "override", + ) + }); + } + } + + /// Attempts to receive a value. + /// + /// If a pending value exists in the channel, it is returned. If no value + /// has been sent, the current task **will not** be registered for + /// future notification. + /// + /// This function is useful to call from outside the context of an + /// asynchronous task. + /// + /// # Return + /// + /// - `Ok(T)` if a value is pending in the channel. + /// - `Err(TryRecvError::Empty)` if no value has been sent yet. + /// - `Err(TryRecvError::Closed)` if the sender has dropped without sending + /// a value. + /// + /// # Examples + /// + /// `try_recv` before a value is sent, then after. + /// + /// ``` + /// use tokio::sync::oneshot; + /// use tokio::sync::oneshot::error::TryRecvError; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = oneshot::channel(); + /// + /// match rx.try_recv() { + /// // The channel is currently empty + /// Err(TryRecvError::Empty) => {} + /// _ => unreachable!(), + /// } + /// + /// // Send a value + /// tx.send("hello").unwrap(); + /// + /// match rx.try_recv() { + /// Ok(value) => assert_eq!(value, "hello"), + /// _ => unreachable!(), + /// } + /// } + /// ``` + /// + /// `try_recv` when the sender dropped before sending a value + /// + /// ``` + /// use tokio::sync::oneshot; + /// use tokio::sync::oneshot::error::TryRecvError; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = oneshot::channel::<()>(); + /// + /// drop(tx); + /// + /// match rx.try_recv() { + /// // The channel will never receive a value. + /// Err(TryRecvError::Closed) => {} + /// _ => unreachable!(), + /// } + /// } + /// ``` + pub fn try_recv(&mut self) -> Result<T, TryRecvError> { + let result = if let Some(inner) = self.inner.as_ref() { + let state = State::load(&inner.state, Acquire); + + if state.is_complete() { + // SAFETY: If `state.is_complete()` returns true, then the + // `VALUE_SENT` bit has been set and the sender side of the + // channel will no longer attempt to access the inner + // `UnsafeCell`. Therefore, it is now safe for us to access the + // cell. + match unsafe { inner.consume_value() } { + Some(value) => { + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + value_received = true, + value_received.op = "override", + ) + }); + Ok(value) + } + None => Err(TryRecvError::Closed), + } + } else if state.is_closed() { + Err(TryRecvError::Closed) + } else { + // Not ready, this does not clear `inner` + return Err(TryRecvError::Empty); + } + } else { + Err(TryRecvError::Closed) + }; + + self.inner = None; + result + } + + /// Blocking receive to call outside of asynchronous contexts. + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution + /// context. + /// + /// # Examples + /// + /// ``` + /// use std::thread; + /// use tokio::sync::oneshot; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = oneshot::channel::<u8>(); + /// + /// let sync_code = thread::spawn(move || { + /// assert_eq!(Ok(10), rx.blocking_recv()); + /// }); + /// + /// let _ = tx.send(10); + /// sync_code.join().unwrap(); + /// } + /// ``` + #[cfg(feature = "sync")] + pub fn blocking_recv(self) -> Result<T, RecvError> { + crate::future::block_on(self) + } +} + +impl<T> Drop for Receiver<T> { + fn drop(&mut self) { + if let Some(inner) = self.inner.as_ref() { + inner.close(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + rx_dropped = true, + rx_dropped.op = "override", + ) + }); + } + } +} + +impl<T> Future for Receiver<T> { + type Output = Result<T, RecvError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + // If `inner` is `None`, then `poll()` has already completed. + #[cfg(all(tokio_unstable, feature = "tracing"))] + let _res_span = self.resource_span.clone().entered(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + let _ao_span = self.async_op_span.clone().entered(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + let _ao_poll_span = self.async_op_poll_span.clone().entered(); + + let ret = if let Some(inner) = self.as_ref().get_ref().inner.as_ref() { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let res = ready!(trace_poll_op!("poll_recv", inner.poll_recv(cx)))?; + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + let res = ready!(inner.poll_recv(cx))?; + + res + } else { + panic!("called after complete"); + }; + + self.inner = None; + Ready(Ok(ret)) + } +} + +impl<T> Inner<T> { + fn complete(&self) -> bool { + let prev = State::set_complete(&self.state); + + if prev.is_closed() { + return false; + } + + if prev.is_rx_task_set() { + // TODO: Consume waker? + unsafe { + self.rx_task.with_task(Waker::wake_by_ref); + } + } + + true + } + + fn poll_recv(&self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> { + // Keep track of task budget + let coop = ready!(crate::coop::poll_proceed(cx)); + + // Load the state + let mut state = State::load(&self.state, Acquire); + + if state.is_complete() { + coop.made_progress(); + match unsafe { self.consume_value() } { + Some(value) => Ready(Ok(value)), + None => Ready(Err(RecvError(()))), + } + } else if state.is_closed() { + coop.made_progress(); + Ready(Err(RecvError(()))) + } else { + if state.is_rx_task_set() { + let will_notify = unsafe { self.rx_task.will_wake(cx) }; + + // Check if the task is still the same + if !will_notify { + // Unset the task + state = State::unset_rx_task(&self.state); + if state.is_complete() { + // Set the flag again so that the waker is released in drop + State::set_rx_task(&self.state); + + coop.made_progress(); + // SAFETY: If `state.is_complete()` returns true, then the + // `VALUE_SENT` bit has been set and the sender side of the + // channel will no longer attempt to access the inner + // `UnsafeCell`. Therefore, it is now safe for us to access the + // cell. + return match unsafe { self.consume_value() } { + Some(value) => Ready(Ok(value)), + None => Ready(Err(RecvError(()))), + }; + } else { + unsafe { self.rx_task.drop_task() }; + } + } + } + + if !state.is_rx_task_set() { + // Attempt to set the task + unsafe { + self.rx_task.set_task(cx); + } + + // Update the state + state = State::set_rx_task(&self.state); + + if state.is_complete() { + coop.made_progress(); + match unsafe { self.consume_value() } { + Some(value) => Ready(Ok(value)), + None => Ready(Err(RecvError(()))), + } + } else { + Pending + } + } else { + Pending + } + } + } + + /// Called by `Receiver` to indicate that the value will never be received. + fn close(&self) { + let prev = State::set_closed(&self.state); + + if prev.is_tx_task_set() && !prev.is_complete() { + unsafe { + self.tx_task.with_task(Waker::wake_by_ref); + } + } + } + + /// Consumes the value. This function does not check `state`. + /// + /// # Safety + /// + /// Calling this method concurrently on multiple threads will result in a + /// data race. The `VALUE_SENT` state bit is used to ensure that only the + /// sender *or* the receiver will call this method at a given point in time. + /// If `VALUE_SENT` is not set, then only the sender may call this method; + /// if it is set, then only the receiver may call this method. + unsafe fn consume_value(&self) -> Option<T> { + self.value.with_mut(|ptr| (*ptr).take()) + } +} + +unsafe impl<T: Send> Send for Inner<T> {} +unsafe impl<T: Send> Sync for Inner<T> {} + +fn mut_load(this: &mut AtomicUsize) -> usize { + this.with_mut(|v| *v) +} + +impl<T> Drop for Inner<T> { + fn drop(&mut self) { + let state = State(mut_load(&mut self.state)); + + if state.is_rx_task_set() { + unsafe { + self.rx_task.drop_task(); + } + } + + if state.is_tx_task_set() { + unsafe { + self.tx_task.drop_task(); + } + } + } +} + +impl<T: fmt::Debug> fmt::Debug for Inner<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + use std::sync::atomic::Ordering::Relaxed; + + fmt.debug_struct("Inner") + .field("state", &State::load(&self.state, Relaxed)) + .finish() + } +} + +/// Indicates that a waker for the receiving task has been set. +/// +/// # Safety +/// +/// If this bit is not set, the `rx_task` field may be uninitialized. +const RX_TASK_SET: usize = 0b00001; +/// Indicates that a value has been stored in the channel's inner `UnsafeCell`. +/// +/// # Safety +/// +/// This bit controls which side of the channel is permitted to access the +/// `UnsafeCell`. If it is set, the `UnsafeCell` may ONLY be accessed by the +/// receiver. If this bit is NOT set, the `UnsafeCell` may ONLY be accessed by +/// the sender. +const VALUE_SENT: usize = 0b00010; +const CLOSED: usize = 0b00100; + +/// Indicates that a waker for the sending task has been set. +/// +/// # Safety +/// +/// If this bit is not set, the `tx_task` field may be uninitialized. +const TX_TASK_SET: usize = 0b01000; + +impl State { + fn new() -> State { + State(0) + } + + fn is_complete(self) -> bool { + self.0 & VALUE_SENT == VALUE_SENT + } + + fn set_complete(cell: &AtomicUsize) -> State { + // This method is a compare-and-swap loop rather than a fetch-or like + // other `set_$WHATEVER` methods on `State`. This is because we must + // check if the state has been closed before setting the `VALUE_SENT` + // bit. + // + // We don't want to set both the `VALUE_SENT` bit if the `CLOSED` + // bit is already set, because `VALUE_SENT` will tell the receiver that + // it's okay to access the inner `UnsafeCell`. Immediately after calling + // `set_complete`, if the channel was closed, the sender will _also_ + // access the `UnsafeCell` to take the value back out, so if a + // `poll_recv` or `try_recv` call is occurring concurrently, both + // threads may try to access the `UnsafeCell` if we were to set the + // `VALUE_SENT` bit on a closed channel. + let mut state = cell.load(Ordering::Relaxed); + loop { + if State(state).is_closed() { + break; + } + // TODO: This could be `Release`, followed by an `Acquire` fence *if* + // the `RX_TASK_SET` flag is set. However, `loom` does not support + // fences yet. + match cell.compare_exchange_weak( + state, + state | VALUE_SENT, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => break, + Err(actual) => state = actual, + } + } + State(state) + } + + fn is_rx_task_set(self) -> bool { + self.0 & RX_TASK_SET == RX_TASK_SET + } + + fn set_rx_task(cell: &AtomicUsize) -> State { + let val = cell.fetch_or(RX_TASK_SET, AcqRel); + State(val | RX_TASK_SET) + } + + fn unset_rx_task(cell: &AtomicUsize) -> State { + let val = cell.fetch_and(!RX_TASK_SET, AcqRel); + State(val & !RX_TASK_SET) + } + + fn is_closed(self) -> bool { + self.0 & CLOSED == CLOSED + } + + fn set_closed(cell: &AtomicUsize) -> State { + // Acquire because we want all later writes (attempting to poll) to be + // ordered after this. + let val = cell.fetch_or(CLOSED, Acquire); + State(val) + } + + fn set_tx_task(cell: &AtomicUsize) -> State { + let val = cell.fetch_or(TX_TASK_SET, AcqRel); + State(val | TX_TASK_SET) + } + + fn unset_tx_task(cell: &AtomicUsize) -> State { + let val = cell.fetch_and(!TX_TASK_SET, AcqRel); + State(val & !TX_TASK_SET) + } + + fn is_tx_task_set(self) -> bool { + self.0 & TX_TASK_SET == TX_TASK_SET + } + + fn as_usize(self) -> usize { + self.0 + } + + fn load(cell: &AtomicUsize, order: Ordering) -> State { + let val = cell.load(order); + State(val) + } +} + +impl fmt::Debug for State { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("State") + .field("is_complete", &self.is_complete()) + .field("is_closed", &self.is_closed()) + .field("is_rx_task_set", &self.is_rx_task_set()) + .field("is_tx_task_set", &self.is_tx_task_set()) + .finish() + } +} diff --git a/third_party/rust/tokio/src/sync/rwlock.rs b/third_party/rust/tokio/src/sync/rwlock.rs new file mode 100644 index 0000000000..b856cfc856 --- /dev/null +++ b/third_party/rust/tokio/src/sync/rwlock.rs @@ -0,0 +1,1078 @@ +use crate::sync::batch_semaphore::{Semaphore, TryAcquireError}; +use crate::sync::mutex::TryLockError; +#[cfg(all(tokio_unstable, feature = "tracing"))] +use crate::util::trace; +use std::cell::UnsafeCell; +use std::marker; +use std::marker::PhantomData; +use std::mem::ManuallyDrop; +use std::sync::Arc; + +pub(crate) mod owned_read_guard; +pub(crate) mod owned_write_guard; +pub(crate) mod owned_write_guard_mapped; +pub(crate) mod read_guard; +pub(crate) mod write_guard; +pub(crate) mod write_guard_mapped; +pub(crate) use owned_read_guard::OwnedRwLockReadGuard; +pub(crate) use owned_write_guard::OwnedRwLockWriteGuard; +pub(crate) use owned_write_guard_mapped::OwnedRwLockMappedWriteGuard; +pub(crate) use read_guard::RwLockReadGuard; +pub(crate) use write_guard::RwLockWriteGuard; +pub(crate) use write_guard_mapped::RwLockMappedWriteGuard; + +#[cfg(not(loom))] +const MAX_READS: u32 = std::u32::MAX >> 3; + +#[cfg(loom)] +const MAX_READS: u32 = 10; + +/// An asynchronous reader-writer lock. +/// +/// This type of lock allows a number of readers or at most one writer at any +/// point in time. The write portion of this lock typically allows modification +/// of the underlying data (exclusive access) and the read portion of this lock +/// typically allows for read-only access (shared access). +/// +/// In comparison, a [`Mutex`] does not distinguish between readers or writers +/// that acquire the lock, therefore causing any tasks waiting for the lock to +/// become available to yield. An `RwLock` will allow any number of readers to +/// acquire the lock as long as a writer is not holding the lock. +/// +/// The priority policy of Tokio's read-write lock is _fair_ (or +/// [_write-preferring_]), in order to ensure that readers cannot starve +/// writers. Fairness is ensured using a first-in, first-out queue for the tasks +/// awaiting the lock; if a task that wishes to acquire the write lock is at the +/// head of the queue, read locks will not be given out until the write lock has +/// been released. This is in contrast to the Rust standard library's +/// `std::sync::RwLock`, where the priority policy is dependent on the +/// operating system's implementation. +/// +/// The type parameter `T` represents the data that this lock protects. It is +/// required that `T` satisfies [`Send`] to be shared across threads. The RAII guards +/// returned from the locking methods implement [`Deref`](trait@std::ops::Deref) +/// (and [`DerefMut`](trait@std::ops::DerefMut) +/// for the `write` methods) to allow access to the content of the lock. +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::RwLock; +/// +/// #[tokio::main] +/// async fn main() { +/// let lock = RwLock::new(5); +/// +/// // many reader locks can be held at once +/// { +/// let r1 = lock.read().await; +/// let r2 = lock.read().await; +/// assert_eq!(*r1, 5); +/// assert_eq!(*r2, 5); +/// } // read locks are dropped at this point +/// +/// // only one write lock may be held, however +/// { +/// let mut w = lock.write().await; +/// *w += 1; +/// assert_eq!(*w, 6); +/// } // write lock is dropped here +/// } +/// ``` +/// +/// [`Mutex`]: struct@super::Mutex +/// [`RwLock`]: struct@RwLock +/// [`RwLockReadGuard`]: struct@RwLockReadGuard +/// [`RwLockWriteGuard`]: struct@RwLockWriteGuard +/// [`Send`]: trait@std::marker::Send +/// [_write-preferring_]: https://en.wikipedia.org/wiki/Readers%E2%80%93writer_lock#Priority_policies +#[derive(Debug)] +pub struct RwLock<T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, + + // maximum number of concurrent readers + mr: u32, + + //semaphore to coordinate read and write access to T + s: Semaphore, + + //inner data T + c: UnsafeCell<T>, +} + +#[test] +#[cfg(not(loom))] +fn bounds() { + fn check_send<T: Send>() {} + fn check_sync<T: Sync>() {} + fn check_unpin<T: Unpin>() {} + // This has to take a value, since the async fn's return type is unnameable. + fn check_send_sync_val<T: Send + Sync>(_t: T) {} + + check_send::<RwLock<u32>>(); + check_sync::<RwLock<u32>>(); + check_unpin::<RwLock<u32>>(); + + check_send::<RwLockReadGuard<'_, u32>>(); + check_sync::<RwLockReadGuard<'_, u32>>(); + check_unpin::<RwLockReadGuard<'_, u32>>(); + + check_send::<OwnedRwLockReadGuard<u32, i32>>(); + check_sync::<OwnedRwLockReadGuard<u32, i32>>(); + check_unpin::<OwnedRwLockReadGuard<u32, i32>>(); + + check_send::<RwLockWriteGuard<'_, u32>>(); + check_sync::<RwLockWriteGuard<'_, u32>>(); + check_unpin::<RwLockWriteGuard<'_, u32>>(); + + check_send::<RwLockMappedWriteGuard<'_, u32>>(); + check_sync::<RwLockMappedWriteGuard<'_, u32>>(); + check_unpin::<RwLockMappedWriteGuard<'_, u32>>(); + + check_send::<OwnedRwLockWriteGuard<u32>>(); + check_sync::<OwnedRwLockWriteGuard<u32>>(); + check_unpin::<OwnedRwLockWriteGuard<u32>>(); + + check_send::<OwnedRwLockMappedWriteGuard<u32, i32>>(); + check_sync::<OwnedRwLockMappedWriteGuard<u32, i32>>(); + check_unpin::<OwnedRwLockMappedWriteGuard<u32, i32>>(); + + let rwlock = Arc::new(RwLock::new(0)); + check_send_sync_val(rwlock.read()); + check_send_sync_val(Arc::clone(&rwlock).read_owned()); + check_send_sync_val(rwlock.write()); + check_send_sync_val(Arc::clone(&rwlock).write_owned()); +} + +// As long as T: Send + Sync, it's fine to send and share RwLock<T> between threads. +// If T were not Send, sending and sharing a RwLock<T> would be bad, since you can access T through +// RwLock<T>. +unsafe impl<T> Send for RwLock<T> where T: ?Sized + Send {} +unsafe impl<T> Sync for RwLock<T> where T: ?Sized + Send + Sync {} +// NB: These impls need to be explicit since we're storing a raw pointer. +// Safety: Stores a raw pointer to `T`, so if `T` is `Sync`, the lock guard over +// `T` is `Send`. +unsafe impl<T> Send for RwLockReadGuard<'_, T> where T: ?Sized + Sync {} +unsafe impl<T> Sync for RwLockReadGuard<'_, T> where T: ?Sized + Send + Sync {} +// T is required to be `Send` because an OwnedRwLockReadGuard can be used to drop the value held in +// the RwLock, unlike RwLockReadGuard. +unsafe impl<T, U> Send for OwnedRwLockReadGuard<T, U> +where + T: ?Sized + Send + Sync, + U: ?Sized + Sync, +{ +} +unsafe impl<T, U> Sync for OwnedRwLockReadGuard<T, U> +where + T: ?Sized + Send + Sync, + U: ?Sized + Send + Sync, +{ +} +unsafe impl<T> Sync for RwLockWriteGuard<'_, T> where T: ?Sized + Send + Sync {} +unsafe impl<T> Sync for OwnedRwLockWriteGuard<T> where T: ?Sized + Send + Sync {} +unsafe impl<T> Sync for RwLockMappedWriteGuard<'_, T> where T: ?Sized + Send + Sync {} +unsafe impl<T, U> Sync for OwnedRwLockMappedWriteGuard<T, U> +where + T: ?Sized + Send + Sync, + U: ?Sized + Send + Sync, +{ +} +// Safety: Stores a raw pointer to `T`, so if `T` is `Sync`, the lock guard over +// `T` is `Send` - but since this is also provides mutable access, we need to +// make sure that `T` is `Send` since its value can be sent across thread +// boundaries. +unsafe impl<T> Send for RwLockWriteGuard<'_, T> where T: ?Sized + Send + Sync {} +unsafe impl<T> Send for OwnedRwLockWriteGuard<T> where T: ?Sized + Send + Sync {} +unsafe impl<T> Send for RwLockMappedWriteGuard<'_, T> where T: ?Sized + Send + Sync {} +unsafe impl<T, U> Send for OwnedRwLockMappedWriteGuard<T, U> +where + T: ?Sized + Send + Sync, + U: ?Sized + Send + Sync, +{ +} + +impl<T: ?Sized> RwLock<T> { + /// Creates a new instance of an `RwLock<T>` which is unlocked. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::RwLock; + /// + /// let lock = RwLock::new(5); + /// ``` + #[track_caller] + pub fn new(value: T) -> RwLock<T> + where + T: Sized, + { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = { + let location = std::panic::Location::caller(); + let resource_span = tracing::trace_span!( + "runtime.resource", + concrete_type = "RwLock", + kind = "Sync", + loc.file = location.file(), + loc.line = location.line(), + loc.col = location.column(), + ); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + max_readers = MAX_READS, + ); + + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + ); + + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 0, + ); + }); + + resource_span + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let s = resource_span.in_scope(|| Semaphore::new(MAX_READS as usize)); + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + let s = Semaphore::new(MAX_READS as usize); + + RwLock { + mr: MAX_READS, + c: UnsafeCell::new(value), + s, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Creates a new instance of an `RwLock<T>` which is unlocked + /// and allows a maximum of `max_reads` concurrent readers. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::RwLock; + /// + /// let lock = RwLock::with_max_readers(5, 1024); + /// ``` + /// + /// # Panics + /// + /// Panics if `max_reads` is more than `u32::MAX >> 3`. + #[track_caller] + pub fn with_max_readers(value: T, max_reads: u32) -> RwLock<T> + where + T: Sized, + { + assert!( + max_reads <= MAX_READS, + "a RwLock may not be created with more than {} readers", + MAX_READS + ); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = { + let location = std::panic::Location::caller(); + + let resource_span = tracing::trace_span!( + "runtime.resource", + concrete_type = "RwLock", + kind = "Sync", + loc.file = location.file(), + loc.line = location.line(), + loc.col = location.column(), + ); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + max_readers = max_reads, + ); + + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + ); + + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 0, + ); + }); + + resource_span + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let s = resource_span.in_scope(|| Semaphore::new(max_reads as usize)); + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + let s = Semaphore::new(max_reads as usize); + + RwLock { + mr: max_reads, + c: UnsafeCell::new(value), + s, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Creates a new instance of an `RwLock<T>` which is unlocked. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::RwLock; + /// + /// static LOCK: RwLock<i32> = RwLock::const_new(5); + /// ``` + #[cfg(all(feature = "parking_lot", not(all(loom, test))))] + #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] + pub const fn const_new(value: T) -> RwLock<T> + where + T: Sized, + { + RwLock { + mr: MAX_READS, + c: UnsafeCell::new(value), + s: Semaphore::const_new(MAX_READS as usize), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span::none(), + } + } + + /// Creates a new instance of an `RwLock<T>` which is unlocked + /// and allows a maximum of `max_reads` concurrent readers. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::RwLock; + /// + /// static LOCK: RwLock<i32> = RwLock::const_with_max_readers(5, 1024); + /// ``` + #[cfg(all(feature = "parking_lot", not(all(loom, test))))] + #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] + pub const fn const_with_max_readers(value: T, mut max_reads: u32) -> RwLock<T> + where + T: Sized, + { + max_reads &= MAX_READS; + RwLock { + mr: max_reads, + c: UnsafeCell::new(value), + s: Semaphore::const_new(max_reads as usize), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span::none(), + } + } + + /// Locks this `RwLock` with shared read access, causing the current task + /// to yield until the lock has been acquired. + /// + /// The calling task will yield until there are no writers which hold the + /// lock. There may be other readers inside the lock when the task resumes. + /// + /// Note that under the priority policy of [`RwLock`], read locks are not + /// granted until prior write locks, to prevent starvation. Therefore + /// deadlock may occur if a read lock is held by the current task, a write + /// lock attempt is made, and then a subsequent read lock attempt is made + /// by the current task. + /// + /// Returns an RAII guard which will drop this read access of the `RwLock` + /// when dropped. + /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute locks in the order they + /// were requested. Cancelling a call to `read` makes you lose your place in + /// the queue. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::RwLock; + /// + /// #[tokio::main] + /// async fn main() { + /// let lock = Arc::new(RwLock::new(1)); + /// let c_lock = lock.clone(); + /// + /// let n = lock.read().await; + /// assert_eq!(*n, 1); + /// + /// tokio::spawn(async move { + /// // While main has an active read lock, we acquire one too. + /// let r = c_lock.read().await; + /// assert_eq!(*r, 1); + /// }).await.expect("The spawned task has panicked"); + /// + /// // Drop the guard after the spawned task finishes. + /// drop(n); + /// } + /// ``` + pub async fn read(&self) -> RwLockReadGuard<'_, T> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let inner = trace::async_op( + || self.s.acquire(1), + self.resource_span.clone(), + "RwLock::read", + "poll", + false, + ); + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let inner = self.s.acquire(1); + + inner.await.unwrap_or_else(|_| { + // The semaphore was closed. but, we never explicitly close it, and we have a + // handle to it through the Arc, which means that this can never happen. + unreachable!() + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) + }); + + RwLockReadGuard { + s: &self.s, + data: self.c.get(), + marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), + } + } + + /// Blockingly locks this `RwLock` with shared read access. + /// + /// This method is intended for use cases where you + /// need to use this rwlock in asynchronous code as well as in synchronous code. + /// + /// Returns an RAII guard which will drop the read access of this `RwLock` when dropped. + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution context. + /// + /// - If you find yourself in an asynchronous execution context and needing + /// to call some (synchronous) function which performs one of these + /// `blocking_` operations, then consider wrapping that call inside + /// [`spawn_blocking()`][crate::runtime::Handle::spawn_blocking] + /// (or [`block_in_place()`][crate::task::block_in_place]). + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::RwLock; + /// + /// #[tokio::main] + /// async fn main() { + /// let rwlock = Arc::new(RwLock::new(1)); + /// let mut write_lock = rwlock.write().await; + /// + /// let blocking_task = tokio::task::spawn_blocking({ + /// let rwlock = Arc::clone(&rwlock); + /// move || { + /// // This shall block until the `write_lock` is released. + /// let read_lock = rwlock.blocking_read(); + /// assert_eq!(*read_lock, 0); + /// } + /// }); + /// + /// *write_lock -= 1; + /// drop(write_lock); // release the lock. + /// + /// // Await the completion of the blocking task. + /// blocking_task.await.unwrap(); + /// + /// // Assert uncontended. + /// assert!(rwlock.try_write().is_ok()); + /// } + /// ``` + #[cfg(feature = "sync")] + pub fn blocking_read(&self) -> RwLockReadGuard<'_, T> { + crate::future::block_on(self.read()) + } + + /// Locks this `RwLock` with shared read access, causing the current task + /// to yield until the lock has been acquired. + /// + /// The calling task will yield until there are no writers which hold the + /// lock. There may be other readers inside the lock when the task resumes. + /// + /// This method is identical to [`RwLock::read`], except that the returned + /// guard references the `RwLock` with an [`Arc`] rather than by borrowing + /// it. Therefore, the `RwLock` must be wrapped in an `Arc` to call this + /// method, and the guard will live for the `'static` lifetime, as it keeps + /// the `RwLock` alive by holding an `Arc`. + /// + /// Note that under the priority policy of [`RwLock`], read locks are not + /// granted until prior write locks, to prevent starvation. Therefore + /// deadlock may occur if a read lock is held by the current task, a write + /// lock attempt is made, and then a subsequent read lock attempt is made + /// by the current task. + /// + /// Returns an RAII guard which will drop this read access of the `RwLock` + /// when dropped. + /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute locks in the order they + /// were requested. Cancelling a call to `read_owned` makes you lose your + /// place in the queue. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::RwLock; + /// + /// #[tokio::main] + /// async fn main() { + /// let lock = Arc::new(RwLock::new(1)); + /// let c_lock = lock.clone(); + /// + /// let n = lock.read_owned().await; + /// assert_eq!(*n, 1); + /// + /// tokio::spawn(async move { + /// // While main has an active read lock, we acquire one too. + /// let r = c_lock.read_owned().await; + /// assert_eq!(*r, 1); + /// }).await.expect("The spawned task has panicked"); + /// + /// // Drop the guard after the spawned task finishes. + /// drop(n); + ///} + /// ``` + pub async fn read_owned(self: Arc<Self>) -> OwnedRwLockReadGuard<T> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let inner = trace::async_op( + || self.s.acquire(1), + self.resource_span.clone(), + "RwLock::read_owned", + "poll", + false, + ); + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let inner = self.s.acquire(1); + + inner.await.unwrap_or_else(|_| { + // The semaphore was closed. but, we never explicitly close it, and we have a + // handle to it through the Arc, which means that this can never happen. + unreachable!() + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + + OwnedRwLockReadGuard { + data: self.c.get(), + lock: ManuallyDrop::new(self), + _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Attempts to acquire this `RwLock` with shared read access. + /// + /// If the access couldn't be acquired immediately, returns [`TryLockError`]. + /// Otherwise, an RAII guard is returned which will release read access + /// when dropped. + /// + /// [`TryLockError`]: TryLockError + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::RwLock; + /// + /// #[tokio::main] + /// async fn main() { + /// let lock = Arc::new(RwLock::new(1)); + /// let c_lock = lock.clone(); + /// + /// let v = lock.try_read().unwrap(); + /// assert_eq!(*v, 1); + /// + /// tokio::spawn(async move { + /// // While main has an active read lock, we acquire one too. + /// let n = c_lock.read().await; + /// assert_eq!(*n, 1); + /// }).await.expect("The spawned task has panicked"); + /// + /// // Drop the guard when spawned task finishes. + /// drop(v); + /// } + /// ``` + pub fn try_read(&self) -> Result<RwLockReadGuard<'_, T>, TryLockError> { + match self.s.try_acquire(1) { + Ok(permit) => permit, + Err(TryAcquireError::NoPermits) => return Err(TryLockError(())), + Err(TryAcquireError::Closed) => unreachable!(), + } + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) + }); + + Ok(RwLockReadGuard { + s: &self.s, + data: self.c.get(), + marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), + }) + } + + /// Attempts to acquire this `RwLock` with shared read access. + /// + /// If the access couldn't be acquired immediately, returns [`TryLockError`]. + /// Otherwise, an RAII guard is returned which will release read access + /// when dropped. + /// + /// This method is identical to [`RwLock::try_read`], except that the + /// returned guard references the `RwLock` with an [`Arc`] rather than by + /// borrowing it. Therefore, the `RwLock` must be wrapped in an `Arc` to + /// call this method, and the guard will live for the `'static` lifetime, + /// as it keeps the `RwLock` alive by holding an `Arc`. + /// + /// [`TryLockError`]: TryLockError + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::RwLock; + /// + /// #[tokio::main] + /// async fn main() { + /// let lock = Arc::new(RwLock::new(1)); + /// let c_lock = lock.clone(); + /// + /// let v = lock.try_read_owned().unwrap(); + /// assert_eq!(*v, 1); + /// + /// tokio::spawn(async move { + /// // While main has an active read lock, we acquire one too. + /// let n = c_lock.read_owned().await; + /// assert_eq!(*n, 1); + /// }).await.expect("The spawned task has panicked"); + /// + /// // Drop the guard when spawned task finishes. + /// drop(v); + /// } + /// ``` + pub fn try_read_owned(self: Arc<Self>) -> Result<OwnedRwLockReadGuard<T>, TryLockError> { + match self.s.try_acquire(1) { + Ok(permit) => permit, + Err(TryAcquireError::NoPermits) => return Err(TryLockError(())), + Err(TryAcquireError::Closed) => unreachable!(), + } + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + + Ok(OwnedRwLockReadGuard { + data: self.c.get(), + lock: ManuallyDrop::new(self), + _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + }) + } + + /// Locks this `RwLock` with exclusive write access, causing the current + /// task to yield until the lock has been acquired. + /// + /// The calling task will yield while other writers or readers currently + /// have access to the lock. + /// + /// Returns an RAII guard which will drop the write access of this `RwLock` + /// when dropped. + /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute locks in the order they + /// were requested. Cancelling a call to `write` makes you lose your place + /// in the queue. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::RwLock; + /// + /// #[tokio::main] + /// async fn main() { + /// let lock = RwLock::new(1); + /// + /// let mut n = lock.write().await; + /// *n = 2; + ///} + /// ``` + pub async fn write(&self) -> RwLockWriteGuard<'_, T> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let inner = trace::async_op( + || self.s.acquire(self.mr), + self.resource_span.clone(), + "RwLock::write", + "poll", + false, + ); + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let inner = self.s.acquire(self.mr); + + inner.await.unwrap_or_else(|_| { + // The semaphore was closed. but, we never explicitly close it, and we have a + // handle to it through the Arc, which means that this can never happen. + unreachable!() + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = true, + write_locked.op = "override", + ) + }); + + RwLockWriteGuard { + permits_acquired: self.mr, + s: &self.s, + data: self.c.get(), + marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), + } + } + + /// Blockingly locks this `RwLock` with exclusive write access. + /// + /// This method is intended for use cases where you + /// need to use this rwlock in asynchronous code as well as in synchronous code. + /// + /// Returns an RAII guard which will drop the write access of this `RwLock` when dropped. + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution context. + /// + /// - If you find yourself in an asynchronous execution context and needing + /// to call some (synchronous) function which performs one of these + /// `blocking_` operations, then consider wrapping that call inside + /// [`spawn_blocking()`][crate::runtime::Handle::spawn_blocking] + /// (or [`block_in_place()`][crate::task::block_in_place]). + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::{sync::RwLock}; + /// + /// #[tokio::main] + /// async fn main() { + /// let rwlock = Arc::new(RwLock::new(1)); + /// let read_lock = rwlock.read().await; + /// + /// let blocking_task = tokio::task::spawn_blocking({ + /// let rwlock = Arc::clone(&rwlock); + /// move || { + /// // This shall block until the `read_lock` is released. + /// let mut write_lock = rwlock.blocking_write(); + /// *write_lock = 2; + /// } + /// }); + /// + /// assert_eq!(*read_lock, 1); + /// // Release the last outstanding read lock. + /// drop(read_lock); + /// + /// // Await the completion of the blocking task. + /// blocking_task.await.unwrap(); + /// + /// // Assert uncontended. + /// let read_lock = rwlock.try_read().unwrap(); + /// assert_eq!(*read_lock, 2); + /// } + /// ``` + #[cfg(feature = "sync")] + pub fn blocking_write(&self) -> RwLockWriteGuard<'_, T> { + crate::future::block_on(self.write()) + } + + /// Locks this `RwLock` with exclusive write access, causing the current + /// task to yield until the lock has been acquired. + /// + /// The calling task will yield while other writers or readers currently + /// have access to the lock. + /// + /// This method is identical to [`RwLock::write`], except that the returned + /// guard references the `RwLock` with an [`Arc`] rather than by borrowing + /// it. Therefore, the `RwLock` must be wrapped in an `Arc` to call this + /// method, and the guard will live for the `'static` lifetime, as it keeps + /// the `RwLock` alive by holding an `Arc`. + /// + /// Returns an RAII guard which will drop the write access of this `RwLock` + /// when dropped. + /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute locks in the order they + /// were requested. Cancelling a call to `write_owned` makes you lose your + /// place in the queue. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::RwLock; + /// + /// #[tokio::main] + /// async fn main() { + /// let lock = Arc::new(RwLock::new(1)); + /// + /// let mut n = lock.write_owned().await; + /// *n = 2; + ///} + /// ``` + pub async fn write_owned(self: Arc<Self>) -> OwnedRwLockWriteGuard<T> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let inner = trace::async_op( + || self.s.acquire(self.mr), + self.resource_span.clone(), + "RwLock::write_owned", + "poll", + false, + ); + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let inner = self.s.acquire(self.mr); + + inner.await.unwrap_or_else(|_| { + // The semaphore was closed. but, we never explicitly close it, and we have a + // handle to it through the Arc, which means that this can never happen. + unreachable!() + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = true, + write_locked.op = "override", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + + OwnedRwLockWriteGuard { + permits_acquired: self.mr, + data: self.c.get(), + lock: ManuallyDrop::new(self), + _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Attempts to acquire this `RwLock` with exclusive write access. + /// + /// If the access couldn't be acquired immediately, returns [`TryLockError`]. + /// Otherwise, an RAII guard is returned which will release write access + /// when dropped. + /// + /// [`TryLockError`]: TryLockError + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::RwLock; + /// + /// #[tokio::main] + /// async fn main() { + /// let rw = RwLock::new(1); + /// + /// let v = rw.read().await; + /// assert_eq!(*v, 1); + /// + /// assert!(rw.try_write().is_err()); + /// } + /// ``` + pub fn try_write(&self) -> Result<RwLockWriteGuard<'_, T>, TryLockError> { + match self.s.try_acquire(self.mr) { + Ok(permit) => permit, + Err(TryAcquireError::NoPermits) => return Err(TryLockError(())), + Err(TryAcquireError::Closed) => unreachable!(), + } + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = true, + write_locked.op = "override", + ) + }); + + Ok(RwLockWriteGuard { + permits_acquired: self.mr, + s: &self.s, + data: self.c.get(), + marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), + }) + } + + /// Attempts to acquire this `RwLock` with exclusive write access. + /// + /// If the access couldn't be acquired immediately, returns [`TryLockError`]. + /// Otherwise, an RAII guard is returned which will release write access + /// when dropped. + /// + /// This method is identical to [`RwLock::try_write`], except that the + /// returned guard references the `RwLock` with an [`Arc`] rather than by + /// borrowing it. Therefore, the `RwLock` must be wrapped in an `Arc` to + /// call this method, and the guard will live for the `'static` lifetime, + /// as it keeps the `RwLock` alive by holding an `Arc`. + /// + /// [`TryLockError`]: TryLockError + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::RwLock; + /// + /// #[tokio::main] + /// async fn main() { + /// let rw = Arc::new(RwLock::new(1)); + /// + /// let v = Arc::clone(&rw).read_owned().await; + /// assert_eq!(*v, 1); + /// + /// assert!(rw.try_write_owned().is_err()); + /// } + /// ``` + pub fn try_write_owned(self: Arc<Self>) -> Result<OwnedRwLockWriteGuard<T>, TryLockError> { + match self.s.try_acquire(self.mr) { + Ok(permit) => permit, + Err(TryAcquireError::NoPermits) => return Err(TryLockError(())), + Err(TryAcquireError::Closed) => unreachable!(), + } + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = true, + write_locked.op = "override", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + + Ok(OwnedRwLockWriteGuard { + permits_acquired: self.mr, + data: self.c.get(), + lock: ManuallyDrop::new(self), + _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + }) + } + + /// Returns a mutable reference to the underlying data. + /// + /// Since this call borrows the `RwLock` mutably, no actual locking needs to + /// take place -- the mutable borrow statically guarantees no locks exist. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::RwLock; + /// + /// fn main() { + /// let mut lock = RwLock::new(1); + /// + /// let n = lock.get_mut(); + /// *n = 2; + /// } + /// ``` + pub fn get_mut(&mut self) -> &mut T { + unsafe { + // Safety: This is https://github.com/rust-lang/rust/pull/76936 + &mut *self.c.get() + } + } + + /// Consumes the lock, returning the underlying data. + pub fn into_inner(self) -> T + where + T: Sized, + { + self.c.into_inner() + } +} + +impl<T> From<T> for RwLock<T> { + fn from(s: T) -> Self { + Self::new(s) + } +} + +impl<T: ?Sized> Default for RwLock<T> +where + T: Default, +{ + fn default() -> Self { + Self::new(T::default()) + } +} diff --git a/third_party/rust/tokio/src/sync/rwlock/owned_read_guard.rs b/third_party/rust/tokio/src/sync/rwlock/owned_read_guard.rs new file mode 100644 index 0000000000..27b71bd988 --- /dev/null +++ b/third_party/rust/tokio/src/sync/rwlock/owned_read_guard.rs @@ -0,0 +1,170 @@ +use crate::sync::rwlock::RwLock; +use std::fmt; +use std::marker::PhantomData; +use std::mem; +use std::mem::ManuallyDrop; +use std::ops; +use std::sync::Arc; + +/// Owned RAII structure used to release the shared read access of a lock when +/// dropped. +/// +/// This structure is created by the [`read_owned`] method on +/// [`RwLock`]. +/// +/// [`read_owned`]: method@crate::sync::RwLock::read_owned +/// [`RwLock`]: struct@crate::sync::RwLock +pub struct OwnedRwLockReadGuard<T: ?Sized, U: ?Sized = T> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) resource_span: tracing::Span, + // ManuallyDrop allows us to destructure into this field without running the destructor. + pub(super) lock: ManuallyDrop<Arc<RwLock<T>>>, + pub(super) data: *const U, + pub(super) _p: PhantomData<T>, +} + +impl<T: ?Sized, U: ?Sized> OwnedRwLockReadGuard<T, U> { + /// Makes a new `OwnedRwLockReadGuard` for a component of the locked data. + /// This operation cannot fail as the `OwnedRwLockReadGuard` passed in + /// already locked the data. + /// + /// This is an associated function that needs to be + /// used as `OwnedRwLockReadGuard::map(...)`. A method would interfere with + /// methods of the same name on the contents of the locked data. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{RwLock, OwnedRwLockReadGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = Arc::new(RwLock::new(Foo(1))); + /// + /// let guard = lock.read_owned().await; + /// let guard = OwnedRwLockReadGuard::map(guard, |f| &f.0); + /// + /// assert_eq!(1, *guard); + /// # } + /// ``` + #[inline] + pub fn map<F, V: ?Sized>(mut this: Self, f: F) -> OwnedRwLockReadGuard<T, V> + where + F: FnOnce(&U) -> &V, + { + let data = f(&*this) as *const V; + let lock = unsafe { ManuallyDrop::take(&mut this.lock) }; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = this.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + + OwnedRwLockReadGuard { + lock: ManuallyDrop::new(lock), + data, + _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Attempts to make a new [`OwnedRwLockReadGuard`] for a component of the + /// locked data. The original guard is returned if the closure returns + /// `None`. + /// + /// This operation cannot fail as the `OwnedRwLockReadGuard` passed in + /// already locked the data. + /// + /// This is an associated function that needs to be used as + /// `OwnedRwLockReadGuard::try_map(..)`. A method would interfere with + /// methods of the same name on the contents of the locked data. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{RwLock, OwnedRwLockReadGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = Arc::new(RwLock::new(Foo(1))); + /// + /// let guard = lock.read_owned().await; + /// let guard = OwnedRwLockReadGuard::try_map(guard, |f| Some(&f.0)).expect("should not fail"); + /// + /// assert_eq!(1, *guard); + /// # } + /// ``` + #[inline] + pub fn try_map<F, V: ?Sized>(mut this: Self, f: F) -> Result<OwnedRwLockReadGuard<T, V>, Self> + where + F: FnOnce(&U) -> Option<&V>, + { + let data = match f(&*this) { + Some(data) => data as *const V, + None => return Err(this), + }; + let lock = unsafe { ManuallyDrop::take(&mut this.lock) }; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = this.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + + Ok(OwnedRwLockReadGuard { + lock: ManuallyDrop::new(lock), + data, + _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + }) + } +} + +impl<T: ?Sized, U: ?Sized> ops::Deref for OwnedRwLockReadGuard<T, U> { + type Target = U; + + fn deref(&self) -> &U { + unsafe { &*self.data } + } +} + +impl<T: ?Sized, U: ?Sized> fmt::Debug for OwnedRwLockReadGuard<T, U> +where + U: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<T: ?Sized, U: ?Sized> fmt::Display for OwnedRwLockReadGuard<T, U> +where + U: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} + +impl<T: ?Sized, U: ?Sized> Drop for OwnedRwLockReadGuard<T, U> { + fn drop(&mut self) { + self.lock.s.release(1); + unsafe { ManuallyDrop::drop(&mut self.lock) }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "sub", + ) + }); + } +} diff --git a/third_party/rust/tokio/src/sync/rwlock/owned_write_guard.rs b/third_party/rust/tokio/src/sync/rwlock/owned_write_guard.rs new file mode 100644 index 0000000000..dbedab4cbb --- /dev/null +++ b/third_party/rust/tokio/src/sync/rwlock/owned_write_guard.rs @@ -0,0 +1,279 @@ +use crate::sync::rwlock::owned_read_guard::OwnedRwLockReadGuard; +use crate::sync::rwlock::owned_write_guard_mapped::OwnedRwLockMappedWriteGuard; +use crate::sync::rwlock::RwLock; +use std::fmt; +use std::marker::PhantomData; +use std::mem::{self, ManuallyDrop}; +use std::ops; +use std::sync::Arc; + +/// Owned RAII structure used to release the exclusive write access of a lock when +/// dropped. +/// +/// This structure is created by the [`write_owned`] method +/// on [`RwLock`]. +/// +/// [`write_owned`]: method@crate::sync::RwLock::write_owned +/// [`RwLock`]: struct@crate::sync::RwLock +pub struct OwnedRwLockWriteGuard<T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) resource_span: tracing::Span, + pub(super) permits_acquired: u32, + // ManuallyDrop allows us to destructure into this field without running the destructor. + pub(super) lock: ManuallyDrop<Arc<RwLock<T>>>, + pub(super) data: *mut T, + pub(super) _p: PhantomData<T>, +} + +impl<T: ?Sized> OwnedRwLockWriteGuard<T> { + /// Makes a new [`OwnedRwLockMappedWriteGuard`] for a component of the locked + /// data. + /// + /// This operation cannot fail as the `OwnedRwLockWriteGuard` passed in + /// already locked the data. + /// + /// This is an associated function that needs to be used as + /// `OwnedRwLockWriteGuard::map(..)`. A method would interfere with methods + /// of the same name on the contents of the locked data. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{RwLock, OwnedRwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = Arc::new(RwLock::new(Foo(1))); + /// + /// { + /// let lock = Arc::clone(&lock); + /// let mut mapped = OwnedRwLockWriteGuard::map(lock.write_owned().await, |f| &mut f.0); + /// *mapped = 2; + /// } + /// + /// assert_eq!(Foo(2), *lock.read().await); + /// # } + /// ``` + #[inline] + pub fn map<F, U: ?Sized>(mut this: Self, f: F) -> OwnedRwLockMappedWriteGuard<T, U> + where + F: FnOnce(&mut T) -> &mut U, + { + let data = f(&mut *this) as *mut U; + let lock = unsafe { ManuallyDrop::take(&mut this.lock) }; + let permits_acquired = this.permits_acquired; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = this.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + + OwnedRwLockMappedWriteGuard { + permits_acquired, + lock: ManuallyDrop::new(lock), + data, + _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Attempts to make a new [`OwnedRwLockMappedWriteGuard`] for a component + /// of the locked data. The original guard is returned if the closure + /// returns `None`. + /// + /// This operation cannot fail as the `OwnedRwLockWriteGuard` passed in + /// already locked the data. + /// + /// This is an associated function that needs to be + /// used as `OwnedRwLockWriteGuard::try_map(...)`. A method would interfere + /// with methods of the same name on the contents of the locked data. + /// + /// [`RwLockMappedWriteGuard`]: struct@crate::sync::RwLockMappedWriteGuard + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{RwLock, OwnedRwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = Arc::new(RwLock::new(Foo(1))); + /// + /// { + /// let guard = Arc::clone(&lock).write_owned().await; + /// let mut guard = OwnedRwLockWriteGuard::try_map(guard, |f| Some(&mut f.0)).expect("should not fail"); + /// *guard = 2; + /// } + /// + /// assert_eq!(Foo(2), *lock.read().await); + /// # } + /// ``` + #[inline] + pub fn try_map<F, U: ?Sized>( + mut this: Self, + f: F, + ) -> Result<OwnedRwLockMappedWriteGuard<T, U>, Self> + where + F: FnOnce(&mut T) -> Option<&mut U>, + { + let data = match f(&mut *this) { + Some(data) => data as *mut U, + None => return Err(this), + }; + let permits_acquired = this.permits_acquired; + let lock = unsafe { ManuallyDrop::take(&mut this.lock) }; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = this.resource_span.clone(); + + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + + Ok(OwnedRwLockMappedWriteGuard { + permits_acquired, + lock: ManuallyDrop::new(lock), + data, + _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + }) + } + + /// Converts this `OwnedRwLockWriteGuard` into an + /// `OwnedRwLockMappedWriteGuard`. This method can be used to store a + /// non-mapped guard in a struct field that expects a mapped guard. + /// + /// This is equivalent to calling `OwnedRwLockWriteGuard::map(guard, |me| me)`. + #[inline] + pub fn into_mapped(this: Self) -> OwnedRwLockMappedWriteGuard<T> { + Self::map(this, |me| me) + } + + /// Atomically downgrades a write lock into a read lock without allowing + /// any writers to take exclusive access of the lock in the meantime. + /// + /// **Note:** This won't *necessarily* allow any additional readers to acquire + /// locks, since [`RwLock`] is fair and it is possible that a writer is next + /// in line. + /// + /// Returns an RAII guard which will drop this read access of the `RwLock` + /// when dropped. + /// + /// # Examples + /// + /// ``` + /// # use tokio::sync::RwLock; + /// # use std::sync::Arc; + /// # + /// # #[tokio::main] + /// # async fn main() { + /// let lock = Arc::new(RwLock::new(1)); + /// + /// let n = lock.clone().write_owned().await; + /// + /// let cloned_lock = lock.clone(); + /// let handle = tokio::spawn(async move { + /// *cloned_lock.write_owned().await = 2; + /// }); + /// + /// let n = n.downgrade(); + /// assert_eq!(*n, 1, "downgrade is atomic"); + /// + /// drop(n); + /// handle.await.unwrap(); + /// assert_eq!(*lock.read().await, 2, "second writer obtained write lock"); + /// # } + /// ``` + pub fn downgrade(mut self) -> OwnedRwLockReadGuard<T> { + let lock = unsafe { ManuallyDrop::take(&mut self.lock) }; + let data = self.data; + let to_release = (self.permits_acquired - 1) as usize; + + // Release all but one of the permits held by the write guard + lock.s.release(to_release); + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + write_locked.op = "override", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(self); + + OwnedRwLockReadGuard { + lock: ManuallyDrop::new(lock), + data, + _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } +} + +impl<T: ?Sized> ops::Deref for OwnedRwLockWriteGuard<T> { + type Target = T; + + fn deref(&self) -> &T { + unsafe { &*self.data } + } +} + +impl<T: ?Sized> ops::DerefMut for OwnedRwLockWriteGuard<T> { + fn deref_mut(&mut self) -> &mut T { + unsafe { &mut *self.data } + } +} + +impl<T: ?Sized> fmt::Debug for OwnedRwLockWriteGuard<T> +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<T: ?Sized> fmt::Display for OwnedRwLockWriteGuard<T> +where + T: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} + +impl<T: ?Sized> Drop for OwnedRwLockWriteGuard<T> { + fn drop(&mut self) { + self.lock.s.release(self.permits_acquired as usize); + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + write_locked.op = "override", + ) + }); + unsafe { ManuallyDrop::drop(&mut self.lock) }; + } +} diff --git a/third_party/rust/tokio/src/sync/rwlock/owned_write_guard_mapped.rs b/third_party/rust/tokio/src/sync/rwlock/owned_write_guard_mapped.rs new file mode 100644 index 0000000000..55a24d96ac --- /dev/null +++ b/third_party/rust/tokio/src/sync/rwlock/owned_write_guard_mapped.rs @@ -0,0 +1,191 @@ +use crate::sync::rwlock::RwLock; +use std::fmt; +use std::marker::PhantomData; +use std::mem::{self, ManuallyDrop}; +use std::ops; +use std::sync::Arc; + +/// Owned RAII structure used to release the exclusive write access of a lock when +/// dropped. +/// +/// This structure is created by [mapping] an [`OwnedRwLockWriteGuard`]. It is a +/// separate type from `OwnedRwLockWriteGuard` to disallow downgrading a mapped +/// guard, since doing so can cause undefined behavior. +/// +/// [mapping]: method@crate::sync::OwnedRwLockWriteGuard::map +/// [`OwnedRwLockWriteGuard`]: struct@crate::sync::OwnedRwLockWriteGuard +pub struct OwnedRwLockMappedWriteGuard<T: ?Sized, U: ?Sized = T> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) resource_span: tracing::Span, + pub(super) permits_acquired: u32, + // ManuallyDrop allows us to destructure into this field without running the destructor. + pub(super) lock: ManuallyDrop<Arc<RwLock<T>>>, + pub(super) data: *mut U, + pub(super) _p: PhantomData<T>, +} + +impl<T: ?Sized, U: ?Sized> OwnedRwLockMappedWriteGuard<T, U> { + /// Makes a new `OwnedRwLockMappedWriteGuard` for a component of the locked + /// data. + /// + /// This operation cannot fail as the `OwnedRwLockMappedWriteGuard` passed + /// in already locked the data. + /// + /// This is an associated function that needs to be used as + /// `OwnedRwLockWriteGuard::map(..)`. A method would interfere with methods + /// of the same name on the contents of the locked data. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{RwLock, OwnedRwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = Arc::new(RwLock::new(Foo(1))); + /// + /// { + /// let lock = Arc::clone(&lock); + /// let mut mapped = OwnedRwLockWriteGuard::map(lock.write_owned().await, |f| &mut f.0); + /// *mapped = 2; + /// } + /// + /// assert_eq!(Foo(2), *lock.read().await); + /// # } + /// ``` + #[inline] + pub fn map<F, V: ?Sized>(mut this: Self, f: F) -> OwnedRwLockMappedWriteGuard<T, V> + where + F: FnOnce(&mut U) -> &mut V, + { + let data = f(&mut *this) as *mut V; + let lock = unsafe { ManuallyDrop::take(&mut this.lock) }; + let permits_acquired = this.permits_acquired; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = this.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + + OwnedRwLockMappedWriteGuard { + permits_acquired, + lock: ManuallyDrop::new(lock), + data, + _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Attempts to make a new `OwnedRwLockMappedWriteGuard` for a component + /// of the locked data. The original guard is returned if the closure + /// returns `None`. + /// + /// This operation cannot fail as the `OwnedRwLockMappedWriteGuard` passed + /// in already locked the data. + /// + /// This is an associated function that needs to be + /// used as `OwnedRwLockMappedWriteGuard::try_map(...)`. A method would interfere with + /// methods of the same name on the contents of the locked data. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{RwLock, OwnedRwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = Arc::new(RwLock::new(Foo(1))); + /// + /// { + /// let guard = Arc::clone(&lock).write_owned().await; + /// let mut guard = OwnedRwLockWriteGuard::try_map(guard, |f| Some(&mut f.0)).expect("should not fail"); + /// *guard = 2; + /// } + /// + /// assert_eq!(Foo(2), *lock.read().await); + /// # } + /// ``` + #[inline] + pub fn try_map<F, V: ?Sized>( + mut this: Self, + f: F, + ) -> Result<OwnedRwLockMappedWriteGuard<T, V>, Self> + where + F: FnOnce(&mut U) -> Option<&mut V>, + { + let data = match f(&mut *this) { + Some(data) => data as *mut V, + None => return Err(this), + }; + let lock = unsafe { ManuallyDrop::take(&mut this.lock) }; + let permits_acquired = this.permits_acquired; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = this.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + + Ok(OwnedRwLockMappedWriteGuard { + permits_acquired, + lock: ManuallyDrop::new(lock), + data, + _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + }) + } +} + +impl<T: ?Sized, U: ?Sized> ops::Deref for OwnedRwLockMappedWriteGuard<T, U> { + type Target = U; + + fn deref(&self) -> &U { + unsafe { &*self.data } + } +} + +impl<T: ?Sized, U: ?Sized> ops::DerefMut for OwnedRwLockMappedWriteGuard<T, U> { + fn deref_mut(&mut self) -> &mut U { + unsafe { &mut *self.data } + } +} + +impl<T: ?Sized, U: ?Sized> fmt::Debug for OwnedRwLockMappedWriteGuard<T, U> +where + U: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<T: ?Sized, U: ?Sized> fmt::Display for OwnedRwLockMappedWriteGuard<T, U> +where + U: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} + +impl<T: ?Sized, U: ?Sized> Drop for OwnedRwLockMappedWriteGuard<T, U> { + fn drop(&mut self) { + self.lock.s.release(self.permits_acquired as usize); + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + write_locked.op = "override", + ) + }); + unsafe { ManuallyDrop::drop(&mut self.lock) }; + } +} diff --git a/third_party/rust/tokio/src/sync/rwlock/read_guard.rs b/third_party/rust/tokio/src/sync/rwlock/read_guard.rs new file mode 100644 index 0000000000..3692131992 --- /dev/null +++ b/third_party/rust/tokio/src/sync/rwlock/read_guard.rs @@ -0,0 +1,177 @@ +use crate::sync::batch_semaphore::Semaphore; +use std::fmt; +use std::marker; +use std::mem; +use std::ops; + +/// RAII structure used to release the shared read access of a lock when +/// dropped. +/// +/// This structure is created by the [`read`] method on +/// [`RwLock`]. +/// +/// [`read`]: method@crate::sync::RwLock::read +/// [`RwLock`]: struct@crate::sync::RwLock +pub struct RwLockReadGuard<'a, T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) resource_span: tracing::Span, + pub(super) s: &'a Semaphore, + pub(super) data: *const T, + pub(super) marker: marker::PhantomData<&'a T>, +} + +impl<'a, T: ?Sized> RwLockReadGuard<'a, T> { + /// Makes a new `RwLockReadGuard` for a component of the locked data. + /// + /// This operation cannot fail as the `RwLockReadGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be + /// used as `RwLockReadGuard::map(...)`. A method would interfere with + /// methods of the same name on the contents of the locked data. + /// + /// This is an asynchronous version of [`RwLockReadGuard::map`] from the + /// [`parking_lot` crate]. + /// + /// [`RwLockReadGuard::map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockReadGuard.html#method.map + /// [`parking_lot` crate]: https://crates.io/crates/parking_lot + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockReadGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// + /// let guard = lock.read().await; + /// let guard = RwLockReadGuard::map(guard, |f| &f.0); + /// + /// assert_eq!(1, *guard); + /// # } + /// ``` + #[inline] + pub fn map<F, U: ?Sized>(this: Self, f: F) -> RwLockReadGuard<'a, U> + where + F: FnOnce(&T) -> &U, + { + let data = f(&*this) as *const U; + let s = this.s; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = this.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + + RwLockReadGuard { + s, + data, + marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Attempts to make a new [`RwLockReadGuard`] for a component of the + /// locked data. The original guard is returned if the closure returns + /// `None`. + /// + /// This operation cannot fail as the `RwLockReadGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be used as + /// `RwLockReadGuard::try_map(..)`. A method would interfere with methods of the + /// same name on the contents of the locked data. + /// + /// This is an asynchronous version of [`RwLockReadGuard::try_map`] from the + /// [`parking_lot` crate]. + /// + /// [`RwLockReadGuard::try_map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockReadGuard.html#method.try_map + /// [`parking_lot` crate]: https://crates.io/crates/parking_lot + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockReadGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// + /// let guard = lock.read().await; + /// let guard = RwLockReadGuard::try_map(guard, |f| Some(&f.0)).expect("should not fail"); + /// + /// assert_eq!(1, *guard); + /// # } + /// ``` + #[inline] + pub fn try_map<F, U: ?Sized>(this: Self, f: F) -> Result<RwLockReadGuard<'a, U>, Self> + where + F: FnOnce(&T) -> Option<&U>, + { + let data = match f(&*this) { + Some(data) => data as *const U, + None => return Err(this), + }; + let s = this.s; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = this.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + + Ok(RwLockReadGuard { + s, + data, + marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + }) + } +} + +impl<T: ?Sized> ops::Deref for RwLockReadGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + unsafe { &*self.data } + } +} + +impl<'a, T: ?Sized> fmt::Debug for RwLockReadGuard<'a, T> +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<'a, T: ?Sized> fmt::Display for RwLockReadGuard<'a, T> +where + T: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} + +impl<'a, T: ?Sized> Drop for RwLockReadGuard<'a, T> { + fn drop(&mut self) { + self.s.release(1); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "sub", + ) + }); + } +} diff --git a/third_party/rust/tokio/src/sync/rwlock/write_guard.rs b/third_party/rust/tokio/src/sync/rwlock/write_guard.rs new file mode 100644 index 0000000000..7cadd74c60 --- /dev/null +++ b/third_party/rust/tokio/src/sync/rwlock/write_guard.rs @@ -0,0 +1,282 @@ +use crate::sync::batch_semaphore::Semaphore; +use crate::sync::rwlock::read_guard::RwLockReadGuard; +use crate::sync::rwlock::write_guard_mapped::RwLockMappedWriteGuard; +use std::fmt; +use std::marker; +use std::mem; +use std::ops; + +/// RAII structure used to release the exclusive write access of a lock when +/// dropped. +/// +/// This structure is created by the [`write`] method +/// on [`RwLock`]. +/// +/// [`write`]: method@crate::sync::RwLock::write +/// [`RwLock`]: struct@crate::sync::RwLock +pub struct RwLockWriteGuard<'a, T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) resource_span: tracing::Span, + pub(super) permits_acquired: u32, + pub(super) s: &'a Semaphore, + pub(super) data: *mut T, + pub(super) marker: marker::PhantomData<&'a mut T>, +} + +impl<'a, T: ?Sized> RwLockWriteGuard<'a, T> { + /// Makes a new [`RwLockMappedWriteGuard`] for a component of the locked data. + /// + /// This operation cannot fail as the `RwLockWriteGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be used as + /// `RwLockWriteGuard::map(..)`. A method would interfere with methods of + /// the same name on the contents of the locked data. + /// + /// This is an asynchronous version of [`RwLockWriteGuard::map`] from the + /// [`parking_lot` crate]. + /// + /// [`RwLockMappedWriteGuard`]: struct@crate::sync::RwLockMappedWriteGuard + /// [`RwLockWriteGuard::map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockWriteGuard.html#method.map + /// [`parking_lot` crate]: https://crates.io/crates/parking_lot + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// + /// { + /// let mut mapped = RwLockWriteGuard::map(lock.write().await, |f| &mut f.0); + /// *mapped = 2; + /// } + /// + /// assert_eq!(Foo(2), *lock.read().await); + /// # } + /// ``` + #[inline] + pub fn map<F, U: ?Sized>(mut this: Self, f: F) -> RwLockMappedWriteGuard<'a, U> + where + F: FnOnce(&mut T) -> &mut U, + { + let data = f(&mut *this) as *mut U; + let s = this.s; + let permits_acquired = this.permits_acquired; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = this.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + RwLockMappedWriteGuard { + permits_acquired, + s, + data, + marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Attempts to make a new [`RwLockMappedWriteGuard`] for a component of + /// the locked data. The original guard is returned if the closure returns + /// `None`. + /// + /// This operation cannot fail as the `RwLockWriteGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be + /// used as `RwLockWriteGuard::try_map(...)`. A method would interfere with + /// methods of the same name on the contents of the locked data. + /// + /// This is an asynchronous version of [`RwLockWriteGuard::try_map`] from + /// the [`parking_lot` crate]. + /// + /// [`RwLockMappedWriteGuard`]: struct@crate::sync::RwLockMappedWriteGuard + /// [`RwLockWriteGuard::try_map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockWriteGuard.html#method.try_map + /// [`parking_lot` crate]: https://crates.io/crates/parking_lot + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// + /// { + /// let guard = lock.write().await; + /// let mut guard = RwLockWriteGuard::try_map(guard, |f| Some(&mut f.0)).expect("should not fail"); + /// *guard = 2; + /// } + /// + /// assert_eq!(Foo(2), *lock.read().await); + /// # } + /// ``` + #[inline] + pub fn try_map<F, U: ?Sized>( + mut this: Self, + f: F, + ) -> Result<RwLockMappedWriteGuard<'a, U>, Self> + where + F: FnOnce(&mut T) -> Option<&mut U>, + { + let data = match f(&mut *this) { + Some(data) => data as *mut U, + None => return Err(this), + }; + let s = this.s; + let permits_acquired = this.permits_acquired; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = this.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + Ok(RwLockMappedWriteGuard { + permits_acquired, + s, + data, + marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + }) + } + + /// Converts this `RwLockWriteGuard` into an `RwLockMappedWriteGuard`. This + /// method can be used to store a non-mapped guard in a struct field that + /// expects a mapped guard. + /// + /// This is equivalent to calling `RwLockWriteGuard::map(guard, |me| me)`. + #[inline] + pub fn into_mapped(this: Self) -> RwLockMappedWriteGuard<'a, T> { + RwLockWriteGuard::map(this, |me| me) + } + + /// Atomically downgrades a write lock into a read lock without allowing + /// any writers to take exclusive access of the lock in the meantime. + /// + /// **Note:** This won't *necessarily* allow any additional readers to acquire + /// locks, since [`RwLock`] is fair and it is possible that a writer is next + /// in line. + /// + /// Returns an RAII guard which will drop this read access of the `RwLock` + /// when dropped. + /// + /// # Examples + /// + /// ``` + /// # use tokio::sync::RwLock; + /// # use std::sync::Arc; + /// # + /// # #[tokio::main] + /// # async fn main() { + /// let lock = Arc::new(RwLock::new(1)); + /// + /// let n = lock.write().await; + /// + /// let cloned_lock = lock.clone(); + /// let handle = tokio::spawn(async move { + /// *cloned_lock.write().await = 2; + /// }); + /// + /// let n = n.downgrade(); + /// assert_eq!(*n, 1, "downgrade is atomic"); + /// + /// drop(n); + /// handle.await.unwrap(); + /// assert_eq!(*lock.read().await, 2, "second writer obtained write lock"); + /// # } + /// ``` + /// + /// [`RwLock`]: struct@crate::sync::RwLock + pub fn downgrade(self) -> RwLockReadGuard<'a, T> { + let RwLockWriteGuard { s, data, .. } = self; + let to_release = (self.permits_acquired - 1) as usize; + // Release all but one of the permits held by the write guard + s.release(to_release); + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + write_locked.op = "override", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(self); + + RwLockReadGuard { + s, + data, + marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } +} + +impl<T: ?Sized> ops::Deref for RwLockWriteGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + unsafe { &*self.data } + } +} + +impl<T: ?Sized> ops::DerefMut for RwLockWriteGuard<'_, T> { + fn deref_mut(&mut self) -> &mut T { + unsafe { &mut *self.data } + } +} + +impl<'a, T: ?Sized> fmt::Debug for RwLockWriteGuard<'a, T> +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<'a, T: ?Sized> fmt::Display for RwLockWriteGuard<'a, T> +where + T: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} + +impl<'a, T: ?Sized> Drop for RwLockWriteGuard<'a, T> { + fn drop(&mut self) { + self.s.release(self.permits_acquired as usize); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + write_locked.op = "override", + ) + }); + } +} diff --git a/third_party/rust/tokio/src/sync/rwlock/write_guard_mapped.rs b/third_party/rust/tokio/src/sync/rwlock/write_guard_mapped.rs new file mode 100644 index 0000000000..b5c644a9e8 --- /dev/null +++ b/third_party/rust/tokio/src/sync/rwlock/write_guard_mapped.rs @@ -0,0 +1,197 @@ +use crate::sync::batch_semaphore::Semaphore; +use std::fmt; +use std::marker; +use std::mem; +use std::ops; + +/// RAII structure used to release the exclusive write access of a lock when +/// dropped. +/// +/// This structure is created by [mapping] an [`RwLockWriteGuard`]. It is a +/// separate type from `RwLockWriteGuard` to disallow downgrading a mapped +/// guard, since doing so can cause undefined behavior. +/// +/// [mapping]: method@crate::sync::RwLockWriteGuard::map +/// [`RwLockWriteGuard`]: struct@crate::sync::RwLockWriteGuard +pub struct RwLockMappedWriteGuard<'a, T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) resource_span: tracing::Span, + pub(super) permits_acquired: u32, + pub(super) s: &'a Semaphore, + pub(super) data: *mut T, + pub(super) marker: marker::PhantomData<&'a mut T>, +} + +impl<'a, T: ?Sized> RwLockMappedWriteGuard<'a, T> { + /// Makes a new `RwLockMappedWriteGuard` for a component of the locked data. + /// + /// This operation cannot fail as the `RwLockMappedWriteGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be used as + /// `RwLockMappedWriteGuard::map(..)`. A method would interfere with methods + /// of the same name on the contents of the locked data. + /// + /// This is an asynchronous version of [`RwLockWriteGuard::map`] from the + /// [`parking_lot` crate]. + /// + /// [`RwLockWriteGuard::map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockWriteGuard.html#method.map + /// [`parking_lot` crate]: https://crates.io/crates/parking_lot + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// + /// { + /// let mut mapped = RwLockWriteGuard::map(lock.write().await, |f| &mut f.0); + /// *mapped = 2; + /// } + /// + /// assert_eq!(Foo(2), *lock.read().await); + /// # } + /// ``` + #[inline] + pub fn map<F, U: ?Sized>(mut this: Self, f: F) -> RwLockMappedWriteGuard<'a, U> + where + F: FnOnce(&mut T) -> &mut U, + { + let data = f(&mut *this) as *mut U; + let s = this.s; + let permits_acquired = this.permits_acquired; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = this.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + + RwLockMappedWriteGuard { + permits_acquired, + s, + data, + marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Attempts to make a new [`RwLockMappedWriteGuard`] for a component of + /// the locked data. The original guard is returned if the closure returns + /// `None`. + /// + /// This operation cannot fail as the `RwLockMappedWriteGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be + /// used as `RwLockMappedWriteGuard::try_map(...)`. A method would interfere + /// with methods of the same name on the contents of the locked data. + /// + /// This is an asynchronous version of [`RwLockWriteGuard::try_map`] from + /// the [`parking_lot` crate]. + /// + /// [`RwLockWriteGuard::try_map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockWriteGuard.html#method.try_map + /// [`parking_lot` crate]: https://crates.io/crates/parking_lot + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// + /// { + /// let guard = lock.write().await; + /// let mut guard = RwLockWriteGuard::try_map(guard, |f| Some(&mut f.0)).expect("should not fail"); + /// *guard = 2; + /// } + /// + /// assert_eq!(Foo(2), *lock.read().await); + /// # } + /// ``` + #[inline] + pub fn try_map<F, U: ?Sized>( + mut this: Self, + f: F, + ) -> Result<RwLockMappedWriteGuard<'a, U>, Self> + where + F: FnOnce(&mut T) -> Option<&mut U>, + { + let data = match f(&mut *this) { + Some(data) => data as *mut U, + None => return Err(this), + }; + let s = this.s; + let permits_acquired = this.permits_acquired; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = this.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + + Ok(RwLockMappedWriteGuard { + permits_acquired, + s, + data, + marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + }) + } +} + +impl<T: ?Sized> ops::Deref for RwLockMappedWriteGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + unsafe { &*self.data } + } +} + +impl<T: ?Sized> ops::DerefMut for RwLockMappedWriteGuard<'_, T> { + fn deref_mut(&mut self) -> &mut T { + unsafe { &mut *self.data } + } +} + +impl<'a, T: ?Sized> fmt::Debug for RwLockMappedWriteGuard<'a, T> +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<'a, T: ?Sized> fmt::Display for RwLockMappedWriteGuard<'a, T> +where + T: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} + +impl<'a, T: ?Sized> Drop for RwLockMappedWriteGuard<'a, T> { + fn drop(&mut self) { + self.s.release(self.permits_acquired as usize); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + write_locked.op = "override", + ) + }); + } +} diff --git a/third_party/rust/tokio/src/sync/semaphore.rs b/third_party/rust/tokio/src/sync/semaphore.rs new file mode 100644 index 0000000000..860f46f399 --- /dev/null +++ b/third_party/rust/tokio/src/sync/semaphore.rs @@ -0,0 +1,644 @@ +use super::batch_semaphore as ll; // low level implementation +use super::{AcquireError, TryAcquireError}; +#[cfg(all(tokio_unstable, feature = "tracing"))] +use crate::util::trace; +use std::sync::Arc; + +/// Counting semaphore performing asynchronous permit acquisition. +/// +/// A semaphore maintains a set of permits. Permits are used to synchronize +/// access to a shared resource. A semaphore differs from a mutex in that it +/// can allow more than one concurrent caller to access the shared resource at a +/// time. +/// +/// When `acquire` is called and the semaphore has remaining permits, the +/// function immediately returns a permit. However, if no remaining permits are +/// available, `acquire` (asynchronously) waits until an outstanding permit is +/// dropped. At this point, the freed permit is assigned to the caller. +/// +/// This `Semaphore` is fair, which means that permits are given out in the order +/// they were requested. This fairness is also applied when `acquire_many` gets +/// involved, so if a call to `acquire_many` at the front of the queue requests +/// more permits than currently available, this can prevent a call to `acquire` +/// from completing, even if the semaphore has enough permits complete the call +/// to `acquire`. +/// +/// To use the `Semaphore` in a poll function, you can use the [`PollSemaphore`] +/// utility. +/// +/// # Examples +/// +/// Basic usage: +/// +/// ``` +/// use tokio::sync::{Semaphore, TryAcquireError}; +/// +/// #[tokio::main] +/// async fn main() { +/// let semaphore = Semaphore::new(3); +/// +/// let a_permit = semaphore.acquire().await.unwrap(); +/// let two_permits = semaphore.acquire_many(2).await.unwrap(); +/// +/// assert_eq!(semaphore.available_permits(), 0); +/// +/// let permit_attempt = semaphore.try_acquire(); +/// assert_eq!(permit_attempt.err(), Some(TryAcquireError::NoPermits)); +/// } +/// ``` +/// +/// Use [`Semaphore::acquire_owned`] to move permits across tasks: +/// +/// ``` +/// use std::sync::Arc; +/// use tokio::sync::Semaphore; +/// +/// #[tokio::main] +/// async fn main() { +/// let semaphore = Arc::new(Semaphore::new(3)); +/// let mut join_handles = Vec::new(); +/// +/// for _ in 0..5 { +/// let permit = semaphore.clone().acquire_owned().await.unwrap(); +/// join_handles.push(tokio::spawn(async move { +/// // perform task... +/// // explicitly own `permit` in the task +/// drop(permit); +/// })); +/// } +/// +/// for handle in join_handles { +/// handle.await.unwrap(); +/// } +/// } +/// ``` +/// +/// [`PollSemaphore`]: https://docs.rs/tokio-util/0.6/tokio_util/sync/struct.PollSemaphore.html +/// [`Semaphore::acquire_owned`]: crate::sync::Semaphore::acquire_owned +#[derive(Debug)] +pub struct Semaphore { + /// The low level semaphore + ll_sem: ll::Semaphore, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, +} + +/// A permit from the semaphore. +/// +/// This type is created by the [`acquire`] method. +/// +/// [`acquire`]: crate::sync::Semaphore::acquire() +#[must_use] +#[derive(Debug)] +pub struct SemaphorePermit<'a> { + sem: &'a Semaphore, + permits: u32, +} + +/// An owned permit from the semaphore. +/// +/// This type is created by the [`acquire_owned`] method. +/// +/// [`acquire_owned`]: crate::sync::Semaphore::acquire_owned() +#[must_use] +#[derive(Debug)] +pub struct OwnedSemaphorePermit { + sem: Arc<Semaphore>, + permits: u32, +} + +#[test] +#[cfg(not(loom))] +fn bounds() { + fn check_unpin<T: Unpin>() {} + // This has to take a value, since the async fn's return type is unnameable. + fn check_send_sync_val<T: Send + Sync>(_t: T) {} + fn check_send_sync<T: Send + Sync>() {} + check_unpin::<Semaphore>(); + check_unpin::<SemaphorePermit<'_>>(); + check_send_sync::<Semaphore>(); + + let semaphore = Semaphore::new(0); + check_send_sync_val(semaphore.acquire()); +} + +impl Semaphore { + /// Creates a new semaphore with the initial number of permits. + #[track_caller] + pub fn new(permits: usize) -> Self { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = { + let location = std::panic::Location::caller(); + + tracing::trace_span!( + "runtime.resource", + concrete_type = "Semaphore", + kind = "Sync", + loc.file = location.file(), + loc.line = location.line(), + loc.col = location.column(), + inherits_child_attrs = true, + ) + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let ll_sem = resource_span.in_scope(|| ll::Semaphore::new(permits)); + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + let ll_sem = ll::Semaphore::new(permits); + + Self { + ll_sem, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Creates a new semaphore with the initial number of permits. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Semaphore; + /// + /// static SEM: Semaphore = Semaphore::const_new(10); + /// ``` + /// + #[cfg(all(feature = "parking_lot", not(all(loom, test))))] + #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] + pub const fn const_new(permits: usize) -> Self { + #[cfg(all(tokio_unstable, feature = "tracing"))] + return Self { + ll_sem: ll::Semaphore::const_new(permits), + resource_span: tracing::Span::none(), + }; + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + return Self { + ll_sem: ll::Semaphore::const_new(permits), + }; + } + + /// Returns the current number of available permits. + pub fn available_permits(&self) -> usize { + self.ll_sem.available_permits() + } + + /// Adds `n` new permits to the semaphore. + /// + /// The maximum number of permits is `usize::MAX >> 3`, and this function will panic if the limit is exceeded. + pub fn add_permits(&self, n: usize) { + self.ll_sem.release(n); + } + + /// Acquires a permit from the semaphore. + /// + /// If the semaphore has been closed, this returns an [`AcquireError`]. + /// Otherwise, this returns a [`SemaphorePermit`] representing the + /// acquired permit. + /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute permits in the order they + /// were requested. Cancelling a call to `acquire` makes you lose your place + /// in the queue. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Semaphore; + /// + /// #[tokio::main] + /// async fn main() { + /// let semaphore = Semaphore::new(2); + /// + /// let permit_1 = semaphore.acquire().await.unwrap(); + /// assert_eq!(semaphore.available_permits(), 1); + /// + /// let permit_2 = semaphore.acquire().await.unwrap(); + /// assert_eq!(semaphore.available_permits(), 0); + /// + /// drop(permit_1); + /// assert_eq!(semaphore.available_permits(), 1); + /// } + /// ``` + /// + /// [`AcquireError`]: crate::sync::AcquireError + /// [`SemaphorePermit`]: crate::sync::SemaphorePermit + pub async fn acquire(&self) -> Result<SemaphorePermit<'_>, AcquireError> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let inner = trace::async_op( + || self.ll_sem.acquire(1), + self.resource_span.clone(), + "Semaphore::acquire", + "poll", + true, + ); + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let inner = self.ll_sem.acquire(1); + + inner.await?; + Ok(SemaphorePermit { + sem: self, + permits: 1, + }) + } + + /// Acquires `n` permits from the semaphore. + /// + /// If the semaphore has been closed, this returns an [`AcquireError`]. + /// Otherwise, this returns a [`SemaphorePermit`] representing the + /// acquired permits. + /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute permits in the order they + /// were requested. Cancelling a call to `acquire_many` makes you lose your + /// place in the queue. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Semaphore; + /// + /// #[tokio::main] + /// async fn main() { + /// let semaphore = Semaphore::new(5); + /// + /// let permit = semaphore.acquire_many(3).await.unwrap(); + /// assert_eq!(semaphore.available_permits(), 2); + /// } + /// ``` + /// + /// [`AcquireError`]: crate::sync::AcquireError + /// [`SemaphorePermit`]: crate::sync::SemaphorePermit + pub async fn acquire_many(&self, n: u32) -> Result<SemaphorePermit<'_>, AcquireError> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + trace::async_op( + || self.ll_sem.acquire(n), + self.resource_span.clone(), + "Semaphore::acquire_many", + "poll", + true, + ) + .await?; + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + self.ll_sem.acquire(n).await?; + + Ok(SemaphorePermit { + sem: self, + permits: n, + }) + } + + /// Tries to acquire a permit from the semaphore. + /// + /// If the semaphore has been closed, this returns a [`TryAcquireError::Closed`] + /// and a [`TryAcquireError::NoPermits`] if there are no permits left. Otherwise, + /// this returns a [`SemaphorePermit`] representing the acquired permits. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{Semaphore, TryAcquireError}; + /// + /// # fn main() { + /// let semaphore = Semaphore::new(2); + /// + /// let permit_1 = semaphore.try_acquire().unwrap(); + /// assert_eq!(semaphore.available_permits(), 1); + /// + /// let permit_2 = semaphore.try_acquire().unwrap(); + /// assert_eq!(semaphore.available_permits(), 0); + /// + /// let permit_3 = semaphore.try_acquire(); + /// assert_eq!(permit_3.err(), Some(TryAcquireError::NoPermits)); + /// # } + /// ``` + /// + /// [`TryAcquireError::Closed`]: crate::sync::TryAcquireError::Closed + /// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits + /// [`SemaphorePermit`]: crate::sync::SemaphorePermit + pub fn try_acquire(&self) -> Result<SemaphorePermit<'_>, TryAcquireError> { + match self.ll_sem.try_acquire(1) { + Ok(_) => Ok(SemaphorePermit { + sem: self, + permits: 1, + }), + Err(e) => Err(e), + } + } + + /// Tries to acquire `n` permits from the semaphore. + /// + /// If the semaphore has been closed, this returns a [`TryAcquireError::Closed`] + /// and a [`TryAcquireError::NoPermits`] if there are not enough permits left. + /// Otherwise, this returns a [`SemaphorePermit`] representing the acquired permits. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{Semaphore, TryAcquireError}; + /// + /// # fn main() { + /// let semaphore = Semaphore::new(4); + /// + /// let permit_1 = semaphore.try_acquire_many(3).unwrap(); + /// assert_eq!(semaphore.available_permits(), 1); + /// + /// let permit_2 = semaphore.try_acquire_many(2); + /// assert_eq!(permit_2.err(), Some(TryAcquireError::NoPermits)); + /// # } + /// ``` + /// + /// [`TryAcquireError::Closed`]: crate::sync::TryAcquireError::Closed + /// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits + /// [`SemaphorePermit`]: crate::sync::SemaphorePermit + pub fn try_acquire_many(&self, n: u32) -> Result<SemaphorePermit<'_>, TryAcquireError> { + match self.ll_sem.try_acquire(n) { + Ok(_) => Ok(SemaphorePermit { + sem: self, + permits: n, + }), + Err(e) => Err(e), + } + } + + /// Acquires a permit from the semaphore. + /// + /// The semaphore must be wrapped in an [`Arc`] to call this method. + /// If the semaphore has been closed, this returns an [`AcquireError`]. + /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the + /// acquired permit. + /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute permits in the order they + /// were requested. Cancelling a call to `acquire_owned` makes you lose your + /// place in the queue. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::Semaphore; + /// + /// #[tokio::main] + /// async fn main() { + /// let semaphore = Arc::new(Semaphore::new(3)); + /// let mut join_handles = Vec::new(); + /// + /// for _ in 0..5 { + /// let permit = semaphore.clone().acquire_owned().await.unwrap(); + /// join_handles.push(tokio::spawn(async move { + /// // perform task... + /// // explicitly own `permit` in the task + /// drop(permit); + /// })); + /// } + /// + /// for handle in join_handles { + /// handle.await.unwrap(); + /// } + /// } + /// ``` + /// + /// [`Arc`]: std::sync::Arc + /// [`AcquireError`]: crate::sync::AcquireError + /// [`OwnedSemaphorePermit`]: crate::sync::OwnedSemaphorePermit + pub async fn acquire_owned(self: Arc<Self>) -> Result<OwnedSemaphorePermit, AcquireError> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let inner = trace::async_op( + || self.ll_sem.acquire(1), + self.resource_span.clone(), + "Semaphore::acquire_owned", + "poll", + true, + ); + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let inner = self.ll_sem.acquire(1); + + inner.await?; + Ok(OwnedSemaphorePermit { + sem: self, + permits: 1, + }) + } + + /// Acquires `n` permits from the semaphore. + /// + /// The semaphore must be wrapped in an [`Arc`] to call this method. + /// If the semaphore has been closed, this returns an [`AcquireError`]. + /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the + /// acquired permit. + /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute permits in the order they + /// were requested. Cancelling a call to `acquire_many_owned` makes you lose + /// your place in the queue. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::Semaphore; + /// + /// #[tokio::main] + /// async fn main() { + /// let semaphore = Arc::new(Semaphore::new(10)); + /// let mut join_handles = Vec::new(); + /// + /// for _ in 0..5 { + /// let permit = semaphore.clone().acquire_many_owned(2).await.unwrap(); + /// join_handles.push(tokio::spawn(async move { + /// // perform task... + /// // explicitly own `permit` in the task + /// drop(permit); + /// })); + /// } + /// + /// for handle in join_handles { + /// handle.await.unwrap(); + /// } + /// } + /// ``` + /// + /// [`Arc`]: std::sync::Arc + /// [`AcquireError`]: crate::sync::AcquireError + /// [`OwnedSemaphorePermit`]: crate::sync::OwnedSemaphorePermit + pub async fn acquire_many_owned( + self: Arc<Self>, + n: u32, + ) -> Result<OwnedSemaphorePermit, AcquireError> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let inner = trace::async_op( + || self.ll_sem.acquire(n), + self.resource_span.clone(), + "Semaphore::acquire_many_owned", + "poll", + true, + ); + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let inner = self.ll_sem.acquire(n); + + inner.await?; + Ok(OwnedSemaphorePermit { + sem: self, + permits: n, + }) + } + + /// Tries to acquire a permit from the semaphore. + /// + /// The semaphore must be wrapped in an [`Arc`] to call this method. If + /// the semaphore has been closed, this returns a [`TryAcquireError::Closed`] + /// and a [`TryAcquireError::NoPermits`] if there are no permits left. + /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the + /// acquired permit. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{Semaphore, TryAcquireError}; + /// + /// # fn main() { + /// let semaphore = Arc::new(Semaphore::new(2)); + /// + /// let permit_1 = Arc::clone(&semaphore).try_acquire_owned().unwrap(); + /// assert_eq!(semaphore.available_permits(), 1); + /// + /// let permit_2 = Arc::clone(&semaphore).try_acquire_owned().unwrap(); + /// assert_eq!(semaphore.available_permits(), 0); + /// + /// let permit_3 = semaphore.try_acquire_owned(); + /// assert_eq!(permit_3.err(), Some(TryAcquireError::NoPermits)); + /// # } + /// ``` + /// + /// [`Arc`]: std::sync::Arc + /// [`TryAcquireError::Closed`]: crate::sync::TryAcquireError::Closed + /// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits + /// [`OwnedSemaphorePermit`]: crate::sync::OwnedSemaphorePermit + pub fn try_acquire_owned(self: Arc<Self>) -> Result<OwnedSemaphorePermit, TryAcquireError> { + match self.ll_sem.try_acquire(1) { + Ok(_) => Ok(OwnedSemaphorePermit { + sem: self, + permits: 1, + }), + Err(e) => Err(e), + } + } + + /// Tries to acquire `n` permits from the semaphore. + /// + /// The semaphore must be wrapped in an [`Arc`] to call this method. If + /// the semaphore has been closed, this returns a [`TryAcquireError::Closed`] + /// and a [`TryAcquireError::NoPermits`] if there are no permits left. + /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the + /// acquired permit. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{Semaphore, TryAcquireError}; + /// + /// # fn main() { + /// let semaphore = Arc::new(Semaphore::new(4)); + /// + /// let permit_1 = Arc::clone(&semaphore).try_acquire_many_owned(3).unwrap(); + /// assert_eq!(semaphore.available_permits(), 1); + /// + /// let permit_2 = semaphore.try_acquire_many_owned(2); + /// assert_eq!(permit_2.err(), Some(TryAcquireError::NoPermits)); + /// # } + /// ``` + /// + /// [`Arc`]: std::sync::Arc + /// [`TryAcquireError::Closed`]: crate::sync::TryAcquireError::Closed + /// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits + /// [`OwnedSemaphorePermit`]: crate::sync::OwnedSemaphorePermit + pub fn try_acquire_many_owned( + self: Arc<Self>, + n: u32, + ) -> Result<OwnedSemaphorePermit, TryAcquireError> { + match self.ll_sem.try_acquire(n) { + Ok(_) => Ok(OwnedSemaphorePermit { + sem: self, + permits: n, + }), + Err(e) => Err(e), + } + } + + /// Closes the semaphore. + /// + /// This prevents the semaphore from issuing new permits and notifies all pending waiters. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Semaphore; + /// use std::sync::Arc; + /// use tokio::sync::TryAcquireError; + /// + /// #[tokio::main] + /// async fn main() { + /// let semaphore = Arc::new(Semaphore::new(1)); + /// let semaphore2 = semaphore.clone(); + /// + /// tokio::spawn(async move { + /// let permit = semaphore.acquire_many(2).await; + /// assert!(permit.is_err()); + /// println!("waiter received error"); + /// }); + /// + /// println!("closing semaphore"); + /// semaphore2.close(); + /// + /// // Cannot obtain more permits + /// assert_eq!(semaphore2.try_acquire().err(), Some(TryAcquireError::Closed)) + /// } + /// ``` + pub fn close(&self) { + self.ll_sem.close(); + } + + /// Returns true if the semaphore is closed + pub fn is_closed(&self) -> bool { + self.ll_sem.is_closed() + } +} + +impl<'a> SemaphorePermit<'a> { + /// Forgets the permit **without** releasing it back to the semaphore. + /// This can be used to reduce the amount of permits available from a + /// semaphore. + pub fn forget(mut self) { + self.permits = 0; + } +} + +impl OwnedSemaphorePermit { + /// Forgets the permit **without** releasing it back to the semaphore. + /// This can be used to reduce the amount of permits available from a + /// semaphore. + pub fn forget(mut self) { + self.permits = 0; + } +} + +impl<'a> Drop for SemaphorePermit<'_> { + fn drop(&mut self) { + self.sem.add_permits(self.permits as usize); + } +} + +impl Drop for OwnedSemaphorePermit { + fn drop(&mut self) { + self.sem.add_permits(self.permits as usize); + } +} diff --git a/third_party/rust/tokio/src/sync/task/atomic_waker.rs b/third_party/rust/tokio/src/sync/task/atomic_waker.rs new file mode 100644 index 0000000000..13aba35448 --- /dev/null +++ b/third_party/rust/tokio/src/sync/task/atomic_waker.rs @@ -0,0 +1,382 @@ +#![cfg_attr(any(loom, not(feature = "sync")), allow(dead_code, unreachable_pub))] + +use crate::loom::cell::UnsafeCell; +use crate::loom::hint; +use crate::loom::sync::atomic::AtomicUsize; + +use std::fmt; +use std::panic::{resume_unwind, AssertUnwindSafe, RefUnwindSafe, UnwindSafe}; +use std::sync::atomic::Ordering::{AcqRel, Acquire, Release}; +use std::task::Waker; + +/// A synchronization primitive for task waking. +/// +/// `AtomicWaker` will coordinate concurrent wakes with the consumer +/// potentially "waking" the underlying task. This is useful in scenarios +/// where a computation completes in another thread and wants to wake the +/// consumer, but the consumer is in the process of being migrated to a new +/// logical task. +/// +/// Consumers should call `register` before checking the result of a computation +/// and producers should call `wake` after producing the computation (this +/// differs from the usual `thread::park` pattern). It is also permitted for +/// `wake` to be called **before** `register`. This results in a no-op. +/// +/// A single `AtomicWaker` may be reused for any number of calls to `register` or +/// `wake`. +pub(crate) struct AtomicWaker { + state: AtomicUsize, + waker: UnsafeCell<Option<Waker>>, +} + +impl RefUnwindSafe for AtomicWaker {} +impl UnwindSafe for AtomicWaker {} + +// `AtomicWaker` is a multi-consumer, single-producer transfer cell. The cell +// stores a `Waker` value produced by calls to `register` and many threads can +// race to take the waker by calling `wake`. +// +// If a new `Waker` instance is produced by calling `register` before an existing +// one is consumed, then the existing one is overwritten. +// +// While `AtomicWaker` is single-producer, the implementation ensures memory +// safety. In the event of concurrent calls to `register`, there will be a +// single winner whose waker will get stored in the cell. The losers will not +// have their tasks woken. As such, callers should ensure to add synchronization +// to calls to `register`. +// +// The implementation uses a single `AtomicUsize` value to coordinate access to +// the `Waker` cell. There are two bits that are operated on independently. These +// are represented by `REGISTERING` and `WAKING`. +// +// The `REGISTERING` bit is set when a producer enters the critical section. The +// `WAKING` bit is set when a consumer enters the critical section. Neither +// bit being set is represented by `WAITING`. +// +// A thread obtains an exclusive lock on the waker cell by transitioning the +// state from `WAITING` to `REGISTERING` or `WAKING`, depending on the +// operation the thread wishes to perform. When this transition is made, it is +// guaranteed that no other thread will access the waker cell. +// +// # Registering +// +// On a call to `register`, an attempt to transition the state from WAITING to +// REGISTERING is made. On success, the caller obtains a lock on the waker cell. +// +// If the lock is obtained, then the thread sets the waker cell to the waker +// provided as an argument. Then it attempts to transition the state back from +// `REGISTERING` -> `WAITING`. +// +// If this transition is successful, then the registering process is complete +// and the next call to `wake` will observe the waker. +// +// If the transition fails, then there was a concurrent call to `wake` that +// was unable to access the waker cell (due to the registering thread holding the +// lock). To handle this, the registering thread removes the waker it just set +// from the cell and calls `wake` on it. This call to wake represents the +// attempt to wake by the other thread (that set the `WAKING` bit). The +// state is then transitioned from `REGISTERING | WAKING` back to `WAITING`. +// This transition must succeed because, at this point, the state cannot be +// transitioned by another thread. +// +// # Waking +// +// On a call to `wake`, an attempt to transition the state from `WAITING` to +// `WAKING` is made. On success, the caller obtains a lock on the waker cell. +// +// If the lock is obtained, then the thread takes ownership of the current value +// in the waker cell, and calls `wake` on it. The state is then transitioned +// back to `WAITING`. This transition must succeed as, at this point, the state +// cannot be transitioned by another thread. +// +// If the thread is unable to obtain the lock, the `WAKING` bit is still set. +// This is because it has either been set by the current thread but the previous +// value included the `REGISTERING` bit **or** a concurrent thread is in the +// `WAKING` critical section. Either way, no action must be taken. +// +// If the current thread is the only concurrent call to `wake` and another +// thread is in the `register` critical section, when the other thread **exits** +// the `register` critical section, it will observe the `WAKING` bit and +// handle the waker itself. +// +// If another thread is in the `waker` critical section, then it will handle +// waking the caller task. +// +// # A potential race (is safely handled). +// +// Imagine the following situation: +// +// * Thread A obtains the `wake` lock and wakes a task. +// +// * Before thread A releases the `wake` lock, the woken task is scheduled. +// +// * Thread B attempts to wake the task. In theory this should result in the +// task being woken, but it cannot because thread A still holds the wake +// lock. +// +// This case is handled by requiring users of `AtomicWaker` to call `register` +// **before** attempting to observe the application state change that resulted +// in the task being woken. The wakers also change the application state +// before calling wake. +// +// Because of this, the task will do one of two things. +// +// 1) Observe the application state change that Thread B is waking on. In +// this case, it is OK for Thread B's wake to be lost. +// +// 2) Call register before attempting to observe the application state. Since +// Thread A still holds the `wake` lock, the call to `register` will result +// in the task waking itself and get scheduled again. + +/// Idle state. +const WAITING: usize = 0; + +/// A new waker value is being registered with the `AtomicWaker` cell. +const REGISTERING: usize = 0b01; + +/// The task currently registered with the `AtomicWaker` cell is being woken. +const WAKING: usize = 0b10; + +impl AtomicWaker { + /// Create an `AtomicWaker` + pub(crate) fn new() -> AtomicWaker { + AtomicWaker { + state: AtomicUsize::new(WAITING), + waker: UnsafeCell::new(None), + } + } + + /* + /// Registers the current waker to be notified on calls to `wake`. + pub(crate) fn register(&self, waker: Waker) { + self.do_register(waker); + } + */ + + /// Registers the provided waker to be notified on calls to `wake`. + /// + /// The new waker will take place of any previous wakers that were registered + /// by previous calls to `register`. Any calls to `wake` that happen after + /// a call to `register` (as defined by the memory ordering rules), will + /// wake the `register` caller's task. + /// + /// It is safe to call `register` with multiple other threads concurrently + /// calling `wake`. This will result in the `register` caller's current + /// task being woken once. + /// + /// This function is safe to call concurrently, but this is generally a bad + /// idea. Concurrent calls to `register` will attempt to register different + /// tasks to be woken. One of the callers will win and have its task set, + /// but there is no guarantee as to which caller will succeed. + pub(crate) fn register_by_ref(&self, waker: &Waker) { + self.do_register(waker); + } + + fn do_register<W>(&self, waker: W) + where + W: WakerRef, + { + fn catch_unwind<F: FnOnce() -> R, R>(f: F) -> std::thread::Result<R> { + std::panic::catch_unwind(AssertUnwindSafe(f)) + } + + match self + .state + .compare_exchange(WAITING, REGISTERING, Acquire, Acquire) + .unwrap_or_else(|x| x) + { + WAITING => { + unsafe { + // If `into_waker` panics (because it's code outside of + // AtomicWaker) we need to prime a guard that is called on + // unwind to restore the waker to a WAITING state. Otherwise + // any future calls to register will incorrectly be stuck + // believing it's being updated by someone else. + let new_waker_or_panic = catch_unwind(move || waker.into_waker()); + + // Set the field to contain the new waker, or if + // `into_waker` panicked, leave the old value. + let mut maybe_panic = None; + let mut old_waker = None; + match new_waker_or_panic { + Ok(new_waker) => { + old_waker = self.waker.with_mut(|t| (*t).take()); + self.waker.with_mut(|t| *t = Some(new_waker)); + } + Err(panic) => maybe_panic = Some(panic), + } + + // Release the lock. If the state transitioned to include + // the `WAKING` bit, this means that a wake has been + // called concurrently, so we have to remove the waker and + // wake it.` + // + // Start by assuming that the state is `REGISTERING` as this + // is what we jut set it to. + let res = self + .state + .compare_exchange(REGISTERING, WAITING, AcqRel, Acquire); + + match res { + Ok(_) => { + // We don't want to give the caller the panic if it + // was someone else who put in that waker. + let _ = catch_unwind(move || { + drop(old_waker); + }); + } + Err(actual) => { + // This branch can only be reached if a + // concurrent thread called `wake`. In this + // case, `actual` **must** be `REGISTERING | + // WAKING`. + debug_assert_eq!(actual, REGISTERING | WAKING); + + // Take the waker to wake once the atomic operation has + // completed. + let mut waker = self.waker.with_mut(|t| (*t).take()); + + // Just swap, because no one could change state + // while state == `Registering | `Waking` + self.state.swap(WAITING, AcqRel); + + // If `into_waker` panicked, then the waker in the + // waker slot is actually the old waker. + if maybe_panic.is_some() { + old_waker = waker.take(); + } + + // We don't want to give the caller the panic if it + // was someone else who put in that waker. + if let Some(old_waker) = old_waker { + let _ = catch_unwind(move || { + old_waker.wake(); + }); + } + + // The atomic swap was complete, now wake the waker + // and return. + // + // If this panics, we end up in a consumed state and + // return the panic to the caller. + if let Some(waker) = waker { + debug_assert!(maybe_panic.is_none()); + waker.wake(); + } + } + } + + if let Some(panic) = maybe_panic { + // If `into_waker` panicked, return the panic to the caller. + resume_unwind(panic); + } + } + } + WAKING => { + // Currently in the process of waking the task, i.e., + // `wake` is currently being called on the old waker. + // So, we call wake on the new waker. + // + // If this panics, someone else is responsible for restoring the + // state of the waker. + waker.wake(); + + // This is equivalent to a spin lock, so use a spin hint. + hint::spin_loop(); + } + state => { + // In this case, a concurrent thread is holding the + // "registering" lock. This probably indicates a bug in the + // caller's code as racing to call `register` doesn't make much + // sense. + // + // We just want to maintain memory safety. It is ok to drop the + // call to `register`. + debug_assert!(state == REGISTERING || state == REGISTERING | WAKING); + } + } + } + + /// Wakes the task that last called `register`. + /// + /// If `register` has not been called yet, then this does nothing. + pub(crate) fn wake(&self) { + if let Some(waker) = self.take_waker() { + // If wake panics, we've consumed the waker which is a legitimate + // outcome. + waker.wake(); + } + } + + /// Attempts to take the `Waker` value out of the `AtomicWaker` with the + /// intention that the caller will wake the task later. + pub(crate) fn take_waker(&self) -> Option<Waker> { + // AcqRel ordering is used in order to acquire the value of the `waker` + // cell as well as to establish a `release` ordering with whatever + // memory the `AtomicWaker` is associated with. + match self.state.fetch_or(WAKING, AcqRel) { + WAITING => { + // The waking lock has been acquired. + let waker = unsafe { self.waker.with_mut(|t| (*t).take()) }; + + // Release the lock + self.state.fetch_and(!WAKING, Release); + + waker + } + state => { + // There is a concurrent thread currently updating the + // associated waker. + // + // Nothing more to do as the `WAKING` bit has been set. It + // doesn't matter if there are concurrent registering threads or + // not. + // + debug_assert!( + state == REGISTERING || state == REGISTERING | WAKING || state == WAKING + ); + None + } + } + } +} + +impl Default for AtomicWaker { + fn default() -> Self { + AtomicWaker::new() + } +} + +impl fmt::Debug for AtomicWaker { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "AtomicWaker") + } +} + +unsafe impl Send for AtomicWaker {} +unsafe impl Sync for AtomicWaker {} + +trait WakerRef { + fn wake(self); + fn into_waker(self) -> Waker; +} + +impl WakerRef for Waker { + fn wake(self) { + self.wake() + } + + fn into_waker(self) -> Waker { + self + } +} + +impl WakerRef for &Waker { + fn wake(self) { + self.wake_by_ref() + } + + fn into_waker(self) -> Waker { + self.clone() + } +} diff --git a/third_party/rust/tokio/src/sync/task/mod.rs b/third_party/rust/tokio/src/sync/task/mod.rs new file mode 100644 index 0000000000..a6bc6ed06e --- /dev/null +++ b/third_party/rust/tokio/src/sync/task/mod.rs @@ -0,0 +1,4 @@ +//! Thread-safe task notification primitives. + +mod atomic_waker; +pub(crate) use self::atomic_waker::AtomicWaker; diff --git a/third_party/rust/tokio/src/sync/tests/atomic_waker.rs b/third_party/rust/tokio/src/sync/tests/atomic_waker.rs new file mode 100644 index 0000000000..ec13cbd658 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/atomic_waker.rs @@ -0,0 +1,77 @@ +use crate::sync::AtomicWaker; +use tokio_test::task; + +use std::task::Waker; + +trait AssertSend: Send {} +trait AssertSync: Send {} + +impl AssertSend for AtomicWaker {} +impl AssertSync for AtomicWaker {} + +impl AssertSend for Waker {} +impl AssertSync for Waker {} + +#[cfg(target_arch = "wasm32")] +use wasm_bindgen_test::wasm_bindgen_test as test; + +#[test] +fn basic_usage() { + let mut waker = task::spawn(AtomicWaker::new()); + + waker.enter(|cx, waker| waker.register_by_ref(cx.waker())); + waker.wake(); + + assert!(waker.is_woken()); +} + +#[test] +fn wake_without_register() { + let mut waker = task::spawn(AtomicWaker::new()); + waker.wake(); + + // Registering should not result in a notification + waker.enter(|cx, waker| waker.register_by_ref(cx.waker())); + + assert!(!waker.is_woken()); +} + +#[test] +#[cfg(not(target_arch = "wasm32"))] // wasm currently doesn't support unwinding +fn atomic_waker_panic_safe() { + use std::panic; + use std::ptr; + use std::task::{RawWaker, RawWakerVTable, Waker}; + + static PANICKING_VTABLE: RawWakerVTable = RawWakerVTable::new( + |_| panic!("clone"), + |_| unimplemented!("wake"), + |_| unimplemented!("wake_by_ref"), + |_| (), + ); + + static NONPANICKING_VTABLE: RawWakerVTable = RawWakerVTable::new( + |_| RawWaker::new(ptr::null(), &NONPANICKING_VTABLE), + |_| unimplemented!("wake"), + |_| unimplemented!("wake_by_ref"), + |_| (), + ); + + let panicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &PANICKING_VTABLE)) }; + let nonpanicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NONPANICKING_VTABLE)) }; + + let atomic_waker = AtomicWaker::new(); + + let panicking = panic::AssertUnwindSafe(&panicking); + + let result = panic::catch_unwind(|| { + let panic::AssertUnwindSafe(panicking) = panicking; + atomic_waker.register_by_ref(panicking); + }); + + assert!(result.is_err()); + assert!(atomic_waker.take_waker().is_none()); + + atomic_waker.register_by_ref(&nonpanicking); + assert!(atomic_waker.take_waker().is_some()); +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_atomic_waker.rs b/third_party/rust/tokio/src/sync/tests/loom_atomic_waker.rs new file mode 100644 index 0000000000..f8bae65d13 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_atomic_waker.rs @@ -0,0 +1,100 @@ +use crate::sync::task::AtomicWaker; + +use futures::future::poll_fn; +use loom::future::block_on; +use loom::sync::atomic::AtomicUsize; +use loom::thread; +use std::sync::atomic::Ordering::Relaxed; +use std::sync::Arc; +use std::task::Poll::{Pending, Ready}; + +struct Chan { + num: AtomicUsize, + task: AtomicWaker, +} + +#[test] +fn basic_notification() { + const NUM_NOTIFY: usize = 2; + + loom::model(|| { + let chan = Arc::new(Chan { + num: AtomicUsize::new(0), + task: AtomicWaker::new(), + }); + + for _ in 0..NUM_NOTIFY { + let chan = chan.clone(); + + thread::spawn(move || { + chan.num.fetch_add(1, Relaxed); + chan.task.wake(); + }); + } + + block_on(poll_fn(move |cx| { + chan.task.register_by_ref(cx.waker()); + + if NUM_NOTIFY == chan.num.load(Relaxed) { + return Ready(()); + } + + Pending + })); + }); +} + +#[test] +fn test_panicky_waker() { + use std::panic; + use std::ptr; + use std::task::{RawWaker, RawWakerVTable, Waker}; + + static PANICKING_VTABLE: RawWakerVTable = + RawWakerVTable::new(|_| panic!("clone"), |_| (), |_| (), |_| ()); + + let panicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &PANICKING_VTABLE)) }; + + // If you're working with this test (and I sure hope you never have to!), + // uncomment the following section because there will be a lot of panics + // which would otherwise log. + // + // We can't however leaved it uncommented, because it's global. + // panic::set_hook(Box::new(|_| ())); + + const NUM_NOTIFY: usize = 2; + + loom::model(move || { + let chan = Arc::new(Chan { + num: AtomicUsize::new(0), + task: AtomicWaker::new(), + }); + + for _ in 0..NUM_NOTIFY { + let chan = chan.clone(); + + thread::spawn(move || { + chan.num.fetch_add(1, Relaxed); + chan.task.wake(); + }); + } + + // Note: this panic should have no effect on the overall state of the + // waker and it should proceed as normal. + // + // A thread above might race to flag a wakeup, and a WAKING state will + // be preserved if this expected panic races with that so the below + // procedure should be allowed to continue uninterrupted. + let _ = panic::catch_unwind(|| chan.task.register_by_ref(&panicking)); + + block_on(poll_fn(move |cx| { + chan.task.register_by_ref(cx.waker()); + + if NUM_NOTIFY == chan.num.load(Relaxed) { + return Ready(()); + } + + Pending + })); + }); +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_broadcast.rs b/third_party/rust/tokio/src/sync/tests/loom_broadcast.rs new file mode 100644 index 0000000000..039b01bf43 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_broadcast.rs @@ -0,0 +1,207 @@ +use crate::sync::broadcast; +use crate::sync::broadcast::error::RecvError::{Closed, Lagged}; + +use loom::future::block_on; +use loom::sync::Arc; +use loom::thread; +use tokio_test::{assert_err, assert_ok}; + +#[test] +fn broadcast_send() { + loom::model(|| { + let (tx1, mut rx) = broadcast::channel(2); + let tx1 = Arc::new(tx1); + let tx2 = tx1.clone(); + + let th1 = thread::spawn(move || { + block_on(async { + assert_ok!(tx1.send("one")); + assert_ok!(tx1.send("two")); + assert_ok!(tx1.send("three")); + }); + }); + + let th2 = thread::spawn(move || { + block_on(async { + assert_ok!(tx2.send("eins")); + assert_ok!(tx2.send("zwei")); + assert_ok!(tx2.send("drei")); + }); + }); + + block_on(async { + let mut num = 0; + loop { + match rx.recv().await { + Ok(_) => num += 1, + Err(Closed) => break, + Err(Lagged(n)) => num += n as usize, + } + } + assert_eq!(num, 6); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + }); +} + +// An `Arc` is used as the value in order to detect memory leaks. +#[test] +fn broadcast_two() { + loom::model(|| { + let (tx, mut rx1) = broadcast::channel::<Arc<&'static str>>(16); + let mut rx2 = tx.subscribe(); + + let th1 = thread::spawn(move || { + block_on(async { + let v = assert_ok!(rx1.recv().await); + assert_eq!(*v, "hello"); + + let v = assert_ok!(rx1.recv().await); + assert_eq!(*v, "world"); + + match assert_err!(rx1.recv().await) { + Closed => {} + _ => panic!(), + } + }); + }); + + let th2 = thread::spawn(move || { + block_on(async { + let v = assert_ok!(rx2.recv().await); + assert_eq!(*v, "hello"); + + let v = assert_ok!(rx2.recv().await); + assert_eq!(*v, "world"); + + match assert_err!(rx2.recv().await) { + Closed => {} + _ => panic!(), + } + }); + }); + + assert_ok!(tx.send(Arc::new("hello"))); + assert_ok!(tx.send(Arc::new("world"))); + drop(tx); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + }); +} + +#[test] +fn broadcast_wrap() { + loom::model(|| { + let (tx, mut rx1) = broadcast::channel(2); + let mut rx2 = tx.subscribe(); + + let th1 = thread::spawn(move || { + block_on(async { + let mut num = 0; + + loop { + match rx1.recv().await { + Ok(_) => num += 1, + Err(Closed) => break, + Err(Lagged(n)) => num += n as usize, + } + } + + assert_eq!(num, 3); + }); + }); + + let th2 = thread::spawn(move || { + block_on(async { + let mut num = 0; + + loop { + match rx2.recv().await { + Ok(_) => num += 1, + Err(Closed) => break, + Err(Lagged(n)) => num += n as usize, + } + } + + assert_eq!(num, 3); + }); + }); + + assert_ok!(tx.send("one")); + assert_ok!(tx.send("two")); + assert_ok!(tx.send("three")); + + drop(tx); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + }); +} + +#[test] +fn drop_rx() { + loom::model(|| { + let (tx, mut rx1) = broadcast::channel(16); + let rx2 = tx.subscribe(); + + let th1 = thread::spawn(move || { + block_on(async { + let v = assert_ok!(rx1.recv().await); + assert_eq!(v, "one"); + + let v = assert_ok!(rx1.recv().await); + assert_eq!(v, "two"); + + let v = assert_ok!(rx1.recv().await); + assert_eq!(v, "three"); + + match assert_err!(rx1.recv().await) { + Closed => {} + _ => panic!(), + } + }); + }); + + let th2 = thread::spawn(move || { + drop(rx2); + }); + + assert_ok!(tx.send("one")); + assert_ok!(tx.send("two")); + assert_ok!(tx.send("three")); + drop(tx); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + }); +} + +#[test] +fn drop_multiple_rx_with_overflow() { + loom::model(move || { + // It is essential to have multiple senders and receivers in this test case. + let (tx, mut rx) = broadcast::channel(1); + let _rx2 = tx.subscribe(); + + let _ = tx.send(()); + let tx2 = tx.clone(); + let th1 = thread::spawn(move || { + block_on(async { + for _ in 0..100 { + let _ = tx2.send(()); + } + }); + }); + let _ = tx.send(()); + + let th2 = thread::spawn(move || { + block_on(async { while let Ok(_) = rx.recv().await {} }); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + }); +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_list.rs b/third_party/rust/tokio/src/sync/tests/loom_list.rs new file mode 100644 index 0000000000..4067f865ce --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_list.rs @@ -0,0 +1,48 @@ +use crate::sync::mpsc::list; + +use loom::thread; +use std::sync::Arc; + +#[test] +fn smoke() { + use crate::sync::mpsc::block::Read::*; + + const NUM_TX: usize = 2; + const NUM_MSG: usize = 2; + + loom::model(|| { + let (tx, mut rx) = list::channel(); + let tx = Arc::new(tx); + + for th in 0..NUM_TX { + let tx = tx.clone(); + + thread::spawn(move || { + for i in 0..NUM_MSG { + tx.push((th, i)); + } + }); + } + + let mut next = vec![0; NUM_TX]; + + loop { + match rx.pop(&tx) { + Some(Value((th, v))) => { + assert_eq!(v, next[th]); + next[th] += 1; + + if next.iter().all(|&i| i == NUM_MSG) { + break; + } + } + Some(Closed) => { + panic!(); + } + None => { + thread::yield_now(); + } + } + } + }); +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_mpsc.rs b/third_party/rust/tokio/src/sync/tests/loom_mpsc.rs new file mode 100644 index 0000000000..f165e7076e --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_mpsc.rs @@ -0,0 +1,190 @@ +use crate::sync::mpsc; + +use futures::future::poll_fn; +use loom::future::block_on; +use loom::sync::Arc; +use loom::thread; +use tokio_test::assert_ok; + +#[test] +fn closing_tx() { + loom::model(|| { + let (tx, mut rx) = mpsc::channel(16); + + thread::spawn(move || { + tx.try_send(()).unwrap(); + drop(tx); + }); + + let v = block_on(rx.recv()); + assert!(v.is_some()); + + let v = block_on(rx.recv()); + assert!(v.is_none()); + }); +} + +#[test] +fn closing_unbounded_tx() { + loom::model(|| { + let (tx, mut rx) = mpsc::unbounded_channel(); + + thread::spawn(move || { + tx.send(()).unwrap(); + drop(tx); + }); + + let v = block_on(rx.recv()); + assert!(v.is_some()); + + let v = block_on(rx.recv()); + assert!(v.is_none()); + }); +} + +#[test] +fn closing_bounded_rx() { + loom::model(|| { + let (tx1, rx) = mpsc::channel::<()>(16); + let tx2 = tx1.clone(); + thread::spawn(move || { + drop(rx); + }); + + block_on(tx1.closed()); + block_on(tx2.closed()); + }); +} + +#[test] +fn closing_and_sending() { + loom::model(|| { + let (tx1, mut rx) = mpsc::channel::<()>(16); + let tx1 = Arc::new(tx1); + let tx2 = tx1.clone(); + + let th1 = thread::spawn(move || { + tx1.try_send(()).unwrap(); + }); + + let th2 = thread::spawn(move || { + block_on(tx2.closed()); + }); + + let th3 = thread::spawn(move || { + let v = block_on(rx.recv()); + assert!(v.is_some()); + drop(rx); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn closing_unbounded_rx() { + loom::model(|| { + let (tx1, rx) = mpsc::unbounded_channel::<()>(); + let tx2 = tx1.clone(); + thread::spawn(move || { + drop(rx); + }); + + block_on(tx1.closed()); + block_on(tx2.closed()); + }); +} + +#[test] +fn dropping_tx() { + loom::model(|| { + let (tx, mut rx) = mpsc::channel::<()>(16); + + for _ in 0..2 { + let tx = tx.clone(); + thread::spawn(move || { + drop(tx); + }); + } + drop(tx); + + let v = block_on(rx.recv()); + assert!(v.is_none()); + }); +} + +#[test] +fn dropping_unbounded_tx() { + loom::model(|| { + let (tx, mut rx) = mpsc::unbounded_channel::<()>(); + + for _ in 0..2 { + let tx = tx.clone(); + thread::spawn(move || { + drop(tx); + }); + } + drop(tx); + + let v = block_on(rx.recv()); + assert!(v.is_none()); + }); +} + +#[test] +fn try_recv() { + loom::model(|| { + use crate::sync::{mpsc, Semaphore}; + use loom::sync::{Arc, Mutex}; + + const PERMITS: usize = 2; + const TASKS: usize = 2; + const CYCLES: usize = 1; + + struct Context { + sem: Arc<Semaphore>, + tx: mpsc::Sender<()>, + rx: Mutex<mpsc::Receiver<()>>, + } + + fn run(ctx: &Context) { + block_on(async { + let permit = ctx.sem.acquire().await; + assert_ok!(ctx.rx.lock().unwrap().try_recv()); + crate::task::yield_now().await; + assert_ok!(ctx.tx.clone().try_send(())); + drop(permit); + }); + } + + let (tx, rx) = mpsc::channel(PERMITS); + let sem = Arc::new(Semaphore::new(PERMITS)); + let ctx = Arc::new(Context { + sem, + tx, + rx: Mutex::new(rx), + }); + + for _ in 0..PERMITS { + assert_ok!(ctx.tx.clone().try_send(())); + } + + let mut ths = Vec::new(); + + for _ in 0..TASKS { + let ctx = ctx.clone(); + + ths.push(thread::spawn(move || { + run(&ctx); + })); + } + + run(&ctx); + + for th in ths { + th.join().unwrap(); + } + }); +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_notify.rs b/third_party/rust/tokio/src/sync/tests/loom_notify.rs new file mode 100644 index 0000000000..d484a75817 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_notify.rs @@ -0,0 +1,140 @@ +use crate::sync::Notify; + +use loom::future::block_on; +use loom::sync::Arc; +use loom::thread; + +#[test] +fn notify_one() { + loom::model(|| { + let tx = Arc::new(Notify::new()); + let rx = tx.clone(); + + let th = thread::spawn(move || { + block_on(async { + rx.notified().await; + }); + }); + + tx.notify_one(); + th.join().unwrap(); + }); +} + +#[test] +fn notify_waiters() { + loom::model(|| { + let notify = Arc::new(Notify::new()); + let tx = notify.clone(); + let notified1 = notify.notified(); + let notified2 = notify.notified(); + + let th = thread::spawn(move || { + tx.notify_waiters(); + }); + + block_on(async { + notified1.await; + notified2.await; + }); + + th.join().unwrap(); + }); +} + +#[test] +fn notify_waiters_and_one() { + loom::model(|| { + let notify = Arc::new(Notify::new()); + let tx1 = notify.clone(); + let tx2 = notify.clone(); + + let th1 = thread::spawn(move || { + tx1.notify_waiters(); + }); + + let th2 = thread::spawn(move || { + tx2.notify_one(); + }); + + let th3 = thread::spawn(move || { + let notified = notify.notified(); + + block_on(async { + notified.await; + }); + }); + + th1.join().unwrap(); + th2.join().unwrap(); + th3.join().unwrap(); + }); +} + +#[test] +fn notify_multi() { + loom::model(|| { + let notify = Arc::new(Notify::new()); + + let mut ths = vec![]; + + for _ in 0..2 { + let notify = notify.clone(); + + ths.push(thread::spawn(move || { + block_on(async { + notify.notified().await; + notify.notify_one(); + }) + })); + } + + notify.notify_one(); + + for th in ths.drain(..) { + th.join().unwrap(); + } + + block_on(async { + notify.notified().await; + }); + }); +} + +#[test] +fn notify_drop() { + use crate::future::poll_fn; + use std::future::Future; + use std::task::Poll; + + loom::model(|| { + let notify = Arc::new(Notify::new()); + let rx1 = notify.clone(); + let rx2 = notify.clone(); + + let th1 = thread::spawn(move || { + let mut recv = Box::pin(rx1.notified()); + + block_on(poll_fn(|cx| { + if recv.as_mut().poll(cx).is_ready() { + rx1.notify_one(); + } + Poll::Ready(()) + })); + }); + + let th2 = thread::spawn(move || { + block_on(async { + rx2.notified().await; + // Trigger second notification + rx2.notify_one(); + rx2.notified().await; + }); + }); + + notify.notify_one(); + + th1.join().unwrap(); + th2.join().unwrap(); + }); +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_oneshot.rs b/third_party/rust/tokio/src/sync/tests/loom_oneshot.rs new file mode 100644 index 0000000000..c5f7972079 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_oneshot.rs @@ -0,0 +1,140 @@ +use crate::sync::oneshot; + +use futures::future::poll_fn; +use loom::future::block_on; +use loom::thread; +use std::task::Poll::{Pending, Ready}; + +#[test] +fn smoke() { + loom::model(|| { + let (tx, rx) = oneshot::channel(); + + thread::spawn(move || { + tx.send(1).unwrap(); + }); + + let value = block_on(rx).unwrap(); + assert_eq!(1, value); + }); +} + +#[test] +fn changing_rx_task() { + loom::model(|| { + let (tx, mut rx) = oneshot::channel(); + + thread::spawn(move || { + tx.send(1).unwrap(); + }); + + let rx = thread::spawn(move || { + let ready = block_on(poll_fn(|cx| match Pin::new(&mut rx).poll(cx) { + Ready(Ok(value)) => { + assert_eq!(1, value); + Ready(true) + } + Ready(Err(_)) => unimplemented!(), + Pending => Ready(false), + })); + + if ready { + None + } else { + Some(rx) + } + }) + .join() + .unwrap(); + + if let Some(rx) = rx { + // Previous task parked, use a new task... + let value = block_on(rx).unwrap(); + assert_eq!(1, value); + } + }); +} + +#[test] +fn try_recv_close() { + // reproduces https://github.com/tokio-rs/tokio/issues/4225 + loom::model(|| { + let (tx, mut rx) = oneshot::channel(); + thread::spawn(move || { + let _ = tx.send(()); + }); + + rx.close(); + let _ = rx.try_recv(); + }) +} + +#[test] +fn recv_closed() { + // reproduces https://github.com/tokio-rs/tokio/issues/4225 + loom::model(|| { + let (tx, mut rx) = oneshot::channel(); + + thread::spawn(move || { + let _ = tx.send(1); + }); + + rx.close(); + let _ = block_on(rx); + }); +} + +// TODO: Move this into `oneshot` proper. + +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +struct OnClose<'a> { + tx: &'a mut oneshot::Sender<i32>, +} + +impl<'a> OnClose<'a> { + fn new(tx: &'a mut oneshot::Sender<i32>) -> Self { + OnClose { tx } + } +} + +impl Future for OnClose<'_> { + type Output = bool; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<bool> { + let fut = self.get_mut().tx.closed(); + crate::pin!(fut); + + Ready(fut.poll(cx).is_ready()) + } +} + +#[test] +fn changing_tx_task() { + loom::model(|| { + let (mut tx, rx) = oneshot::channel::<i32>(); + + thread::spawn(move || { + drop(rx); + }); + + let tx = thread::spawn(move || { + let t1 = block_on(OnClose::new(&mut tx)); + + if t1 { + None + } else { + Some(tx) + } + }) + .join() + .unwrap(); + + if let Some(mut tx) = tx { + // Previous task parked, use a new task... + block_on(OnClose::new(&mut tx)); + } + }); +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_rwlock.rs b/third_party/rust/tokio/src/sync/tests/loom_rwlock.rs new file mode 100644 index 0000000000..4b5cc7edc6 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_rwlock.rs @@ -0,0 +1,105 @@ +use crate::sync::rwlock::*; + +use loom::future::block_on; +use loom::thread; +use std::sync::Arc; + +#[test] +fn concurrent_write() { + let b = loom::model::Builder::new(); + + b.check(|| { + let rwlock = Arc::new(RwLock::<u32>::new(0)); + + let rwclone = rwlock.clone(); + let t1 = thread::spawn(move || { + block_on(async { + let mut guard = rwclone.write().await; + *guard += 5; + }); + }); + + let rwclone = rwlock.clone(); + let t2 = thread::spawn(move || { + block_on(async { + let mut guard = rwclone.write_owned().await; + *guard += 5; + }); + }); + + t1.join().expect("thread 1 write should not panic"); + t2.join().expect("thread 2 write should not panic"); + //when all threads have finished the value on the lock should be 10 + let guard = block_on(rwlock.read()); + assert_eq!(10, *guard); + }); +} + +#[test] +fn concurrent_read_write() { + let b = loom::model::Builder::new(); + + b.check(|| { + let rwlock = Arc::new(RwLock::<u32>::new(0)); + + let rwclone = rwlock.clone(); + let t1 = thread::spawn(move || { + block_on(async { + let mut guard = rwclone.write().await; + *guard += 5; + }); + }); + + let rwclone = rwlock.clone(); + let t2 = thread::spawn(move || { + block_on(async { + let mut guard = rwclone.write_owned().await; + *guard += 5; + }); + }); + + let rwclone = rwlock.clone(); + let t3 = thread::spawn(move || { + block_on(async { + let guard = rwclone.read().await; + //at this state the value on the lock may either be 0, 5, or 10 + assert!(*guard == 0 || *guard == 5 || *guard == 10); + }); + }); + + { + let guard = block_on(rwlock.clone().read_owned()); + //at this state the value on the lock may either be 0, 5, or 10 + assert!(*guard == 0 || *guard == 5 || *guard == 10); + } + + t1.join().expect("thread 1 write should not panic"); + t2.join().expect("thread 2 write should not panic"); + t3.join().expect("thread 3 read should not panic"); + + let guard = block_on(rwlock.read()); + //when all threads have finished the value on the lock should be 10 + assert_eq!(10, *guard); + }); +} +#[test] +fn downgrade() { + loom::model(|| { + let lock = Arc::new(RwLock::new(1)); + + let n = block_on(lock.write()); + + let cloned_lock = lock.clone(); + let handle = thread::spawn(move || { + let mut guard = block_on(cloned_lock.write()); + *guard = 2; + }); + + let n = n.downgrade(); + assert_eq!(*n, 1); + + drop(n); + handle.join().unwrap(); + assert_eq!(*block_on(lock.read()), 2); + }); +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_semaphore_batch.rs b/third_party/rust/tokio/src/sync/tests/loom_semaphore_batch.rs new file mode 100644 index 0000000000..76a1bc0062 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_semaphore_batch.rs @@ -0,0 +1,215 @@ +use crate::sync::batch_semaphore::*; + +use futures::future::poll_fn; +use loom::future::block_on; +use loom::sync::atomic::AtomicUsize; +use loom::thread; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::Ordering::SeqCst; +use std::sync::Arc; +use std::task::Poll::Ready; +use std::task::{Context, Poll}; + +#[test] +fn basic_usage() { + const NUM: usize = 2; + + struct Shared { + semaphore: Semaphore, + active: AtomicUsize, + } + + async fn actor(shared: Arc<Shared>) { + shared.semaphore.acquire(1).await.unwrap(); + let actual = shared.active.fetch_add(1, SeqCst); + assert!(actual <= NUM - 1); + + let actual = shared.active.fetch_sub(1, SeqCst); + assert!(actual <= NUM); + shared.semaphore.release(1); + } + + loom::model(|| { + let shared = Arc::new(Shared { + semaphore: Semaphore::new(NUM), + active: AtomicUsize::new(0), + }); + + for _ in 0..NUM { + let shared = shared.clone(); + + thread::spawn(move || { + block_on(actor(shared)); + }); + } + + block_on(actor(shared)); + }); +} + +#[test] +fn release() { + loom::model(|| { + let semaphore = Arc::new(Semaphore::new(1)); + + { + let semaphore = semaphore.clone(); + thread::spawn(move || { + block_on(semaphore.acquire(1)).unwrap(); + semaphore.release(1); + }); + } + + block_on(semaphore.acquire(1)).unwrap(); + + semaphore.release(1); + }); +} + +#[test] +fn basic_closing() { + const NUM: usize = 2; + + loom::model(|| { + let semaphore = Arc::new(Semaphore::new(1)); + + for _ in 0..NUM { + let semaphore = semaphore.clone(); + + thread::spawn(move || { + for _ in 0..2 { + block_on(semaphore.acquire(1)).map_err(|_| ())?; + + semaphore.release(1); + } + + Ok::<(), ()>(()) + }); + } + + semaphore.close(); + }); +} + +#[test] +fn concurrent_close() { + const NUM: usize = 3; + + loom::model(|| { + let semaphore = Arc::new(Semaphore::new(1)); + + for _ in 0..NUM { + let semaphore = semaphore.clone(); + + thread::spawn(move || { + block_on(semaphore.acquire(1)).map_err(|_| ())?; + semaphore.release(1); + semaphore.close(); + + Ok::<(), ()>(()) + }); + } + }); +} + +#[test] +fn concurrent_cancel() { + async fn poll_and_cancel(semaphore: Arc<Semaphore>) { + let mut acquire1 = Some(semaphore.acquire(1)); + let mut acquire2 = Some(semaphore.acquire(1)); + poll_fn(|cx| { + // poll the acquire future once, and then immediately throw + // it away. this simulates a situation where a future is + // polled and then cancelled, such as by a timeout. + if let Some(acquire) = acquire1.take() { + pin!(acquire); + let _ = acquire.poll(cx); + } + if let Some(acquire) = acquire2.take() { + pin!(acquire); + let _ = acquire.poll(cx); + } + Poll::Ready(()) + }) + .await + } + + loom::model(|| { + let semaphore = Arc::new(Semaphore::new(0)); + let t1 = { + let semaphore = semaphore.clone(); + thread::spawn(move || block_on(poll_and_cancel(semaphore))) + }; + let t2 = { + let semaphore = semaphore.clone(); + thread::spawn(move || block_on(poll_and_cancel(semaphore))) + }; + let t3 = { + let semaphore = semaphore.clone(); + thread::spawn(move || block_on(poll_and_cancel(semaphore))) + }; + + t1.join().unwrap(); + semaphore.release(10); + t2.join().unwrap(); + t3.join().unwrap(); + }); +} + +#[test] +fn batch() { + let mut b = loom::model::Builder::new(); + b.preemption_bound = Some(1); + + b.check(|| { + let semaphore = Arc::new(Semaphore::new(10)); + let active = Arc::new(AtomicUsize::new(0)); + let mut ths = vec![]; + + for _ in 0..2 { + let semaphore = semaphore.clone(); + let active = active.clone(); + + ths.push(thread::spawn(move || { + for n in &[4, 10, 8] { + block_on(semaphore.acquire(*n)).unwrap(); + + active.fetch_add(*n as usize, SeqCst); + + let num_active = active.load(SeqCst); + assert!(num_active <= 10); + + thread::yield_now(); + + active.fetch_sub(*n as usize, SeqCst); + + semaphore.release(*n as usize); + } + })); + } + + for th in ths.into_iter() { + th.join().unwrap(); + } + + assert_eq!(10, semaphore.available_permits()); + }); +} + +#[test] +fn release_during_acquire() { + loom::model(|| { + let semaphore = Arc::new(Semaphore::new(10)); + semaphore + .try_acquire(8) + .expect("try_acquire should succeed; semaphore uncontended"); + let semaphore2 = semaphore.clone(); + let thread = thread::spawn(move || block_on(semaphore2.acquire(4)).unwrap()); + + semaphore.release(8); + thread.join().unwrap(); + semaphore.release(4); + assert_eq!(10, semaphore.available_permits()); + }) +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_watch.rs b/third_party/rust/tokio/src/sync/tests/loom_watch.rs new file mode 100644 index 0000000000..c575b5b66c --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_watch.rs @@ -0,0 +1,36 @@ +use crate::sync::watch; + +use loom::future::block_on; +use loom::thread; + +#[test] +fn smoke() { + loom::model(|| { + let (tx, mut rx1) = watch::channel(1); + let mut rx2 = rx1.clone(); + let mut rx3 = rx1.clone(); + let mut rx4 = rx1.clone(); + let mut rx5 = rx1.clone(); + + let th = thread::spawn(move || { + tx.send(2).unwrap(); + }); + + block_on(rx1.changed()).unwrap(); + assert_eq!(*rx1.borrow(), 2); + + block_on(rx2.changed()).unwrap(); + assert_eq!(*rx2.borrow(), 2); + + block_on(rx3.changed()).unwrap(); + assert_eq!(*rx3.borrow(), 2); + + block_on(rx4.changed()).unwrap(); + assert_eq!(*rx4.borrow(), 2); + + block_on(rx5.changed()).unwrap(); + assert_eq!(*rx5.borrow(), 2); + + th.join().unwrap(); + }) +} diff --git a/third_party/rust/tokio/src/sync/tests/mod.rs b/third_party/rust/tokio/src/sync/tests/mod.rs new file mode 100644 index 0000000000..ee76418ac5 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/mod.rs @@ -0,0 +1,17 @@ +cfg_not_loom! { + mod atomic_waker; + mod notify; + mod semaphore_batch; +} + +cfg_loom! { + mod loom_atomic_waker; + mod loom_broadcast; + mod loom_list; + mod loom_mpsc; + mod loom_notify; + mod loom_oneshot; + mod loom_semaphore_batch; + mod loom_watch; + mod loom_rwlock; +} diff --git a/third_party/rust/tokio/src/sync/tests/notify.rs b/third_party/rust/tokio/src/sync/tests/notify.rs new file mode 100644 index 0000000000..20153b7a5a --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/notify.rs @@ -0,0 +1,81 @@ +use crate::sync::Notify; +use std::future::Future; +use std::mem::ManuallyDrop; +use std::sync::Arc; +use std::task::{Context, RawWaker, RawWakerVTable, Waker}; + +#[cfg(target_arch = "wasm32")] +use wasm_bindgen_test::wasm_bindgen_test as test; + +#[test] +fn notify_clones_waker_before_lock() { + const VTABLE: &RawWakerVTable = &RawWakerVTable::new(clone_w, wake, wake_by_ref, drop_w); + + unsafe fn clone_w(data: *const ()) -> RawWaker { + let arc = ManuallyDrop::new(Arc::<Notify>::from_raw(data as *const Notify)); + // Or some other arbitrary code that shouldn't be executed while the + // Notify wait list is locked. + arc.notify_one(); + let _arc_clone: ManuallyDrop<_> = arc.clone(); + RawWaker::new(data, VTABLE) + } + + unsafe fn drop_w(data: *const ()) { + let _ = Arc::<Notify>::from_raw(data as *const Notify); + } + + unsafe fn wake(_data: *const ()) { + unreachable!() + } + + unsafe fn wake_by_ref(_data: *const ()) { + unreachable!() + } + + let notify = Arc::new(Notify::new()); + let notify2 = notify.clone(); + + let waker = + unsafe { Waker::from_raw(RawWaker::new(Arc::into_raw(notify2) as *const _, VTABLE)) }; + let mut cx = Context::from_waker(&waker); + + let future = notify.notified(); + pin!(future); + + // The result doesn't matter, we're just testing that we don't deadlock. + let _ = future.poll(&mut cx); +} + +#[test] +fn notify_simple() { + let notify = Notify::new(); + + let mut fut1 = tokio_test::task::spawn(notify.notified()); + assert!(fut1.poll().is_pending()); + + let mut fut2 = tokio_test::task::spawn(notify.notified()); + assert!(fut2.poll().is_pending()); + + notify.notify_waiters(); + + assert!(fut1.poll().is_ready()); + assert!(fut2.poll().is_ready()); +} + +#[test] +#[cfg(not(target_arch = "wasm32"))] +fn watch_test() { + let rt = crate::runtime::Builder::new_current_thread() + .build() + .unwrap(); + + rt.block_on(async { + let (tx, mut rx) = crate::sync::watch::channel(()); + + crate::spawn(async move { + let _ = tx.send(()); + }); + + let _ = rx.changed().await; + }); +} diff --git a/third_party/rust/tokio/src/sync/tests/semaphore_batch.rs b/third_party/rust/tokio/src/sync/tests/semaphore_batch.rs new file mode 100644 index 0000000000..d529a9e886 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/semaphore_batch.rs @@ -0,0 +1,254 @@ +use crate::sync::batch_semaphore::Semaphore; +use tokio_test::*; + +#[cfg(target_arch = "wasm32")] +use wasm_bindgen_test::wasm_bindgen_test as test; + +#[test] +fn poll_acquire_one_available() { + let s = Semaphore::new(100); + assert_eq!(s.available_permits(), 100); + + // Polling for a permit succeeds immediately + assert_ready_ok!(task::spawn(s.acquire(1)).poll()); + assert_eq!(s.available_permits(), 99); +} + +#[test] +fn poll_acquire_many_available() { + let s = Semaphore::new(100); + assert_eq!(s.available_permits(), 100); + + // Polling for a permit succeeds immediately + assert_ready_ok!(task::spawn(s.acquire(5)).poll()); + assert_eq!(s.available_permits(), 95); + + assert_ready_ok!(task::spawn(s.acquire(5)).poll()); + assert_eq!(s.available_permits(), 90); +} + +#[test] +fn try_acquire_one_available() { + let s = Semaphore::new(100); + assert_eq!(s.available_permits(), 100); + + assert_ok!(s.try_acquire(1)); + assert_eq!(s.available_permits(), 99); + + assert_ok!(s.try_acquire(1)); + assert_eq!(s.available_permits(), 98); +} + +#[test] +fn try_acquire_many_available() { + let s = Semaphore::new(100); + assert_eq!(s.available_permits(), 100); + + assert_ok!(s.try_acquire(5)); + assert_eq!(s.available_permits(), 95); + + assert_ok!(s.try_acquire(5)); + assert_eq!(s.available_permits(), 90); +} + +#[test] +fn poll_acquire_one_unavailable() { + let s = Semaphore::new(1); + + // Acquire the first permit + assert_ready_ok!(task::spawn(s.acquire(1)).poll()); + assert_eq!(s.available_permits(), 0); + + let mut acquire_2 = task::spawn(s.acquire(1)); + // Try to acquire the second permit + assert_pending!(acquire_2.poll()); + assert_eq!(s.available_permits(), 0); + + s.release(1); + + assert_eq!(s.available_permits(), 0); + assert!(acquire_2.is_woken()); + assert_ready_ok!(acquire_2.poll()); + assert_eq!(s.available_permits(), 0); + + s.release(1); + assert_eq!(s.available_permits(), 1); +} + +#[test] +fn poll_acquire_many_unavailable() { + let s = Semaphore::new(5); + + // Acquire the first permit + assert_ready_ok!(task::spawn(s.acquire(1)).poll()); + assert_eq!(s.available_permits(), 4); + + // Try to acquire the second permit + let mut acquire_2 = task::spawn(s.acquire(5)); + assert_pending!(acquire_2.poll()); + assert_eq!(s.available_permits(), 0); + + // Try to acquire the third permit + let mut acquire_3 = task::spawn(s.acquire(3)); + assert_pending!(acquire_3.poll()); + assert_eq!(s.available_permits(), 0); + + s.release(1); + + assert_eq!(s.available_permits(), 0); + assert!(acquire_2.is_woken()); + assert_ready_ok!(acquire_2.poll()); + + assert!(!acquire_3.is_woken()); + assert_eq!(s.available_permits(), 0); + + s.release(1); + assert!(!acquire_3.is_woken()); + assert_eq!(s.available_permits(), 0); + + s.release(2); + assert!(acquire_3.is_woken()); + + assert_ready_ok!(acquire_3.poll()); +} + +#[test] +fn try_acquire_one_unavailable() { + let s = Semaphore::new(1); + + // Acquire the first permit + assert_ok!(s.try_acquire(1)); + assert_eq!(s.available_permits(), 0); + + assert_err!(s.try_acquire(1)); + + s.release(1); + + assert_eq!(s.available_permits(), 1); + assert_ok!(s.try_acquire(1)); + + s.release(1); + assert_eq!(s.available_permits(), 1); +} + +#[test] +fn try_acquire_many_unavailable() { + let s = Semaphore::new(5); + + // Acquire the first permit + assert_ok!(s.try_acquire(1)); + assert_eq!(s.available_permits(), 4); + + assert_err!(s.try_acquire(5)); + + s.release(1); + assert_eq!(s.available_permits(), 5); + + assert_ok!(s.try_acquire(5)); + + s.release(1); + assert_eq!(s.available_permits(), 1); + + s.release(1); + assert_eq!(s.available_permits(), 2); +} + +#[test] +fn poll_acquire_one_zero_permits() { + let s = Semaphore::new(0); + assert_eq!(s.available_permits(), 0); + + // Try to acquire the permit + let mut acquire = task::spawn(s.acquire(1)); + assert_pending!(acquire.poll()); + + s.release(1); + + assert!(acquire.is_woken()); + assert_ready_ok!(acquire.poll()); +} + +#[test] +#[should_panic] +#[cfg(not(target_arch = "wasm32"))] // wasm currently doesn't support unwinding +fn validates_max_permits() { + use std::usize; + Semaphore::new((usize::MAX >> 2) + 1); +} + +#[test] +fn close_semaphore_prevents_acquire() { + let s = Semaphore::new(5); + s.close(); + + assert_eq!(5, s.available_permits()); + + assert_ready_err!(task::spawn(s.acquire(1)).poll()); + assert_eq!(5, s.available_permits()); + + assert_ready_err!(task::spawn(s.acquire(1)).poll()); + assert_eq!(5, s.available_permits()); +} + +#[test] +fn close_semaphore_notifies_permit1() { + let s = Semaphore::new(0); + let mut acquire = task::spawn(s.acquire(1)); + + assert_pending!(acquire.poll()); + + s.close(); + + assert!(acquire.is_woken()); + assert_ready_err!(acquire.poll()); +} + +#[test] +fn close_semaphore_notifies_permit2() { + let s = Semaphore::new(2); + + // Acquire a couple of permits + assert_ready_ok!(task::spawn(s.acquire(1)).poll()); + assert_ready_ok!(task::spawn(s.acquire(1)).poll()); + + let mut acquire3 = task::spawn(s.acquire(1)); + let mut acquire4 = task::spawn(s.acquire(1)); + assert_pending!(acquire3.poll()); + assert_pending!(acquire4.poll()); + + s.close(); + + assert!(acquire3.is_woken()); + assert!(acquire4.is_woken()); + + assert_ready_err!(acquire3.poll()); + assert_ready_err!(acquire4.poll()); + + assert_eq!(0, s.available_permits()); + + s.release(1); + + assert_eq!(1, s.available_permits()); + + assert_ready_err!(task::spawn(s.acquire(1)).poll()); + + s.release(1); + + assert_eq!(2, s.available_permits()); +} + +#[test] +fn cancel_acquire_releases_permits() { + let s = Semaphore::new(10); + s.try_acquire(4).expect("uncontended try_acquire succeeds"); + assert_eq!(6, s.available_permits()); + + let mut acquire = task::spawn(s.acquire(8)); + assert_pending!(acquire.poll()); + + assert_eq!(0, s.available_permits()); + drop(acquire); + + assert_eq!(6, s.available_permits()); + assert_ok!(s.try_acquire(6)); +} diff --git a/third_party/rust/tokio/src/sync/watch.rs b/third_party/rust/tokio/src/sync/watch.rs new file mode 100644 index 0000000000..5673e0fca7 --- /dev/null +++ b/third_party/rust/tokio/src/sync/watch.rs @@ -0,0 +1,834 @@ +#![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))] + +//! A single-producer, multi-consumer channel that only retains the *last* sent +//! value. +//! +//! This channel is useful for watching for changes to a value from multiple +//! points in the code base, for example, changes to configuration values. +//! +//! # Usage +//! +//! [`channel`] returns a [`Sender`] / [`Receiver`] pair. These are the producer +//! and sender halves of the channel. The channel is created with an initial +//! value. The **latest** value stored in the channel is accessed with +//! [`Receiver::borrow()`]. Awaiting [`Receiver::changed()`] waits for a new +//! value to sent by the [`Sender`] half. +//! +//! # Examples +//! +//! ``` +//! use tokio::sync::watch; +//! +//! # async fn dox() -> Result<(), Box<dyn std::error::Error>> { +//! let (tx, mut rx) = watch::channel("hello"); +//! +//! tokio::spawn(async move { +//! while rx.changed().await.is_ok() { +//! println!("received = {:?}", *rx.borrow()); +//! } +//! }); +//! +//! tx.send("world")?; +//! # Ok(()) +//! # } +//! ``` +//! +//! # Closing +//! +//! [`Sender::is_closed`] and [`Sender::closed`] allow the producer to detect +//! when all [`Receiver`] handles have been dropped. This indicates that there +//! is no further interest in the values being produced and work can be stopped. +//! +//! # Thread safety +//! +//! Both [`Sender`] and [`Receiver`] are thread safe. They can be moved to other +//! threads and can be used in a concurrent environment. Clones of [`Receiver`] +//! handles may be moved to separate threads and also used concurrently. +//! +//! [`Sender`]: crate::sync::watch::Sender +//! [`Receiver`]: crate::sync::watch::Receiver +//! [`Receiver::changed()`]: crate::sync::watch::Receiver::changed +//! [`Receiver::borrow()`]: crate::sync::watch::Receiver::borrow +//! [`channel`]: crate::sync::watch::channel +//! [`Sender::is_closed`]: crate::sync::watch::Sender::is_closed +//! [`Sender::closed`]: crate::sync::watch::Sender::closed + +use crate::sync::notify::Notify; + +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::atomic::Ordering::Relaxed; +use crate::loom::sync::{Arc, RwLock, RwLockReadGuard}; +use std::mem; +use std::ops; + +/// Receives values from the associated [`Sender`](struct@Sender). +/// +/// Instances are created by the [`channel`](fn@channel) function. +/// +/// To turn this receiver into a `Stream`, you can use the [`WatchStream`] +/// wrapper. +/// +/// [`WatchStream`]: https://docs.rs/tokio-stream/0.1/tokio_stream/wrappers/struct.WatchStream.html +#[derive(Debug)] +pub struct Receiver<T> { + /// Pointer to the shared state + shared: Arc<Shared<T>>, + + /// Last observed version + version: Version, +} + +/// Sends values to the associated [`Receiver`](struct@Receiver). +/// +/// Instances are created by the [`channel`](fn@channel) function. +#[derive(Debug)] +pub struct Sender<T> { + shared: Arc<Shared<T>>, +} + +/// Returns a reference to the inner value. +/// +/// Outstanding borrows hold a read lock on the inner value. This means that +/// long lived borrows could cause the produce half to block. It is recommended +/// to keep the borrow as short lived as possible. +/// +/// The priority policy of the lock is dependent on the underlying lock +/// implementation, and this type does not guarantee that any particular policy +/// will be used. In particular, a producer which is waiting to acquire the lock +/// in `send` might or might not block concurrent calls to `borrow`, e.g.: +/// +/// <details><summary>Potential deadlock example</summary> +/// +/// ```text +/// // Task 1 (on thread A) | // Task 2 (on thread B) +/// let _ref1 = rx.borrow(); | +/// | // will block +/// | let _ = tx.send(()); +/// // may deadlock | +/// let _ref2 = rx.borrow(); | +/// ``` +/// </details> +#[derive(Debug)] +pub struct Ref<'a, T> { + inner: RwLockReadGuard<'a, T>, +} + +#[derive(Debug)] +struct Shared<T> { + /// The most recent value. + value: RwLock<T>, + + /// The current version. + /// + /// The lowest bit represents a "closed" state. The rest of the bits + /// represent the current version. + state: AtomicState, + + /// Tracks the number of `Receiver` instances. + ref_count_rx: AtomicUsize, + + /// Notifies waiting receivers that the value changed. + notify_rx: Notify, + + /// Notifies any task listening for `Receiver` dropped events. + notify_tx: Notify, +} + +pub mod error { + //! Watch error types. + + use std::fmt; + + /// Error produced when sending a value fails. + #[derive(Debug)] + pub struct SendError<T>(pub T); + + // ===== impl SendError ===== + + impl<T: fmt::Debug> fmt::Display for SendError<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } + } + + impl<T: fmt::Debug> std::error::Error for SendError<T> {} + + /// Error produced when receiving a change notification. + #[derive(Debug)] + pub struct RecvError(pub(super) ()); + + // ===== impl RecvError ===== + + impl fmt::Display for RecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } + } + + impl std::error::Error for RecvError {} +} + +use self::state::{AtomicState, Version}; +mod state { + use crate::loom::sync::atomic::AtomicUsize; + use crate::loom::sync::atomic::Ordering::SeqCst; + + const CLOSED: usize = 1; + + /// The version part of the state. The lowest bit is always zero. + #[derive(Copy, Clone, Debug, Eq, PartialEq)] + pub(super) struct Version(usize); + + /// Snapshot of the state. The first bit is used as the CLOSED bit. + /// The remaining bits are used as the version. + /// + /// The CLOSED bit tracks whether the Sender has been dropped. Dropping all + /// receivers does not set it. + #[derive(Copy, Clone, Debug)] + pub(super) struct StateSnapshot(usize); + + /// The state stored in an atomic integer. + #[derive(Debug)] + pub(super) struct AtomicState(AtomicUsize); + + impl Version { + /// Get the initial version when creating the channel. + pub(super) fn initial() -> Self { + Version(0) + } + } + + impl StateSnapshot { + /// Extract the version from the state. + pub(super) fn version(self) -> Version { + Version(self.0 & !CLOSED) + } + + /// Is the closed bit set? + pub(super) fn is_closed(self) -> bool { + (self.0 & CLOSED) == CLOSED + } + } + + impl AtomicState { + /// Create a new `AtomicState` that is not closed and which has the + /// version set to `Version::initial()`. + pub(super) fn new() -> Self { + AtomicState(AtomicUsize::new(0)) + } + + /// Load the current value of the state. + pub(super) fn load(&self) -> StateSnapshot { + StateSnapshot(self.0.load(SeqCst)) + } + + /// Increment the version counter. + pub(super) fn increment_version(&self) { + // Increment by two to avoid touching the CLOSED bit. + self.0.fetch_add(2, SeqCst); + } + + /// Set the closed bit in the state. + pub(super) fn set_closed(&self) { + self.0.fetch_or(CLOSED, SeqCst); + } + } +} + +/// Creates a new watch channel, returning the "send" and "receive" handles. +/// +/// All values sent by [`Sender`] will become visible to the [`Receiver`] handles. +/// Only the last value sent is made available to the [`Receiver`] half. All +/// intermediate values are dropped. +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::watch; +/// +/// # async fn dox() -> Result<(), Box<dyn std::error::Error>> { +/// let (tx, mut rx) = watch::channel("hello"); +/// +/// tokio::spawn(async move { +/// while rx.changed().await.is_ok() { +/// println!("received = {:?}", *rx.borrow()); +/// } +/// }); +/// +/// tx.send("world")?; +/// # Ok(()) +/// # } +/// ``` +/// +/// [`Sender`]: struct@Sender +/// [`Receiver`]: struct@Receiver +pub fn channel<T>(init: T) -> (Sender<T>, Receiver<T>) { + let shared = Arc::new(Shared { + value: RwLock::new(init), + state: AtomicState::new(), + ref_count_rx: AtomicUsize::new(1), + notify_rx: Notify::new(), + notify_tx: Notify::new(), + }); + + let tx = Sender { + shared: shared.clone(), + }; + + let rx = Receiver { + shared, + version: Version::initial(), + }; + + (tx, rx) +} + +impl<T> Receiver<T> { + fn from_shared(version: Version, shared: Arc<Shared<T>>) -> Self { + // No synchronization necessary as this is only used as a counter and + // not memory access. + shared.ref_count_rx.fetch_add(1, Relaxed); + + Self { shared, version } + } + + /// Returns a reference to the most recently sent value. + /// + /// This method does not mark the returned value as seen, so future calls to + /// [`changed`] may return immediately even if you have already seen the + /// value with a call to `borrow`. + /// + /// Outstanding borrows hold a read lock. This means that long lived borrows + /// could cause the send half to block. It is recommended to keep the borrow + /// as short lived as possible. + /// + /// The priority policy of the lock is dependent on the underlying lock + /// implementation, and this type does not guarantee that any particular policy + /// will be used. In particular, a producer which is waiting to acquire the lock + /// in `send` might or might not block concurrent calls to `borrow`, e.g.: + /// + /// <details><summary>Potential deadlock example</summary> + /// + /// ```text + /// // Task 1 (on thread A) | // Task 2 (on thread B) + /// let _ref1 = rx.borrow(); | + /// | // will block + /// | let _ = tx.send(()); + /// // may deadlock | + /// let _ref2 = rx.borrow(); | + /// ``` + /// </details> + /// + /// [`changed`]: Receiver::changed + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// let (_, rx) = watch::channel("hello"); + /// assert_eq!(*rx.borrow(), "hello"); + /// ``` + pub fn borrow(&self) -> Ref<'_, T> { + let inner = self.shared.value.read().unwrap(); + Ref { inner } + } + + /// Returns a reference to the most recently sent value and mark that value + /// as seen. + /// + /// This method marks the value as seen, so [`changed`] will not return + /// immediately if the newest value is one previously returned by + /// `borrow_and_update`. + /// + /// Outstanding borrows hold a read lock. This means that long lived borrows + /// could cause the send half to block. It is recommended to keep the borrow + /// as short lived as possible. + /// + /// The priority policy of the lock is dependent on the underlying lock + /// implementation, and this type does not guarantee that any particular policy + /// will be used. In particular, a producer which is waiting to acquire the lock + /// in `send` might or might not block concurrent calls to `borrow`, e.g.: + /// + /// <details><summary>Potential deadlock example</summary> + /// + /// ```text + /// // Task 1 (on thread A) | // Task 2 (on thread B) + /// let _ref1 = rx1.borrow_and_update(); | + /// | // will block + /// | let _ = tx.send(()); + /// // may deadlock | + /// let _ref2 = rx2.borrow_and_update(); | + /// ``` + /// </details> + /// + /// [`changed`]: Receiver::changed + pub fn borrow_and_update(&mut self) -> Ref<'_, T> { + let inner = self.shared.value.read().unwrap(); + self.version = self.shared.state.load().version(); + Ref { inner } + } + + /// Checks if this channel contains a message that this receiver has not yet + /// seen. The new value is not marked as seen. + /// + /// Although this method is called `has_changed`, it does not check new + /// messages for equality, so this call will return true even if the new + /// message is equal to the old message. + /// + /// Returns an error if the channel has been closed. + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = watch::channel("hello"); + /// + /// tx.send("goodbye").unwrap(); + /// + /// assert!(rx.has_changed().unwrap()); + /// assert_eq!(*rx.borrow_and_update(), "goodbye"); + /// + /// // The value has been marked as seen + /// assert!(!rx.has_changed().unwrap()); + /// + /// drop(tx); + /// // The `tx` handle has been dropped + /// assert!(rx.has_changed().is_err()); + /// } + /// ``` + pub fn has_changed(&self) -> Result<bool, error::RecvError> { + // Load the version from the state + let state = self.shared.state.load(); + if state.is_closed() { + // The sender has dropped. + return Err(error::RecvError(())); + } + let new_version = state.version(); + + Ok(self.version != new_version) + } + + /// Waits for a change notification, then marks the newest value as seen. + /// + /// If the newest value in the channel has not yet been marked seen when + /// this method is called, the method marks that value seen and returns + /// immediately. If the newest value has already been marked seen, then the + /// method sleeps until a new message is sent by the [`Sender`] connected to + /// this `Receiver`, or until the [`Sender`] is dropped. + /// + /// This method returns an error if and only if the [`Sender`] is dropped. + /// + /// # Cancel safety + /// + /// This method is cancel safe. If you use it as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that no values have been marked + /// seen by this call to `changed`. + /// + /// [`Sender`]: struct@Sender + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = watch::channel("hello"); + /// + /// tokio::spawn(async move { + /// tx.send("goodbye").unwrap(); + /// }); + /// + /// assert!(rx.changed().await.is_ok()); + /// assert_eq!(*rx.borrow(), "goodbye"); + /// + /// // The `tx` handle has been dropped + /// assert!(rx.changed().await.is_err()); + /// } + /// ``` + pub async fn changed(&mut self) -> Result<(), error::RecvError> { + loop { + // In order to avoid a race condition, we first request a notification, + // **then** check the current value's version. If a new version exists, + // the notification request is dropped. + let notified = self.shared.notify_rx.notified(); + + if let Some(ret) = maybe_changed(&self.shared, &mut self.version) { + return ret; + } + + notified.await; + // loop around again in case the wake-up was spurious + } + } + + cfg_process_driver! { + pub(crate) fn try_has_changed(&mut self) -> Option<Result<(), error::RecvError>> { + maybe_changed(&self.shared, &mut self.version) + } + } +} + +fn maybe_changed<T>( + shared: &Shared<T>, + version: &mut Version, +) -> Option<Result<(), error::RecvError>> { + // Load the version from the state + let state = shared.state.load(); + let new_version = state.version(); + + if *version != new_version { + // Observe the new version and return + *version = new_version; + return Some(Ok(())); + } + + if state.is_closed() { + // All receivers have dropped. + return Some(Err(error::RecvError(()))); + } + + None +} + +impl<T> Clone for Receiver<T> { + fn clone(&self) -> Self { + let version = self.version; + let shared = self.shared.clone(); + + Self::from_shared(version, shared) + } +} + +impl<T> Drop for Receiver<T> { + fn drop(&mut self) { + // No synchronization necessary as this is only used as a counter and + // not memory access. + if 1 == self.shared.ref_count_rx.fetch_sub(1, Relaxed) { + // This is the last `Receiver` handle, tasks waiting on `Sender::closed()` + self.shared.notify_tx.notify_waiters(); + } + } +} + +impl<T> Sender<T> { + /// Sends a new value via the channel, notifying all receivers. + /// + /// This method fails if the channel has been closed, which happens when + /// every receiver has been dropped. + pub fn send(&self, value: T) -> Result<(), error::SendError<T>> { + // This is pretty much only useful as a hint anyway, so synchronization isn't critical. + if 0 == self.receiver_count() { + return Err(error::SendError(value)); + } + + self.send_replace(value); + Ok(()) + } + + /// Sends a new value via the channel, notifying all receivers and returning + /// the previous value in the channel. + /// + /// This can be useful for reusing the buffers inside a watched value. + /// Additionally, this method permits sending values even when there are no + /// receivers. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// let (tx, _rx) = watch::channel(1); + /// assert_eq!(tx.send_replace(2), 1); + /// assert_eq!(tx.send_replace(3), 2); + /// ``` + pub fn send_replace(&self, value: T) -> T { + let old = { + // Acquire the write lock and update the value. + let mut lock = self.shared.value.write().unwrap(); + let old = mem::replace(&mut *lock, value); + + self.shared.state.increment_version(); + + // Release the write lock. + // + // Incrementing the version counter while holding the lock ensures + // that receivers are able to figure out the version number of the + // value they are currently looking at. + drop(lock); + + old + }; + + // Notify all watchers + self.shared.notify_rx.notify_waiters(); + + old + } + + /// Returns a reference to the most recently sent value + /// + /// Outstanding borrows hold a read lock. This means that long lived borrows + /// could cause the send half to block. It is recommended to keep the borrow + /// as short lived as possible. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// let (tx, _) = watch::channel("hello"); + /// assert_eq!(*tx.borrow(), "hello"); + /// ``` + pub fn borrow(&self) -> Ref<'_, T> { + let inner = self.shared.value.read().unwrap(); + Ref { inner } + } + + /// Checks if the channel has been closed. This happens when all receivers + /// have dropped. + /// + /// # Examples + /// + /// ``` + /// let (tx, rx) = tokio::sync::watch::channel(()); + /// assert!(!tx.is_closed()); + /// + /// drop(rx); + /// assert!(tx.is_closed()); + /// ``` + pub fn is_closed(&self) -> bool { + self.receiver_count() == 0 + } + + /// Completes when all receivers have dropped. + /// + /// This allows the producer to get notified when interest in the produced + /// values is canceled and immediately stop doing work. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once the channel is closed, it stays closed + /// forever and all future calls to `closed` will return immediately. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = watch::channel("hello"); + /// + /// tokio::spawn(async move { + /// // use `rx` + /// drop(rx); + /// }); + /// + /// // Waits for `rx` to drop + /// tx.closed().await; + /// println!("the `rx` handles dropped") + /// } + /// ``` + pub async fn closed(&self) { + while self.receiver_count() > 0 { + let notified = self.shared.notify_tx.notified(); + + if self.receiver_count() == 0 { + return; + } + + notified.await; + // The channel could have been reopened in the meantime by calling + // `subscribe`, so we loop again. + } + } + + /// Creates a new [`Receiver`] connected to this `Sender`. + /// + /// All messages sent before this call to `subscribe` are initially marked + /// as seen by the new `Receiver`. + /// + /// This method can be called even if there are no other receivers. In this + /// case, the channel is reopened. + /// + /// # Examples + /// + /// The new channel will receive messages sent on this `Sender`. + /// + /// ``` + /// use tokio::sync::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, _rx) = watch::channel(0u64); + /// + /// tx.send(5).unwrap(); + /// + /// let rx = tx.subscribe(); + /// assert_eq!(5, *rx.borrow()); + /// + /// tx.send(10).unwrap(); + /// assert_eq!(10, *rx.borrow()); + /// } + /// ``` + /// + /// The most recent message is considered seen by the channel, so this test + /// is guaranteed to pass. + /// + /// ``` + /// use tokio::sync::watch; + /// use tokio::time::Duration; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, _rx) = watch::channel(0u64); + /// tx.send(5).unwrap(); + /// let mut rx = tx.subscribe(); + /// + /// tokio::spawn(async move { + /// // by spawning and sleeping, the message is sent after `main` + /// // hits the call to `changed`. + /// # if false { + /// tokio::time::sleep(Duration::from_millis(10)).await; + /// # } + /// tx.send(100).unwrap(); + /// }); + /// + /// rx.changed().await.unwrap(); + /// assert_eq!(100, *rx.borrow()); + /// } + /// ``` + pub fn subscribe(&self) -> Receiver<T> { + let shared = self.shared.clone(); + let version = shared.state.load().version(); + + // The CLOSED bit in the state tracks only whether the sender is + // dropped, so we do not need to unset it if this reopens the channel. + Receiver::from_shared(version, shared) + } + + /// Returns the number of receivers that currently exist. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx1) = watch::channel("hello"); + /// + /// assert_eq!(1, tx.receiver_count()); + /// + /// let mut _rx2 = rx1.clone(); + /// + /// assert_eq!(2, tx.receiver_count()); + /// } + /// ``` + pub fn receiver_count(&self) -> usize { + self.shared.ref_count_rx.load(Relaxed) + } +} + +impl<T> Drop for Sender<T> { + fn drop(&mut self) { + self.shared.state.set_closed(); + self.shared.notify_rx.notify_waiters(); + } +} + +// ===== impl Ref ===== + +impl<T> ops::Deref for Ref<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + self.inner.deref() + } +} + +#[cfg(all(test, loom))] +mod tests { + use futures::future::FutureExt; + use loom::thread; + + // test for https://github.com/tokio-rs/tokio/issues/3168 + #[test] + fn watch_spurious_wakeup() { + loom::model(|| { + let (send, mut recv) = crate::sync::watch::channel(0i32); + + send.send(1).unwrap(); + + let send_thread = thread::spawn(move || { + send.send(2).unwrap(); + send + }); + + recv.changed().now_or_never(); + + let send = send_thread.join().unwrap(); + let recv_thread = thread::spawn(move || { + recv.changed().now_or_never(); + recv.changed().now_or_never(); + recv + }); + + send.send(3).unwrap(); + + let mut recv = recv_thread.join().unwrap(); + let send_thread = thread::spawn(move || { + send.send(2).unwrap(); + }); + + recv.changed().now_or_never(); + + send_thread.join().unwrap(); + }); + } + + #[test] + fn watch_borrow() { + loom::model(|| { + let (send, mut recv) = crate::sync::watch::channel(0i32); + + assert!(send.borrow().eq(&0)); + assert!(recv.borrow().eq(&0)); + + send.send(1).unwrap(); + assert!(send.borrow().eq(&1)); + + let send_thread = thread::spawn(move || { + send.send(2).unwrap(); + send + }); + + recv.changed().now_or_never(); + + let send = send_thread.join().unwrap(); + let recv_thread = thread::spawn(move || { + recv.changed().now_or_never(); + recv.changed().now_or_never(); + recv + }); + + send.send(3).unwrap(); + + let recv = recv_thread.join().unwrap(); + assert!(recv.borrow().eq(&3)); + assert!(send.borrow().eq(&3)); + + send.send(2).unwrap(); + + thread::spawn(move || { + assert!(recv.borrow().eq(&2)); + }); + assert!(send.borrow().eq(&2)); + }); + } +} |