diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-30 03:57:31 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-30 03:57:31 +0000 |
commit | dc0db358abe19481e475e10c32149b53370f1a1c (patch) | |
tree | ab8ce99c4b255ce46f99ef402c27916055b899ee /vendor/tokio/src/sync | |
parent | Releasing progress-linux version 1.71.1+dfsg1-2~progress7.99u1. (diff) | |
download | rustc-dc0db358abe19481e475e10c32149b53370f1a1c.tar.xz rustc-dc0db358abe19481e475e10c32149b53370f1a1c.zip |
Merging upstream version 1.72.1+dfsg1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'vendor/tokio/src/sync')
33 files changed, 5825 insertions, 913 deletions
diff --git a/vendor/tokio/src/sync/barrier.rs b/vendor/tokio/src/sync/barrier.rs index 0e39dac8b..29b6d4e48 100644 --- a/vendor/tokio/src/sync/barrier.rs +++ b/vendor/tokio/src/sync/barrier.rs @@ -1,5 +1,7 @@ use crate::loom::sync::Mutex; use crate::sync::watch; +#[cfg(all(tokio_unstable, feature = "tracing"))] +use crate::util::trace; /// A barrier enables multiple tasks to synchronize the beginning of some computation. /// @@ -41,6 +43,8 @@ pub struct Barrier { state: Mutex<BarrierState>, wait: watch::Receiver<usize>, n: usize, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, } #[derive(Debug)] @@ -55,6 +59,7 @@ impl Barrier { /// /// A barrier will block `n`-1 tasks which call [`Barrier::wait`] and then wake up all /// tasks at once when the `n`th task calls `wait`. + #[track_caller] pub fn new(mut n: usize) -> Barrier { let (waker, wait) = crate::sync::watch::channel(0); @@ -65,6 +70,32 @@ impl Barrier { n = 1; } + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = { + let location = std::panic::Location::caller(); + let resource_span = tracing::trace_span!( + "runtime.resource", + concrete_type = "Barrier", + kind = "Sync", + loc.file = location.file(), + loc.line = location.line(), + loc.col = location.column(), + ); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + size = n, + ); + + tracing::trace!( + target: "runtime::resource::state_update", + arrived = 0, + ) + }); + resource_span + }; + Barrier { state: Mutex::new(BarrierState { waker, @@ -73,6 +104,8 @@ impl Barrier { }), n, wait, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, } } @@ -85,10 +118,26 @@ impl Barrier { /// [`BarrierWaitResult::is_leader`] when returning from this function, and all other tasks /// will receive a result that will return `false` from `is_leader`. pub async fn wait(&self) -> BarrierWaitResult { + #[cfg(all(tokio_unstable, feature = "tracing"))] + return trace::async_op( + || self.wait_internal(), + self.resource_span.clone(), + "Barrier::wait", + "poll", + false, + ) + .await; + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + return self.wait_internal().await; + } + async fn wait_internal(&self) -> BarrierWaitResult { + crate::trace::async_trace_leaf().await; + // NOTE: we are taking a _synchronous_ lock here. // It is okay to do so because the critical section is fast and never yields, so it cannot // deadlock even if another future is concurrently holding the lock. - // It is _desireable_ to do so as synchronous Mutexes are, at least in theory, faster than + // It is _desirable_ to do so as synchronous Mutexes are, at least in theory, faster than // the asynchronous counter-parts, so we should use them where possible [citation needed]. // NOTE: the extra scope here is so that the compiler doesn't think `state` is held across // a yield point, and thus marks the returned future as !Send. @@ -96,7 +145,23 @@ impl Barrier { let mut state = self.state.lock(); let generation = state.generation; state.arrived += 1; + #[cfg(all(tokio_unstable, feature = "tracing"))] + tracing::trace!( + target: "runtime::resource::state_update", + arrived = 1, + arrived.op = "add", + ); + #[cfg(all(tokio_unstable, feature = "tracing"))] + tracing::trace!( + target: "runtime::resource::async_op::state_update", + arrived = true, + ); if state.arrived == self.n { + #[cfg(all(tokio_unstable, feature = "tracing"))] + tracing::trace!( + target: "runtime::resource::async_op::state_update", + is_leader = true, + ); // we are the leader for this generation // wake everyone, increment the generation, and return state diff --git a/vendor/tokio/src/sync/batch_semaphore.rs b/vendor/tokio/src/sync/batch_semaphore.rs index a0bf5ef94..a762f799d 100644 --- a/vendor/tokio/src/sync/batch_semaphore.rs +++ b/vendor/tokio/src/sync/batch_semaphore.rs @@ -1,5 +1,5 @@ #![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))] -//! # Implementation Details +//! # Implementation Details. //! //! The semaphore is implemented using an intrusive linked list of waiters. An //! atomic counter tracks the number of available permits. If the semaphore does @@ -19,6 +19,9 @@ use crate::loom::cell::UnsafeCell; use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::{Mutex, MutexGuard}; use crate::util::linked_list::{self, LinkedList}; +#[cfg(all(tokio_unstable, feature = "tracing"))] +use crate::util::trace; +use crate::util::WakeList; use std::future::Future; use std::marker::PhantomPinned; @@ -34,6 +37,8 @@ pub(crate) struct Semaphore { waiters: Mutex<Waitlist>, /// The current number of available permits in the semaphore. permits: AtomicUsize, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, } struct Waitlist { @@ -44,7 +49,7 @@ struct Waitlist { /// Error returned from the [`Semaphore::try_acquire`] function. /// /// [`Semaphore::try_acquire`]: crate::sync::Semaphore::try_acquire -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq)] pub enum TryAcquireError { /// The semaphore has been [closed] and cannot issue new permits. /// @@ -100,10 +105,21 @@ struct Waiter { /// use `UnsafeCell` internally. pointers: linked_list::Pointers<Waiter>, + #[cfg(all(tokio_unstable, feature = "tracing"))] + ctx: trace::AsyncOpTracingCtx, + /// Should not be `Unpin`. _p: PhantomPinned, } +generate_addr_of_methods! { + impl<> Waiter { + unsafe fn addr_of_pointers(self: NonNull<Self>) -> NonNull<linked_list::Pointers<Waiter>> { + &self.pointers + } + } +} + impl Semaphore { /// The maximum number of permits which a semaphore can hold. /// @@ -128,16 +144,38 @@ impl Semaphore { "a semaphore may not have more than MAX_PERMITS permits ({})", Self::MAX_PERMITS ); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = { + let resource_span = tracing::trace_span!( + "runtime.resource", + concrete_type = "Semaphore", + kind = "Sync", + is_internal = true + ); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + permits = permits, + permits.op = "override", + ) + }); + resource_span + }; + Self { permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT), waiters: Mutex::new(Waitlist { queue: LinkedList::new(), closed: false, }), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, } } - /// Creates a new semaphore with the initial number of permits + /// Creates a new semaphore with the initial number of permits. /// /// Maximum number of permits on 32-bit platforms is `1<<29`. /// @@ -155,10 +193,12 @@ impl Semaphore { queue: LinkedList::new(), closed: false, }), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span::none(), } } - /// Returns the current number of available permits + /// Returns the current number of available permits. pub(crate) fn available_permits(&self) -> usize { self.permits.load(Acquire) >> Self::PERMIT_SHIFT } @@ -196,7 +236,7 @@ impl Semaphore { } } - /// Returns true if the semaphore is closed + /// Returns true if the semaphore is closed. pub(crate) fn is_closed(&self) -> bool { self.permits.load(Acquire) & Self::CLOSED == Self::CLOSED } @@ -223,7 +263,10 @@ impl Semaphore { let next = curr - num_permits; match self.permits.compare_exchange(curr, next, AcqRel, Acquire) { - Ok(_) => return Ok(()), + Ok(_) => { + // TODO: Instrument once issue has been solved + return Ok(()); + } Err(actual) => curr = actual, } } @@ -239,12 +282,12 @@ impl Semaphore { /// If `rem` exceeds the number of permits needed by the wait list, the /// remainder are assigned back to the semaphore. fn add_permits_locked(&self, mut rem: usize, waiters: MutexGuard<'_, Waitlist>) { - let mut wakers: [Option<Waker>; 8] = Default::default(); + let mut wakers = WakeList::new(); let mut lock = Some(waiters); let mut is_empty = false; while rem > 0 { let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock()); - 'inner: for slot in &mut wakers[..] { + 'inner: while wakers.can_push() { // Was the waiter assigned enough permits to wake it? match waiters.queue.last() { Some(waiter) => { @@ -260,7 +303,11 @@ impl Semaphore { } }; let mut waiter = waiters.queue.pop_back().unwrap(); - *slot = unsafe { waiter.as_mut().waker.with_mut(|waker| (*waker).take()) }; + if let Some(waker) = + unsafe { waiter.as_mut().waker.with_mut(|waker| (*waker).take()) } + { + wakers.push(waker); + } } if rem > 0 && is_empty { @@ -278,15 +325,23 @@ impl Semaphore { rem, Self::MAX_PERMITS ); + + // add remaining permits back + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + permits = rem, + permits.op = "add", + ) + }); + rem = 0; } drop(waiters); // release the lock - wakers - .iter_mut() - .filter_map(Option::take) - .for_each(Waker::wake); + wakers.wake_all(); } assert_eq!(rem, 0); @@ -345,6 +400,20 @@ impl Semaphore { acquired += acq; if remaining == 0 { if !queued { + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + permits = acquired, + permits.op = "sub", + ); + tracing::trace!( + target: "runtime::resource::async_op::state_update", + permits_obtained = acquired, + permits.op = "add", + ) + }); + return Ready(Ok(())); } else if lock.is_none() { break self.waiters.lock(); @@ -360,12 +429,22 @@ impl Semaphore { return Ready(Err(AcquireError::closed())); } + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + permits = acquired, + permits.op = "sub", + ) + }); + if node.assign_permits(&mut acquired) { self.add_permits_locked(acquired, waiters); return Ready(Ok(())); } assert_eq!(acquired, 0); + let mut old_waker = None; // Otherwise, register the waker & enqueue the node. node.waker.with_mut(|waker| { @@ -377,7 +456,7 @@ impl Semaphore { .map(|waker| !waker.will_wake(cx.waker())) .unwrap_or(true) { - *waker = Some(cx.waker().clone()); + old_waker = std::mem::replace(waker, Some(cx.waker().clone())); } }); @@ -390,6 +469,8 @@ impl Semaphore { waiters.queue.push_front(node); } + drop(waiters); + drop(old_waker); Pending } @@ -404,11 +485,16 @@ impl fmt::Debug for Semaphore { } impl Waiter { - fn new(num_permits: u32) -> Self { + fn new( + num_permits: u32, + #[cfg(all(tokio_unstable, feature = "tracing"))] ctx: trace::AsyncOpTracingCtx, + ) -> Self { Waiter { waker: UnsafeCell::new(None), state: AtomicUsize::new(num_permits as usize), pointers: linked_list::Pointers::new(), + #[cfg(all(tokio_unstable, feature = "tracing"))] + ctx, _p: PhantomPinned, } } @@ -424,6 +510,14 @@ impl Waiter { match self.state.compare_exchange(curr, next, AcqRel, Acquire) { Ok(_) => { *n -= assign; + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.ctx.async_op_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::async_op::state_update", + permits_obtained = assign, + permits.op = "add", + ); + }); return next == 0; } Err(actual) => curr = actual, @@ -436,12 +530,26 @@ impl Future for Acquire<'_> { type Output = Result<(), AcquireError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - // First, ensure the current task has enough budget to proceed. - let coop = ready!(crate::coop::poll_proceed(cx)); + #[cfg(all(tokio_unstable, feature = "tracing"))] + let _resource_span = self.node.ctx.resource_span.clone().entered(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + let _async_op_span = self.node.ctx.async_op_span.clone().entered(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + let _async_op_poll_span = self.node.ctx.async_op_poll_span.clone().entered(); let (node, semaphore, needed, queued) = self.project(); - match semaphore.poll_acquire(cx, needed, node, *queued) { + // First, ensure the current task has enough budget to proceed. + #[cfg(all(tokio_unstable, feature = "tracing"))] + let coop = ready!(trace_poll_op!( + "poll_acquire", + crate::runtime::coop::poll_proceed(cx), + )); + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let coop = ready!(crate::runtime::coop::poll_proceed(cx)); + + let result = match semaphore.poll_acquire(cx, needed, node, *queued) { Pending => { *queued = true; Pending @@ -452,18 +560,59 @@ impl Future for Acquire<'_> { *queued = false; Ready(Ok(())) } - } + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + return trace_poll_op!("poll_acquire", result); + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + return result; } } impl<'a> Acquire<'a> { fn new(semaphore: &'a Semaphore, num_permits: u32) -> Self { - Self { + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + return Self { node: Waiter::new(num_permits), semaphore, num_permits, queued: false, - } + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + return semaphore.resource_span.in_scope(|| { + let async_op_span = + tracing::trace_span!("runtime.resource.async_op", source = "Acquire::new"); + let async_op_poll_span = async_op_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::async_op::state_update", + permits_requested = num_permits, + permits.op = "override", + ); + + tracing::trace!( + target: "runtime::resource::async_op::state_update", + permits_obtained = 0usize, + permits.op = "override", + ); + + tracing::trace_span!("runtime.resource.async_op.poll") + }); + + let ctx = trace::AsyncOpTracingCtx { + async_op_span, + async_op_poll_span, + resource_span: semaphore.resource_span.clone(), + }; + + Self { + node: Waiter::new(num_permits, ctx), + semaphore, + num_permits, + queued: false, + } + }); } fn project(self: Pin<&mut Self>) -> (Pin<&mut Waiter>, &Semaphore, u32, &mut bool) { @@ -478,7 +627,7 @@ impl<'a> Acquire<'a> { let this = self.get_unchecked_mut(); ( Pin::new_unchecked(&mut this.node), - &this.semaphore, + this.semaphore, this.num_permits, &mut this.queued, ) @@ -566,12 +715,6 @@ impl std::error::Error for TryAcquireError {} /// /// `Waiter` is forced to be !Unpin. unsafe impl linked_list::Link for Waiter { - // XXX: ideally, we would be able to use `Pin` here, to enforce the - // invariant that list entries may not move while in the list. However, we - // can't do this currently, as using `Pin<&'a mut Waiter>` as the `Handle` - // type would require `Semaphore` to be generic over a lifetime. We can't - // use `Pin<*mut Waiter>`, as raw pointers are `Unpin` regardless of whether - // or not they dereference to an `!Unpin` target. type Handle = NonNull<Waiter>; type Target = Waiter; @@ -583,7 +726,7 @@ unsafe impl linked_list::Link for Waiter { ptr } - unsafe fn pointers(mut target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> { - NonNull::from(&mut target.as_mut().pointers) + unsafe fn pointers(target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> { + Waiter::addr_of_pointers(target) } } diff --git a/vendor/tokio/src/sync/broadcast.rs b/vendor/tokio/src/sync/broadcast.rs index a2ca4459e..4b36452ce 100644 --- a/vendor/tokio/src/sync/broadcast.rs +++ b/vendor/tokio/src/sync/broadcast.rs @@ -4,7 +4,7 @@ //! A [`Sender`] is used to broadcast values to **all** connected [`Receiver`] //! values. [`Sender`] handles are clone-able, allowing concurrent send and //! receive actions. [`Sender`] and [`Receiver`] are both `Send` and `Sync` as -//! long as `T` is also `Send` or `Sync` respectively. +//! long as `T` is `Send`. //! //! When a value is sent, **all** [`Receiver`] handles are notified and will //! receive the value. The value is stored once inside the channel and cloned on @@ -18,6 +18,9 @@ //! returned [`Receiver`] will receive values sent **after** the call to //! `subscribe`. //! +//! This channel is also suitable for the single-producer multi-consumer +//! use-case, where a single sender broadcasts values to many receivers. +//! //! ## Lagging //! //! As sent messages must be retained until **all** [`Receiver`] handles receive @@ -51,6 +54,10 @@ //! all values retained by the channel, the next call to [`recv`] will return //! with [`RecvError::Closed`]. //! +//! When a [`Receiver`] handle is dropped, any messages not read by the receiver +//! will be marked as read. If this receiver was the only one not to have read +//! that message, the message will be dropped at this point. +//! //! [`Sender`]: crate::sync::broadcast::Sender //! [`Sender::subscribe`]: crate::sync::broadcast::Sender::subscribe //! [`Receiver`]: crate::sync::broadcast::Receiver @@ -111,8 +118,9 @@ use crate::loom::cell::UnsafeCell; use crate::loom::sync::atomic::AtomicUsize; -use crate::loom::sync::{Arc, Mutex, RwLock, RwLockReadGuard}; +use crate::loom::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard}; use crate::util::linked_list::{self, LinkedList}; +use crate::util::WakeList; use std::fmt; use std::future::Future; @@ -230,7 +238,7 @@ pub mod error { /// /// [`recv`]: crate::sync::broadcast::Receiver::recv /// [`Receiver`]: crate::sync::broadcast::Receiver - #[derive(Debug, PartialEq)] + #[derive(Debug, PartialEq, Eq, Clone)] pub enum RecvError { /// There are no more active senders implying no further messages will ever /// be sent. @@ -258,7 +266,7 @@ pub mod error { /// /// [`try_recv`]: crate::sync::broadcast::Receiver::try_recv /// [`Receiver`]: crate::sync::broadcast::Receiver - #[derive(Debug, PartialEq)] + #[derive(Debug, PartialEq, Eq, Clone)] pub enum TryRecvError { /// The channel is currently empty. There are still active /// [`Sender`] handles, so data may yet become available. @@ -293,37 +301,37 @@ pub mod error { use self::error::*; -/// Data shared between senders and receivers +/// Data shared between senders and receivers. struct Shared<T> { - /// slots in the channel + /// slots in the channel. buffer: Box<[RwLock<Slot<T>>]>, - /// Mask a position -> index + /// Mask a position -> index. mask: usize, /// Tail of the queue. Includes the rx wait list. tail: Mutex<Tail>, - /// Number of outstanding Sender handles + /// Number of outstanding Sender handles. num_tx: AtomicUsize, } -/// Next position to write a value +/// Next position to write a value. struct Tail { - /// Next position to write to + /// Next position to write to. pos: u64, - /// Number of active receivers + /// Number of active receivers. rx_cnt: usize, - /// True if the channel is closed + /// True if the channel is closed. closed: bool, - /// Receivers waiting for a value + /// Receivers waiting for a value. waiters: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>, } -/// Slot in the buffer +/// Slot in the buffer. struct Slot<T> { /// Remaining number of receivers that are expected to see this value. /// @@ -333,12 +341,9 @@ struct Slot<T> { /// acquired. rem: AtomicUsize, - /// Uniquely identifies the `send` stored in the slot + /// Uniquely identifies the `send` stored in the slot. pos: u64, - /// True signals the channel is closed. - closed: bool, - /// The value being broadcast. /// /// The value is set by `send` when the write lock is held. When a reader @@ -346,9 +351,9 @@ struct Slot<T> { val: UnsafeCell<Option<T>>, } -/// An entry in the wait queue +/// An entry in the wait queue. struct Waiter { - /// True if queued + /// True if queued. queued: bool, /// Task waiting on the broadcast channel. @@ -361,16 +366,24 @@ struct Waiter { _p: PhantomPinned, } +generate_addr_of_methods! { + impl<> Waiter { + unsafe fn addr_of_pointers(self: NonNull<Self>) -> NonNull<linked_list::Pointers<Waiter>> { + &self.pointers + } + } +} + struct RecvGuard<'a, T> { slot: RwLockReadGuard<'a, Slot<T>>, } -/// Receive a value future +/// Receive a value future. struct Recv<'a, T> { - /// Receiver being waited on + /// Receiver being waited on. receiver: &'a mut Receiver<T>, - /// Entry in the waiter `LinkedList` + /// Entry in the waiter `LinkedList`. waiter: UnsafeCell<Waiter>, } @@ -425,6 +438,12 @@ const MAX_RECEIVERS: usize = usize::MAX >> 2; /// tx.send(20).unwrap(); /// } /// ``` +/// +/// # Panics +/// +/// This will panic if `capacity` is equal to `0` or larger +/// than `usize::MAX / 2`. +#[track_caller] pub fn channel<T: Clone>(mut capacity: usize) -> (Sender<T>, Receiver<T>) { assert!(capacity > 0, "capacity is empty"); assert!(capacity <= usize::MAX >> 1, "requested capacity too large"); @@ -438,7 +457,6 @@ pub fn channel<T: Clone>(mut capacity: usize) -> (Sender<T>, Receiver<T>) { buffer.push(RwLock::new(Slot { rem: AtomicUsize::new(0), pos: (i as u64).wrapping_sub(capacity as u64), - closed: false, val: UnsafeCell::new(None), })); } @@ -523,8 +541,41 @@ impl<T> Sender<T> { /// } /// ``` pub fn send(&self, value: T) -> Result<usize, SendError<T>> { - self.send2(Some(value)) - .map_err(|SendError(maybe_v)| SendError(maybe_v.unwrap())) + let mut tail = self.shared.tail.lock(); + + if tail.rx_cnt == 0 { + return Err(SendError(value)); + } + + // Position to write into + let pos = tail.pos; + let rem = tail.rx_cnt; + let idx = (pos & self.shared.mask as u64) as usize; + + // Update the tail position + tail.pos = tail.pos.wrapping_add(1); + + // Get the slot + let mut slot = self.shared.buffer[idx].write().unwrap(); + + // Track the position + slot.pos = pos; + + // Set remaining receivers + slot.rem.with_mut(|v| *v = rem); + + // Write the value + slot.val = UnsafeCell::new(Some(value)); + + // Release the slot lock before notifying the receivers. + drop(slot); + + // Notify and release the mutex. This must happen after the slot lock is + // released, otherwise the writer lock bit could be cleared while another + // thread is in the critical section. + self.shared.notify_rx(tail); + + Ok(rem) } /// Creates a new [`Receiver`] handle that will receive values sent **after** @@ -555,6 +606,97 @@ impl<T> Sender<T> { new_receiver(shared) } + /// Returns the number of queued values. + /// + /// A value is queued until it has either been seen by all receivers that were alive at the time + /// it was sent, or has been evicted from the queue by subsequent sends that exceeded the + /// queue's capacity. + /// + /// # Note + /// + /// In contrast to [`Receiver::len`], this method only reports queued values and not values that + /// have been evicted from the queue before being seen by all receivers. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx1) = broadcast::channel(16); + /// let mut rx2 = tx.subscribe(); + /// + /// tx.send(10).unwrap(); + /// tx.send(20).unwrap(); + /// tx.send(30).unwrap(); + /// + /// assert_eq!(tx.len(), 3); + /// + /// rx1.recv().await.unwrap(); + /// + /// // The len is still 3 since rx2 hasn't seen the first value yet. + /// assert_eq!(tx.len(), 3); + /// + /// rx2.recv().await.unwrap(); + /// + /// assert_eq!(tx.len(), 2); + /// } + /// ``` + pub fn len(&self) -> usize { + let tail = self.shared.tail.lock(); + + let base_idx = (tail.pos & self.shared.mask as u64) as usize; + let mut low = 0; + let mut high = self.shared.buffer.len(); + while low < high { + let mid = low + (high - low) / 2; + let idx = base_idx.wrapping_add(mid) & self.shared.mask; + if self.shared.buffer[idx].read().unwrap().rem.load(SeqCst) == 0 { + low = mid + 1; + } else { + high = mid; + } + } + + self.shared.buffer.len() - low + } + + /// Returns true if there are no queued values. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx1) = broadcast::channel(16); + /// let mut rx2 = tx.subscribe(); + /// + /// assert!(tx.is_empty()); + /// + /// tx.send(10).unwrap(); + /// + /// assert!(!tx.is_empty()); + /// + /// rx1.recv().await.unwrap(); + /// + /// // The queue is still not empty since rx2 hasn't seen the value. + /// assert!(!tx.is_empty()); + /// + /// rx2.recv().await.unwrap(); + /// + /// assert!(tx.is_empty()); + /// } + /// ``` + pub fn is_empty(&self) -> bool { + let tail = self.shared.tail.lock(); + + let idx = (tail.pos.wrapping_sub(1) & self.shared.mask as u64) as usize; + self.shared.buffer[idx].read().unwrap().rem.load(SeqCst) == 0 + } + /// Returns the number of active receivers /// /// An active receiver is a [`Receiver`] handle returned from [`channel`] or @@ -596,52 +738,38 @@ impl<T> Sender<T> { tail.rx_cnt } - fn send2(&self, value: Option<T>) -> Result<usize, SendError<Option<T>>> { - let mut tail = self.shared.tail.lock(); - - if tail.rx_cnt == 0 { - return Err(SendError(value)); - } - - // Position to write into - let pos = tail.pos; - let rem = tail.rx_cnt; - let idx = (pos & self.shared.mask as u64) as usize; - - // Update the tail position - tail.pos = tail.pos.wrapping_add(1); - - // Get the slot - let mut slot = self.shared.buffer[idx].write().unwrap(); - - // Track the position - slot.pos = pos; - - // Set remaining receivers - slot.rem.with_mut(|v| *v = rem); - - // Set the closed bit if the value is `None`; otherwise write the value - if value.is_none() { - tail.closed = true; - slot.closed = true; - } else { - slot.val.with_mut(|ptr| unsafe { *ptr = value }); - } - - // Release the slot lock before notifying the receivers. - drop(slot); - - tail.notify_rx(); + /// Returns `true` if senders belong to the same channel. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, _rx) = broadcast::channel::<()>(16); + /// let tx2 = tx.clone(); + /// + /// assert!(tx.same_channel(&tx2)); + /// + /// let (tx3, _rx3) = broadcast::channel::<()>(16); + /// + /// assert!(!tx3.same_channel(&tx2)); + /// } + /// ``` + pub fn same_channel(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.shared, &other.shared) + } - // Release the mutex. This must happen after the slot lock is released, - // otherwise the writer lock bit could be cleared while another thread - // is in the critical section. - drop(tail); + fn close_channel(&self) { + let mut tail = self.shared.tail.lock(); + tail.closed = true; - Ok(rem) + self.shared.notify_rx(tail); } } +/// Create a new `Receiver` which reads starting from the tail. fn new_receiver<T>(shared: Arc<Shared<T>>) -> Receiver<T> { let mut tail = shared.tail.lock(); @@ -658,18 +786,47 @@ fn new_receiver<T>(shared: Arc<Shared<T>>) -> Receiver<T> { Receiver { shared, next } } -impl Tail { - fn notify_rx(&mut self) { - while let Some(mut waiter) = self.waiters.pop_back() { - // Safety: `waiters` lock is still held. - let waiter = unsafe { waiter.as_mut() }; +impl<T> Shared<T> { + fn notify_rx<'a, 'b: 'a>(&'b self, mut tail: MutexGuard<'a, Tail>) { + let mut wakers = WakeList::new(); + 'outer: loop { + while wakers.can_push() { + match tail.waiters.pop_back() { + Some(mut waiter) => { + // Safety: `tail` lock is still held. + let waiter = unsafe { waiter.as_mut() }; + + assert!(waiter.queued); + waiter.queued = false; + + if let Some(waker) = waiter.waker.take() { + wakers.push(waker); + } + } + None => { + break 'outer; + } + } + } + + // Release the lock before waking. + drop(tail); - assert!(waiter.queued); - waiter.queued = false; + // Before we acquire the lock again all sorts of things can happen: + // some waiters may remove themselves from the list and new waiters + // may be added. This is fine since at worst we will unnecessarily + // wake up waiters which will then queue themselves again. - let waker = waiter.waker.take().unwrap(); - waker.wake(); + wakers.wake_all(); + + // Acquire the lock again. + tail = self.tail.lock(); } + + // Release the lock before waking. + drop(tail); + + wakers.wake_all(); } } @@ -685,12 +842,102 @@ impl<T> Clone for Sender<T> { impl<T> Drop for Sender<T> { fn drop(&mut self) { if 1 == self.shared.num_tx.fetch_sub(1, SeqCst) { - let _ = self.send2(None); + self.close_channel(); } } } impl<T> Receiver<T> { + /// Returns the number of messages that were sent into the channel and that + /// this [`Receiver`] has yet to receive. + /// + /// If the returned value from `len` is larger than the next largest power of 2 + /// of the capacity of the channel any call to [`recv`] will return an + /// `Err(RecvError::Lagged)` and any call to [`try_recv`] will return an + /// `Err(TryRecvError::Lagged)`, e.g. if the capacity of the channel is 10, + /// [`recv`] will start to return `Err(RecvError::Lagged)` once `len` returns + /// values larger than 16. + /// + /// [`Receiver`]: crate::sync::broadcast::Receiver + /// [`recv`]: crate::sync::broadcast::Receiver::recv + /// [`try_recv`]: crate::sync::broadcast::Receiver::try_recv + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx1) = broadcast::channel(16); + /// + /// tx.send(10).unwrap(); + /// tx.send(20).unwrap(); + /// + /// assert_eq!(rx1.len(), 2); + /// assert_eq!(rx1.recv().await.unwrap(), 10); + /// assert_eq!(rx1.len(), 1); + /// assert_eq!(rx1.recv().await.unwrap(), 20); + /// assert_eq!(rx1.len(), 0); + /// } + /// ``` + pub fn len(&self) -> usize { + let next_send_pos = self.shared.tail.lock().pos; + (next_send_pos - self.next) as usize + } + + /// Returns true if there aren't any messages in the channel that the [`Receiver`] + /// has yet to receive. + /// + /// [`Receiver]: create::sync::broadcast::Receiver + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx1) = broadcast::channel(16); + /// + /// assert!(rx1.is_empty()); + /// + /// tx.send(10).unwrap(); + /// tx.send(20).unwrap(); + /// + /// assert!(!rx1.is_empty()); + /// assert_eq!(rx1.recv().await.unwrap(), 10); + /// assert_eq!(rx1.recv().await.unwrap(), 20); + /// assert!(rx1.is_empty()); + /// } + /// ``` + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns `true` if receivers belong to the same channel. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = broadcast::channel::<()>(16); + /// let rx2 = tx.subscribe(); + /// + /// assert!(rx.same_channel(&rx2)); + /// + /// let (_tx3, rx3) = broadcast::channel::<()>(16); + /// + /// assert!(!rx3.same_channel(&rx2)); + /// } + /// ``` + pub fn same_channel(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.shared, &other.shared) + } + /// Locks the next value if there is one. fn recv_ref( &mut self, @@ -702,14 +949,6 @@ impl<T> Receiver<T> { let mut slot = self.shared.buffer[idx].read().unwrap(); if slot.pos != self.next { - let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64); - - // The receiver has read all current values in the channel and there - // is no waiter to register - if waiter.is_none() && next_pos == self.next { - return Err(TryRecvError::Empty); - } - // Release the `slot` lock before attempting to acquire the `tail` // lock. This is required because `send2` acquires the tail lock // first followed by the slot lock. Acquiring the locks in reverse @@ -719,6 +958,8 @@ impl<T> Receiver<T> { // the slot lock. drop(slot); + let mut old_waker = None; + let mut tail = self.shared.tail.lock(); // Acquire slot lock again @@ -731,6 +972,13 @@ impl<T> Receiver<T> { let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64); if next_pos == self.next { + // At this point the channel is empty for *this* receiver. If + // it's been closed, then that's what we return, otherwise we + // set a waker and return empty. + if tail.closed { + return Err(TryRecvError::Closed); + } + // Store the waker if let Some((waiter, waker)) = waiter { // Safety: called while locked. @@ -744,7 +992,10 @@ impl<T> Receiver<T> { match (*ptr).waker { Some(ref w) if w.will_wake(waker) => {} _ => { - (*ptr).waker = Some(waker.clone()); + old_waker = std::mem::replace( + &mut (*ptr).waker, + Some(waker.clone()), + ); } } @@ -756,6 +1007,11 @@ impl<T> Receiver<T> { } } + // Drop the old waker after releasing the locks. + drop(slot); + drop(tail); + drop(old_waker); + return Err(TryRecvError::Empty); } @@ -764,22 +1020,7 @@ impl<T> Receiver<T> { // catch up by skipping dropped messages and setting the // internal cursor to the **oldest** message stored by the // channel. - // - // However, finding the oldest position is a bit more - // complicated than `tail-position - buffer-size`. When - // the channel is closed, the tail position is incremented to - // signal a new `None` message, but `None` is not stored in the - // channel itself (see issue #2425 for why). - // - // To account for this, if the channel is closed, the tail - // position is decremented by `buffer-size + 1`. - let mut adjust = 0; - if tail.closed { - adjust = 1 - } - let next = tail - .pos - .wrapping_sub(self.shared.buffer.len() as u64 + adjust); + let next = tail.pos.wrapping_sub(self.shared.buffer.len() as u64); let missed = next.wrapping_sub(self.next); @@ -800,15 +1041,38 @@ impl<T> Receiver<T> { self.next = self.next.wrapping_add(1); - if slot.closed { - return Err(TryRecvError::Closed); - } - Ok(RecvGuard { slot }) } } impl<T: Clone> Receiver<T> { + /// Re-subscribes to the channel starting from the current tail element. + /// + /// This [`Receiver`] handle will receive a clone of all values sent + /// **after** it has resubscribed. This will not include elements that are + /// in the queue of the current receiver. Consider the following example. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = broadcast::channel(2); + /// + /// tx.send(1).unwrap(); + /// let mut rx2 = rx.resubscribe(); + /// tx.send(2).unwrap(); + /// + /// assert_eq!(rx2.recv().await.unwrap(), 2); + /// assert_eq!(rx.recv().await.unwrap(), 1); + /// } + /// ``` + pub fn resubscribe(&self) -> Self { + let shared = self.shared.clone(); + new_receiver(shared) + } /// Receives the next value for this receiver. /// /// Each [`Receiver`] handle will receive a clone of all values sent @@ -930,6 +1194,33 @@ impl<T: Clone> Receiver<T> { let guard = self.recv_ref(None)?; guard.clone_value().ok_or(TryRecvError::Closed) } + + /// Blocking receive to call outside of asynchronous contexts. + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution + /// context. + /// + /// # Examples + /// ``` + /// use std::thread; + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = broadcast::channel(16); + /// + /// let sync_code = thread::spawn(move || { + /// assert_eq!(rx.blocking_recv(), Ok(10)); + /// }); + /// + /// let _ = tx.send(10); + /// sync_code.join().unwrap(); + /// } + pub fn blocking_recv(&mut self) -> Result<T, RecvError> { + crate::future::block_on(self.recv()) + } } impl<T> Drop for Receiver<T> { @@ -988,6 +1279,8 @@ where type Output = Result<T, RecvError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> { + ready!(crate::trace::trace_leaf(cx)); + let (receiver, waiter) = self.project(); let guard = match receiver.recv_ref(Some((waiter, cx.waker()))) { @@ -1039,8 +1332,8 @@ unsafe impl linked_list::Link for Waiter { ptr } - unsafe fn pointers(mut target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> { - NonNull::from(&mut target.as_mut().pointers) + unsafe fn pointers(target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> { + Waiter::addr_of_pointers(target) } } diff --git a/vendor/tokio/src/sync/mod.rs b/vendor/tokio/src/sync/mod.rs index 457e6ab29..70fd9b9e3 100644 --- a/vendor/tokio/src/sync/mod.rs +++ b/vendor/tokio/src/sync/mod.rs @@ -94,6 +94,10 @@ //! producers to a single consumer. This channel is often used to send work to a //! task or to receive the result of many computations. //! +//! This is also the channel you should use if you want to send many messages +//! from a single producer to a single consumer. There is no dedicated spsc +//! channel. +//! //! **Example:** using an mpsc to incrementally stream the results of a series //! of computations. //! @@ -244,6 +248,10 @@ //! This channel tends to be used less often than `oneshot` and `mpsc` but still //! has its use cases. //! +//! This is also the channel you should use if you want to broadcast values from +//! a single producer to many consumers. There is no dedicated spmc broadcast +//! channel. +//! //! Basic usage //! //! ``` @@ -441,7 +449,7 @@ cfg_sync! { pub mod mpsc; mod mutex; - pub use mutex::{Mutex, MutexGuard, TryLockError, OwnedMutexGuard, MappedMutexGuard}; + pub use mutex::{Mutex, MutexGuard, TryLockError, OwnedMutexGuard, MappedMutexGuard, OwnedMappedMutexGuard}; pub(crate) mod notify; pub use notify::Notify; diff --git a/vendor/tokio/src/sync/mpsc/block.rs b/vendor/tokio/src/sync/mpsc/block.rs index 1c9ab14e9..39c3e1be2 100644 --- a/vendor/tokio/src/sync/mpsc/block.rs +++ b/vendor/tokio/src/sync/mpsc/block.rs @@ -1,7 +1,7 @@ use crate::loom::cell::UnsafeCell; use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize}; -use crate::loom::thread; +use std::alloc::Layout; use std::mem::MaybeUninit; use std::ops; use std::ptr::{self, NonNull}; @@ -11,6 +11,17 @@ use std::sync::atomic::Ordering::{self, AcqRel, Acquire, Release}; /// /// Each block in the list can hold up to `BLOCK_CAP` messages. pub(crate) struct Block<T> { + /// The header fields. + header: BlockHeader<T>, + + /// Array containing values pushed into the block. Values are stored in a + /// continuous array in order to improve cache line behavior when reading. + /// The values must be manually dropped. + values: Values<T>, +} + +/// Extra fields for a `Block<T>`. +struct BlockHeader<T> { /// The start index of this block. /// /// Slots in this block have indices in `start_index .. start_index + BLOCK_CAP`. @@ -25,11 +36,6 @@ pub(crate) struct Block<T> { /// The observed `tail_position` value *after* the block has been passed by /// `block_tail`. observed_tail_position: UnsafeCell<usize>, - - /// Array containing values pushed into the block. Values are stored in a - /// continuous array in order to improve cache line behavior when reading. - /// The values must be manually dropped. - values: Values<T>, } pub(crate) enum Read<T> { @@ -37,11 +43,12 @@ pub(crate) enum Read<T> { Closed, } +#[repr(transparent)] struct Values<T>([UnsafeCell<MaybeUninit<T>>; BLOCK_CAP]); use super::BLOCK_CAP; -/// Masks an index to get the block identifier +/// Masks an index to get the block identifier. const BLOCK_MASK: usize = !(BLOCK_CAP - 1); /// Masks an index to get the value offset in a block. @@ -72,28 +79,56 @@ pub(crate) fn offset(slot_index: usize) -> usize { SLOT_MASK & slot_index } +generate_addr_of_methods! { + impl<T> Block<T> { + unsafe fn addr_of_header(self: NonNull<Self>) -> NonNull<BlockHeader<T>> { + &self.header + } + + unsafe fn addr_of_values(self: NonNull<Self>) -> NonNull<Values<T>> { + &self.values + } + } +} + impl<T> Block<T> { - pub(crate) fn new(start_index: usize) -> Block<T> { - Block { - // The absolute index in the channel of the first slot in the block. - start_index, + pub(crate) fn new(start_index: usize) -> Box<Block<T>> { + unsafe { + // Allocate the block on the heap. + // SAFETY: The size of the Block<T> is non-zero, since it is at least the size of the header. + let block = std::alloc::alloc(Layout::new::<Block<T>>()) as *mut Block<T>; + let block = match NonNull::new(block) { + Some(block) => block, + None => std::alloc::handle_alloc_error(Layout::new::<Block<T>>()), + }; + + // Write the header to the block. + Block::addr_of_header(block).as_ptr().write(BlockHeader { + // The absolute index in the channel of the first slot in the block. + start_index, - // Pointer to the next block in the linked list. - next: AtomicPtr::new(ptr::null_mut()), + // Pointer to the next block in the linked list. + next: AtomicPtr::new(ptr::null_mut()), - ready_slots: AtomicUsize::new(0), + ready_slots: AtomicUsize::new(0), - observed_tail_position: UnsafeCell::new(0), + observed_tail_position: UnsafeCell::new(0), + }); - // Value storage - values: unsafe { Values::uninitialized() }, + // Initialize the values array. + Values::initialize(Block::addr_of_values(block)); + + // Convert the pointer to a `Box`. + // Safety: The raw pointer was allocated using the global allocator, and with + // the layout for a `Block<T>`, so it's valid to convert it to box. + Box::from_raw(block.as_ptr()) } } - /// Returns `true` if the block matches the given index + /// Returns `true` if the block matches the given index. pub(crate) fn is_at_index(&self, index: usize) -> bool { debug_assert!(offset(index) == 0); - self.start_index == index + self.header.start_index == index } /// Returns the number of blocks between `self` and the block at the @@ -102,7 +137,7 @@ impl<T> Block<T> { /// `start_index` must represent a block *after* `self`. pub(crate) fn distance(&self, other_index: usize) -> usize { debug_assert!(offset(other_index) == 0); - other_index.wrapping_sub(self.start_index) / BLOCK_CAP + other_index.wrapping_sub(self.header.start_index) / BLOCK_CAP } /// Reads the value at the given offset. @@ -117,7 +152,7 @@ impl<T> Block<T> { pub(crate) unsafe fn read(&self, slot_index: usize) -> Option<Read<T>> { let offset = offset(slot_index); - let ready_bits = self.ready_slots.load(Acquire); + let ready_bits = self.header.ready_slots.load(Acquire); if !is_ready(ready_bits, offset) { if is_tx_closed(ready_bits) { @@ -157,7 +192,7 @@ impl<T> Block<T> { /// Signal to the receiver that the sender half of the list is closed. pub(crate) unsafe fn tx_close(&self) { - self.ready_slots.fetch_or(TX_CLOSED, Release); + self.header.ready_slots.fetch_or(TX_CLOSED, Release); } /// Resets the block to a blank state. This enables reusing blocks in the @@ -170,9 +205,9 @@ impl<T> Block<T> { /// * All slots are empty. /// * The caller holds a unique pointer to the block. pub(crate) unsafe fn reclaim(&mut self) { - self.start_index = 0; - self.next = AtomicPtr::new(ptr::null_mut()); - self.ready_slots = AtomicUsize::new(0); + self.header.start_index = 0; + self.header.next = AtomicPtr::new(ptr::null_mut()); + self.header.ready_slots = AtomicUsize::new(0); } /// Releases the block to the rx half for freeing. @@ -188,19 +223,20 @@ impl<T> Block<T> { pub(crate) unsafe fn tx_release(&self, tail_position: usize) { // Track the observed tail_position. Any sender targeting a greater // tail_position is guaranteed to not access this block. - self.observed_tail_position + self.header + .observed_tail_position .with_mut(|ptr| *ptr = tail_position); // Set the released bit, signalling to the receiver that it is safe to // free the block's memory as soon as all slots **prior** to // `observed_tail_position` have been filled. - self.ready_slots.fetch_or(RELEASED, Release); + self.header.ready_slots.fetch_or(RELEASED, Release); } /// Mark a slot as ready fn set_ready(&self, slot: usize) { let mask = 1 << slot; - self.ready_slots.fetch_or(mask, Release); + self.header.ready_slots.fetch_or(mask, Release); } /// Returns `true` when all slots have their `ready` bits set. @@ -215,25 +251,31 @@ impl<T> Block<T> { /// single atomic cell. However, this could have negative impact on cache /// behavior as there would be many more mutations to a single slot. pub(crate) fn is_final(&self) -> bool { - self.ready_slots.load(Acquire) & READY_MASK == READY_MASK + self.header.ready_slots.load(Acquire) & READY_MASK == READY_MASK } /// Returns the `observed_tail_position` value, if set pub(crate) fn observed_tail_position(&self) -> Option<usize> { - if 0 == RELEASED & self.ready_slots.load(Acquire) { + if 0 == RELEASED & self.header.ready_slots.load(Acquire) { None } else { - Some(self.observed_tail_position.with(|ptr| unsafe { *ptr })) + Some( + self.header + .observed_tail_position + .with(|ptr| unsafe { *ptr }), + ) } } /// Loads the next block pub(crate) fn load_next(&self, ordering: Ordering) -> Option<NonNull<Block<T>>> { - let ret = NonNull::new(self.next.load(ordering)); + let ret = NonNull::new(self.header.next.load(ordering)); debug_assert!(unsafe { - ret.map(|block| block.as_ref().start_index == self.start_index.wrapping_add(BLOCK_CAP)) - .unwrap_or(true) + ret.map(|block| { + block.as_ref().header.start_index == self.header.start_index.wrapping_add(BLOCK_CAP) + }) + .unwrap_or(true) }); ret @@ -261,9 +303,10 @@ impl<T> Block<T> { success: Ordering, failure: Ordering, ) -> Result<(), NonNull<Block<T>>> { - block.as_mut().start_index = self.start_index.wrapping_add(BLOCK_CAP); + block.as_mut().header.start_index = self.header.start_index.wrapping_add(BLOCK_CAP); let next_ptr = self + .header .next .compare_exchange(ptr::null_mut(), block.as_ptr(), success, failure) .unwrap_or_else(|x| x); @@ -292,7 +335,7 @@ impl<T> Block<T> { // Create the new block. It is assumed that the block will become the // next one after `&self`. If this turns out to not be the case, // `start_index` is updated accordingly. - let new_block = Box::new(Block::new(self.start_index + BLOCK_CAP)); + let new_block = Block::new(self.header.start_index + BLOCK_CAP); let mut new_block = unsafe { NonNull::new_unchecked(Box::into_raw(new_block)) }; @@ -309,7 +352,8 @@ impl<T> Block<T> { // `Release` ensures that the newly allocated block is available to // other threads acquiring the next pointer. let next = NonNull::new( - self.next + self.header + .next .compare_exchange(ptr::null_mut(), new_block.as_ptr(), AcqRel, Acquire) .unwrap_or_else(|x| x), ); @@ -344,8 +388,7 @@ impl<T> Block<T> { Err(curr) => curr, }; - // When running outside of loom, this calls `spin_loop_hint`. - thread::yield_now(); + crate::loom::thread::yield_now(); } } } @@ -362,19 +405,20 @@ fn is_tx_closed(bits: usize) -> bool { } impl<T> Values<T> { - unsafe fn uninitialized() -> Values<T> { - let mut vals = MaybeUninit::uninit(); - + /// Initialize a `Values` struct from a pointer. + /// + /// # Safety + /// + /// The raw pointer must be valid for writing a `Values<T>`. + unsafe fn initialize(_value: NonNull<Values<T>>) { // When fuzzing, `UnsafeCell` needs to be initialized. if_loom! { - let p = vals.as_mut_ptr() as *mut UnsafeCell<MaybeUninit<T>>; + let p = _value.as_ptr() as *mut UnsafeCell<MaybeUninit<T>>; for i in 0..BLOCK_CAP { p.add(i) .write(UnsafeCell::new(MaybeUninit::uninit())); } } - - Values(vals.assume_init()) } } @@ -385,3 +429,20 @@ impl<T> ops::Index<usize> for Values<T> { self.0.index(index) } } + +#[cfg(all(test, not(loom)))] +#[test] +fn assert_no_stack_overflow() { + // https://github.com/tokio-rs/tokio/issues/5293 + + struct Foo { + _a: [u8; 2_000_000], + } + + assert_eq!( + Layout::new::<MaybeUninit<Block<Foo>>>(), + Layout::new::<Block<Foo>>() + ); + + let _block = Block::<Foo>::new(0); +} diff --git a/vendor/tokio/src/sync/mpsc/bounded.rs b/vendor/tokio/src/sync/mpsc/bounded.rs index d7af17251..e870ae5f4 100644 --- a/vendor/tokio/src/sync/mpsc/bounded.rs +++ b/vendor/tokio/src/sync/mpsc/bounded.rs @@ -1,6 +1,7 @@ +use crate::loom::sync::Arc; use crate::sync::batch_semaphore::{self as semaphore, TryAcquireError}; use crate::sync::mpsc::chan; -use crate::sync::mpsc::error::{SendError, TrySendError}; +use crate::sync::mpsc::error::{SendError, TryRecvError, TrySendError}; cfg_time! { use crate::sync::mpsc::error::SendTimeoutError; @@ -10,19 +11,53 @@ cfg_time! { use std::fmt; use std::task::{Context, Poll}; -/// Send values to the associated `Receiver`. +/// Sends values to the associated `Receiver`. /// /// Instances are created by the [`channel`](channel) function. /// -/// To use the `Sender` in a poll function, you can use the [`PollSender`] -/// utility. +/// To convert the `Sender` into a `Sink` or use it in a poll function, you can +/// use the [`PollSender`] utility. /// -/// [`PollSender`]: https://docs.rs/tokio-util/0.6/tokio_util/sync/struct.PollSender.html +/// [`PollSender`]: https://docs.rs/tokio-util/latest/tokio_util/sync/struct.PollSender.html pub struct Sender<T> { chan: chan::Tx<T, Semaphore>, } -/// Permit to send one value into the channel. +/// A sender that does not prevent the channel from being closed. +/// +/// If all [`Sender`] instances of a channel were dropped and only `WeakSender` +/// instances remain, the channel is closed. +/// +/// In order to send messages, the `WeakSender` needs to be upgraded using +/// [`WeakSender::upgrade`], which returns `Option<Sender>`. It returns `None` +/// if all `Sender`s have been dropped, and otherwise it returns a `Sender`. +/// +/// [`Sender`]: Sender +/// [`WeakSender::upgrade`]: WeakSender::upgrade +/// +/// #Examples +/// +/// ``` +/// use tokio::sync::mpsc::channel; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, _rx) = channel::<i32>(15); +/// let tx_weak = tx.downgrade(); +/// +/// // Upgrading will succeed because `tx` still exists. +/// assert!(tx_weak.upgrade().is_some()); +/// +/// // If we drop `tx`, then it will fail. +/// drop(tx); +/// assert!(tx_weak.clone().upgrade().is_none()); +/// } +/// ``` +pub struct WeakSender<T> { + chan: Arc<chan::Chan<T, Semaphore>>, +} + +/// Permits to send one value into the channel. /// /// `Permit` values are returned by [`Sender::reserve()`] and [`Sender::try_reserve()`] /// and are used to guarantee channel capacity before generating a message to send. @@ -49,7 +84,7 @@ pub struct OwnedPermit<T> { chan: Option<chan::Tx<T, Semaphore>>, } -/// Receive values from the associated `Sender`. +/// Receives values from the associated `Sender`. /// /// Instances are created by the [`channel`](channel) function. /// @@ -57,7 +92,7 @@ pub struct OwnedPermit<T> { /// /// [`ReceiverStream`]: https://docs.rs/tokio-stream/0.1/tokio_stream/wrappers/struct.ReceiverStream.html pub struct Receiver<T> { - /// The channel receiver + /// The channel receiver. chan: chan::Rx<T, Semaphore>, } @@ -105,9 +140,13 @@ pub struct Receiver<T> { /// } /// } /// ``` +#[track_caller] pub fn channel<T>(buffer: usize) -> (Sender<T>, Receiver<T>) { assert!(buffer > 0, "mpsc bounded channel requires buffer > 0"); - let semaphore = (semaphore::Semaphore::new(buffer), buffer); + let semaphore = Semaphore { + semaphore: semaphore::Semaphore::new(buffer), + bound: buffer, + }; let (tx, rx) = chan::channel(semaphore); let tx = Sender::new(tx); @@ -118,7 +157,11 @@ pub fn channel<T>(buffer: usize) -> (Sender<T>, Receiver<T>) { /// Channel semaphore is a tuple of the semaphore implementation and a `usize` /// representing the channel bound. -type Semaphore = (semaphore::Semaphore, usize); +#[derive(Debug)] +pub(crate) struct Semaphore { + pub(crate) semaphore: semaphore::Semaphore, + pub(crate) bound: usize, +} impl<T> Receiver<T> { pub(crate) fn new(chan: chan::Rx<T, Semaphore>) -> Receiver<T> { @@ -187,6 +230,50 @@ impl<T> Receiver<T> { poll_fn(|cx| self.chan.recv(cx)).await } + /// Tries to receive the next value for this receiver. + /// + /// This method returns the [`Empty`] error if the channel is currently + /// empty, but there are still outstanding [senders] or [permits]. + /// + /// This method returns the [`Disconnected`] error if the channel is + /// currently empty, and there are no outstanding [senders] or [permits]. + /// + /// Unlike the [`poll_recv`] method, this method will never return an + /// [`Empty`] error spuriously. + /// + /// [`Empty`]: crate::sync::mpsc::error::TryRecvError::Empty + /// [`Disconnected`]: crate::sync::mpsc::error::TryRecvError::Disconnected + /// [`poll_recv`]: Self::poll_recv + /// [senders]: crate::sync::mpsc::Sender + /// [permits]: crate::sync::mpsc::Permit + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// use tokio::sync::mpsc::error::TryRecvError; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(100); + /// + /// tx.send("hello").await.unwrap(); + /// + /// assert_eq!(Ok("hello"), rx.try_recv()); + /// assert_eq!(Err(TryRecvError::Empty), rx.try_recv()); + /// + /// tx.send("hello").await.unwrap(); + /// // Drop the last sender, closing the channel. + /// drop(tx); + /// + /// assert_eq!(Ok("hello"), rx.try_recv()); + /// assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv()); + /// } + /// ``` + pub fn try_recv(&mut self) -> Result<T, TryRecvError> { + self.chan.try_recv() + } + /// Blocking receive to call outside of asynchronous contexts. /// /// This method returns `None` if the channel has been closed and there are @@ -237,7 +324,9 @@ impl<T> Receiver<T> { /// sync_code.join().unwrap() /// } /// ``` + #[track_caller] #[cfg(feature = "sync")] + #[cfg_attr(docsrs, doc(alias = "recv_blocking"))] pub fn blocking_recv(&mut self) -> Option<T> { crate::future::block_on(self.recv()) } @@ -291,7 +380,7 @@ impl<T> Receiver<T> { /// This method returns: /// /// * `Poll::Pending` if no messages are available but the channel is not - /// closed. + /// closed, or if a spurious failure happens. /// * `Poll::Ready(Some(message))` if a message is available. /// * `Poll::Ready(None)` if the channel has been closed and all messages /// sent before it was closed have been received. @@ -301,6 +390,12 @@ impl<T> Receiver<T> { /// receiver, or when the channel is closed. Note that on multiple calls to /// `poll_recv`, only the `Waker` from the `Context` passed to the most /// recent call is scheduled to receive a wakeup. + /// + /// If this method returns `Poll::Pending` due to a spurious failure, then + /// the `Waker` will be notified when the situation causing the spurious + /// failure has been resolved. Note that receiving such a wakeup does not + /// guarantee that the next call will succeed — it could fail with another + /// spurious failure. pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { self.chan.recv(cx) } @@ -485,7 +580,7 @@ impl<T> Sender<T> { /// } /// ``` pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> { - match self.chan.semaphore().0.try_acquire(1) { + match self.chan.semaphore().semaphore.try_acquire(1) { Ok(_) => {} Err(TryAcquireError::Closed) => return Err(TrySendError::Closed(message)), Err(TryAcquireError::NoPermits) => return Err(TrySendError::Full(message)), @@ -513,6 +608,11 @@ impl<T> Sender<T> { /// [`close`]: Receiver::close /// [`Receiver`]: Receiver /// + /// # Panics + /// + /// This function panics if it is called outside the context of a Tokio + /// runtime [with time enabled](crate::runtime::Builder::enable_time). + /// /// # Examples /// /// In the following example, each call to `send_timeout` will block until the @@ -595,7 +695,9 @@ impl<T> Sender<T> { /// sync_code.join().unwrap() /// } /// ``` + #[track_caller] #[cfg(feature = "sync")] + #[cfg_attr(docsrs, doc(alias = "send_blocking"))] pub fn blocking_send(&self, value: T) -> Result<(), SendError<T>> { crate::future::block_on(self.send(value)) } @@ -622,7 +724,7 @@ impl<T> Sender<T> { self.chan.is_closed() } - /// Wait for channel capacity. Once capacity to send one message is + /// Waits for channel capacity. Once capacity to send one message is /// available, it is reserved for the caller. /// /// If the channel is full, the function waits for the number of unreceived @@ -671,7 +773,7 @@ impl<T> Sender<T> { Ok(Permit { chan: &self.chan }) } - /// Wait for channel capacity, moving the `Sender` and returning an owned + /// Waits for channel capacity, moving the `Sender` and returning an owned /// permit. Once capacity to send one message is available, it is reserved /// for the caller. /// @@ -759,13 +861,15 @@ impl<T> Sender<T> { } async fn reserve_inner(&self) -> Result<(), SendError<()>> { - match self.chan.semaphore().0.acquire(1).await { + crate::trace::async_trace_leaf().await; + + match self.chan.semaphore().semaphore.acquire(1).await { Ok(_) => Ok(()), Err(_) => Err(SendError(())), } } - /// Try to acquire a slot in the channel without waiting for the slot to become + /// Tries to acquire a slot in the channel without waiting for the slot to become /// available. /// /// If the channel is full this function will return [`TrySendError`], otherwise @@ -809,15 +913,16 @@ impl<T> Sender<T> { /// } /// ``` pub fn try_reserve(&self) -> Result<Permit<'_, T>, TrySendError<()>> { - match self.chan.semaphore().0.try_acquire(1) { + match self.chan.semaphore().semaphore.try_acquire(1) { Ok(_) => {} - Err(_) => return Err(TrySendError::Full(())), + Err(TryAcquireError::Closed) => return Err(TrySendError::Closed(())), + Err(TryAcquireError::NoPermits) => return Err(TrySendError::Full(())), } Ok(Permit { chan: &self.chan }) } - /// Try to acquire a slot in the channel without waiting for the slot to become + /// Tries to acquire a slot in the channel without waiting for the slot to become /// available, returning an owned permit. /// /// This moves the sender _by value_, and returns an owned permit that can @@ -873,9 +978,10 @@ impl<T> Sender<T> { /// } /// ``` pub fn try_reserve_owned(self) -> Result<OwnedPermit<T>, TrySendError<Self>> { - match self.chan.semaphore().0.try_acquire(1) { + match self.chan.semaphore().semaphore.try_acquire(1) { Ok(_) => {} - Err(_) => return Err(TrySendError::Full(self)), + Err(TryAcquireError::Closed) => return Err(TrySendError::Closed(self)), + Err(TryAcquireError::NoPermits) => return Err(TrySendError::Full(self)), } Ok(OwnedPermit { @@ -903,6 +1009,8 @@ impl<T> Sender<T> { /// /// The capacity goes down when sending a value by calling [`send`] or by reserving capacity /// with [`reserve`]. The capacity goes up when values are received by the [`Receiver`]. + /// This is distinct from [`max_capacity`], which always returns buffer capacity initially + /// specified when calling [`channel`] /// /// # Examples /// @@ -928,8 +1036,56 @@ impl<T> Sender<T> { /// /// [`send`]: Sender::send /// [`reserve`]: Sender::reserve + /// [`channel`]: channel + /// [`max_capacity`]: Sender::max_capacity pub fn capacity(&self) -> usize { - self.chan.semaphore().0.available_permits() + self.chan.semaphore().semaphore.available_permits() + } + + /// Converts the `Sender` to a [`WeakSender`] that does not count + /// towards RAII semantics, i.e. if all `Sender` instances of the + /// channel were dropped and only `WeakSender` instances remain, + /// the channel is closed. + pub fn downgrade(&self) -> WeakSender<T> { + WeakSender { + chan: self.chan.downgrade(), + } + } + + /// Returns the maximum buffer capacity of the channel. + /// + /// The maximum capacity is the buffer capacity initially specified when calling + /// [`channel`]. This is distinct from [`capacity`], which returns the *current* + /// available buffer capacity: as messages are sent and received, the + /// value returned by [`capacity`] will go up or down, whereas the value + /// returned by `max_capacity` will remain constant. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, _rx) = mpsc::channel::<()>(5); + /// + /// // both max capacity and capacity are the same at first + /// assert_eq!(tx.max_capacity(), 5); + /// assert_eq!(tx.capacity(), 5); + /// + /// // Making a reservation doesn't change the max capacity. + /// let permit = tx.reserve().await.unwrap(); + /// assert_eq!(tx.max_capacity(), 5); + /// // but drops the capacity by one + /// assert_eq!(tx.capacity(), 4); + /// } + /// ``` + /// + /// [`channel`]: channel + /// [`max_capacity`]: Sender::max_capacity + /// [`capacity`]: Sender::capacity + pub fn max_capacity(&self) -> usize { + self.chan.semaphore().bound } } @@ -949,6 +1105,29 @@ impl<T> fmt::Debug for Sender<T> { } } +impl<T> Clone for WeakSender<T> { + fn clone(&self) -> Self { + WeakSender { + chan: self.chan.clone(), + } + } +} + +impl<T> WeakSender<T> { + /// Tries to convert a WeakSender into a [`Sender`]. This will return `Some` + /// if there are other `Sender` instances alive and the channel wasn't + /// previously dropped, otherwise `None` is returned. + pub fn upgrade(&self) -> Option<Sender<T>> { + chan::Tx::upgrade(self.chan.clone()).map(Sender::new) + } +} + +impl<T> fmt::Debug for WeakSender<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("WeakSender").finish() + } +} + // ===== impl Permit ===== impl<T> Permit<'_, T> { @@ -1065,7 +1244,7 @@ impl<T> OwnedPermit<T> { Sender { chan } } - /// Release the reserved capacity *without* sending a message, returning the + /// Releases the reserved capacity *without* sending a message, returning the /// [`Sender`]. /// /// # Examples diff --git a/vendor/tokio/src/sync/mpsc/chan.rs b/vendor/tokio/src/sync/mpsc/chan.rs index 554d02284..6f87715dd 100644 --- a/vendor/tokio/src/sync/mpsc/chan.rs +++ b/vendor/tokio/src/sync/mpsc/chan.rs @@ -2,16 +2,18 @@ use crate::loom::cell::UnsafeCell; use crate::loom::future::AtomicWaker; use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::Arc; -use crate::sync::mpsc::list; +use crate::runtime::park::CachedParkThread; +use crate::sync::mpsc::error::TryRecvError; +use crate::sync::mpsc::{bounded, list, unbounded}; use crate::sync::notify::Notify; use std::fmt; use std::process; -use std::sync::atomic::Ordering::{AcqRel, Relaxed}; +use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release}; use std::task::Poll::{Pending, Ready}; use std::task::{Context, Poll}; -/// Channel sender +/// Channel sender. pub(crate) struct Tx<T, S> { inner: Arc<Chan<T, S>>, } @@ -22,7 +24,7 @@ impl<T, S: fmt::Debug> fmt::Debug for Tx<T, S> { } } -/// Channel receiver +/// Channel receiver. pub(crate) struct Rx<T, S: Semaphore> { inner: Arc<Chan<T, S>>, } @@ -43,8 +45,8 @@ pub(crate) trait Semaphore { fn is_closed(&self) -> bool; } -struct Chan<T, S> { - /// Notifies all tasks listening for the receiver being dropped +pub(super) struct Chan<T, S> { + /// Notifies all tasks listening for the receiver being dropped. notify_rx_closed: Notify, /// Handle to the push half of the lock-free list. @@ -126,6 +128,30 @@ impl<T, S> Tx<T, S> { Tx { inner: chan } } + pub(super) fn downgrade(&self) -> Arc<Chan<T, S>> { + self.inner.clone() + } + + // Returns the upgraded channel or None if the upgrade failed. + pub(super) fn upgrade(chan: Arc<Chan<T, S>>) -> Option<Self> { + let mut tx_count = chan.tx_count.load(Acquire); + + loop { + if tx_count == 0 { + // channel is closed + return None; + } + + match chan + .tx_count + .compare_exchange_weak(tx_count, tx_count + 1, AcqRel, Acquire) + { + Ok(_) => return Some(Tx { inner: chan }), + Err(prev_count) => tx_count = prev_count, + } + } + } + pub(super) fn semaphore(&self) -> &S { &self.inner.semaphore } @@ -216,8 +242,10 @@ impl<T, S: Semaphore> Rx<T, S> { pub(crate) fn recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { use super::block::Read::*; + ready!(crate::trace::trace_leaf(cx)); + // Keep track of task budget - let coop = ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::runtime::coop::poll_proceed(cx)); self.inner.rx_fields.with_mut(|rx_fields_ptr| { let rx_fields = unsafe { &mut *rx_fields_ptr }; @@ -263,6 +291,51 @@ impl<T, S: Semaphore> Rx<T, S> { } }) } + + /// Try to receive the next value. + pub(crate) fn try_recv(&mut self) -> Result<T, TryRecvError> { + use super::list::TryPopResult; + + self.inner.rx_fields.with_mut(|rx_fields_ptr| { + let rx_fields = unsafe { &mut *rx_fields_ptr }; + + macro_rules! try_recv { + () => { + match rx_fields.list.try_pop(&self.inner.tx) { + TryPopResult::Ok(value) => { + self.inner.semaphore.add_permit(); + return Ok(value); + } + TryPopResult::Closed => return Err(TryRecvError::Disconnected), + TryPopResult::Empty => return Err(TryRecvError::Empty), + TryPopResult::Busy => {} // fall through + } + }; + } + + try_recv!(); + + // If a previous `poll_recv` call has set a waker, we wake it here. + // This allows us to put our own CachedParkThread waker in the + // AtomicWaker slot instead. + // + // This is not a spurious wakeup to `poll_recv` since we just got a + // Busy from `try_pop`, which only happens if there are messages in + // the queue. + self.inner.rx_waker.wake(); + + // Park the thread until the problematic send has completed. + let mut park = CachedParkThread::new(); + let waker = park.waker().unwrap(); + loop { + self.inner.rx_waker.register_by_ref(&waker); + // It is possible that the problematic send has now completed, + // so we have to check for messages again. + try_recv!(); + park.park(); + } + }) + } } impl<T, S: Semaphore> Drop for Rx<T, S> { @@ -297,7 +370,7 @@ impl<T, S> Drop for Chan<T, S> { fn drop(&mut self) { use super::block::Read::Value; - // Safety: the only owner of the rx fields is Chan, and eing + // Safety: the only owner of the rx fields is Chan, and being // inside its own Drop means we're the last ones to touch it. self.rx_fields.with_mut(|rx_fields_ptr| { let rx_fields = unsafe { &mut *rx_fields_ptr }; @@ -310,32 +383,29 @@ impl<T, S> Drop for Chan<T, S> { // ===== impl Semaphore for (::Semaphore, capacity) ===== -impl Semaphore for (crate::sync::batch_semaphore::Semaphore, usize) { +impl Semaphore for bounded::Semaphore { fn add_permit(&self) { - self.0.release(1) + self.semaphore.release(1) } fn is_idle(&self) -> bool { - self.0.available_permits() == self.1 + self.semaphore.available_permits() == self.bound } fn close(&self) { - self.0.close(); + self.semaphore.close(); } fn is_closed(&self) -> bool { - self.0.is_closed() + self.semaphore.is_closed() } } // ===== impl Semaphore for AtomicUsize ===== -use std::sync::atomic::Ordering::{Acquire, Release}; -use std::usize; - -impl Semaphore for AtomicUsize { +impl Semaphore for unbounded::Semaphore { fn add_permit(&self) { - let prev = self.fetch_sub(2, Release); + let prev = self.0.fetch_sub(2, Release); if prev >> 1 == 0 { // Something went wrong @@ -344,14 +414,14 @@ impl Semaphore for AtomicUsize { } fn is_idle(&self) -> bool { - self.load(Acquire) >> 1 == 0 + self.0.load(Acquire) >> 1 == 0 } fn close(&self) { - self.fetch_or(1, Release); + self.0.fetch_or(1, Release); } fn is_closed(&self) -> bool { - self.load(Acquire) & 1 == 1 + self.0.load(Acquire) & 1 == 1 } } diff --git a/vendor/tokio/src/sync/mpsc/error.rs b/vendor/tokio/src/sync/mpsc/error.rs index 0d25ad386..25b4455be 100644 --- a/vendor/tokio/src/sync/mpsc/error.rs +++ b/vendor/tokio/src/sync/mpsc/error.rs @@ -1,25 +1,31 @@ -//! Channel error types +//! Channel error types. use std::error::Error; use std::fmt; /// Error returned by the `Sender`. -#[derive(Debug)] +#[derive(PartialEq, Eq, Clone, Copy)] pub struct SendError<T>(pub T); +impl<T> fmt::Debug for SendError<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SendError").finish_non_exhaustive() + } +} + impl<T> fmt::Display for SendError<T> { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { write!(fmt, "channel closed") } } -impl<T: fmt::Debug> std::error::Error for SendError<T> {} +impl<T> std::error::Error for SendError<T> {} // ===== TrySendError ===== /// This enumeration is the list of the possible error outcomes for the /// [try_send](super::Sender::try_send) method. -#[derive(Debug)] +#[derive(PartialEq, Eq, Clone, Copy)] pub enum TrySendError<T> { /// The data could not be sent on the channel because the channel is /// currently full and sending would require blocking. @@ -30,7 +36,14 @@ pub enum TrySendError<T> { Closed(T), } -impl<T: fmt::Debug> Error for TrySendError<T> {} +impl<T> fmt::Debug for TrySendError<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + TrySendError::Full(..) => "Full(..)".fmt(f), + TrySendError::Closed(..) => "Closed(..)".fmt(f), + } + } +} impl<T> fmt::Display for TrySendError<T> { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -45,16 +58,42 @@ impl<T> fmt::Display for TrySendError<T> { } } +impl<T> Error for TrySendError<T> {} + impl<T> From<SendError<T>> for TrySendError<T> { fn from(src: SendError<T>) -> TrySendError<T> { TrySendError::Closed(src.0) } } +// ===== TryRecvError ===== + +/// Error returned by `try_recv`. +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +pub enum TryRecvError { + /// This **channel** is currently empty, but the **Sender**(s) have not yet + /// disconnected, so data may yet become available. + Empty, + /// The **channel**'s sending half has become disconnected, and there will + /// never be any more data received on it. + Disconnected, +} + +impl fmt::Display for TryRecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + TryRecvError::Empty => "receiving on an empty channel".fmt(fmt), + TryRecvError::Disconnected => "receiving on a closed channel".fmt(fmt), + } + } +} + +impl Error for TryRecvError {} + // ===== RecvError ===== /// Error returned by `Receiver`. -#[derive(Debug)] +#[derive(Debug, Clone)] #[doc(hidden)] #[deprecated(note = "This type is unused because recv returns an Option.")] pub struct RecvError(()); @@ -72,7 +111,7 @@ impl Error for RecvError {} cfg_time! { // ===== SendTimeoutError ===== - #[derive(Debug)] + #[derive(PartialEq, Eq, Clone, Copy)] /// Error returned by [`Sender::send_timeout`](super::Sender::send_timeout)]. pub enum SendTimeoutError<T> { /// The data could not be sent on the channel because the channel is @@ -84,7 +123,14 @@ cfg_time! { Closed(T), } - impl<T: fmt::Debug> Error for SendTimeoutError<T> {} + impl<T> fmt::Debug for SendTimeoutError<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + SendTimeoutError::Timeout(..) => "Timeout(..)".fmt(f), + SendTimeoutError::Closed(..) => "Closed(..)".fmt(f), + } + } + } impl<T> fmt::Display for SendTimeoutError<T> { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -98,4 +144,6 @@ cfg_time! { ) } } + + impl<T> Error for SendTimeoutError<T> {} } diff --git a/vendor/tokio/src/sync/mpsc/list.rs b/vendor/tokio/src/sync/mpsc/list.rs index 5dad2babf..10b29575b 100644 --- a/vendor/tokio/src/sync/mpsc/list.rs +++ b/vendor/tokio/src/sync/mpsc/list.rs @@ -8,31 +8,43 @@ use std::fmt; use std::ptr::NonNull; use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release}; -/// List queue transmit handle +/// List queue transmit handle. pub(crate) struct Tx<T> { /// Tail in the `Block` mpmc list. block_tail: AtomicPtr<Block<T>>, - /// Position to push the next message. This reference a block and offset + /// Position to push the next message. This references a block and offset /// into the block. tail_position: AtomicUsize, } /// List queue receive handle pub(crate) struct Rx<T> { - /// Pointer to the block being processed + /// Pointer to the block being processed. head: NonNull<Block<T>>, - /// Next slot index to process + /// Next slot index to process. index: usize, - /// Pointer to the next block pending release + /// Pointer to the next block pending release. free_head: NonNull<Block<T>>, } +/// Return value of `Rx::try_pop`. +pub(crate) enum TryPopResult<T> { + /// Successfully popped a value. + Ok(T), + /// The channel is empty. + Empty, + /// The channel is empty and closed. + Closed, + /// The channel is not empty, but the first value is being written. + Busy, +} + pub(crate) fn channel<T>() -> (Tx<T>, Rx<T>) { // Create the initial block shared between the tx and rx halves. - let initial_block = Box::new(Block::new(0)); + let initial_block = Block::new(0); let initial_block_ptr = Box::into_raw(initial_block); let tx = Tx { @@ -67,7 +79,7 @@ impl<T> Tx<T> { } } - /// Closes the send half of the list + /// Closes the send half of the list. /// /// Similar process as pushing a value, but instead of writing the value & /// setting the ready flag, the TX_CLOSED flag is set on the block. @@ -218,7 +230,7 @@ impl<T> fmt::Debug for Tx<T> { } impl<T> Rx<T> { - /// Pops the next value off the queue + /// Pops the next value off the queue. pub(crate) fn pop(&mut self, tx: &Tx<T>) -> Option<block::Read<T>> { // Advance `head`, if needed if !self.try_advancing_head() { @@ -240,6 +252,26 @@ impl<T> Rx<T> { } } + /// Pops the next value off the queue, detecting whether the block + /// is busy or empty on failure. + /// + /// This function exists because `Rx::pop` can return `None` even if the + /// channel's queue contains a message that has been completely written. + /// This can happen if the fully delivered message is behind another message + /// that is in the middle of being written to the block, since the channel + /// can't return the messages out of order. + pub(crate) fn try_pop(&mut self, tx: &Tx<T>) -> TryPopResult<T> { + let tail_position = tx.tail_position.load(Acquire); + let result = self.pop(tx); + + match result { + Some(block::Read::Value(t)) => TryPopResult::Ok(t), + Some(block::Read::Closed) => TryPopResult::Closed, + None if tail_position == self.index => TryPopResult::Empty, + None => TryPopResult::Busy, + } + } + /// Tries advancing the block pointer to the block referenced by `self.index`. /// /// Returns `true` if successful, `false` if there is no next block to load. diff --git a/vendor/tokio/src/sync/mpsc/mod.rs b/vendor/tokio/src/sync/mpsc/mod.rs index 879e3dcfc..b2af084b2 100644 --- a/vendor/tokio/src/sync/mpsc/mod.rs +++ b/vendor/tokio/src/sync/mpsc/mod.rs @@ -21,6 +21,9 @@ //! when additional capacity is available. In other words, the channel provides //! backpressure. //! +//! This channel is also suitable for the single-producer single-consumer +//! use-case. (Unless you only need to send one message, in which case you +//! should use the [oneshot] channel.) //! //! # Disconnection //! @@ -30,7 +33,8 @@ //! //! If the [`Receiver`] handle is dropped, then messages can no longer //! be read out of the channel. In this case, all further attempts to send will -//! result in an error. +//! result in an error. Additionally, all unread messages will be drained from the +//! channel and dropped. //! //! # Clean Shutdown //! @@ -58,6 +62,22 @@ //! [crossbeam][crossbeam-unbounded]. Similarly, for sending a message _from sync //! to async_, you should use an unbounded Tokio `mpsc` channel. //! +//! Please be aware that the above remarks were written with the `mpsc` channel +//! in mind, but they can also be generalized to other kinds of channels. In +//! general, any channel method that isn't marked async can be called anywhere, +//! including outside of the runtime. For example, sending a message on a +//! [oneshot] channel from outside the runtime is perfectly fine. +//! +//! # Multiple runtimes +//! +//! The mpsc channel does not care about which runtime you use it in, and can be +//! used to send messages from one runtime to another. It can also be used in +//! non-Tokio runtimes. +//! +//! There is one exception to the above: the [`send_timeout`] must be used from +//! within a Tokio runtime, however it is still not tied to one specific Tokio +//! runtime, and the sender may be moved from one Tokio runtime to another. +//! //! [`Sender`]: crate::sync::mpsc::Sender //! [`Receiver`]: crate::sync::mpsc::Receiver //! [bounded-send]: crate::sync::mpsc::Sender::send() @@ -66,21 +86,25 @@ //! [blocking-recv]: crate::sync::mpsc::Receiver::blocking_recv() //! [`UnboundedSender`]: crate::sync::mpsc::UnboundedSender //! [`UnboundedReceiver`]: crate::sync::mpsc::UnboundedReceiver +//! [oneshot]: crate::sync::oneshot //! [`Handle::block_on`]: crate::runtime::Handle::block_on() //! [std-unbounded]: std::sync::mpsc::channel //! [crossbeam-unbounded]: https://docs.rs/crossbeam/*/crossbeam/channel/fn.unbounded.html +//! [`send_timeout`]: crate::sync::mpsc::Sender::send_timeout pub(super) mod block; mod bounded; -pub use self::bounded::{channel, OwnedPermit, Permit, Receiver, Sender}; +pub use self::bounded::{channel, OwnedPermit, Permit, Receiver, Sender, WeakSender}; mod chan; pub(super) mod list; mod unbounded; -pub use self::unbounded::{unbounded_channel, UnboundedReceiver, UnboundedSender}; +pub use self::unbounded::{ + unbounded_channel, UnboundedReceiver, UnboundedSender, WeakUnboundedSender, +}; pub mod error; diff --git a/vendor/tokio/src/sync/mpsc/unbounded.rs b/vendor/tokio/src/sync/mpsc/unbounded.rs index 23c80f60a..cd83fc125 100644 --- a/vendor/tokio/src/sync/mpsc/unbounded.rs +++ b/vendor/tokio/src/sync/mpsc/unbounded.rs @@ -1,6 +1,6 @@ -use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::{atomic::AtomicUsize, Arc}; use crate::sync::mpsc::chan; -use crate::sync::mpsc::error::SendError; +use crate::sync::mpsc::error::{SendError, TryRecvError}; use std::fmt; use std::task::{Context, Poll}; @@ -13,6 +13,40 @@ pub struct UnboundedSender<T> { chan: chan::Tx<T, Semaphore>, } +/// An unbounded sender that does not prevent the channel from being closed. +/// +/// If all [`UnboundedSender`] instances of a channel were dropped and only +/// `WeakUnboundedSender` instances remain, the channel is closed. +/// +/// In order to send messages, the `WeakUnboundedSender` needs to be upgraded using +/// [`WeakUnboundedSender::upgrade`], which returns `Option<UnboundedSender>`. It returns `None` +/// if all `UnboundedSender`s have been dropped, and otherwise it returns an `UnboundedSender`. +/// +/// [`UnboundedSender`]: UnboundedSender +/// [`WeakUnboundedSender::upgrade`]: WeakUnboundedSender::upgrade +/// +/// #Examples +/// +/// ``` +/// use tokio::sync::mpsc::unbounded_channel; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, _rx) = unbounded_channel::<i32>(); +/// let tx_weak = tx.downgrade(); +/// +/// // Upgrading will succeed because `tx` still exists. +/// assert!(tx_weak.upgrade().is_some()); +/// +/// // If we drop `tx`, then it will fail. +/// drop(tx); +/// assert!(tx_weak.clone().upgrade().is_none()); +/// } +/// ``` +pub struct WeakUnboundedSender<T> { + chan: Arc<chan::Chan<T, Semaphore>>, +} + impl<T> Clone for UnboundedSender<T> { fn clone(&self) -> Self { UnboundedSender { @@ -61,7 +95,7 @@ impl<T> fmt::Debug for UnboundedReceiver<T> { /// the channel. Using an `unbounded` channel has the ability of causing the /// process to run out of memory. In this case, the process will be aborted. pub fn unbounded_channel<T>() -> (UnboundedSender<T>, UnboundedReceiver<T>) { - let (tx, rx) = chan::channel(AtomicUsize::new(0)); + let (tx, rx) = chan::channel(Semaphore(AtomicUsize::new(0))); let tx = UnboundedSender::new(tx); let rx = UnboundedReceiver::new(rx); @@ -70,7 +104,8 @@ pub fn unbounded_channel<T>() -> (UnboundedSender<T>, UnboundedReceiver<T>) { } /// No capacity -type Semaphore = AtomicUsize; +#[derive(Debug)] +pub(crate) struct Semaphore(pub(crate) AtomicUsize); impl<T> UnboundedReceiver<T> { pub(crate) fn new(chan: chan::Rx<T, Semaphore>) -> UnboundedReceiver<T> { @@ -79,8 +114,14 @@ impl<T> UnboundedReceiver<T> { /// Receives the next value for this receiver. /// - /// `None` is returned when all `Sender` halves have dropped, indicating - /// that no further values can be sent on the channel. + /// This method returns `None` if the channel has been closed and there are + /// no remaining messages in the channel's buffer. This indicates that no + /// further values can ever be received from this `Receiver`. The channel is + /// closed when all senders have been dropped, or when [`close`] is called. + /// + /// If there are no messages in the channel's buffer, but the channel has + /// not yet been closed, this method will sleep until a message is sent or + /// the channel is closed. /// /// # Cancel safety /// @@ -89,6 +130,8 @@ impl<T> UnboundedReceiver<T> { /// completes first, it is guaranteed that no messages were received on this /// channel. /// + /// [`close`]: Self::close + /// /// # Examples /// /// ``` @@ -129,6 +172,50 @@ impl<T> UnboundedReceiver<T> { poll_fn(|cx| self.poll_recv(cx)).await } + /// Tries to receive the next value for this receiver. + /// + /// This method returns the [`Empty`] error if the channel is currently + /// empty, but there are still outstanding [senders] or [permits]. + /// + /// This method returns the [`Disconnected`] error if the channel is + /// currently empty, and there are no outstanding [senders] or [permits]. + /// + /// Unlike the [`poll_recv`] method, this method will never return an + /// [`Empty`] error spuriously. + /// + /// [`Empty`]: crate::sync::mpsc::error::TryRecvError::Empty + /// [`Disconnected`]: crate::sync::mpsc::error::TryRecvError::Disconnected + /// [`poll_recv`]: Self::poll_recv + /// [senders]: crate::sync::mpsc::Sender + /// [permits]: crate::sync::mpsc::Permit + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// use tokio::sync::mpsc::error::TryRecvError; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::unbounded_channel(); + /// + /// tx.send("hello").unwrap(); + /// + /// assert_eq!(Ok("hello"), rx.try_recv()); + /// assert_eq!(Err(TryRecvError::Empty), rx.try_recv()); + /// + /// tx.send("hello").unwrap(); + /// // Drop the last sender, closing the channel. + /// drop(tx); + /// + /// assert_eq!(Ok("hello"), rx.try_recv()); + /// assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv()); + /// } + /// ``` + pub fn try_recv(&mut self) -> Result<T, TryRecvError> { + self.chan.try_recv() + } + /// Blocking receive to call outside of asynchronous contexts. /// /// # Panics @@ -154,7 +241,9 @@ impl<T> UnboundedReceiver<T> { /// sync_code.join().unwrap(); /// } /// ``` + #[track_caller] #[cfg(feature = "sync")] + #[cfg_attr(docsrs, doc(alias = "recv_blocking"))] pub fn blocking_recv(&mut self) -> Option<T> { crate::future::block_on(self.recv()) } @@ -163,6 +252,9 @@ impl<T> UnboundedReceiver<T> { /// /// This prevents any further messages from being sent on the channel while /// still enabling the receiver to drain messages that are buffered. + /// + /// To guarantee that no messages are dropped, after calling `close()`, + /// `recv()` must be called until `None` is returned. pub fn close(&mut self) { self.chan.close(); } @@ -172,7 +264,7 @@ impl<T> UnboundedReceiver<T> { /// This method returns: /// /// * `Poll::Pending` if no messages are available but the channel is not - /// closed. + /// closed, or if a spurious failure happens. /// * `Poll::Ready(Some(message))` if a message is available. /// * `Poll::Ready(None)` if the channel has been closed and all messages /// sent before it was closed have been received. @@ -182,6 +274,12 @@ impl<T> UnboundedReceiver<T> { /// receiver, or when the channel is closed. Note that on multiple calls to /// `poll_recv`, only the `Waker` from the `Context` passed to the most /// recent call is scheduled to receive a wakeup. + /// + /// If this method returns `Poll::Pending` due to a spurious failure, then + /// the `Waker` will be notified when the situation causing the spurious + /// failure has been resolved. Note that receiving such a wakeup does not + /// guarantee that the next call will succeed — it could fail with another + /// spurious failure. pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { self.chan.recv(cx) } @@ -217,7 +315,7 @@ impl<T> UnboundedSender<T> { use std::process; use std::sync::atomic::Ordering::{AcqRel, Acquire}; - let mut curr = self.chan.semaphore().load(Acquire); + let mut curr = self.chan.semaphore().0.load(Acquire); loop { if curr & 1 == 1 { @@ -233,6 +331,7 @@ impl<T> UnboundedSender<T> { match self .chan .semaphore() + .0 .compare_exchange(curr, curr + 2, AcqRel, Acquire) { Ok(_) => return true, @@ -320,4 +419,37 @@ impl<T> UnboundedSender<T> { pub fn same_channel(&self, other: &Self) -> bool { self.chan.same_channel(&other.chan) } + + /// Converts the `UnboundedSender` to a [`WeakUnboundedSender`] that does not count + /// towards RAII semantics, i.e. if all `UnboundedSender` instances of the + /// channel were dropped and only `WeakUnboundedSender` instances remain, + /// the channel is closed. + pub fn downgrade(&self) -> WeakUnboundedSender<T> { + WeakUnboundedSender { + chan: self.chan.downgrade(), + } + } +} + +impl<T> Clone for WeakUnboundedSender<T> { + fn clone(&self) -> Self { + WeakUnboundedSender { + chan: self.chan.clone(), + } + } +} + +impl<T> WeakUnboundedSender<T> { + /// Tries to convert a WeakUnboundedSender into an [`UnboundedSender`]. + /// This will return `Some` if there are other `Sender` instances alive and + /// the channel wasn't previously dropped, otherwise `None` is returned. + pub fn upgrade(&self) -> Option<UnboundedSender<T>> { + chan::Tx::upgrade(self.chan.clone()).map(UnboundedSender::new) + } +} + +impl<T> fmt::Debug for WeakUnboundedSender<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("WeakUnboundedSender").finish() + } } diff --git a/vendor/tokio/src/sync/mutex.rs b/vendor/tokio/src/sync/mutex.rs index 8ae824770..549c77b32 100644 --- a/vendor/tokio/src/sync/mutex.rs +++ b/vendor/tokio/src/sync/mutex.rs @@ -1,12 +1,15 @@ #![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))] use crate::sync::batch_semaphore as semaphore; +#[cfg(all(tokio_unstable, feature = "tracing"))] +use crate::util::trace; use std::cell::UnsafeCell; use std::error::Error; +use std::marker::PhantomData; use std::ops::{Deref, DerefMut}; use std::sync::Arc; -use std::{fmt, marker, mem}; +use std::{fmt, mem, ptr}; /// An asynchronous `Mutex`-like type. /// @@ -124,6 +127,8 @@ use std::{fmt, marker, mem}; /// [`Send`]: trait@std::marker::Send /// [`lock`]: method@Mutex::lock pub struct Mutex<T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, s: semaphore::Semaphore, c: UnsafeCell<T>, } @@ -137,7 +142,13 @@ pub struct Mutex<T: ?Sized> { /// /// The lock is automatically released whenever the guard is dropped, at which /// point `lock` will succeed yet again. +#[clippy::has_significant_drop] +#[must_use = "if unused the Mutex will immediately unlock"] pub struct MutexGuard<'a, T: ?Sized> { + // When changing the fields in this struct, make sure to update the + // `skip_drop` method. + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, lock: &'a Mutex<T>, } @@ -156,7 +167,12 @@ pub struct MutexGuard<'a, T: ?Sized> { /// point `lock` will succeed yet again. /// /// [`Arc`]: std::sync::Arc +#[clippy::has_significant_drop] pub struct OwnedMutexGuard<T: ?Sized> { + // When changing the fields in this struct, make sure to update the + // `skip_drop` method. + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, lock: Arc<Mutex<T>>, } @@ -165,12 +181,71 @@ pub struct OwnedMutexGuard<T: ?Sized> { /// This can be used to hold a subfield of the protected data. /// /// [`MutexGuard::map`]: method@MutexGuard::map +#[clippy::has_significant_drop] #[must_use = "if unused the Mutex will immediately unlock"] pub struct MappedMutexGuard<'a, T: ?Sized> { + // When changing the fields in this struct, make sure to update the + // `skip_drop` method. + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, s: &'a semaphore::Semaphore, data: *mut T, // Needed to tell the borrow checker that we are holding a `&mut T` - marker: marker::PhantomData<&'a mut T>, + marker: PhantomData<&'a mut T>, +} + +/// A owned handle to a held `Mutex` that has had a function applied to it via +/// [`OwnedMutexGuard::map`]. +/// +/// This can be used to hold a subfield of the protected data. +/// +/// [`OwnedMutexGuard::map`]: method@OwnedMutexGuard::map +#[clippy::has_significant_drop] +#[must_use = "if unused the Mutex will immediately unlock"] +pub struct OwnedMappedMutexGuard<T: ?Sized, U: ?Sized = T> { + // When changing the fields in this struct, make sure to update the + // `skip_drop` method. + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, + data: *mut U, + lock: Arc<Mutex<T>>, +} + +/// A helper type used when taking apart a `MutexGuard` without running its +/// Drop implementation. +#[allow(dead_code)] // Unused fields are still used in Drop. +struct MutexGuardInner<'a, T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, + lock: &'a Mutex<T>, +} + +/// A helper type used when taking apart a `OwnedMutexGuard` without running +/// its Drop implementation. +struct OwnedMutexGuardInner<T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, + lock: Arc<Mutex<T>>, +} + +/// A helper type used when taking apart a `MappedMutexGuard` without running +/// its Drop implementation. +#[allow(dead_code)] // Unused fields are still used in Drop. +struct MappedMutexGuardInner<'a, T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, + s: &'a semaphore::Semaphore, + data: *mut T, +} + +/// A helper type used when taking apart a `OwnedMappedMutexGuard` without running +/// its Drop implementation. +#[allow(dead_code)] // Unused fields are still used in Drop. +struct OwnedMappedMutexGuardInner<T: ?Sized, U: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, + data: *mut U, + lock: Arc<Mutex<T>>, } // As long as T: Send, it's fine to send and share Mutex<T> between threads. @@ -183,6 +258,19 @@ unsafe impl<T> Sync for OwnedMutexGuard<T> where T: ?Sized + Send + Sync {} unsafe impl<'a, T> Sync for MappedMutexGuard<'a, T> where T: ?Sized + Sync + 'a {} unsafe impl<'a, T> Send for MappedMutexGuard<'a, T> where T: ?Sized + Send + 'a {} +unsafe impl<T, U> Sync for OwnedMappedMutexGuard<T, U> +where + T: ?Sized + Send + Sync, + U: ?Sized + Send + Sync, +{ +} +unsafe impl<T, U> Send for OwnedMappedMutexGuard<T, U> +where + T: ?Sized + Send, + U: ?Sized + Send, +{ +} + /// Error returned from the [`Mutex::try_lock`], [`RwLock::try_read`] and /// [`RwLock::try_write`] functions. /// @@ -191,8 +279,8 @@ unsafe impl<'a, T> Send for MappedMutexGuard<'a, T> where T: ?Sized + Send + 'a /// `RwLock::try_read` operation will only fail if the lock is currently held /// by an exclusive writer. /// -/// `RwLock::try_write` operation will if lock is held by any reader or by an -/// exclusive writer. +/// `RwLock::try_write` operation will only fail if the lock is currently held +/// by any reader or by an exclusive writer. /// /// [`Mutex::try_lock`]: Mutex::try_lock /// [`RwLock::try_read`]: fn@super::RwLock::try_read @@ -242,13 +330,42 @@ impl<T: ?Sized> Mutex<T> { /// /// let lock = Mutex::new(5); /// ``` + #[track_caller] pub fn new(t: T) -> Self where T: Sized, { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = { + let location = std::panic::Location::caller(); + + tracing::trace_span!( + "runtime.resource", + concrete_type = "Mutex", + kind = "Sync", + loc.file = location.file(), + loc.line = location.line(), + loc.col = location.column(), + ) + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let s = resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + locked = false, + ); + semaphore::Semaphore::new(1) + }); + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + let s = semaphore::Semaphore::new(1); + Self { c: UnsafeCell::new(t), - s: semaphore::Semaphore::new(1), + s, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, } } @@ -270,6 +387,8 @@ impl<T: ?Sized> Mutex<T> { Self { c: UnsafeCell::new(t), s: semaphore::Semaphore::const_new(1), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span::none(), } } @@ -297,8 +416,147 @@ impl<T: ?Sized> Mutex<T> { /// } /// ``` pub async fn lock(&self) -> MutexGuard<'_, T> { - self.acquire().await; - MutexGuard { lock: self } + let acquire_fut = async { + self.acquire().await; + + MutexGuard { + lock: self, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), + } + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let acquire_fut = trace::async_op( + move || acquire_fut, + self.resource_span.clone(), + "Mutex::lock", + "poll", + false, + ); + + #[allow(clippy::let_and_return)] // this lint triggers when disabling tracing + let guard = acquire_fut.await; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + locked = true, + ); + }); + + guard + } + + /// Blockingly locks this `Mutex`. When the lock has been acquired, function returns a + /// [`MutexGuard`]. + /// + /// This method is intended for use cases where you + /// need to use this mutex in asynchronous code as well as in synchronous code. + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution context. + /// + /// - If you find yourself in an asynchronous execution context and needing + /// to call some (synchronous) function which performs one of these + /// `blocking_` operations, then consider wrapping that call inside + /// [`spawn_blocking()`][crate::runtime::Handle::spawn_blocking] + /// (or [`block_in_place()`][crate::task::block_in_place]). + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::Mutex; + /// + /// #[tokio::main] + /// async fn main() { + /// let mutex = Arc::new(Mutex::new(1)); + /// let lock = mutex.lock().await; + /// + /// let mutex1 = Arc::clone(&mutex); + /// let blocking_task = tokio::task::spawn_blocking(move || { + /// // This shall block until the `lock` is released. + /// let mut n = mutex1.blocking_lock(); + /// *n = 2; + /// }); + /// + /// assert_eq!(*lock, 1); + /// // Release the lock. + /// drop(lock); + /// + /// // Await the completion of the blocking task. + /// blocking_task.await.unwrap(); + /// + /// // Assert uncontended. + /// let n = mutex.try_lock().unwrap(); + /// assert_eq!(*n, 2); + /// } + /// + /// ``` + #[track_caller] + #[cfg(feature = "sync")] + #[cfg_attr(docsrs, doc(alias = "lock_blocking"))] + pub fn blocking_lock(&self) -> MutexGuard<'_, T> { + crate::future::block_on(self.lock()) + } + + /// Blockingly locks this `Mutex`. When the lock has been acquired, function returns an + /// [`OwnedMutexGuard`]. + /// + /// This method is identical to [`Mutex::blocking_lock`], except that the returned + /// guard references the `Mutex` with an [`Arc`] rather than by borrowing + /// it. Therefore, the `Mutex` must be wrapped in an `Arc` to call this + /// method, and the guard will live for the `'static` lifetime, as it keeps + /// the `Mutex` alive by holding an `Arc`. + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution context. + /// + /// - If you find yourself in an asynchronous execution context and needing + /// to call some (synchronous) function which performs one of these + /// `blocking_` operations, then consider wrapping that call inside + /// [`spawn_blocking()`][crate::runtime::Handle::spawn_blocking] + /// (or [`block_in_place()`][crate::task::block_in_place]). + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::Mutex; + /// + /// #[tokio::main] + /// async fn main() { + /// let mutex = Arc::new(Mutex::new(1)); + /// let lock = mutex.lock().await; + /// + /// let mutex1 = Arc::clone(&mutex); + /// let blocking_task = tokio::task::spawn_blocking(move || { + /// // This shall block until the `lock` is released. + /// let mut n = mutex1.blocking_lock_owned(); + /// *n = 2; + /// }); + /// + /// assert_eq!(*lock, 1); + /// // Release the lock. + /// drop(lock); + /// + /// // Await the completion of the blocking task. + /// blocking_task.await.unwrap(); + /// + /// // Assert uncontended. + /// let n = mutex.try_lock().unwrap(); + /// assert_eq!(*n, 2); + /// } + /// + /// ``` + #[track_caller] + #[cfg(feature = "sync")] + pub fn blocking_lock_owned(self: Arc<Self>) -> OwnedMutexGuard<T> { + crate::future::block_on(self.lock_owned()) } /// Locks this mutex, causing the current task to yield until the lock has @@ -334,11 +592,45 @@ impl<T: ?Sized> Mutex<T> { /// /// [`Arc`]: std::sync::Arc pub async fn lock_owned(self: Arc<Self>) -> OwnedMutexGuard<T> { - self.acquire().await; - OwnedMutexGuard { lock: self } + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + + let acquire_fut = async { + self.acquire().await; + + OwnedMutexGuard { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), + lock: self, + } + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let acquire_fut = trace::async_op( + move || acquire_fut, + resource_span, + "Mutex::lock_owned", + "poll", + false, + ); + + #[allow(clippy::let_and_return)] // this lint triggers when disabling tracing + let guard = acquire_fut.await; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + guard.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + locked = true, + ); + }); + + guard } async fn acquire(&self) { + crate::trace::async_trace_leaf().await; + self.s.acquire(1).await.unwrap_or_else(|_| { // The semaphore was closed. but, we never explicitly close it, and // we own it exclusively, which means that this can never happen. @@ -365,7 +657,23 @@ impl<T: ?Sized> Mutex<T> { /// ``` pub fn try_lock(&self) -> Result<MutexGuard<'_, T>, TryLockError> { match self.s.try_acquire(1) { - Ok(_) => Ok(MutexGuard { lock: self }), + Ok(_) => { + let guard = MutexGuard { + lock: self, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + locked = true, + ); + }); + + Ok(guard) + } Err(_) => Err(TryLockError(())), } } @@ -420,7 +728,23 @@ impl<T: ?Sized> Mutex<T> { /// # } pub fn try_lock_owned(self: Arc<Self>) -> Result<OwnedMutexGuard<T>, TryLockError> { match self.s.try_acquire(1) { - Ok(_) => Ok(OwnedMutexGuard { lock: self }), + Ok(_) => { + let guard = OwnedMutexGuard { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), + lock: self, + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + guard.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + locked = true, + ); + }); + + Ok(guard) + } Err(_) => Err(TryLockError(())), } } @@ -462,14 +786,14 @@ where } } -impl<T> std::fmt::Debug for Mutex<T> +impl<T: ?Sized> std::fmt::Debug for Mutex<T> where T: std::fmt::Debug, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let mut d = f.debug_struct("Mutex"); match self.try_lock() { - Ok(inner) => d.field("data", &*inner), + Ok(inner) => d.field("data", &&*inner), Err(_) => d.field("data", &format_args!("<locked>")), }; d.finish() @@ -479,6 +803,17 @@ where // === impl MutexGuard === impl<'a, T: ?Sized> MutexGuard<'a, T> { + fn skip_drop(self) -> MutexGuardInner<'a, T> { + let me = mem::ManuallyDrop::new(self); + // SAFETY: This duplicates the `resource_span` and then forgets the + // original. In the end, we have not duplicated or forgotten any values. + MutexGuardInner { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: unsafe { std::ptr::read(&me.resource_span) }, + lock: me.lock, + } + } + /// Makes a new [`MappedMutexGuard`] for a component of the locked data. /// /// This operation cannot fail as the [`MutexGuard`] passed in already locked the mutex. @@ -515,12 +850,13 @@ impl<'a, T: ?Sized> MutexGuard<'a, T> { F: FnOnce(&mut T) -> &mut U, { let data = f(&mut *this) as *mut U; - let s = &this.lock.s; - mem::forget(this); + let inner = this.skip_drop(); MappedMutexGuard { - s, + s: &inner.lock.s, data, - marker: marker::PhantomData, + marker: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: inner.resource_span, } } @@ -565,19 +901,54 @@ impl<'a, T: ?Sized> MutexGuard<'a, T> { Some(data) => data as *mut U, None => return Err(this), }; - let s = &this.lock.s; - mem::forget(this); + let inner = this.skip_drop(); Ok(MappedMutexGuard { - s, + s: &inner.lock.s, data, - marker: marker::PhantomData, + marker: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: inner.resource_span, }) } + + /// Returns a reference to the original `Mutex`. + /// + /// ``` + /// use tokio::sync::{Mutex, MutexGuard}; + /// + /// async fn unlock_and_relock<'l>(guard: MutexGuard<'l, u32>) -> MutexGuard<'l, u32> { + /// println!("1. contains: {:?}", *guard); + /// let mutex = MutexGuard::mutex(&guard); + /// drop(guard); + /// let guard = mutex.lock().await; + /// println!("2. contains: {:?}", *guard); + /// guard + /// } + /// # + /// # #[tokio::main] + /// # async fn main() { + /// # let mutex = Mutex::new(0u32); + /// # let guard = mutex.lock().await; + /// # let _guard = unlock_and_relock(guard).await; + /// # } + /// ``` + #[inline] + pub fn mutex(this: &Self) -> &'a Mutex<T> { + this.lock + } } impl<T: ?Sized> Drop for MutexGuard<'_, T> { fn drop(&mut self) { - self.lock.s.release(1) + self.lock.s.release(1); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + locked = false, + ); + }); } } @@ -608,9 +979,156 @@ impl<T: ?Sized + fmt::Display> fmt::Display for MutexGuard<'_, T> { // === impl OwnedMutexGuard === +impl<T: ?Sized> OwnedMutexGuard<T> { + fn skip_drop(self) -> OwnedMutexGuardInner<T> { + let me = mem::ManuallyDrop::new(self); + // SAFETY: This duplicates the values in every field of the guard, then + // forgets the originals, so in the end no value is duplicated. + unsafe { + OwnedMutexGuardInner { + lock: ptr::read(&me.lock), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: ptr::read(&me.resource_span), + } + } + } + + /// Makes a new [`OwnedMappedMutexGuard`] for a component of the locked data. + /// + /// This operation cannot fail as the [`OwnedMutexGuard`] passed in already locked the mutex. + /// + /// This is an associated function that needs to be used as `OwnedMutexGuard::map(...)`. A method + /// would interfere with methods of the same name on the contents of the locked data. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{Mutex, OwnedMutexGuard}; + /// use std::sync::Arc; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let foo = Arc::new(Mutex::new(Foo(1))); + /// + /// { + /// let mut mapped = OwnedMutexGuard::map(foo.clone().lock_owned().await, |f| &mut f.0); + /// *mapped = 2; + /// } + /// + /// assert_eq!(Foo(2), *foo.lock().await); + /// # } + /// ``` + /// + /// [`OwnedMutexGuard`]: struct@OwnedMutexGuard + /// [`OwnedMappedMutexGuard`]: struct@OwnedMappedMutexGuard + #[inline] + pub fn map<U, F>(mut this: Self, f: F) -> OwnedMappedMutexGuard<T, U> + where + F: FnOnce(&mut T) -> &mut U, + { + let data = f(&mut *this) as *mut U; + let inner = this.skip_drop(); + OwnedMappedMutexGuard { + data, + lock: inner.lock, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: inner.resource_span, + } + } + + /// Attempts to make a new [`OwnedMappedMutexGuard`] for a component of the locked data. The + /// original guard is returned if the closure returns `None`. + /// + /// This operation cannot fail as the [`OwnedMutexGuard`] passed in already locked the mutex. + /// + /// This is an associated function that needs to be used as `OwnedMutexGuard::try_map(...)`. A + /// method would interfere with methods of the same name on the contents of the locked data. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{Mutex, OwnedMutexGuard}; + /// use std::sync::Arc; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let foo = Arc::new(Mutex::new(Foo(1))); + /// + /// { + /// let mut mapped = OwnedMutexGuard::try_map(foo.clone().lock_owned().await, |f| Some(&mut f.0)) + /// .expect("should not fail"); + /// *mapped = 2; + /// } + /// + /// assert_eq!(Foo(2), *foo.lock().await); + /// # } + /// ``` + /// + /// [`OwnedMutexGuard`]: struct@OwnedMutexGuard + /// [`OwnedMappedMutexGuard`]: struct@OwnedMappedMutexGuard + #[inline] + pub fn try_map<U, F>(mut this: Self, f: F) -> Result<OwnedMappedMutexGuard<T, U>, Self> + where + F: FnOnce(&mut T) -> Option<&mut U>, + { + let data = match f(&mut *this) { + Some(data) => data as *mut U, + None => return Err(this), + }; + let inner = this.skip_drop(); + Ok(OwnedMappedMutexGuard { + data, + lock: inner.lock, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: inner.resource_span, + }) + } + + /// Returns a reference to the original `Arc<Mutex>`. + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{Mutex, OwnedMutexGuard}; + /// + /// async fn unlock_and_relock(guard: OwnedMutexGuard<u32>) -> OwnedMutexGuard<u32> { + /// println!("1. contains: {:?}", *guard); + /// let mutex: Arc<Mutex<u32>> = OwnedMutexGuard::mutex(&guard).clone(); + /// drop(guard); + /// let guard = mutex.lock_owned().await; + /// println!("2. contains: {:?}", *guard); + /// guard + /// } + /// # + /// # #[tokio::main] + /// # async fn main() { + /// # let mutex = Arc::new(Mutex::new(0u32)); + /// # let guard = mutex.lock_owned().await; + /// # unlock_and_relock(guard).await; + /// # } + /// ``` + #[inline] + pub fn mutex(this: &Self) -> &Arc<Mutex<T>> { + &this.lock + } +} + impl<T: ?Sized> Drop for OwnedMutexGuard<T> { fn drop(&mut self) { - self.lock.s.release(1) + self.lock.s.release(1); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + locked = false, + ); + }); } } @@ -642,6 +1160,16 @@ impl<T: ?Sized + fmt::Display> fmt::Display for OwnedMutexGuard<T> { // === impl MappedMutexGuard === impl<'a, T: ?Sized> MappedMutexGuard<'a, T> { + fn skip_drop(self) -> MappedMutexGuardInner<'a, T> { + let me = mem::ManuallyDrop::new(self); + MappedMutexGuardInner { + s: me.s, + data: me.data, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: unsafe { std::ptr::read(&me.resource_span) }, + } + } + /// Makes a new [`MappedMutexGuard`] for a component of the locked data. /// /// This operation cannot fail as the [`MappedMutexGuard`] passed in already locked the mutex. @@ -656,12 +1184,13 @@ impl<'a, T: ?Sized> MappedMutexGuard<'a, T> { F: FnOnce(&mut T) -> &mut U, { let data = f(&mut *this) as *mut U; - let s = this.s; - mem::forget(this); + let inner = this.skip_drop(); MappedMutexGuard { - s, + s: inner.s, data, - marker: marker::PhantomData, + marker: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: inner.resource_span, } } @@ -683,19 +1212,28 @@ impl<'a, T: ?Sized> MappedMutexGuard<'a, T> { Some(data) => data as *mut U, None => return Err(this), }; - let s = this.s; - mem::forget(this); + let inner = this.skip_drop(); Ok(MappedMutexGuard { - s, + s: inner.s, data, - marker: marker::PhantomData, + marker: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: inner.resource_span, }) } } impl<'a, T: ?Sized> Drop for MappedMutexGuard<'a, T> { fn drop(&mut self) { - self.s.release(1) + self.s.release(1); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + locked = false, + ); + }); } } @@ -723,3 +1261,111 @@ impl<'a, T: ?Sized + fmt::Display> fmt::Display for MappedMutexGuard<'a, T> { fmt::Display::fmt(&**self, f) } } + +// === impl OwnedMappedMutexGuard === + +impl<T: ?Sized, U: ?Sized> OwnedMappedMutexGuard<T, U> { + fn skip_drop(self) -> OwnedMappedMutexGuardInner<T, U> { + let me = mem::ManuallyDrop::new(self); + // SAFETY: This duplicates the values in every field of the guard, then + // forgets the originals, so in the end no value is duplicated. + unsafe { + OwnedMappedMutexGuardInner { + data: me.data, + lock: ptr::read(&me.lock), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: ptr::read(&me.resource_span), + } + } + } + + /// Makes a new [`OwnedMappedMutexGuard`] for a component of the locked data. + /// + /// This operation cannot fail as the [`OwnedMappedMutexGuard`] passed in already locked the mutex. + /// + /// This is an associated function that needs to be used as `OwnedMappedMutexGuard::map(...)`. A method + /// would interfere with methods of the same name on the contents of the locked data. + /// + /// [`OwnedMappedMutexGuard`]: struct@OwnedMappedMutexGuard + #[inline] + pub fn map<S, F>(mut this: Self, f: F) -> OwnedMappedMutexGuard<T, S> + where + F: FnOnce(&mut U) -> &mut S, + { + let data = f(&mut *this) as *mut S; + let inner = this.skip_drop(); + OwnedMappedMutexGuard { + data, + lock: inner.lock, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: inner.resource_span, + } + } + + /// Attempts to make a new [`OwnedMappedMutexGuard`] for a component of the locked data. The + /// original guard is returned if the closure returns `None`. + /// + /// This operation cannot fail as the [`OwnedMutexGuard`] passed in already locked the mutex. + /// + /// This is an associated function that needs to be used as `OwnedMutexGuard::try_map(...)`. A + /// method would interfere with methods of the same name on the contents of the locked data. + /// + /// [`OwnedMutexGuard`]: struct@OwnedMutexGuard + /// [`OwnedMappedMutexGuard`]: struct@OwnedMappedMutexGuard + #[inline] + pub fn try_map<S, F>(mut this: Self, f: F) -> Result<OwnedMappedMutexGuard<T, S>, Self> + where + F: FnOnce(&mut U) -> Option<&mut S>, + { + let data = match f(&mut *this) { + Some(data) => data as *mut S, + None => return Err(this), + }; + let inner = this.skip_drop(); + Ok(OwnedMappedMutexGuard { + data, + lock: inner.lock, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: inner.resource_span, + }) + } +} + +impl<T: ?Sized, U: ?Sized> Drop for OwnedMappedMutexGuard<T, U> { + fn drop(&mut self) { + self.lock.s.release(1); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + locked = false, + ); + }); + } +} + +impl<T: ?Sized, U: ?Sized> Deref for OwnedMappedMutexGuard<T, U> { + type Target = U; + fn deref(&self) -> &Self::Target { + unsafe { &*self.data } + } +} + +impl<T: ?Sized, U: ?Sized> DerefMut for OwnedMappedMutexGuard<T, U> { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.data } + } +} + +impl<T: ?Sized, U: ?Sized + fmt::Debug> fmt::Debug for OwnedMappedMutexGuard<T, U> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<T: ?Sized, U: ?Sized + fmt::Display> fmt::Display for OwnedMappedMutexGuard<T, U> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} diff --git a/vendor/tokio/src/sync/notify.rs b/vendor/tokio/src/sync/notify.rs index af7b9423a..0f104b71a 100644 --- a/vendor/tokio/src/sync/notify.rs +++ b/vendor/tokio/src/sync/notify.rs @@ -5,42 +5,46 @@ // triggers this warning but it is safe to ignore in this case. #![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))] +use crate::loom::cell::UnsafeCell; use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::Mutex; -use crate::util::linked_list::{self, LinkedList}; +use crate::util::linked_list::{self, GuardedLinkedList, LinkedList}; +use crate::util::WakeList; -use std::cell::UnsafeCell; use std::future::Future; use std::marker::PhantomPinned; +use std::panic::{RefUnwindSafe, UnwindSafe}; use std::pin::Pin; use std::ptr::NonNull; -use std::sync::atomic::Ordering::SeqCst; +use std::sync::atomic::Ordering::{self, Acquire, Relaxed, Release, SeqCst}; use std::task::{Context, Poll, Waker}; type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>; +type GuardedWaitList = GuardedLinkedList<Waiter, <Waiter as linked_list::Link>::Target>; -/// Notify a single task to wake up. +/// Notifies a single task to wake up. /// /// `Notify` provides a basic mechanism to notify a single task of an event. /// `Notify` itself does not carry any data. Instead, it is to be used to signal /// another task to perform an operation. /// -/// `Notify` can be thought of as a [`Semaphore`] starting with 0 permits. -/// [`notified().await`] waits for a permit to become available, and [`notify_one()`] -/// sets a permit **if there currently are no available permits**. +/// A `Notify` can be thought of as a [`Semaphore`] starting with 0 permits. The +/// [`notified().await`] method waits for a permit to become available, and +/// [`notify_one()`] sets a permit **if there currently are no available +/// permits**. /// /// The synchronization details of `Notify` are similar to /// [`thread::park`][park] and [`Thread::unpark`][unpark] from std. A [`Notify`] /// value contains a single permit. [`notified().await`] waits for the permit to -/// be made available, consumes the permit, and resumes. [`notify_one()`] sets the -/// permit, waking a pending task if there is one. +/// be made available, consumes the permit, and resumes. [`notify_one()`] sets +/// the permit, waking a pending task if there is one. /// -/// If `notify_one()` is called **before** `notified().await`, then the next call to -/// `notified().await` will complete immediately, consuming the permit. Any -/// subsequent calls to `notified().await` will wait for a new permit. +/// If `notify_one()` is called **before** `notified().await`, then the next +/// call to `notified().await` will complete immediately, consuming the permit. +/// Any subsequent calls to `notified().await` will wait for a new permit. /// -/// If `notify_one()` is called **multiple** times before `notified().await`, only a -/// **single** permit is stored. The next call to `notified().await` will +/// If `notify_one()` is called **multiple** times before `notified().await`, +/// only a **single** permit is stored. The next call to `notified().await` will /// complete immediately, but the one after will wait for a new permit. /// /// # Examples @@ -56,17 +60,24 @@ type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>; /// let notify = Arc::new(Notify::new()); /// let notify2 = notify.clone(); /// -/// tokio::spawn(async move { +/// let handle = tokio::spawn(async move { /// notify2.notified().await; /// println!("received notification"); /// }); /// /// println!("sending notification"); /// notify.notify_one(); +/// +/// // Wait for task to receive notification. +/// handle.await.unwrap(); /// } /// ``` /// -/// Unbound mpsc channel. +/// Unbound multi-producer single-consumer (mpsc) channel. +/// +/// No wakeups can be lost when using this channel because the call to +/// `notify_one()` will store a permit in the `Notify`, which the following call +/// to `notified()` will consume. /// /// ``` /// use tokio::sync::Notify; @@ -88,6 +99,8 @@ type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>; /// self.notify.notify_one(); /// } /// +/// // This is a single-consumer channel, so several concurrent calls to +/// // `recv` are not allowed. /// pub async fn recv(&self) -> T { /// loop { /// // Drain values @@ -102,45 +115,250 @@ type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>; /// } /// ``` /// +/// Unbound multi-producer multi-consumer (mpmc) channel. +/// +/// The call to [`enable`] is important because otherwise if you have two +/// calls to `recv` and two calls to `send` in parallel, the following could +/// happen: +/// +/// 1. Both calls to `try_recv` return `None`. +/// 2. Both new elements are added to the vector. +/// 3. The `notify_one` method is called twice, adding only a single +/// permit to the `Notify`. +/// 4. Both calls to `recv` reach the `Notified` future. One of them +/// consumes the permit, and the other sleeps forever. +/// +/// By adding the `Notified` futures to the list by calling `enable` before +/// `try_recv`, the `notify_one` calls in step three would remove the +/// futures from the list and mark them notified instead of adding a permit +/// to the `Notify`. This ensures that both futures are woken. +/// +/// Notice that this failure can only happen if there are two concurrent calls +/// to `recv`. This is why the mpsc example above does not require a call to +/// `enable`. +/// +/// ``` +/// use tokio::sync::Notify; +/// +/// use std::collections::VecDeque; +/// use std::sync::Mutex; +/// +/// struct Channel<T> { +/// messages: Mutex<VecDeque<T>>, +/// notify_on_sent: Notify, +/// } +/// +/// impl<T> Channel<T> { +/// pub fn send(&self, msg: T) { +/// let mut locked_queue = self.messages.lock().unwrap(); +/// locked_queue.push_back(msg); +/// drop(locked_queue); +/// +/// // Send a notification to one of the calls currently +/// // waiting in a call to `recv`. +/// self.notify_on_sent.notify_one(); +/// } +/// +/// pub fn try_recv(&self) -> Option<T> { +/// let mut locked_queue = self.messages.lock().unwrap(); +/// locked_queue.pop_front() +/// } +/// +/// pub async fn recv(&self) -> T { +/// let future = self.notify_on_sent.notified(); +/// tokio::pin!(future); +/// +/// loop { +/// // Make sure that no wakeup is lost if we get +/// // `None` from `try_recv`. +/// future.as_mut().enable(); +/// +/// if let Some(msg) = self.try_recv() { +/// return msg; +/// } +/// +/// // Wait for a call to `notify_one`. +/// // +/// // This uses `.as_mut()` to avoid consuming the future, +/// // which lets us call `Pin::set` below. +/// future.as_mut().await; +/// +/// // Reset the future in case another call to +/// // `try_recv` got the message before us. +/// future.set(self.notify_on_sent.notified()); +/// } +/// } +/// } +/// ``` +/// /// [park]: std::thread::park /// [unpark]: std::thread::Thread::unpark /// [`notified().await`]: Notify::notified() /// [`notify_one()`]: Notify::notify_one() +/// [`enable`]: Notified::enable() /// [`Semaphore`]: crate::sync::Semaphore #[derive(Debug)] pub struct Notify { - // This uses 2 bits to store one of `EMPTY`, + // `state` uses 2 bits to store one of `EMPTY`, // `WAITING` or `NOTIFIED`. The rest of the bits // are used to store the number of times `notify_waiters` // was called. + // + // Throughout the code there are two assumptions: + // - state can be transitioned *from* `WAITING` only if + // `waiters` lock is held + // - number of times `notify_waiters` was called can + // be modified only if `waiters` lock is held state: AtomicUsize, waiters: Mutex<WaitList>, } -#[derive(Debug, Clone, Copy)] -enum NotificationType { - // Notification triggered by calling `notify_waiters` - AllWaiters, - // Notification triggered by calling `notify_one` - OneWaiter, -} - #[derive(Debug)] struct Waiter { - /// Intrusive linked-list pointers + /// Intrusive linked-list pointers. pointers: linked_list::Pointers<Waiter>, - /// Waiting task's waker - waker: Option<Waker>, + /// Waiting task's waker. Depending on the value of `notification`, + /// this field is either protected by the `waiters` lock in + /// `Notify`, or it is exclusively owned by the enclosing `Waiter`. + waker: UnsafeCell<Option<Waker>>, - /// `true` if the notification has been assigned to this waiter. - notified: Option<NotificationType>, + /// Notification for this waiter. + /// * if it's `None`, then `waker` is protected by the `waiters` lock. + /// * if it's `Some`, then `waker` is exclusively owned by the + /// enclosing `Waiter` and can be accessed without locking. + notification: AtomicNotification, /// Should not be `Unpin`. _p: PhantomPinned, } -/// Future returned from [`Notify::notified()`] +impl Waiter { + fn new() -> Waiter { + Waiter { + pointers: linked_list::Pointers::new(), + waker: UnsafeCell::new(None), + notification: AtomicNotification::none(), + _p: PhantomPinned, + } + } +} + +generate_addr_of_methods! { + impl<> Waiter { + unsafe fn addr_of_pointers(self: NonNull<Self>) -> NonNull<linked_list::Pointers<Waiter>> { + &self.pointers + } + } +} + +// No notification. +const NOTIFICATION_NONE: usize = 0; + +// Notification type used by `notify_one`. +const NOTIFICATION_ONE: usize = 1; + +// Notification type used by `notify_waiters`. +const NOTIFICATION_ALL: usize = 2; + +/// Notification for a `Waiter`. +/// This struct is equivalent to `Option<Notification>`, but uses +/// `AtomicUsize` inside for atomic operations. +#[derive(Debug)] +struct AtomicNotification(AtomicUsize); + +impl AtomicNotification { + fn none() -> Self { + AtomicNotification(AtomicUsize::new(NOTIFICATION_NONE)) + } + + /// Store-release a notification. + /// This method should be called exactly once. + fn store_release(&self, notification: Notification) { + self.0.store(notification as usize, Release); + } + + fn load(&self, ordering: Ordering) -> Option<Notification> { + match self.0.load(ordering) { + NOTIFICATION_NONE => None, + NOTIFICATION_ONE => Some(Notification::One), + NOTIFICATION_ALL => Some(Notification::All), + _ => unreachable!(), + } + } + + /// Clears the notification. + /// This method is used by a `Notified` future to consume the + /// notification. It uses relaxed ordering and should be only + /// used once the atomic notification is no longer shared. + fn clear(&self) { + self.0.store(NOTIFICATION_NONE, Relaxed); + } +} + +#[derive(Debug, PartialEq, Eq)] +#[repr(usize)] +enum Notification { + One = NOTIFICATION_ONE, + All = NOTIFICATION_ALL, +} + +/// List used in `Notify::notify_waiters`. It wraps a guarded linked list +/// and gates the access to it on `notify.waiters` mutex. It also empties +/// the list on drop. +struct NotifyWaitersList<'a> { + list: GuardedWaitList, + is_empty: bool, + notify: &'a Notify, +} + +impl<'a> NotifyWaitersList<'a> { + fn new( + unguarded_list: WaitList, + guard: Pin<&'a Waiter>, + notify: &'a Notify, + ) -> NotifyWaitersList<'a> { + let guard_ptr = NonNull::from(guard.get_ref()); + let list = unguarded_list.into_guarded(guard_ptr); + NotifyWaitersList { + list, + is_empty: false, + notify, + } + } + + /// Removes the last element from the guarded list. Modifying this list + /// requires an exclusive access to the main list in `Notify`. + fn pop_back_locked(&mut self, _waiters: &mut WaitList) -> Option<NonNull<Waiter>> { + let result = self.list.pop_back(); + if result.is_none() { + // Save information about emptiness to avoid waiting for lock + // in the destructor. + self.is_empty = true; + } + result + } +} + +impl Drop for NotifyWaitersList<'_> { + fn drop(&mut self) { + // If the list is not empty, we unlink all waiters from it. + // We do not wake the waiters to avoid double panics. + if !self.is_empty { + let _lock_guard = self.notify.waiters.lock(); + while let Some(waiter) = self.list.pop_back() { + // Safety: we never make mutable references to waiters. + let waiter = unsafe { waiter.as_ref() }; + waiter.notification.store_release(Notification::All); + } + } + } +} + +/// Future returned from [`Notify::notified()`]. +/// +/// This future is fused, so once it has completed, any future calls to poll +/// will immediately return `Poll::Ready`. #[derive(Debug)] pub struct Notified<'a> { /// The `Notify` being received on. @@ -149,8 +367,11 @@ pub struct Notified<'a> { /// The current state of the receiving process. state: State, + /// Number of calls to `notify_waiters` at the time of creation. + notify_waiters_calls: usize, + /// Entry in the waiter `LinkedList`. - waiter: UnsafeCell<Waiter>, + waiter: Waiter, } unsafe impl<'a> Send for Notified<'a> {} @@ -158,7 +379,7 @@ unsafe impl<'a> Sync for Notified<'a> {} #[derive(Debug)] enum State { - Init(usize), + Init, Waiting, Done, } @@ -167,13 +388,13 @@ const NOTIFY_WAITERS_SHIFT: usize = 2; const STATE_MASK: usize = (1 << NOTIFY_WAITERS_SHIFT) - 1; const NOTIFY_WAITERS_CALLS_MASK: usize = !STATE_MASK; -/// Initial "idle" state +/// Initial "idle" state. const EMPTY: usize = 0; /// One or more threads are currently waiting to be notified. const WAITING: usize = 1; -/// Pending notification +/// Pending notification. const NOTIFIED: usize = 2; fn set_state(data: usize, state: usize) -> usize { @@ -244,7 +465,16 @@ impl Notify { /// immediately, consuming that permit. Otherwise, `notified().await` waits /// for a permit to be made available by the next call to `notify_one()`. /// + /// The `Notified` future is not guaranteed to receive wakeups from calls to + /// `notify_one()` if it has not yet been polled. See the documentation for + /// [`Notified::enable()`] for more details. + /// + /// The `Notified` future is guaranteed to receive wakeups from + /// `notify_waiters()` as soon as it has been created, even if it has not + /// yet been polled. + /// /// [`notify_one()`]: Notify::notify_one + /// [`Notified::enable()`]: Notified::enable /// /// # Cancel safety /// @@ -274,21 +504,17 @@ impl Notify { /// ``` pub fn notified(&self) -> Notified<'_> { // we load the number of times notify_waiters - // was called and store that in our initial state + // was called and store that in the future. let state = self.state.load(SeqCst); Notified { notify: self, - state: State::Init(state >> NOTIFY_WAITERS_SHIFT), - waiter: UnsafeCell::new(Waiter { - pointers: linked_list::Pointers::new(), - waker: None, - notified: None, - _p: PhantomPinned, - }), + state: State::Init, + notify_waiters_calls: get_num_notify_waiters_calls(state), + waiter: Waiter::new(), } } - /// Notifies a waiting task + /// Notifies a waiting task. /// /// If a task is currently waiting, that task is notified. Otherwise, a /// permit is stored in this `Notify` value and the **next** call to @@ -358,7 +584,7 @@ impl Notify { } } - /// Notifies all waiting tasks + /// Notifies all waiting tasks. /// /// If a task is currently waiting, that task is notified. Unlike with /// `notify_one()`, no permit is stored to be used by the next call to @@ -391,43 +617,56 @@ impl Notify { /// } /// ``` pub fn notify_waiters(&self) { - const NUM_WAKERS: usize = 32; - - let mut wakers: [Option<Waker>; NUM_WAKERS] = Default::default(); - let mut curr_waker = 0; - - // There are waiters, the lock must be acquired to notify. let mut waiters = self.waiters.lock(); - // The state must be reloaded while the lock is held. The state may only + // The state must be loaded while the lock is held. The state may only // transition out of WAITING while the lock is held. let curr = self.state.load(SeqCst); - if let EMPTY | NOTIFIED = get_state(curr) { + if matches!(get_state(curr), EMPTY | NOTIFIED) { // There are no waiting tasks. All we need to do is increment the // number of times this method was called. atomic_inc_num_notify_waiters_calls(&self.state); return; } - // At this point, it is guaranteed that the state will not - // concurrently change, as holding the lock is required to - // transition **out** of `WAITING`. + // Increment the number of times this method was called + // and transition to empty. + let new_state = set_state(inc_num_notify_waiters_calls(curr), EMPTY); + self.state.store(new_state, SeqCst); + + // It is critical for `GuardedLinkedList` safety that the guard node is + // pinned in memory and is not dropped until the guarded list is dropped. + let guard = Waiter::new(); + pin!(guard); + + // We move all waiters to a secondary list. It uses a `GuardedLinkedList` + // underneath to allow every waiter to safely remove itself from it. + // + // * This list will be still guarded by the `waiters` lock. + // `NotifyWaitersList` wrapper makes sure we hold the lock to modify it. + // * This wrapper will empty the list on drop. It is critical for safety + // that we will not leave any list entry with a pointer to the local + // guard node after this function returns / panics. + let mut list = NotifyWaitersList::new(std::mem::take(&mut *waiters), guard.as_ref(), self); + + let mut wakers = WakeList::new(); 'outer: loop { - while curr_waker < NUM_WAKERS { - match waiters.pop_back() { - Some(mut waiter) => { - // Safety: `waiters` lock is still held. - let waiter = unsafe { waiter.as_mut() }; - - assert!(waiter.notified.is_none()); - - waiter.notified = Some(NotificationType::AllWaiters); - - if let Some(waker) = waiter.waker.take() { - wakers[curr_waker] = Some(waker); - curr_waker += 1; + while wakers.can_push() { + match list.pop_back_locked(&mut waiters) { + Some(waiter) => { + // Safety: we never make mutable references to waiters. + let waiter = unsafe { waiter.as_ref() }; + + // Safety: we hold the lock, so we can access the waker. + if let Some(waker) = + unsafe { waiter.waker.with_mut(|waker| (*waker).take()) } + { + wakers.push(waker); } + + // This waiter is unlinked and will not be shared ever again, release it. + waiter.notification.store_release(Notification::All); } None => { break 'outer; @@ -435,30 +674,21 @@ impl Notify { } } + // Release the lock before notifying. drop(waiters); - for waker in wakers.iter_mut().take(curr_waker) { - waker.take().unwrap().wake(); - } - - curr_waker = 0; + // One of the wakers may panic, but the remaining waiters will still + // be unlinked from the list in `NotifyWaitersList` destructor. + wakers.wake_all(); // Acquire the lock again. waiters = self.waiters.lock(); } - // All waiters will be notified, the state must be transitioned to - // `EMPTY`. As transitioning **from** `WAITING` requires the lock to be - // held, a `store` is sufficient. - let new = set_state(inc_num_notify_waiters_calls(curr), EMPTY); - self.state.store(new, SeqCst); - // Release the lock before notifying drop(waiters); - for waker in wakers.iter_mut().take(curr_waker) { - waker.take().unwrap().wake(); - } + wakers.wake_all(); } } @@ -468,6 +698,9 @@ impl Default for Notify { } } +impl UnwindSafe for Notify {} +impl RefUnwindSafe for Notify {} + fn notify_locked(waiters: &mut WaitList, state: &AtomicUsize, curr: usize) -> Option<Waker> { loop { match get_state(curr) { @@ -490,15 +723,16 @@ fn notify_locked(waiters: &mut WaitList, state: &AtomicUsize, curr: usize) -> Op // transition **out** of `WAITING`. // // Get a pending waiter - let mut waiter = waiters.pop_back().unwrap(); + let waiter = waiters.pop_back().unwrap(); - // Safety: `waiters` lock is still held. - let waiter = unsafe { waiter.as_mut() }; + // Safety: we never make mutable references to waiters. + let waiter = unsafe { waiter.as_ref() }; - assert!(waiter.notified.is_none()); + // Safety: we hold the lock, so we can access the waker. + let waker = unsafe { waiter.waker.with_mut(|waker| (*waker).take()) }; - waiter.notified = Some(NotificationType::OneWaiter); - let waker = waiter.waker.take(); + // This waiter is unlinked and will not be shared ever again, release it. + waiter.notification.store_release(Notification::One); if waiters.is_empty() { // As this the **final** waiter in the list, the state @@ -518,32 +752,142 @@ fn notify_locked(waiters: &mut WaitList, state: &AtomicUsize, curr: usize) -> Op // ===== impl Notified ===== impl Notified<'_> { + /// Adds this future to the list of futures that are ready to receive + /// wakeups from calls to [`notify_one`]. + /// + /// Polling the future also adds it to the list, so this method should only + /// be used if you want to add the future to the list before the first call + /// to `poll`. (In fact, this method is equivalent to calling `poll` except + /// that no `Waker` is registered.) + /// + /// This has no effect on notifications sent using [`notify_waiters`], which + /// are received as long as they happen after the creation of the `Notified` + /// regardless of whether `enable` or `poll` has been called. + /// + /// This method returns true if the `Notified` is ready. This happens in the + /// following situations: + /// + /// 1. The `notify_waiters` method was called between the creation of the + /// `Notified` and the call to this method. + /// 2. This is the first call to `enable` or `poll` on this future, and the + /// `Notify` was holding a permit from a previous call to `notify_one`. + /// The call consumes the permit in that case. + /// 3. The future has previously been enabled or polled, and it has since + /// then been marked ready by either consuming a permit from the + /// `Notify`, or by a call to `notify_one` or `notify_waiters` that + /// removed it from the list of futures ready to receive wakeups. + /// + /// If this method returns true, any future calls to poll on the same future + /// will immediately return `Poll::Ready`. + /// + /// # Examples + /// + /// Unbound multi-producer multi-consumer (mpmc) channel. + /// + /// The call to `enable` is important because otherwise if you have two + /// calls to `recv` and two calls to `send` in parallel, the following could + /// happen: + /// + /// 1. Both calls to `try_recv` return `None`. + /// 2. Both new elements are added to the vector. + /// 3. The `notify_one` method is called twice, adding only a single + /// permit to the `Notify`. + /// 4. Both calls to `recv` reach the `Notified` future. One of them + /// consumes the permit, and the other sleeps forever. + /// + /// By adding the `Notified` futures to the list by calling `enable` before + /// `try_recv`, the `notify_one` calls in step three would remove the + /// futures from the list and mark them notified instead of adding a permit + /// to the `Notify`. This ensures that both futures are woken. + /// + /// ``` + /// use tokio::sync::Notify; + /// + /// use std::collections::VecDeque; + /// use std::sync::Mutex; + /// + /// struct Channel<T> { + /// messages: Mutex<VecDeque<T>>, + /// notify_on_sent: Notify, + /// } + /// + /// impl<T> Channel<T> { + /// pub fn send(&self, msg: T) { + /// let mut locked_queue = self.messages.lock().unwrap(); + /// locked_queue.push_back(msg); + /// drop(locked_queue); + /// + /// // Send a notification to one of the calls currently + /// // waiting in a call to `recv`. + /// self.notify_on_sent.notify_one(); + /// } + /// + /// pub fn try_recv(&self) -> Option<T> { + /// let mut locked_queue = self.messages.lock().unwrap(); + /// locked_queue.pop_front() + /// } + /// + /// pub async fn recv(&self) -> T { + /// let future = self.notify_on_sent.notified(); + /// tokio::pin!(future); + /// + /// loop { + /// // Make sure that no wakeup is lost if we get + /// // `None` from `try_recv`. + /// future.as_mut().enable(); + /// + /// if let Some(msg) = self.try_recv() { + /// return msg; + /// } + /// + /// // Wait for a call to `notify_one`. + /// // + /// // This uses `.as_mut()` to avoid consuming the future, + /// // which lets us call `Pin::set` below. + /// future.as_mut().await; + /// + /// // Reset the future in case another call to + /// // `try_recv` got the message before us. + /// future.set(self.notify_on_sent.notified()); + /// } + /// } + /// } + /// ``` + /// + /// [`notify_one`]: Notify::notify_one() + /// [`notify_waiters`]: Notify::notify_waiters() + pub fn enable(self: Pin<&mut Self>) -> bool { + self.poll_notified(None).is_ready() + } + /// A custom `project` implementation is used in place of `pin-project-lite` /// as a custom drop implementation is needed. - fn project(self: Pin<&mut Self>) -> (&Notify, &mut State, &UnsafeCell<Waiter>) { + fn project(self: Pin<&mut Self>) -> (&Notify, &mut State, &usize, &Waiter) { unsafe { - // Safety: both `notify` and `state` are `Unpin`. + // Safety: `notify`, `state` and `notify_waiters_calls` are `Unpin`. is_unpin::<&Notify>(); - is_unpin::<AtomicUsize>(); + is_unpin::<State>(); + is_unpin::<usize>(); let me = self.get_unchecked_mut(); - (&me.notify, &mut me.state, &me.waiter) + ( + me.notify, + &mut me.state, + &me.notify_waiters_calls, + &me.waiter, + ) } } -} - -impl Future for Notified<'_> { - type Output = (); - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + fn poll_notified(self: Pin<&mut Self>, waker: Option<&Waker>) -> Poll<()> { use State::*; - let (notify, state, waiter) = self.project(); + let (notify, state, notify_waiters_calls, waiter) = self.project(); - loop { + 'outer_loop: loop { match *state { - Init(initial_notify_waiters_calls) => { + Init => { let curr = notify.state.load(SeqCst); // Optimistically try acquiring a pending notification @@ -557,9 +901,13 @@ impl Future for Notified<'_> { if res.is_ok() { // Acquired the notification *state = Done; - return Poll::Ready(()); + continue 'outer_loop; } + // Clone the waker before locking, a waker clone can be + // triggering arbitrary code. + let waker = waker.cloned(); + // Acquire the lock and attempt to transition to the waiting // state. let mut waiters = notify.waiters.lock(); @@ -569,9 +917,9 @@ impl Future for Notified<'_> { // if notify_waiters has been called after the future // was created, then we are done - if get_num_notify_waiters_calls(curr) != initial_notify_waiters_calls { + if get_num_notify_waiters_calls(curr) != *notify_waiters_calls { *state = Done; - return Poll::Ready(()); + continue 'outer_loop; } // Transition the state to WAITING. @@ -607,7 +955,7 @@ impl Future for Notified<'_> { Ok(_) => { // Acquired the notification *state = Done; - return Poll::Ready(()); + continue 'outer_loop; } Err(actual) => { assert_eq!(get_state(actual), EMPTY); @@ -619,44 +967,109 @@ impl Future for Notified<'_> { } } - // Safety: called while locked. - unsafe { - (*waiter.get()).waker = Some(cx.waker().clone()); + let mut old_waker = None; + if waker.is_some() { + // Safety: called while locked. + // + // The use of `old_waiter` here is not necessary, as the field is always + // None when we reach this line. + unsafe { + old_waker = + waiter.waker.with_mut(|v| std::mem::replace(&mut *v, waker)); + } } // Insert the waiter into the linked list - // - // safety: pointers from `UnsafeCell` are never null. - waiters.push_front(unsafe { NonNull::new_unchecked(waiter.get()) }); + waiters.push_front(NonNull::from(waiter)); *state = Waiting; + drop(waiters); + drop(old_waker); + return Poll::Pending; } Waiting => { - // Currently in the "Waiting" state, implying the caller has - // a waiter stored in the waiter list (guarded by - // `notify.waiters`). In order to access the waker fields, - // we must hold the lock. + #[cfg(tokio_taskdump)] + if let Some(waker) = waker { + let mut ctx = Context::from_waker(waker); + ready!(crate::trace::trace_leaf(&mut ctx)); + } + + if waiter.notification.load(Acquire).is_some() { + // Safety: waiter is already unlinked and will not be shared again, + // so we have an exclusive access to `waker`. + drop(unsafe { waiter.waker.with_mut(|waker| (*waker).take()) }); + + waiter.notification.clear(); + *state = Done; + return Poll::Ready(()); + } + + // Our waiter was not notified, implying it is still stored in a waiter + // list (guarded by `notify.waiters`). In order to access the waker + // fields, we must acquire the lock. + + let mut old_waker = None; + let mut waiters = notify.waiters.lock(); + + // We hold the lock and notifications are set only with the lock held, + // so this can be relaxed, because the happens-before relationship is + // established through the mutex. + if waiter.notification.load(Relaxed).is_some() { + // Safety: waiter is already unlinked and will not be shared again, + // so we have an exclusive access to `waker`. + old_waker = unsafe { waiter.waker.with_mut(|waker| (*waker).take()) }; + + waiter.notification.clear(); + + // Drop the old waker after releasing the lock. + drop(waiters); + drop(old_waker); - let waiters = notify.waiters.lock(); + *state = Done; + return Poll::Ready(()); + } + + // Load the state with the lock held. + let curr = notify.state.load(SeqCst); + + if get_num_notify_waiters_calls(curr) != *notify_waiters_calls { + // Before we add a waiter to the list we check if these numbers are + // different while holding the lock. If these numbers are different now, + // it means that there is a call to `notify_waiters` in progress and this + // waiter must be contained by a guarded list used in `notify_waiters`. + // We can treat the waiter as notified and remove it from the list, as + // it would have been notified in the `notify_waiters` call anyways. - // Safety: called while locked - let w = unsafe { &mut *waiter.get() }; + // Safety: we hold the lock, so we can modify the waker. + old_waker = unsafe { waiter.waker.with_mut(|waker| (*waker).take()) }; - if w.notified.is_some() { - // Our waker has been notified. Reset the fields and - // remove it from the list. - w.waker = None; - w.notified = None; + // Safety: we hold the lock, so we have an exclusive access to the list. + // The list is used in `notify_waiters`, so it must be guarded. + unsafe { waiters.remove(NonNull::from(waiter)) }; *state = Done; } else { - // Update the waker, if necessary. - if !w.waker.as_ref().unwrap().will_wake(cx.waker()) { - w.waker = Some(cx.waker().clone()); + // Safety: we hold the lock, so we can modify the waker. + unsafe { + waiter.waker.with_mut(|v| { + if let Some(waker) = waker { + let should_update = match &*v { + Some(current_waker) => !current_waker.will_wake(waker), + None => true, + }; + if should_update { + old_waker = std::mem::replace(&mut *v, Some(waker.clone())); + } + } + }); } + // Drop the old waker after releasing the lock. + drop(waiters); + drop(old_waker); + return Poll::Pending; } @@ -666,8 +1079,16 @@ impl Future for Notified<'_> { // is helpful to visualize the scope of the critical // section. drop(waiters); + + // Drop the old waker after releasing the lock. + drop(old_waker); } Done => { + #[cfg(tokio_taskdump)] + if let Some(waker) = waker { + let mut ctx = Context::from_waker(waker); + ready!(crate::trace::trace_leaf(&mut ctx)); + } return Poll::Ready(()); } } @@ -675,40 +1096,48 @@ impl Future for Notified<'_> { } } +impl Future for Notified<'_> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + self.poll_notified(Some(cx.waker())) + } +} + impl Drop for Notified<'_> { fn drop(&mut self) { use State::*; // Safety: The type only transitions to a "Waiting" state when pinned. - let (notify, state, waiter) = unsafe { Pin::new_unchecked(self).project() }; + let (notify, state, _, waiter) = unsafe { Pin::new_unchecked(self).project() }; // This is where we ensure safety. The `Notified` value is being // dropped, which means we must ensure that the waiter entry is no // longer stored in the linked list. - if let Waiting = *state { + if matches!(*state, Waiting) { let mut waiters = notify.waiters.lock(); let mut notify_state = notify.state.load(SeqCst); + // We hold the lock, so this field is not concurrently accessed by + // `notify_*` functions and we can use the relaxed ordering. + let notification = waiter.notification.load(Relaxed); + // remove the entry from the list (if not already removed) // - // safety: the waiter is only added to `waiters` by virtue of it - // being the only `LinkedList` available to the type. - unsafe { waiters.remove(NonNull::new_unchecked(waiter.get())) }; - - if waiters.is_empty() { - if let WAITING = get_state(notify_state) { - notify_state = set_state(notify_state, EMPTY); - notify.state.store(notify_state, SeqCst); - } + // Safety: we hold the lock, so we have an exclusive access to every list the + // waiter may be contained in. If the node is not contained in the `waiters` + // list, then it is contained by a guarded list used by `notify_waiters`. + unsafe { waiters.remove(NonNull::from(waiter)) }; + + if waiters.is_empty() && get_state(notify_state) == WAITING { + notify_state = set_state(notify_state, EMPTY); + notify.state.store(notify_state, SeqCst); } // See if the node was notified but not received. In this case, if // the notification was triggered via `notify_one`, it must be sent // to the next waiter. - // - // Safety: with the entry removed from the linked list, there can be - // no concurrent access to the entry - if let Some(NotificationType::OneWaiter) = unsafe { (*waiter.get()).notified } { + if notification == Some(Notification::One) { if let Some(waker) = notify_locked(&mut waiters, ¬ify.state, notify_state) { drop(waiters); waker.wake(); @@ -733,8 +1162,8 @@ unsafe impl linked_list::Link for Waiter { ptr } - unsafe fn pointers(mut target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> { - NonNull::from(&mut target.as_mut().pointers) + unsafe fn pointers(target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> { + Waiter::addr_of_pointers(target) } } diff --git a/vendor/tokio/src/sync/once_cell.rs b/vendor/tokio/src/sync/once_cell.rs index ce55d9e35..90ea5cd68 100644 --- a/vendor/tokio/src/sync/once_cell.rs +++ b/vendor/tokio/src/sync/once_cell.rs @@ -1,4 +1,4 @@ -use super::Semaphore; +use super::{Semaphore, SemaphorePermit, TryAcquireError}; use crate::loom::cell::UnsafeCell; use std::error::Error; use std::fmt; @@ -8,15 +8,30 @@ use std::ops::Drop; use std::ptr; use std::sync::atomic::{AtomicBool, Ordering}; -/// A thread-safe cell which can be written to only once. +// This file contains an implementation of an OnceCell. The principle +// behind the safety the of the cell is that any thread with an `&OnceCell` may +// access the `value` field according the following rules: +// +// 1. When `value_set` is false, the `value` field may be modified by the +// thread holding the permit on the semaphore. +// 2. When `value_set` is true, the `value` field may be accessed immutably by +// any thread. +// +// It is an invariant that if the semaphore is closed, then `value_set` is true. +// The reverse does not necessarily hold — but if not, the semaphore may not +// have any available permits. +// +// A thread with a `&mut OnceCell` may modify the value in any way it wants as +// long as the invariants are upheld. + +/// A thread-safe cell that can be written to only once. /// -/// Provides the functionality to either set the value, in case `OnceCell` -/// is uninitialized, or get the already initialized value by using an async -/// function via [`OnceCell::get_or_init`]. -/// -/// [`OnceCell::get_or_init`]: crate::sync::OnceCell::get_or_init +/// A `OnceCell` is typically used for global variables that need to be +/// initialized once on first use, but need no further changes. The `OnceCell` +/// in Tokio allows the initialization procedure to be asynchronous. /// /// # Examples +/// /// ``` /// use tokio::sync::OnceCell; /// @@ -28,8 +43,28 @@ use std::sync::atomic::{AtomicBool, Ordering}; /// /// #[tokio::main] /// async fn main() { -/// let result1 = ONCE.get_or_init(some_computation).await; -/// assert_eq!(*result1, 2); +/// let result = ONCE.get_or_init(some_computation).await; +/// assert_eq!(*result, 2); +/// } +/// ``` +/// +/// It is often useful to write a wrapper method for accessing the value. +/// +/// ``` +/// use tokio::sync::OnceCell; +/// +/// static ONCE: OnceCell<u32> = OnceCell::const_new(); +/// +/// async fn get_global_integer() -> &'static u32 { +/// ONCE.get_or_init(|| async { +/// 1 + 1 +/// }).await +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// let result = get_global_integer().await; +/// assert_eq!(*result, 2); /// } /// ``` pub struct OnceCell<T> { @@ -68,10 +103,10 @@ impl<T: Eq> Eq for OnceCell<T> {} impl<T> Drop for OnceCell<T> { fn drop(&mut self) { - if self.initialized() { + if self.initialized_mut() { unsafe { self.value - .with_mut(|ptr| ptr::drop_in_place((&mut *ptr).as_mut_ptr())); + .with_mut(|ptr| ptr::drop_in_place((*ptr).as_mut_ptr())); }; } } @@ -90,7 +125,7 @@ impl<T> From<T> for OnceCell<T> { } impl<T> OnceCell<T> { - /// Creates a new uninitialized OnceCell instance. + /// Creates a new empty `OnceCell` instance. pub fn new() -> Self { OnceCell { value_set: AtomicBool::new(false), @@ -99,8 +134,9 @@ impl<T> OnceCell<T> { } } - /// Creates a new initialized OnceCell instance if `value` is `Some`, otherwise - /// has the same functionality as [`OnceCell::new`]. + /// Creates a new `OnceCell` that contains the provided value, if any. + /// + /// If the `Option` is `None`, this is equivalent to `OnceCell::new`. /// /// [`OnceCell::new`]: crate::sync::OnceCell::new pub fn new_with(value: Option<T>) -> Self { @@ -111,8 +147,31 @@ impl<T> OnceCell<T> { } } - /// Creates a new uninitialized OnceCell instance. - #[cfg(all(feature = "parking_lot", not(all(loom, test)),))] + /// Creates a new empty `OnceCell` instance. + /// + /// Equivalent to `OnceCell::new`, except that it can be used in static + /// variables. + /// + /// # Example + /// + /// ``` + /// use tokio::sync::OnceCell; + /// + /// static ONCE: OnceCell<u32> = OnceCell::const_new(); + /// + /// async fn get_global_integer() -> &'static u32 { + /// ONCE.get_or_init(|| async { + /// 1 + 1 + /// }).await + /// } + /// + /// #[tokio::main] + /// async fn main() { + /// let result = get_global_integer().await; + /// assert_eq!(*result, 2); + /// } + /// ``` + #[cfg(all(feature = "parking_lot", not(all(loom, test))))] #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] pub const fn const_new() -> Self { OnceCell { @@ -122,33 +181,48 @@ impl<T> OnceCell<T> { } } - /// Whether the value of the OnceCell is set or not. + /// Returns `true` if the `OnceCell` currently contains a value, and `false` + /// otherwise. pub fn initialized(&self) -> bool { + // Using acquire ordering so any threads that read a true from this + // atomic is able to read the value. self.value_set.load(Ordering::Acquire) } - // SAFETY: safe to call only once self.initialized() is true + /// Returns `true` if the `OnceCell` currently contains a value, and `false` + /// otherwise. + fn initialized_mut(&mut self) -> bool { + *self.value_set.get_mut() + } + + // SAFETY: The OnceCell must not be empty. unsafe fn get_unchecked(&self) -> &T { &*self.value.with(|ptr| (*ptr).as_ptr()) } - // SAFETY: safe to call only once self.initialized() is true. Safe because - // because of the mutable reference. + // SAFETY: The OnceCell must not be empty. unsafe fn get_unchecked_mut(&mut self) -> &mut T { &mut *self.value.with_mut(|ptr| (*ptr).as_mut_ptr()) } - // SAFETY: safe to call only once a permit on the semaphore has been - // acquired - unsafe fn set_value(&self, value: T) { - self.value.with_mut(|ptr| (*ptr).as_mut_ptr().write(value)); + fn set_value(&self, value: T, permit: SemaphorePermit<'_>) -> &T { + // SAFETY: We are holding the only permit on the semaphore. + unsafe { + self.value.with_mut(|ptr| (*ptr).as_mut_ptr().write(value)); + } + + // Using release ordering so any threads that read a true from this + // atomic is able to read the value we just stored. self.value_set.store(true, Ordering::Release); self.semaphore.close(); + permit.forget(); + + // SAFETY: We just initialized the cell. + unsafe { self.get_unchecked() } } - /// Tries to get a reference to the value of the OnceCell. - /// - /// Returns None if the value of the OnceCell hasn't previously been initialized. + /// Returns a reference to the value currently stored in the `OnceCell`, or + /// `None` if the `OnceCell` is empty. pub fn get(&self) -> Option<&T> { if self.initialized() { Some(unsafe { self.get_unchecked() }) @@ -157,179 +231,165 @@ impl<T> OnceCell<T> { } } - /// Tries to return a mutable reference to the value of the cell. + /// Returns a mutable reference to the value currently stored in the + /// `OnceCell`, or `None` if the `OnceCell` is empty. /// - /// Returns None if the cell hasn't previously been initialized. + /// Since this call borrows the `OnceCell` mutably, it is safe to mutate the + /// value inside the `OnceCell` — the mutable borrow statically guarantees + /// no other references exist. pub fn get_mut(&mut self) -> Option<&mut T> { - if self.initialized() { + if self.initialized_mut() { Some(unsafe { self.get_unchecked_mut() }) } else { None } } - /// Sets the value of the OnceCell to the argument value. + /// Sets the value of the `OnceCell` to the given value if the `OnceCell` is + /// empty. + /// + /// If the `OnceCell` already has a value, this call will fail with an + /// [`SetError::AlreadyInitializedError`]. /// - /// If the value of the OnceCell was already set prior to this call - /// then [`SetError::AlreadyInitializedError`] is returned. If another thread - /// is initializing the cell while this method is called, - /// [`SetError::InitializingError`] is returned. In order to wait - /// for an ongoing initialization to finish, call - /// [`OnceCell::get_or_init`] instead. + /// If the `OnceCell` is empty, but some other task is currently trying to + /// set the value, this call will fail with [`SetError::InitializingError`]. /// /// [`SetError::AlreadyInitializedError`]: crate::sync::SetError::AlreadyInitializedError /// [`SetError::InitializingError`]: crate::sync::SetError::InitializingError - /// ['OnceCell::get_or_init`]: crate::sync::OnceCell::get_or_init pub fn set(&self, value: T) -> Result<(), SetError<T>> { - if !self.initialized() { - // Another thread might be initializing the cell, in which case `try_acquire` will - // return an error - match self.semaphore.try_acquire() { - Ok(_permit) => { - if !self.initialized() { - // SAFETY: There is only one permit on the semaphore, hence only one - // mutable reference is created - unsafe { self.set_value(value) }; - - return Ok(()); - } else { - unreachable!( - "acquired the permit after OnceCell value was already initialized." - ); - } - } - _ => { - // Couldn't acquire the permit, look if initializing process is already completed - if !self.initialized() { - return Err(SetError::InitializingError(value)); - } - } - } + if self.initialized() { + return Err(SetError::AlreadyInitializedError(value)); } - Err(SetError::AlreadyInitializedError(value)) + // Another task might be initializing the cell, in which case + // `try_acquire` will return an error. If we succeed to acquire the + // permit, then we can set the value. + match self.semaphore.try_acquire() { + Ok(permit) => { + debug_assert!(!self.initialized()); + self.set_value(value, permit); + Ok(()) + } + Err(TryAcquireError::NoPermits) => { + // Some other task is holding the permit. That task is + // currently trying to initialize the value. + Err(SetError::InitializingError(value)) + } + Err(TryAcquireError::Closed) => { + // The semaphore was closed. Some other task has initialized + // the value. + Err(SetError::AlreadyInitializedError(value)) + } + } } - /// Tries to initialize the value of the OnceCell using the async function `f`. - /// If the value of the OnceCell was already initialized prior to this call, - /// a reference to that initialized value is returned. If some other thread - /// initiated the initialization prior to this call and the initialization - /// hasn't completed, this call waits until the initialization is finished. + /// Gets the value currently in the `OnceCell`, or initialize it with the + /// given asynchronous operation. + /// + /// If some other task is currently working on initializing the `OnceCell`, + /// this call will wait for that other task to finish, then return the value + /// that the other task produced. + /// + /// If the provided operation is cancelled or panics, the initialization + /// attempt is cancelled. If there are other tasks waiting for the value to + /// be initialized, one of them will start another attempt at initializing + /// the value. /// - /// This will deadlock if `f` tries to initialize the cell itself. + /// This will deadlock if `f` tries to initialize the cell recursively. pub async fn get_or_init<F, Fut>(&self, f: F) -> &T where F: FnOnce() -> Fut, Fut: Future<Output = T>, { + crate::trace::async_trace_leaf().await; + if self.initialized() { - // SAFETY: once the value is initialized, no mutable references are given out, so - // we can give out arbitrarily many immutable references + // SAFETY: The OnceCell has been fully initialized. unsafe { self.get_unchecked() } } else { - // After acquire().await we have either acquired a permit while self.value - // is still uninitialized, or the current thread is awoken after another thread - // has initialized the value and closed the semaphore, in which case self.initialized - // is true and we don't set the value here + // Here we try to acquire the semaphore permit. Holding the permit + // will allow us to set the value of the OnceCell, and prevents + // other tasks from initializing the OnceCell while we are holding + // it. match self.semaphore.acquire().await { - Ok(_permit) => { - if !self.initialized() { - // If `f()` panics or `select!` is called, this `get_or_init` call - // is aborted and the semaphore permit is dropped. - let value = f().await; - - // SAFETY: There is only one permit on the semaphore, hence only one - // mutable reference is created - unsafe { self.set_value(value) }; - - // SAFETY: once the value is initialized, no mutable references are given out, so - // we can give out arbitrarily many immutable references - unsafe { self.get_unchecked() } - } else { - unreachable!("acquired semaphore after value was already initialized."); - } + Ok(permit) => { + debug_assert!(!self.initialized()); + + // If `f()` panics or `select!` is called, this + // `get_or_init` call is aborted and the semaphore permit is + // dropped. + let value = f().await; + + self.set_value(value, permit) } Err(_) => { - if self.initialized() { - // SAFETY: once the value is initialized, no mutable references are given out, so - // we can give out arbitrarily many immutable references - unsafe { self.get_unchecked() } - } else { - unreachable!( - "Semaphore closed, but the OnceCell has not been initialized." - ); - } + debug_assert!(self.initialized()); + + // SAFETY: The semaphore has been closed. This only happens + // when the OnceCell is fully initialized. + unsafe { self.get_unchecked() } } } } } - /// Tries to initialize the value of the OnceCell using the async function `f`. - /// If the value of the OnceCell was already initialized prior to this call, - /// a reference to that initialized value is returned. If some other thread - /// initiated the initialization prior to this call and the initialization - /// hasn't completed, this call waits until the initialization is finished. - /// If the function argument `f` returns an error, `get_or_try_init` - /// returns that error, otherwise the result of `f` will be stored in the cell. + /// Gets the value currently in the `OnceCell`, or initialize it with the + /// given asynchronous operation. + /// + /// If some other task is currently working on initializing the `OnceCell`, + /// this call will wait for that other task to finish, then return the value + /// that the other task produced. /// - /// This will deadlock if `f` tries to initialize the cell itself. + /// If the provided operation returns an error, is cancelled or panics, the + /// initialization attempt is cancelled. If there are other tasks waiting + /// for the value to be initialized, one of them will start another attempt + /// at initializing the value. + /// + /// This will deadlock if `f` tries to initialize the cell recursively. pub async fn get_or_try_init<E, F, Fut>(&self, f: F) -> Result<&T, E> where F: FnOnce() -> Fut, Fut: Future<Output = Result<T, E>>, { + crate::trace::async_trace_leaf().await; + if self.initialized() { - // SAFETY: once the value is initialized, no mutable references are given out, so - // we can give out arbitrarily many immutable references + // SAFETY: The OnceCell has been fully initialized. unsafe { Ok(self.get_unchecked()) } } else { - // After acquire().await we have either acquired a permit while self.value - // is still uninitialized, or the current thread is awoken after another thread - // has initialized the value and closed the semaphore, in which case self.initialized - // is true and we don't set the value here + // Here we try to acquire the semaphore permit. Holding the permit + // will allow us to set the value of the OnceCell, and prevents + // other tasks from initializing the OnceCell while we are holding + // it. match self.semaphore.acquire().await { - Ok(_permit) => { - if !self.initialized() { - // If `f()` panics or `select!` is called, this `get_or_try_init` call - // is aborted and the semaphore permit is dropped. - let value = f().await; - - match value { - Ok(value) => { - // SAFETY: There is only one permit on the semaphore, hence only one - // mutable reference is created - unsafe { self.set_value(value) }; - - // SAFETY: once the value is initialized, no mutable references are given out, so - // we can give out arbitrarily many immutable references - unsafe { Ok(self.get_unchecked()) } - } - Err(e) => Err(e), - } - } else { - unreachable!("acquired semaphore after value was already initialized."); + Ok(permit) => { + debug_assert!(!self.initialized()); + + // If `f()` panics or `select!` is called, this + // `get_or_try_init` call is aborted and the semaphore + // permit is dropped. + let value = f().await; + + match value { + Ok(value) => Ok(self.set_value(value, permit)), + Err(e) => Err(e), } } Err(_) => { - if self.initialized() { - // SAFETY: once the value is initialized, no mutable references are given out, so - // we can give out arbitrarily many immutable references - unsafe { Ok(self.get_unchecked()) } - } else { - unreachable!( - "Semaphore closed, but the OnceCell has not been initialized." - ); - } + debug_assert!(self.initialized()); + + // SAFETY: The semaphore has been closed. This only happens + // when the OnceCell is fully initialized. + unsafe { Ok(self.get_unchecked()) } } } } } - /// Moves the value out of the cell, destroying the cell in the process. - /// - /// Returns `None` if the cell is uninitialized. + /// Takes the value from the cell, destroying the cell in the process. + /// Returns `None` if the cell is empty. pub fn into_inner(mut self) -> Option<T> { - if self.initialized() { + if self.initialized_mut() { // Set to uninitialized for the destructor of `OnceCell` to work properly *self.value_set.get_mut() = false; Some(unsafe { self.value.with(|ptr| ptr::read(ptr).assume_init()) }) @@ -338,20 +398,18 @@ impl<T> OnceCell<T> { } } - /// Takes ownership of the current value, leaving the cell uninitialized. - /// - /// Returns `None` if the cell is uninitialized. + /// Takes ownership of the current value, leaving the cell empty. Returns + /// `None` if the cell is empty. pub fn take(&mut self) -> Option<T> { std::mem::take(self).into_inner() } } -// Since `get` gives us access to immutable references of the -// OnceCell, OnceCell can only be Sync if T is Sync, otherwise -// OnceCell would allow sharing references of !Sync values across -// threads. We need T to be Send in order for OnceCell to by Sync -// because we can use `set` on `&OnceCell<T>` to send -// values (of type T) across threads. +// Since `get` gives us access to immutable references of the OnceCell, OnceCell +// can only be Sync if T is Sync, otherwise OnceCell would allow sharing +// references of !Sync values across threads. We need T to be Send in order for +// OnceCell to by Sync because we can use `set` on `&OnceCell<T>` to send values +// (of type T) across threads. unsafe impl<T: Sync + Send> Sync for OnceCell<T> {} // Access to OnceCell's value is guarded by the semaphore permit @@ -359,20 +417,17 @@ unsafe impl<T: Sync + Send> Sync for OnceCell<T> {} // it's safe to send it to another thread unsafe impl<T: Send> Send for OnceCell<T> {} -/// Errors that can be returned from [`OnceCell::set`] +/// Errors that can be returned from [`OnceCell::set`]. /// /// [`OnceCell::set`]: crate::sync::OnceCell::set -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq)] pub enum SetError<T> { - /// Error resulting from [`OnceCell::set`] calls if the cell was previously initialized. + /// The cell was already initialized when [`OnceCell::set`] was called. /// /// [`OnceCell::set`]: crate::sync::OnceCell::set AlreadyInitializedError(T), - /// Error resulting from [`OnceCell::set`] calls when the cell is currently being - /// initialized during the calls to that method. - /// - /// [`OnceCell::set`]: crate::sync::OnceCell::set + /// The cell is currently being initialized. InitializingError(T), } diff --git a/vendor/tokio/src/sync/oneshot.rs b/vendor/tokio/src/sync/oneshot.rs index cb4649d86..af3cc854f 100644 --- a/vendor/tokio/src/sync/oneshot.rs +++ b/vendor/tokio/src/sync/oneshot.rs @@ -9,6 +9,13 @@ //! //! Each handle can be used on separate tasks. //! +//! Since the `send` method is not async, it can be used anywhere. This includes +//! sending between two runtimes, and using it from non-async code. +//! +//! If the [`Receiver`] is closed before receiving a message which has already +//! been sent, the message will remain in the channel until the receiver is +//! dropped, at which point the message will be dropped immediately. +//! //! # Examples //! //! ``` @@ -51,10 +58,76 @@ //! } //! } //! ``` +//! +//! To use a oneshot channel in a `tokio::select!` loop, add `&mut` in front of +//! the channel. +//! +//! ``` +//! use tokio::sync::oneshot; +//! use tokio::time::{interval, sleep, Duration}; +//! +//! #[tokio::main] +//! # async fn _doc() {} +//! # #[tokio::main(flavor = "current_thread", start_paused = true)] +//! async fn main() { +//! let (send, mut recv) = oneshot::channel(); +//! let mut interval = interval(Duration::from_millis(100)); +//! +//! # let handle = +//! tokio::spawn(async move { +//! sleep(Duration::from_secs(1)).await; +//! send.send("shut down").unwrap(); +//! }); +//! +//! loop { +//! tokio::select! { +//! _ = interval.tick() => println!("Another 100ms"), +//! msg = &mut recv => { +//! println!("Got message: {}", msg.unwrap()); +//! break; +//! } +//! } +//! } +//! # handle.await.unwrap(); +//! } +//! ``` +//! +//! To use a `Sender` from a destructor, put it in an [`Option`] and call +//! [`Option::take`]. +//! +//! ``` +//! use tokio::sync::oneshot; +//! +//! struct SendOnDrop { +//! sender: Option<oneshot::Sender<&'static str>>, +//! } +//! impl Drop for SendOnDrop { +//! fn drop(&mut self) { +//! if let Some(sender) = self.sender.take() { +//! // Using `let _ =` to ignore send errors. +//! let _ = sender.send("I got dropped!"); +//! } +//! } +//! } +//! +//! #[tokio::main] +//! # async fn _doc() {} +//! # #[tokio::main(flavor = "current_thread")] +//! async fn main() { +//! let (send, recv) = oneshot::channel(); +//! +//! let send_on_drop = SendOnDrop { sender: Some(send) }; +//! drop(send_on_drop); +//! +//! assert_eq!(recv.await, Ok("I got dropped!")); +//! } +//! ``` use crate::loom::cell::UnsafeCell; use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::Arc; +#[cfg(all(tokio_unstable, feature = "tracing"))] +use crate::util::trace; use std::fmt; use std::future::Future; @@ -68,16 +141,108 @@ use std::task::{Context, Poll, Waker}; /// /// A pair of both a [`Sender`] and a [`Receiver`] are created by the /// [`channel`](fn@channel) function. +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::oneshot; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, rx) = oneshot::channel(); +/// +/// tokio::spawn(async move { +/// if let Err(_) = tx.send(3) { +/// println!("the receiver dropped"); +/// } +/// }); +/// +/// match rx.await { +/// Ok(v) => println!("got = {:?}", v), +/// Err(_) => println!("the sender dropped"), +/// } +/// } +/// ``` +/// +/// If the sender is dropped without sending, the receiver will fail with +/// [`error::RecvError`]: +/// +/// ``` +/// use tokio::sync::oneshot; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, rx) = oneshot::channel::<u32>(); +/// +/// tokio::spawn(async move { +/// drop(tx); +/// }); +/// +/// match rx.await { +/// Ok(_) => panic!("This doesn't happen"), +/// Err(_) => println!("the sender dropped"), +/// } +/// } +/// ``` +/// +/// To use a `Sender` from a destructor, put it in an [`Option`] and call +/// [`Option::take`]. +/// +/// ``` +/// use tokio::sync::oneshot; +/// +/// struct SendOnDrop { +/// sender: Option<oneshot::Sender<&'static str>>, +/// } +/// impl Drop for SendOnDrop { +/// fn drop(&mut self) { +/// if let Some(sender) = self.sender.take() { +/// // Using `let _ =` to ignore send errors. +/// let _ = sender.send("I got dropped!"); +/// } +/// } +/// } +/// +/// #[tokio::main] +/// # async fn _doc() {} +/// # #[tokio::main(flavor = "current_thread")] +/// async fn main() { +/// let (send, recv) = oneshot::channel(); +/// +/// let send_on_drop = SendOnDrop { sender: Some(send) }; +/// drop(send_on_drop); +/// +/// assert_eq!(recv.await, Ok("I got dropped!")); +/// } +/// ``` +/// +/// [`Option`]: std::option::Option +/// [`Option::take`]: std::option::Option::take #[derive(Debug)] pub struct Sender<T> { inner: Option<Arc<Inner<T>>>, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, } -/// Receive a value from the associated [`Sender`]. +/// Receives a value from the associated [`Sender`]. /// /// A pair of both a [`Sender`] and a [`Receiver`] are created by the /// [`channel`](fn@channel) function. /// +/// This channel has no `recv` method because the receiver itself implements the +/// [`Future`] trait. To receive a `Result<T, `[`error::RecvError`]`>`, `.await` the `Receiver` object directly. +/// +/// The `poll` method on the `Future` trait is allowed to spuriously return +/// `Poll::Pending` even if the message has been sent. If such a spurious +/// failure happens, then the caller will be woken when the spurious failure has +/// been resolved so that the caller can attempt to receive the message again. +/// Note that receiving such a wakeup does not guarantee that the next call will +/// succeed — it could fail with another spurious failure. (A spurious failure +/// does not mean that the message is lost. It is just delayed.) +/// +/// [`Future`]: trait@std::future::Future +/// /// # Examples /// /// ``` @@ -120,22 +285,63 @@ pub struct Sender<T> { /// } /// } /// ``` +/// +/// To use a `Receiver` in a `tokio::select!` loop, add `&mut` in front of the +/// channel. +/// +/// ``` +/// use tokio::sync::oneshot; +/// use tokio::time::{interval, sleep, Duration}; +/// +/// #[tokio::main] +/// # async fn _doc() {} +/// # #[tokio::main(flavor = "current_thread", start_paused = true)] +/// async fn main() { +/// let (send, mut recv) = oneshot::channel(); +/// let mut interval = interval(Duration::from_millis(100)); +/// +/// # let handle = +/// tokio::spawn(async move { +/// sleep(Duration::from_secs(1)).await; +/// send.send("shut down").unwrap(); +/// }); +/// +/// loop { +/// tokio::select! { +/// _ = interval.tick() => println!("Another 100ms"), +/// msg = &mut recv => { +/// println!("Got message: {}", msg.unwrap()); +/// break; +/// } +/// } +/// } +/// # handle.await.unwrap(); +/// } +/// ``` #[derive(Debug)] pub struct Receiver<T> { inner: Option<Arc<Inner<T>>>, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, + #[cfg(all(tokio_unstable, feature = "tracing"))] + async_op_span: tracing::Span, + #[cfg(all(tokio_unstable, feature = "tracing"))] + async_op_poll_span: tracing::Span, } pub mod error { - //! Oneshot error types + //! Oneshot error types. use std::fmt; /// Error returned by the `Future` implementation for `Receiver`. - #[derive(Debug, Eq, PartialEq)] + /// + /// This error is returned by the receiver when the sender is dropped without sending. + #[derive(Debug, Eq, PartialEq, Clone)] pub struct RecvError(pub(super) ()); /// Error returned by the `try_recv` function on `Receiver`. - #[derive(Debug, Eq, PartialEq)] + #[derive(Debug, Eq, PartialEq, Clone)] pub enum TryRecvError { /// The send half of the channel has not yet sent a value. Empty, @@ -171,7 +377,7 @@ pub mod error { use self::error::*; struct Inner<T> { - /// Manages the state of the inner cell + /// Manages the state of the inner cell. state: AtomicUsize, /// The value. This is set by `Sender` and read by `Receiver`. The state of @@ -207,21 +413,21 @@ impl Task { F: FnOnce(&Waker) -> R, { self.0.with(|ptr| { - let waker: *const Waker = (&*ptr).as_ptr(); + let waker: *const Waker = (*ptr).as_ptr(); f(&*waker) }) } unsafe fn drop_task(&self) { self.0.with_mut(|ptr| { - let ptr: *mut Waker = (&mut *ptr).as_mut_ptr(); + let ptr: *mut Waker = (*ptr).as_mut_ptr(); ptr.drop_in_place(); }); } unsafe fn set_task(&self, cx: &mut Context<'_>) { self.0.with_mut(|ptr| { - let ptr: *mut Waker = (&mut *ptr).as_mut_ptr(); + let ptr: *mut Waker = (*ptr).as_mut_ptr(); ptr.write(cx.waker().clone()); }); } @@ -230,7 +436,7 @@ impl Task { #[derive(Clone, Copy)] struct State(usize); -/// Create a new one-shot channel for sending single values across asynchronous +/// Creates a new one-shot channel for sending single values across asynchronous /// tasks. /// /// The function returns separate "send" and "receive" handles. The `Sender` @@ -260,7 +466,56 @@ struct State(usize); /// } /// } /// ``` +#[track_caller] pub fn channel<T>() -> (Sender<T>, Receiver<T>) { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = { + let location = std::panic::Location::caller(); + + let resource_span = tracing::trace_span!( + "runtime.resource", + concrete_type = "Sender|Receiver", + kind = "Sync", + loc.file = location.file(), + loc.line = location.line(), + loc.col = location.column(), + ); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + tx_dropped = false, + tx_dropped.op = "override", + ) + }); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + rx_dropped = false, + rx_dropped.op = "override", + ) + }); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + value_sent = false, + value_sent.op = "override", + ) + }); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + value_received = false, + value_received.op = "override", + ) + }); + + resource_span + }; + let inner = Arc::new(Inner { state: AtomicUsize::new(State::new().as_usize()), value: UnsafeCell::new(None), @@ -270,8 +525,27 @@ pub fn channel<T>() -> (Sender<T>, Receiver<T>) { let tx = Sender { inner: Some(inner.clone()), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: resource_span.clone(), + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let async_op_span = resource_span + .in_scope(|| tracing::trace_span!("runtime.resource.async_op", source = "Receiver::await")); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let async_op_poll_span = + async_op_span.in_scope(|| tracing::trace_span!("runtime.resource.async_op.poll")); + + let rx = Receiver { + inner: Some(inner), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + #[cfg(all(tokio_unstable, feature = "tracing"))] + async_op_span, + #[cfg(all(tokio_unstable, feature = "tracing"))] + async_op_poll_span, }; - let rx = Receiver { inner: Some(inner) }; (tx, rx) } @@ -343,6 +617,15 @@ impl<T> Sender<T> { } } + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + value_sent = true, + value_sent.op = "override", + ) + }); + Ok(()) } @@ -416,7 +699,20 @@ impl<T> Sender<T> { pub async fn closed(&mut self) { use crate::future::poll_fn; - poll_fn(|cx| self.poll_closed(cx)).await + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + let closed = trace::async_op( + || poll_fn(|cx| self.poll_closed(cx)), + resource_span, + "Sender::closed", + "poll_closed", + false, + ); + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let closed = poll_fn(|cx| self.poll_closed(cx)); + + closed.await } /// Returns `true` if the associated [`Receiver`] handle has been dropped. @@ -453,7 +749,7 @@ impl<T> Sender<T> { state.is_closed() } - /// Check whether the oneshot channel has been closed, and if not, schedules the + /// Checks whether the oneshot channel has been closed, and if not, schedules the /// `Waker` in the provided `Context` to receive a notification when the channel is /// closed. /// @@ -494,8 +790,10 @@ impl<T> Sender<T> { /// } /// ``` pub fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> { + ready!(crate::trace::trace_leaf(cx)); + // Keep track of task budget - let coop = ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::runtime::coop::poll_proceed(cx)); let inner = self.inner.as_ref().unwrap(); @@ -503,7 +801,7 @@ impl<T> Sender<T> { if state.is_closed() { coop.made_progress(); - return Poll::Ready(()); + return Ready(()); } if state.is_tx_task_set() { @@ -546,6 +844,14 @@ impl<T> Drop for Sender<T> { fn drop(&mut self) { if let Some(inner) = self.inner.as_ref() { inner.complete(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + tx_dropped = true, + tx_dropped.op = "override", + ) + }); } } } @@ -613,6 +919,14 @@ impl<T> Receiver<T> { pub fn close(&mut self) { if let Some(inner) = self.inner.as_ref() { inner.close(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + rx_dropped = true, + rx_dropped.op = "override", + ) + }); } } @@ -625,12 +939,16 @@ impl<T> Receiver<T> { /// This function is useful to call from outside the context of an /// asynchronous task. /// + /// Note that unlike the `poll` method, the `try_recv` method cannot fail + /// spuriously. Any send or close event that happens before this call to + /// `try_recv` will be correctly returned to the caller. + /// /// # Return /// /// - `Ok(T)` if a value is pending in the channel. /// - `Err(TryRecvError::Empty)` if no value has been sent yet. /// - `Err(TryRecvError::Closed)` if the sender has dropped without sending - /// a value. + /// a value, or if the message has already been received. /// /// # Examples /// @@ -690,7 +1008,17 @@ impl<T> Receiver<T> { // `UnsafeCell`. Therefore, it is now safe for us to access the // cell. match unsafe { inner.consume_value() } { - Some(value) => Ok(value), + Some(value) => { + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + value_received = true, + value_received.op = "override", + ) + }); + Ok(value) + } None => Err(TryRecvError::Closed), } } else if state.is_closed() { @@ -706,12 +1034,52 @@ impl<T> Receiver<T> { self.inner = None; result } + + /// Blocking receive to call outside of asynchronous contexts. + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution + /// context. + /// + /// # Examples + /// + /// ``` + /// use std::thread; + /// use tokio::sync::oneshot; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = oneshot::channel::<u8>(); + /// + /// let sync_code = thread::spawn(move || { + /// assert_eq!(Ok(10), rx.blocking_recv()); + /// }); + /// + /// let _ = tx.send(10); + /// sync_code.join().unwrap(); + /// } + /// ``` + #[track_caller] + #[cfg(feature = "sync")] + #[cfg_attr(docsrs, doc(alias = "recv_blocking"))] + pub fn blocking_recv(self) -> Result<T, RecvError> { + crate::future::block_on(self) + } } impl<T> Drop for Receiver<T> { fn drop(&mut self) { if let Some(inner) = self.inner.as_ref() { inner.close(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + rx_dropped = true, + rx_dropped.op = "override", + ) + }); } } } @@ -721,8 +1089,21 @@ impl<T> Future for Receiver<T> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { // If `inner` is `None`, then `poll()` has already completed. + #[cfg(all(tokio_unstable, feature = "tracing"))] + let _res_span = self.resource_span.clone().entered(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + let _ao_span = self.async_op_span.clone().entered(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + let _ao_poll_span = self.async_op_poll_span.clone().entered(); + let ret = if let Some(inner) = self.as_ref().get_ref().inner.as_ref() { - ready!(inner.poll_recv(cx))? + #[cfg(all(tokio_unstable, feature = "tracing"))] + let res = ready!(trace_poll_op!("poll_recv", inner.poll_recv(cx)))?; + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + let res = ready!(inner.poll_recv(cx))?; + + res } else { panic!("called after complete"); }; @@ -751,8 +1132,9 @@ impl<T> Inner<T> { } fn poll_recv(&self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> { + ready!(crate::trace::trace_leaf(cx)); // Keep track of task budget - let coop = ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::runtime::coop::poll_proceed(cx)); // Load the state let mut state = State::load(&self.state, Acquire); diff --git a/vendor/tokio/src/sync/rwlock.rs b/vendor/tokio/src/sync/rwlock.rs index 120bc72b8..dd4928546 100644 --- a/vendor/tokio/src/sync/rwlock.rs +++ b/vendor/tokio/src/sync/rwlock.rs @@ -1,9 +1,10 @@ use crate::sync::batch_semaphore::{Semaphore, TryAcquireError}; use crate::sync::mutex::TryLockError; +#[cfg(all(tokio_unstable, feature = "tracing"))] +use crate::util::trace; use std::cell::UnsafeCell; use std::marker; use std::marker::PhantomData; -use std::mem::ManuallyDrop; use std::sync::Arc; pub(crate) mod owned_read_guard; @@ -84,8 +85,10 @@ const MAX_READS: u32 = 10; /// [`RwLockWriteGuard`]: struct@RwLockWriteGuard /// [`Send`]: trait@std::marker::Send /// [_write-preferring_]: https://en.wikipedia.org/wiki/Readers%E2%80%93writer_lock#Priority_policies -#[derive(Debug)] pub struct RwLock<T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, + // maximum number of concurrent readers mr: u32, @@ -197,14 +200,55 @@ impl<T: ?Sized> RwLock<T> { /// /// let lock = RwLock::new(5); /// ``` + #[track_caller] pub fn new(value: T) -> RwLock<T> where T: Sized, { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = { + let location = std::panic::Location::caller(); + let resource_span = tracing::trace_span!( + "runtime.resource", + concrete_type = "RwLock", + kind = "Sync", + loc.file = location.file(), + loc.line = location.line(), + loc.col = location.column(), + ); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + max_readers = MAX_READS, + ); + + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + ); + + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 0, + ); + }); + + resource_span + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let s = resource_span.in_scope(|| Semaphore::new(MAX_READS as usize)); + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + let s = Semaphore::new(MAX_READS as usize); + RwLock { mr: MAX_READS, c: UnsafeCell::new(value), - s: Semaphore::new(MAX_READS as usize), + s, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, } } @@ -222,6 +266,7 @@ impl<T: ?Sized> RwLock<T> { /// # Panics /// /// Panics if `max_reads` is more than `u32::MAX >> 3`. + #[track_caller] pub fn with_max_readers(value: T, max_reads: u32) -> RwLock<T> where T: Sized, @@ -231,10 +276,52 @@ impl<T: ?Sized> RwLock<T> { "a RwLock may not be created with more than {} readers", MAX_READS ); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = { + let location = std::panic::Location::caller(); + + let resource_span = tracing::trace_span!( + "runtime.resource", + concrete_type = "RwLock", + kind = "Sync", + loc.file = location.file(), + loc.line = location.line(), + loc.col = location.column(), + ); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + max_readers = max_reads, + ); + + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + ); + + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 0, + ); + }); + + resource_span + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let s = resource_span.in_scope(|| Semaphore::new(max_reads as usize)); + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + let s = Semaphore::new(max_reads as usize); + RwLock { mr: max_reads, c: UnsafeCell::new(value), - s: Semaphore::new(max_reads as usize), + s, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, } } @@ -257,6 +344,8 @@ impl<T: ?Sized> RwLock<T> { mr: MAX_READS, c: UnsafeCell::new(value), s: Semaphore::const_new(MAX_READS as usize), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span::none(), } } @@ -281,6 +370,8 @@ impl<T: ?Sized> RwLock<T> { mr: max_reads, c: UnsafeCell::new(value), s: Semaphore::const_new(max_reads as usize), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span::none(), } } @@ -327,19 +418,100 @@ impl<T: ?Sized> RwLock<T> { /// /// // Drop the guard after the spawned task finishes. /// drop(n); - ///} + /// } /// ``` pub async fn read(&self) -> RwLockReadGuard<'_, T> { - self.s.acquire(1).await.unwrap_or_else(|_| { - // The semaphore was closed. but, we never explicitly close it, and we have a - // handle to it through the Arc, which means that this can never happen. - unreachable!() + let acquire_fut = async { + self.s.acquire(1).await.unwrap_or_else(|_| { + // The semaphore was closed. but, we never explicitly close it, and we have a + // handle to it through the Arc, which means that this can never happen. + unreachable!() + }); + + RwLockReadGuard { + s: &self.s, + data: self.c.get(), + marker: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), + } + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let acquire_fut = trace::async_op( + move || acquire_fut, + self.resource_span.clone(), + "RwLock::read", + "poll", + false, + ); + + #[allow(clippy::let_and_return)] // this lint triggers when disabling tracing + let guard = acquire_fut.await; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) }); - RwLockReadGuard { - s: &self.s, - data: self.c.get(), - marker: marker::PhantomData, - } + + guard + } + + /// Blockingly locks this `RwLock` with shared read access. + /// + /// This method is intended for use cases where you + /// need to use this rwlock in asynchronous code as well as in synchronous code. + /// + /// Returns an RAII guard which will drop the read access of this `RwLock` when dropped. + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution context. + /// + /// - If you find yourself in an asynchronous execution context and needing + /// to call some (synchronous) function which performs one of these + /// `blocking_` operations, then consider wrapping that call inside + /// [`spawn_blocking()`][crate::runtime::Handle::spawn_blocking] + /// (or [`block_in_place()`][crate::task::block_in_place]). + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::RwLock; + /// + /// #[tokio::main] + /// async fn main() { + /// let rwlock = Arc::new(RwLock::new(1)); + /// let mut write_lock = rwlock.write().await; + /// + /// let blocking_task = tokio::task::spawn_blocking({ + /// let rwlock = Arc::clone(&rwlock); + /// move || { + /// // This shall block until the `write_lock` is released. + /// let read_lock = rwlock.blocking_read(); + /// assert_eq!(*read_lock, 0); + /// } + /// }); + /// + /// *write_lock -= 1; + /// drop(write_lock); // release the lock. + /// + /// // Await the completion of the blocking task. + /// blocking_task.await.unwrap(); + /// + /// // Assert uncontended. + /// assert!(rwlock.try_write().is_ok()); + /// } + /// ``` + #[track_caller] + #[cfg(feature = "sync")] + pub fn blocking_read(&self) -> RwLockReadGuard<'_, T> { + crate::future::block_on(self.read()) } /// Locks this `RwLock` with shared read access, causing the current task @@ -394,16 +566,47 @@ impl<T: ?Sized> RwLock<T> { ///} /// ``` pub async fn read_owned(self: Arc<Self>) -> OwnedRwLockReadGuard<T> { - self.s.acquire(1).await.unwrap_or_else(|_| { - // The semaphore was closed. but, we never explicitly close it, and we have a - // handle to it through the Arc, which means that this can never happen. - unreachable!() + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + + let acquire_fut = async { + self.s.acquire(1).await.unwrap_or_else(|_| { + // The semaphore was closed. but, we never explicitly close it, and we have a + // handle to it through the Arc, which means that this can never happen. + unreachable!() + }); + + OwnedRwLockReadGuard { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), + data: self.c.get(), + lock: self, + _p: PhantomData, + } + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let acquire_fut = trace::async_op( + move || acquire_fut, + resource_span, + "RwLock::read_owned", + "poll", + false, + ); + + #[allow(clippy::let_and_return)] // this lint triggers when disabling tracing + let guard = acquire_fut.await; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + guard.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) }); - OwnedRwLockReadGuard { - data: self.c.get(), - lock: ManuallyDrop::new(self), - _p: PhantomData, - } + + guard } /// Attempts to acquire this `RwLock` with shared read access. @@ -445,11 +648,24 @@ impl<T: ?Sized> RwLock<T> { Err(TryAcquireError::Closed) => unreachable!(), } - Ok(RwLockReadGuard { + let guard = RwLockReadGuard { s: &self.s, data: self.c.get(), marker: marker::PhantomData, - }) + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) + }); + + Ok(guard) } /// Attempts to acquire this `RwLock` with shared read access. @@ -497,11 +713,24 @@ impl<T: ?Sized> RwLock<T> { Err(TryAcquireError::Closed) => unreachable!(), } - Ok(OwnedRwLockReadGuard { + let guard = OwnedRwLockReadGuard { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), data: self.c.get(), - lock: ManuallyDrop::new(self), + lock: self, _p: PhantomData, - }) + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + guard.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) + }); + + Ok(guard) } /// Locks this `RwLock` with exclusive write access, causing the current @@ -533,17 +762,100 @@ impl<T: ?Sized> RwLock<T> { ///} /// ``` pub async fn write(&self) -> RwLockWriteGuard<'_, T> { - self.s.acquire(self.mr).await.unwrap_or_else(|_| { - // The semaphore was closed. but, we never explicitly close it, and we have a - // handle to it through the Arc, which means that this can never happen. - unreachable!() + let acquire_fut = async { + self.s.acquire(self.mr).await.unwrap_or_else(|_| { + // The semaphore was closed. but, we never explicitly close it, and we have a + // handle to it through the Arc, which means that this can never happen. + unreachable!() + }); + + RwLockWriteGuard { + permits_acquired: self.mr, + s: &self.s, + data: self.c.get(), + marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), + } + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let acquire_fut = trace::async_op( + move || acquire_fut, + self.resource_span.clone(), + "RwLock::write", + "poll", + false, + ); + + #[allow(clippy::let_and_return)] // this lint triggers when disabling tracing + let guard = acquire_fut.await; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = true, + write_locked.op = "override", + ) }); - RwLockWriteGuard { - permits_acquired: self.mr, - s: &self.s, - data: self.c.get(), - marker: marker::PhantomData, - } + + guard + } + + /// Blockingly locks this `RwLock` with exclusive write access. + /// + /// This method is intended for use cases where you + /// need to use this rwlock in asynchronous code as well as in synchronous code. + /// + /// Returns an RAII guard which will drop the write access of this `RwLock` when dropped. + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution context. + /// + /// - If you find yourself in an asynchronous execution context and needing + /// to call some (synchronous) function which performs one of these + /// `blocking_` operations, then consider wrapping that call inside + /// [`spawn_blocking()`][crate::runtime::Handle::spawn_blocking] + /// (or [`block_in_place()`][crate::task::block_in_place]). + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::{sync::RwLock}; + /// + /// #[tokio::main] + /// async fn main() { + /// let rwlock = Arc::new(RwLock::new(1)); + /// let read_lock = rwlock.read().await; + /// + /// let blocking_task = tokio::task::spawn_blocking({ + /// let rwlock = Arc::clone(&rwlock); + /// move || { + /// // This shall block until the `read_lock` is released. + /// let mut write_lock = rwlock.blocking_write(); + /// *write_lock = 2; + /// } + /// }); + /// + /// assert_eq!(*read_lock, 1); + /// // Release the last outstanding read lock. + /// drop(read_lock); + /// + /// // Await the completion of the blocking task. + /// blocking_task.await.unwrap(); + /// + /// // Assert uncontended. + /// let read_lock = rwlock.try_read().unwrap(); + /// assert_eq!(*read_lock, 2); + /// } + /// ``` + #[track_caller] + #[cfg(feature = "sync")] + pub fn blocking_write(&self) -> RwLockWriteGuard<'_, T> { + crate::future::block_on(self.write()) } /// Locks this `RwLock` with exclusive write access, causing the current @@ -582,17 +894,48 @@ impl<T: ?Sized> RwLock<T> { ///} /// ``` pub async fn write_owned(self: Arc<Self>) -> OwnedRwLockWriteGuard<T> { - self.s.acquire(self.mr).await.unwrap_or_else(|_| { - // The semaphore was closed. but, we never explicitly close it, and we have a - // handle to it through the Arc, which means that this can never happen. - unreachable!() + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + + let acquire_fut = async { + self.s.acquire(self.mr).await.unwrap_or_else(|_| { + // The semaphore was closed. but, we never explicitly close it, and we have a + // handle to it through the Arc, which means that this can never happen. + unreachable!() + }); + + OwnedRwLockWriteGuard { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), + permits_acquired: self.mr, + data: self.c.get(), + lock: self, + _p: PhantomData, + } + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let acquire_fut = trace::async_op( + move || acquire_fut, + resource_span, + "RwLock::write_owned", + "poll", + false, + ); + + #[allow(clippy::let_and_return)] // this lint triggers when disabling tracing + let guard = acquire_fut.await; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + guard.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = true, + write_locked.op = "override", + ) }); - OwnedRwLockWriteGuard { - permits_acquired: self.mr, - data: self.c.get(), - lock: ManuallyDrop::new(self), - _p: PhantomData, - } + + guard } /// Attempts to acquire this `RwLock` with exclusive write access. @@ -625,12 +968,25 @@ impl<T: ?Sized> RwLock<T> { Err(TryAcquireError::Closed) => unreachable!(), } - Ok(RwLockWriteGuard { + let guard = RwLockWriteGuard { permits_acquired: self.mr, s: &self.s, data: self.c.get(), marker: marker::PhantomData, - }) + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = true, + write_locked.op = "override", + ) + }); + + Ok(guard) } /// Attempts to acquire this `RwLock` with exclusive write access. @@ -670,12 +1026,25 @@ impl<T: ?Sized> RwLock<T> { Err(TryAcquireError::Closed) => unreachable!(), } - Ok(OwnedRwLockWriteGuard { + let guard = OwnedRwLockWriteGuard { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), permits_acquired: self.mr, data: self.c.get(), - lock: ManuallyDrop::new(self), + lock: self, _p: PhantomData, - }) + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + guard.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = true, + write_locked.op = "override", + ) + }); + + Ok(guard) } /// Returns a mutable reference to the underlying data. @@ -725,3 +1094,17 @@ where Self::new(T::default()) } } + +impl<T: ?Sized> std::fmt::Debug for RwLock<T> +where + T: std::fmt::Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut d = f.debug_struct("RwLock"); + match self.try_read() { + Ok(inner) => d.field("data", &&*inner), + Err(_) => d.field("data", &format_args!("<locked>")), + }; + d.finish() + } +} diff --git a/vendor/tokio/src/sync/rwlock/owned_read_guard.rs b/vendor/tokio/src/sync/rwlock/owned_read_guard.rs index b7f3926a4..273e7b86f 100644 --- a/vendor/tokio/src/sync/rwlock/owned_read_guard.rs +++ b/vendor/tokio/src/sync/rwlock/owned_read_guard.rs @@ -1,10 +1,7 @@ use crate::sync::rwlock::RwLock; -use std::fmt; use std::marker::PhantomData; -use std::mem; -use std::mem::ManuallyDrop; -use std::ops; use std::sync::Arc; +use std::{fmt, mem, ops, ptr}; /// Owned RAII structure used to release the shared read access of a lock when /// dropped. @@ -14,15 +11,41 @@ use std::sync::Arc; /// /// [`read_owned`]: method@crate::sync::RwLock::read_owned /// [`RwLock`]: struct@crate::sync::RwLock +#[clippy::has_significant_drop] pub struct OwnedRwLockReadGuard<T: ?Sized, U: ?Sized = T> { - // ManuallyDrop allows us to destructure into this field without running the destructor. - pub(super) lock: ManuallyDrop<Arc<RwLock<T>>>, + // When changing the fields in this struct, make sure to update the + // `skip_drop` method. + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) resource_span: tracing::Span, + pub(super) lock: Arc<RwLock<T>>, pub(super) data: *const U, pub(super) _p: PhantomData<T>, } +#[allow(dead_code)] // Unused fields are still used in Drop. +struct Inner<T: ?Sized, U: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, + lock: Arc<RwLock<T>>, + data: *const U, +} + impl<T: ?Sized, U: ?Sized> OwnedRwLockReadGuard<T, U> { - /// Make a new `OwnedRwLockReadGuard` for a component of the locked data. + fn skip_drop(self) -> Inner<T, U> { + let me = mem::ManuallyDrop::new(self); + // SAFETY: This duplicates the values in every field of the guard, then + // forgets the originals, so in the end no value is duplicated. + unsafe { + Inner { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: ptr::read(&me.resource_span), + lock: ptr::read(&me.lock), + data: me.data, + } + } + } + + /// Makes a new `OwnedRwLockReadGuard` for a component of the locked data. /// This operation cannot fail as the `OwnedRwLockReadGuard` passed in /// already locked the data. /// @@ -50,18 +73,19 @@ impl<T: ?Sized, U: ?Sized> OwnedRwLockReadGuard<T, U> { /// # } /// ``` #[inline] - pub fn map<F, V: ?Sized>(mut this: Self, f: F) -> OwnedRwLockReadGuard<T, V> + pub fn map<F, V: ?Sized>(this: Self, f: F) -> OwnedRwLockReadGuard<T, V> where F: FnOnce(&U) -> &V, { let data = f(&*this) as *const V; - let lock = unsafe { ManuallyDrop::take(&mut this.lock) }; - // NB: Forget to avoid drop impl from being called. - mem::forget(this); + let this = this.skip_drop(); + OwnedRwLockReadGuard { - lock: ManuallyDrop::new(lock), + lock: this.lock, data, _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: this.resource_span, } } @@ -96,7 +120,7 @@ impl<T: ?Sized, U: ?Sized> OwnedRwLockReadGuard<T, U> { /// # } /// ``` #[inline] - pub fn try_map<F, V: ?Sized>(mut this: Self, f: F) -> Result<OwnedRwLockReadGuard<T, V>, Self> + pub fn try_map<F, V: ?Sized>(this: Self, f: F) -> Result<OwnedRwLockReadGuard<T, V>, Self> where F: FnOnce(&U) -> Option<&V>, { @@ -104,13 +128,14 @@ impl<T: ?Sized, U: ?Sized> OwnedRwLockReadGuard<T, U> { Some(data) => data as *const V, None => return Err(this), }; - let lock = unsafe { ManuallyDrop::take(&mut this.lock) }; - // NB: Forget to avoid drop impl from being called. - mem::forget(this); + let this = this.skip_drop(); + Ok(OwnedRwLockReadGuard { - lock: ManuallyDrop::new(lock), + lock: this.lock, data, _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: this.resource_span, }) } } @@ -144,6 +169,14 @@ where impl<T: ?Sized, U: ?Sized> Drop for OwnedRwLockReadGuard<T, U> { fn drop(&mut self) { self.lock.s.release(1); - unsafe { ManuallyDrop::drop(&mut self.lock) }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "sub", + ) + }); } } diff --git a/vendor/tokio/src/sync/rwlock/owned_write_guard.rs b/vendor/tokio/src/sync/rwlock/owned_write_guard.rs index 91b659524..a8ce4a160 100644 --- a/vendor/tokio/src/sync/rwlock/owned_write_guard.rs +++ b/vendor/tokio/src/sync/rwlock/owned_write_guard.rs @@ -1,11 +1,9 @@ use crate::sync::rwlock::owned_read_guard::OwnedRwLockReadGuard; use crate::sync::rwlock::owned_write_guard_mapped::OwnedRwLockMappedWriteGuard; use crate::sync::rwlock::RwLock; -use std::fmt; use std::marker::PhantomData; -use std::mem::{self, ManuallyDrop}; -use std::ops; use std::sync::Arc; +use std::{fmt, mem, ops, ptr}; /// Owned RAII structure used to release the exclusive write access of a lock when /// dropped. @@ -15,16 +13,44 @@ use std::sync::Arc; /// /// [`write_owned`]: method@crate::sync::RwLock::write_owned /// [`RwLock`]: struct@crate::sync::RwLock +#[clippy::has_significant_drop] pub struct OwnedRwLockWriteGuard<T: ?Sized> { + // When changing the fields in this struct, make sure to update the + // `skip_drop` method. + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) resource_span: tracing::Span, pub(super) permits_acquired: u32, - // ManuallyDrop allows us to destructure into this field without running the destructor. - pub(super) lock: ManuallyDrop<Arc<RwLock<T>>>, + pub(super) lock: Arc<RwLock<T>>, pub(super) data: *mut T, pub(super) _p: PhantomData<T>, } +#[allow(dead_code)] // Unused fields are still used in Drop. +struct Inner<T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, + permits_acquired: u32, + lock: Arc<RwLock<T>>, + data: *const T, +} + impl<T: ?Sized> OwnedRwLockWriteGuard<T> { - /// Make a new [`OwnedRwLockMappedWriteGuard`] for a component of the locked + fn skip_drop(self) -> Inner<T> { + let me = mem::ManuallyDrop::new(self); + // SAFETY: This duplicates the values in every field of the guard, then + // forgets the originals, so in the end no value is duplicated. + unsafe { + Inner { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: ptr::read(&me.resource_span), + permits_acquired: me.permits_acquired, + lock: ptr::read(&me.lock), + data: me.data, + } + } + } + + /// Makes a new [`OwnedRwLockMappedWriteGuard`] for a component of the locked /// data. /// /// This operation cannot fail as the `OwnedRwLockWriteGuard` passed in @@ -62,19 +88,90 @@ impl<T: ?Sized> OwnedRwLockWriteGuard<T> { F: FnOnce(&mut T) -> &mut U, { let data = f(&mut *this) as *mut U; - let lock = unsafe { ManuallyDrop::take(&mut this.lock) }; - let permits_acquired = this.permits_acquired; - // NB: Forget to avoid drop impl from being called. - mem::forget(this); + let this = this.skip_drop(); + OwnedRwLockMappedWriteGuard { - permits_acquired, - lock: ManuallyDrop::new(lock), + permits_acquired: this.permits_acquired, + lock: this.lock, data, _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: this.resource_span, } } - /// Attempts to make a new [`OwnedRwLockMappedWriteGuard`] for a component + /// Makes a new [`OwnedRwLockReadGuard`] for a component of the locked data. + /// + /// This operation cannot fail as the `OwnedRwLockWriteGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be used as + /// `OwnedRwLockWriteGuard::downgrade_map(..)`. A method would interfere with methods of + /// the same name on the contents of the locked data. + /// + /// Inside of `f`, you retain exclusive access to the data, despite only being given a `&T`. Handing out a + /// `&mut T` would result in unsoundness, as you could use interior mutability. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{RwLock, OwnedRwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = Arc::new(RwLock::new(Foo(1))); + /// + /// let guard = Arc::clone(&lock).write_owned().await; + /// let mapped = OwnedRwLockWriteGuard::downgrade_map(guard, |f| &f.0); + /// let foo = lock.read_owned().await; + /// assert_eq!(foo.0, *mapped); + /// # } + /// ``` + #[inline] + pub fn downgrade_map<F, U: ?Sized>(this: Self, f: F) -> OwnedRwLockReadGuard<T, U> + where + F: FnOnce(&T) -> &U, + { + let data = f(&*this) as *const U; + let this = this.skip_drop(); + let guard = OwnedRwLockReadGuard { + lock: this.lock, + data, + _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: this.resource_span, + }; + + // Release all but one of the permits held by the write guard + let to_release = (this.permits_acquired - 1) as usize; + guard.lock.s.release(to_release); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + guard.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + write_locked.op = "override", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + guard.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) + }); + + guard + } + + /// Attempts to make a new [`OwnedRwLockMappedWriteGuard`] for a component /// of the locked data. The original guard is returned if the closure /// returns `None`. /// @@ -121,18 +218,99 @@ impl<T: ?Sized> OwnedRwLockWriteGuard<T> { Some(data) => data as *mut U, None => return Err(this), }; - let permits_acquired = this.permits_acquired; - let lock = unsafe { ManuallyDrop::take(&mut this.lock) }; - // NB: Forget to avoid drop impl from being called. - mem::forget(this); + let this = this.skip_drop(); + Ok(OwnedRwLockMappedWriteGuard { - permits_acquired, - lock: ManuallyDrop::new(lock), + permits_acquired: this.permits_acquired, + lock: this.lock, data, _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: this.resource_span, }) } + /// Attempts to make a new [`OwnedRwLockReadGuard`] for a component of + /// the locked data. The original guard is returned if the closure returns + /// `None`. + /// + /// This operation cannot fail as the `OwnedRwLockWriteGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be + /// used as `OwnedRwLockWriteGuard::try_downgrade_map(...)`. A method would interfere with + /// methods of the same name on the contents of the locked data. + /// + /// Inside of `f`, you retain exclusive access to the data, despite only being given a `&T`. Handing out a + /// `&mut T` would result in unsoundness, as you could use interior mutability. + /// + /// If this function returns `Err(...)`, the lock is never unlocked nor downgraded. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{RwLock, OwnedRwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = Arc::new(RwLock::new(Foo(1))); + /// + /// let guard = Arc::clone(&lock).write_owned().await; + /// let guard = OwnedRwLockWriteGuard::try_downgrade_map(guard, |f| Some(&f.0)).expect("should not fail"); + /// let foo = lock.read_owned().await; + /// assert_eq!(foo.0, *guard); + /// # } + /// ``` + #[inline] + pub fn try_downgrade_map<F, U: ?Sized>( + this: Self, + f: F, + ) -> Result<OwnedRwLockReadGuard<T, U>, Self> + where + F: FnOnce(&T) -> Option<&U>, + { + let data = match f(&*this) { + Some(data) => data as *const U, + None => return Err(this), + }; + let this = this.skip_drop(); + let guard = OwnedRwLockReadGuard { + lock: this.lock, + data, + _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: this.resource_span, + }; + + // Release all but one of the permits held by the write guard + let to_release = (this.permits_acquired - 1) as usize; + guard.lock.s.release(to_release); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + guard.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + write_locked.op = "override", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + guard.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) + }); + + Ok(guard) + } + /// Converts this `OwnedRwLockWriteGuard` into an /// `OwnedRwLockMappedWriteGuard`. This method can be used to store a /// non-mapped guard in a struct field that expects a mapped guard. @@ -178,19 +356,39 @@ impl<T: ?Sized> OwnedRwLockWriteGuard<T> { /// assert_eq!(*lock.read().await, 2, "second writer obtained write lock"); /// # } /// ``` - pub fn downgrade(mut self) -> OwnedRwLockReadGuard<T> { - let lock = unsafe { ManuallyDrop::take(&mut self.lock) }; - let data = self.data; + pub fn downgrade(self) -> OwnedRwLockReadGuard<T> { + let this = self.skip_drop(); + let guard = OwnedRwLockReadGuard { + lock: this.lock, + data: this.data, + _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: this.resource_span, + }; // Release all but one of the permits held by the write guard - lock.s.release((self.permits_acquired - 1) as usize); - // NB: Forget to avoid drop impl from being called. - mem::forget(self); - OwnedRwLockReadGuard { - lock: ManuallyDrop::new(lock), - data, - _p: PhantomData, - } + let to_release = (this.permits_acquired - 1) as usize; + guard.lock.s.release(to_release); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + guard.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + write_locked.op = "override", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + guard.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) + }); + + guard } } @@ -229,6 +427,14 @@ where impl<T: ?Sized> Drop for OwnedRwLockWriteGuard<T> { fn drop(&mut self) { self.lock.s.release(self.permits_acquired as usize); - unsafe { ManuallyDrop::drop(&mut self.lock) }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + write_locked.op = "override", + ) + }); } } diff --git a/vendor/tokio/src/sync/rwlock/owned_write_guard_mapped.rs b/vendor/tokio/src/sync/rwlock/owned_write_guard_mapped.rs index 6453236eb..9f4952100 100644 --- a/vendor/tokio/src/sync/rwlock/owned_write_guard_mapped.rs +++ b/vendor/tokio/src/sync/rwlock/owned_write_guard_mapped.rs @@ -1,9 +1,7 @@ use crate::sync::rwlock::RwLock; -use std::fmt; use std::marker::PhantomData; -use std::mem::{self, ManuallyDrop}; -use std::ops; use std::sync::Arc; +use std::{fmt, mem, ops, ptr}; /// Owned RAII structure used to release the exclusive write access of a lock when /// dropped. @@ -14,16 +12,44 @@ use std::sync::Arc; /// /// [mapping]: method@crate::sync::OwnedRwLockWriteGuard::map /// [`OwnedRwLockWriteGuard`]: struct@crate::sync::OwnedRwLockWriteGuard +#[clippy::has_significant_drop] pub struct OwnedRwLockMappedWriteGuard<T: ?Sized, U: ?Sized = T> { + // When changing the fields in this struct, make sure to update the + // `skip_drop` method. + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) resource_span: tracing::Span, pub(super) permits_acquired: u32, - // ManuallyDrop allows us to destructure into this field without running the destructor. - pub(super) lock: ManuallyDrop<Arc<RwLock<T>>>, + pub(super) lock: Arc<RwLock<T>>, pub(super) data: *mut U, pub(super) _p: PhantomData<T>, } +#[allow(dead_code)] // Unused fields are still used in Drop. +struct Inner<T: ?Sized, U: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, + permits_acquired: u32, + lock: Arc<RwLock<T>>, + data: *const U, +} + impl<T: ?Sized, U: ?Sized> OwnedRwLockMappedWriteGuard<T, U> { - /// Make a new `OwnedRwLockMappedWriteGuard` for a component of the locked + fn skip_drop(self) -> Inner<T, U> { + let me = mem::ManuallyDrop::new(self); + // SAFETY: This duplicates the values in every field of the guard, then + // forgets the originals, so in the end no value is duplicated. + unsafe { + Inner { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: ptr::read(&me.resource_span), + permits_acquired: me.permits_acquired, + lock: ptr::read(&me.lock), + data: me.data, + } + } + } + + /// Makes a new `OwnedRwLockMappedWriteGuard` for a component of the locked /// data. /// /// This operation cannot fail as the `OwnedRwLockMappedWriteGuard` passed @@ -61,15 +87,15 @@ impl<T: ?Sized, U: ?Sized> OwnedRwLockMappedWriteGuard<T, U> { F: FnOnce(&mut U) -> &mut V, { let data = f(&mut *this) as *mut V; - let lock = unsafe { ManuallyDrop::take(&mut this.lock) }; - let permits_acquired = this.permits_acquired; - // NB: Forget to avoid drop impl from being called. - mem::forget(this); + let this = this.skip_drop(); + OwnedRwLockMappedWriteGuard { - permits_acquired, - lock: ManuallyDrop::new(lock), + permits_acquired: this.permits_acquired, + lock: this.lock, data, _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: this.resource_span, } } @@ -118,15 +144,15 @@ impl<T: ?Sized, U: ?Sized> OwnedRwLockMappedWriteGuard<T, U> { Some(data) => data as *mut V, None => return Err(this), }; - let lock = unsafe { ManuallyDrop::take(&mut this.lock) }; - let permits_acquired = this.permits_acquired; - // NB: Forget to avoid drop impl from being called. - mem::forget(this); + let this = this.skip_drop(); + Ok(OwnedRwLockMappedWriteGuard { - permits_acquired, - lock: ManuallyDrop::new(lock), + permits_acquired: this.permits_acquired, + lock: this.lock, data, _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: this.resource_span, }) } } @@ -166,6 +192,14 @@ where impl<T: ?Sized, U: ?Sized> Drop for OwnedRwLockMappedWriteGuard<T, U> { fn drop(&mut self) { self.lock.s.release(self.permits_acquired as usize); - unsafe { ManuallyDrop::drop(&mut self.lock) }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + write_locked.op = "override", + ) + }); } } diff --git a/vendor/tokio/src/sync/rwlock/read_guard.rs b/vendor/tokio/src/sync/rwlock/read_guard.rs index 38eec7727..a04b59588 100644 --- a/vendor/tokio/src/sync/rwlock/read_guard.rs +++ b/vendor/tokio/src/sync/rwlock/read_guard.rs @@ -1,8 +1,6 @@ use crate::sync::batch_semaphore::Semaphore; -use std::fmt; -use std::marker; -use std::mem; -use std::ops; +use std::marker::PhantomData; +use std::{fmt, mem, ops}; /// RAII structure used to release the shared read access of a lock when /// dropped. @@ -12,14 +10,40 @@ use std::ops; /// /// [`read`]: method@crate::sync::RwLock::read /// [`RwLock`]: struct@crate::sync::RwLock +#[clippy::has_significant_drop] +#[must_use = "if unused the RwLock will immediately unlock"] pub struct RwLockReadGuard<'a, T: ?Sized> { + // When changing the fields in this struct, make sure to update the + // `skip_drop` method. + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) resource_span: tracing::Span, pub(super) s: &'a Semaphore, pub(super) data: *const T, - pub(super) marker: marker::PhantomData<&'a T>, + pub(super) marker: PhantomData<&'a T>, +} + +#[allow(dead_code)] // Unused fields are still used in Drop. +struct Inner<'a, T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, + s: &'a Semaphore, + data: *const T, } impl<'a, T: ?Sized> RwLockReadGuard<'a, T> { - /// Make a new `RwLockReadGuard` for a component of the locked data. + fn skip_drop(self) -> Inner<'a, T> { + let me = mem::ManuallyDrop::new(self); + // SAFETY: This duplicates the values in every field of the guard, then + // forgets the originals, so in the end no value is duplicated. + Inner { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: unsafe { std::ptr::read(&me.resource_span) }, + s: me.s, + data: me.data, + } + } + + /// Makes a new `RwLockReadGuard` for a component of the locked data. /// /// This operation cannot fail as the `RwLockReadGuard` passed in already /// locked the data. @@ -58,13 +82,14 @@ impl<'a, T: ?Sized> RwLockReadGuard<'a, T> { F: FnOnce(&T) -> &U, { let data = f(&*this) as *const U; - let s = this.s; - // NB: Forget to avoid drop impl from being called. - mem::forget(this); + let this = this.skip_drop(); + RwLockReadGuard { - s, + s: this.s, data, - marker: marker::PhantomData, + marker: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: this.resource_span, } } @@ -112,13 +137,14 @@ impl<'a, T: ?Sized> RwLockReadGuard<'a, T> { Some(data) => data as *const U, None => return Err(this), }; - let s = this.s; - // NB: Forget to avoid drop impl from being called. - mem::forget(this); + let this = this.skip_drop(); + Ok(RwLockReadGuard { - s, + s: this.s, data, - marker: marker::PhantomData, + marker: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: this.resource_span, }) } } @@ -152,5 +178,14 @@ where impl<'a, T: ?Sized> Drop for RwLockReadGuard<'a, T> { fn drop(&mut self) { self.s.release(1); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "sub", + ) + }); } } diff --git a/vendor/tokio/src/sync/rwlock/write_guard.rs b/vendor/tokio/src/sync/rwlock/write_guard.rs index 865a121ed..d405fc2b3 100644 --- a/vendor/tokio/src/sync/rwlock/write_guard.rs +++ b/vendor/tokio/src/sync/rwlock/write_guard.rs @@ -1,10 +1,8 @@ use crate::sync::batch_semaphore::Semaphore; use crate::sync::rwlock::read_guard::RwLockReadGuard; use crate::sync::rwlock::write_guard_mapped::RwLockMappedWriteGuard; -use std::fmt; -use std::marker; -use std::mem; -use std::ops; +use std::marker::PhantomData; +use std::{fmt, mem, ops}; /// RAII structure used to release the exclusive write access of a lock when /// dropped. @@ -14,15 +12,43 @@ use std::ops; /// /// [`write`]: method@crate::sync::RwLock::write /// [`RwLock`]: struct@crate::sync::RwLock +#[clippy::has_significant_drop] +#[must_use = "if unused the RwLock will immediately unlock"] pub struct RwLockWriteGuard<'a, T: ?Sized> { + // When changing the fields in this struct, make sure to update the + // `skip_drop` method. + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) resource_span: tracing::Span, pub(super) permits_acquired: u32, pub(super) s: &'a Semaphore, pub(super) data: *mut T, - pub(super) marker: marker::PhantomData<&'a mut T>, + pub(super) marker: PhantomData<&'a mut T>, +} + +#[allow(dead_code)] // Unused fields are still used in Drop. +struct Inner<'a, T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, + permits_acquired: u32, + s: &'a Semaphore, + data: *mut T, } impl<'a, T: ?Sized> RwLockWriteGuard<'a, T> { - /// Make a new [`RwLockMappedWriteGuard`] for a component of the locked data. + fn skip_drop(self) -> Inner<'a, T> { + let me = mem::ManuallyDrop::new(self); + // SAFETY: This duplicates the values in every field of the guard, then + // forgets the originals, so in the end no value is duplicated. + Inner { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: unsafe { std::ptr::read(&me.resource_span) }, + permits_acquired: me.permits_acquired, + s: me.s, + data: me.data, + } + } + + /// Makes a new [`RwLockMappedWriteGuard`] for a component of the locked data. /// /// This operation cannot fail as the `RwLockWriteGuard` passed in already /// locked the data. @@ -64,19 +90,96 @@ impl<'a, T: ?Sized> RwLockWriteGuard<'a, T> { F: FnOnce(&mut T) -> &mut U, { let data = f(&mut *this) as *mut U; - let s = this.s; - let permits_acquired = this.permits_acquired; - // NB: Forget to avoid drop impl from being called. - mem::forget(this); + let this = this.skip_drop(); + RwLockMappedWriteGuard { - permits_acquired, - s, + permits_acquired: this.permits_acquired, + s: this.s, data, - marker: marker::PhantomData, + marker: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: this.resource_span, } } - /// Attempts to make a new [`RwLockMappedWriteGuard`] for a component of + /// Makes a new [`RwLockReadGuard`] for a component of the locked data. + /// + /// This operation cannot fail as the `RwLockWriteGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be used as + /// `RwLockWriteGuard::downgrade_map(..)`. A method would interfere with methods of + /// the same name on the contents of the locked data. + /// + /// This is equivalent to a combination of asynchronous [`RwLockWriteGuard::map`] and [`RwLockWriteGuard::downgrade`] + /// from the [`parking_lot` crate]. + /// + /// Inside of `f`, you retain exclusive access to the data, despite only being given a `&T`. Handing out a + /// `&mut T` would result in unsoundness, as you could use interior mutability. + /// + /// [`RwLockMappedWriteGuard`]: struct@crate::sync::RwLockMappedWriteGuard + /// [`RwLockWriteGuard::map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockWriteGuard.html#method.map + /// [`RwLockWriteGuard::downgrade`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockWriteGuard.html#method.downgrade + /// [`parking_lot` crate]: https://crates.io/crates/parking_lot + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// + /// let mapped = RwLockWriteGuard::downgrade_map(lock.write().await, |f| &f.0); + /// let foo = lock.read().await; + /// assert_eq!(foo.0, *mapped); + /// # } + /// ``` + #[inline] + pub fn downgrade_map<F, U: ?Sized>(this: Self, f: F) -> RwLockReadGuard<'a, U> + where + F: FnOnce(&T) -> &U, + { + let data = f(&*this) as *const U; + let this = this.skip_drop(); + let guard = RwLockReadGuard { + s: this.s, + data, + marker: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: this.resource_span, + }; + + // Release all but one of the permits held by the write guard + let to_release = (this.permits_acquired - 1) as usize; + this.s.release(to_release); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + guard.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + write_locked.op = "override", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + guard.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) + }); + + guard + } + + /// Attempts to make a new [`RwLockMappedWriteGuard`] for a component of /// the locked data. The original guard is returned if the closure returns /// `None`. /// @@ -127,18 +230,102 @@ impl<'a, T: ?Sized> RwLockWriteGuard<'a, T> { Some(data) => data as *mut U, None => return Err(this), }; - let s = this.s; - let permits_acquired = this.permits_acquired; - // NB: Forget to avoid drop impl from being called. - mem::forget(this); + let this = this.skip_drop(); + Ok(RwLockMappedWriteGuard { - permits_acquired, - s, + permits_acquired: this.permits_acquired, + s: this.s, data, - marker: marker::PhantomData, + marker: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: this.resource_span, }) } + /// Attempts to make a new [`RwLockReadGuard`] for a component of + /// the locked data. The original guard is returned if the closure returns + /// `None`. + /// + /// This operation cannot fail as the `RwLockWriteGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be + /// used as `RwLockWriteGuard::try_downgrade_map(...)`. A method would interfere with + /// methods of the same name on the contents of the locked data. + /// + /// This is equivalent to a combination of asynchronous [`RwLockWriteGuard::try_map`] and [`RwLockWriteGuard::downgrade`] + /// from the [`parking_lot` crate]. + /// + /// Inside of `f`, you retain exclusive access to the data, despite only being given a `&T`. Handing out a + /// `&mut T` would result in unsoundness, as you could use interior mutability. + /// + /// If this function returns `Err(...)`, the lock is never unlocked nor downgraded. + /// + /// [`RwLockMappedWriteGuard`]: struct@crate::sync::RwLockMappedWriteGuard + /// [`RwLockWriteGuard::map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockWriteGuard.html#method.map + /// [`RwLockWriteGuard::downgrade`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockWriteGuard.html#method.downgrade + /// [`parking_lot` crate]: https://crates.io/crates/parking_lot + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// + /// let guard = RwLockWriteGuard::try_downgrade_map(lock.write().await, |f| Some(&f.0)).expect("should not fail"); + /// let foo = lock.read().await; + /// assert_eq!(foo.0, *guard); + /// # } + /// ``` + #[inline] + pub fn try_downgrade_map<F, U: ?Sized>(this: Self, f: F) -> Result<RwLockReadGuard<'a, U>, Self> + where + F: FnOnce(&T) -> Option<&U>, + { + let data = match f(&*this) { + Some(data) => data as *const U, + None => return Err(this), + }; + let this = this.skip_drop(); + let guard = RwLockReadGuard { + s: this.s, + data, + marker: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: this.resource_span, + }; + + // Release all but one of the permits held by the write guard + let to_release = (this.permits_acquired - 1) as usize; + this.s.release(to_release); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + guard.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + write_locked.op = "override", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + guard.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) + }); + + Ok(guard) + } + /// Converts this `RwLockWriteGuard` into an `RwLockMappedWriteGuard`. This /// method can be used to store a non-mapped guard in a struct field that /// expects a mapped guard. @@ -187,17 +374,38 @@ impl<'a, T: ?Sized> RwLockWriteGuard<'a, T> { /// /// [`RwLock`]: struct@crate::sync::RwLock pub fn downgrade(self) -> RwLockReadGuard<'a, T> { - let RwLockWriteGuard { s, data, .. } = self; + let this = self.skip_drop(); + let guard = RwLockReadGuard { + s: this.s, + data: this.data, + marker: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: this.resource_span, + }; // Release all but one of the permits held by the write guard - s.release((self.permits_acquired - 1) as usize); - // NB: Forget to avoid drop impl from being called. - mem::forget(self); - RwLockReadGuard { - s, - data, - marker: marker::PhantomData, - } + let to_release = (this.permits_acquired - 1) as usize; + this.s.release(to_release); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + guard.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + write_locked.op = "override", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + guard.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) + }); + + guard } } @@ -236,5 +444,14 @@ where impl<'a, T: ?Sized> Drop for RwLockWriteGuard<'a, T> { fn drop(&mut self) { self.s.release(self.permits_acquired as usize); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + write_locked.op = "override", + ) + }); } } diff --git a/vendor/tokio/src/sync/rwlock/write_guard_mapped.rs b/vendor/tokio/src/sync/rwlock/write_guard_mapped.rs index 9c5b1e7c3..7705189e7 100644 --- a/vendor/tokio/src/sync/rwlock/write_guard_mapped.rs +++ b/vendor/tokio/src/sync/rwlock/write_guard_mapped.rs @@ -1,8 +1,6 @@ use crate::sync::batch_semaphore::Semaphore; -use std::fmt; -use std::marker; -use std::mem; -use std::ops; +use std::marker::PhantomData; +use std::{fmt, mem, ops}; /// RAII structure used to release the exclusive write access of a lock when /// dropped. @@ -13,15 +11,42 @@ use std::ops; /// /// [mapping]: method@crate::sync::RwLockWriteGuard::map /// [`RwLockWriteGuard`]: struct@crate::sync::RwLockWriteGuard +#[clippy::has_significant_drop] pub struct RwLockMappedWriteGuard<'a, T: ?Sized> { + // When changing the fields in this struct, make sure to update the + // `skip_drop` method. + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) resource_span: tracing::Span, pub(super) permits_acquired: u32, pub(super) s: &'a Semaphore, pub(super) data: *mut T, - pub(super) marker: marker::PhantomData<&'a mut T>, + pub(super) marker: PhantomData<&'a mut T>, +} + +#[allow(dead_code)] // Unused fields are still used in Drop. +struct Inner<'a, T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, + permits_acquired: u32, + s: &'a Semaphore, + data: *mut T, } impl<'a, T: ?Sized> RwLockMappedWriteGuard<'a, T> { - /// Make a new `RwLockMappedWriteGuard` for a component of the locked data. + fn skip_drop(self) -> Inner<'a, T> { + let me = mem::ManuallyDrop::new(self); + // SAFETY: This duplicates the values in every field of the guard, then + // forgets the originals, so in the end no value is duplicated. + Inner { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: unsafe { std::ptr::read(&me.resource_span) }, + permits_acquired: me.permits_acquired, + s: me.s, + data: me.data, + } + } + + /// Makes a new `RwLockMappedWriteGuard` for a component of the locked data. /// /// This operation cannot fail as the `RwLockMappedWriteGuard` passed in already /// locked the data. @@ -62,15 +87,15 @@ impl<'a, T: ?Sized> RwLockMappedWriteGuard<'a, T> { F: FnOnce(&mut T) -> &mut U, { let data = f(&mut *this) as *mut U; - let s = this.s; - let permits_acquired = this.permits_acquired; - // NB: Forget to avoid drop impl from being called. - mem::forget(this); + let this = this.skip_drop(); + RwLockMappedWriteGuard { - permits_acquired, - s, + permits_acquired: this.permits_acquired, + s: this.s, data, - marker: marker::PhantomData, + marker: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: this.resource_span, } } @@ -124,17 +149,20 @@ impl<'a, T: ?Sized> RwLockMappedWriteGuard<'a, T> { Some(data) => data as *mut U, None => return Err(this), }; - let s = this.s; - let permits_acquired = this.permits_acquired; - // NB: Forget to avoid drop impl from being called. - mem::forget(this); + let this = this.skip_drop(); + Ok(RwLockMappedWriteGuard { - permits_acquired, - s, + permits_acquired: this.permits_acquired, + s: this.s, data, - marker: marker::PhantomData, + marker: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: this.resource_span, }) } + + // Note: No `downgrade`, `downgrade_map` nor `try_downgrade_map` because they would be unsound, as we're already + // potentially been mapped with internal mutability. } impl<T: ?Sized> ops::Deref for RwLockMappedWriteGuard<'_, T> { @@ -172,5 +200,14 @@ where impl<'a, T: ?Sized> Drop for RwLockMappedWriteGuard<'a, T> { fn drop(&mut self) { self.s.release(self.permits_acquired as usize); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + write_locked.op = "override", + ) + }); } } diff --git a/vendor/tokio/src/sync/semaphore.rs b/vendor/tokio/src/sync/semaphore.rs index 4b697a9bf..e679d0e6b 100644 --- a/vendor/tokio/src/sync/semaphore.rs +++ b/vendor/tokio/src/sync/semaphore.rs @@ -1,5 +1,7 @@ use super::batch_semaphore as ll; // low level implementation use super::{AcquireError, TryAcquireError}; +#[cfg(all(tokio_unstable, feature = "tracing"))] +use crate::util::trace; use std::sync::Arc; /// Counting semaphore performing asynchronous permit acquisition. @@ -71,12 +73,14 @@ use std::sync::Arc; /// } /// ``` /// -/// [`PollSemaphore`]: https://docs.rs/tokio-util/0.6/tokio_util/sync/struct.PollSemaphore.html +/// [`PollSemaphore`]: https://docs.rs/tokio-util/latest/tokio_util/sync/struct.PollSemaphore.html /// [`Semaphore::acquire_owned`]: crate::sync::Semaphore::acquire_owned #[derive(Debug)] pub struct Semaphore { /// The low level semaphore ll_sem: ll::Semaphore, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, } /// A permit from the semaphore. @@ -85,6 +89,7 @@ pub struct Semaphore { /// /// [`acquire`]: crate::sync::Semaphore::acquire() #[must_use] +#[clippy::has_significant_drop] #[derive(Debug)] pub struct SemaphorePermit<'a> { sem: &'a Semaphore, @@ -97,6 +102,7 @@ pub struct SemaphorePermit<'a> { /// /// [`acquire_owned`]: crate::sync::Semaphore::acquire_owned() #[must_use] +#[clippy::has_significant_drop] #[derive(Debug)] pub struct OwnedSemaphorePermit { sem: Arc<Semaphore>, @@ -119,10 +125,41 @@ fn bounds() { } impl Semaphore { + /// The maximum number of permits which a semaphore can hold. It is `usize::MAX >> 3`. + /// + /// Exceeding this limit typically results in a panic. + pub const MAX_PERMITS: usize = super::batch_semaphore::Semaphore::MAX_PERMITS; + /// Creates a new semaphore with the initial number of permits. + /// + /// Panics if `permits` exceeds [`Semaphore::MAX_PERMITS`]. + #[track_caller] pub fn new(permits: usize) -> Self { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = { + let location = std::panic::Location::caller(); + + tracing::trace_span!( + "runtime.resource", + concrete_type = "Semaphore", + kind = "Sync", + loc.file = location.file(), + loc.line = location.line(), + loc.col = location.column(), + inherits_child_attrs = true, + ) + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let ll_sem = resource_span.in_scope(|| ll::Semaphore::new(permits)); + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + let ll_sem = ll::Semaphore::new(permits); + Self { - ll_sem: ll::Semaphore::new(permits), + ll_sem, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, } } @@ -139,9 +176,16 @@ impl Semaphore { #[cfg(all(feature = "parking_lot", not(all(loom, test))))] #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] pub const fn const_new(permits: usize) -> Self { - Self { + #[cfg(all(tokio_unstable, feature = "tracing"))] + return Self { ll_sem: ll::Semaphore::const_new(permits), - } + resource_span: tracing::Span::none(), + }; + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + return Self { + ll_sem: ll::Semaphore::const_new(permits), + }; } /// Returns the current number of available permits. @@ -151,7 +195,7 @@ impl Semaphore { /// Adds `n` new permits to the semaphore. /// - /// The maximum number of permits is `usize::MAX >> 3`, and this function will panic if the limit is exceeded. + /// The maximum number of permits is [`Semaphore::MAX_PERMITS`], and this function will panic if the limit is exceeded. pub fn add_permits(&self, n: usize) { self.ll_sem.release(n); } @@ -191,9 +235,20 @@ impl Semaphore { /// [`AcquireError`]: crate::sync::AcquireError /// [`SemaphorePermit`]: crate::sync::SemaphorePermit pub async fn acquire(&self) -> Result<SemaphorePermit<'_>, AcquireError> { - self.ll_sem.acquire(1).await?; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let inner = trace::async_op( + || self.ll_sem.acquire(1), + self.resource_span.clone(), + "Semaphore::acquire", + "poll", + true, + ); + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let inner = self.ll_sem.acquire(1); + + inner.await?; Ok(SemaphorePermit { - sem: &self, + sem: self, permits: 1, }) } @@ -227,9 +282,21 @@ impl Semaphore { /// [`AcquireError`]: crate::sync::AcquireError /// [`SemaphorePermit`]: crate::sync::SemaphorePermit pub async fn acquire_many(&self, n: u32) -> Result<SemaphorePermit<'_>, AcquireError> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + trace::async_op( + || self.ll_sem.acquire(n), + self.resource_span.clone(), + "Semaphore::acquire_many", + "poll", + true, + ) + .await?; + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] self.ll_sem.acquire(n).await?; + Ok(SemaphorePermit { - sem: &self, + sem: self, permits: n, }) } @@ -350,7 +417,18 @@ impl Semaphore { /// [`AcquireError`]: crate::sync::AcquireError /// [`OwnedSemaphorePermit`]: crate::sync::OwnedSemaphorePermit pub async fn acquire_owned(self: Arc<Self>) -> Result<OwnedSemaphorePermit, AcquireError> { - self.ll_sem.acquire(1).await?; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let inner = trace::async_op( + || self.ll_sem.acquire(1), + self.resource_span.clone(), + "Semaphore::acquire_owned", + "poll", + true, + ); + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let inner = self.ll_sem.acquire(1); + + inner.await?; Ok(OwnedSemaphorePermit { sem: self, permits: 1, @@ -403,7 +481,18 @@ impl Semaphore { self: Arc<Self>, n: u32, ) -> Result<OwnedSemaphorePermit, AcquireError> { - self.ll_sem.acquire(n).await?; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let inner = trace::async_op( + || self.ll_sem.acquire(n), + self.resource_span.clone(), + "Semaphore::acquire_many_owned", + "poll", + true, + ); + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let inner = self.ll_sem.acquire(n); + + inner.await?; Ok(OwnedSemaphorePermit { sem: self, permits: n, @@ -540,6 +629,25 @@ impl<'a> SemaphorePermit<'a> { pub fn forget(mut self) { self.permits = 0; } + + /// Merge two [`SemaphorePermit`] instances together, consuming `other` + /// without releasing the permits it holds. + /// + /// Permits held by both `self` and `other` are released when `self` drops. + /// + /// # Panics + /// + /// This function panics if permits from different [`Semaphore`] instances + /// are merged. + #[track_caller] + pub fn merge(&mut self, mut other: Self) { + assert!( + std::ptr::eq(self.sem, other.sem), + "merging permits from different semaphore instances" + ); + self.permits += other.permits; + other.permits = 0; + } } impl OwnedSemaphorePermit { @@ -549,9 +657,33 @@ impl OwnedSemaphorePermit { pub fn forget(mut self) { self.permits = 0; } + + /// Merge two [`OwnedSemaphorePermit`] instances together, consuming `other` + /// without releasing the permits it holds. + /// + /// Permits held by both `self` and `other` are released when `self` drops. + /// + /// # Panics + /// + /// This function panics if permits from different [`Semaphore`] instances + /// are merged. + #[track_caller] + pub fn merge(&mut self, mut other: Self) { + assert!( + Arc::ptr_eq(&self.sem, &other.sem), + "merging permits from different semaphore instances" + ); + self.permits += other.permits; + other.permits = 0; + } + + /// Returns the [`Semaphore`] from which this permit was acquired. + pub fn semaphore(&self) -> &Arc<Semaphore> { + &self.sem + } } -impl<'a> Drop for SemaphorePermit<'_> { +impl Drop for SemaphorePermit<'_> { fn drop(&mut self) { self.sem.add_permits(self.permits as usize); } diff --git a/vendor/tokio/src/sync/task/atomic_waker.rs b/vendor/tokio/src/sync/task/atomic_waker.rs index 8616007a3..13aba3544 100644 --- a/vendor/tokio/src/sync/task/atomic_waker.rs +++ b/vendor/tokio/src/sync/task/atomic_waker.rs @@ -1,9 +1,11 @@ #![cfg_attr(any(loom, not(feature = "sync")), allow(dead_code, unreachable_pub))] use crate::loom::cell::UnsafeCell; -use crate::loom::sync::atomic::{self, AtomicUsize}; +use crate::loom::hint; +use crate::loom::sync::atomic::AtomicUsize; use std::fmt; +use std::panic::{resume_unwind, AssertUnwindSafe, RefUnwindSafe, UnwindSafe}; use std::sync::atomic::Ordering::{AcqRel, Acquire, Release}; use std::task::Waker; @@ -27,6 +29,9 @@ pub(crate) struct AtomicWaker { waker: UnsafeCell<Option<Waker>>, } +impl RefUnwindSafe for AtomicWaker {} +impl UnwindSafe for AtomicWaker {} + // `AtomicWaker` is a multi-consumer, single-producer transfer cell. The cell // stores a `Waker` value produced by calls to `register` and many threads can // race to take the waker by calling `wake`. @@ -84,7 +89,7 @@ pub(crate) struct AtomicWaker { // back to `WAITING`. This transition must succeed as, at this point, the state // cannot be transitioned by another thread. // -// If the thread is unable to obtain the lock, the `WAKING` bit is still. +// If the thread is unable to obtain the lock, the `WAKING` bit is still set. // This is because it has either been set by the current thread but the previous // value included the `REGISTERING` bit **or** a concurrent thread is in the // `WAKING` critical section. Either way, no action must be taken. @@ -123,7 +128,7 @@ pub(crate) struct AtomicWaker { // Thread A still holds the `wake` lock, the call to `register` will result // in the task waking itself and get scheduled again. -/// Idle state +/// Idle state. const WAITING: usize = 0; /// A new waker value is being registered with the `AtomicWaker` cell. @@ -171,6 +176,10 @@ impl AtomicWaker { where W: WakerRef, { + fn catch_unwind<F: FnOnce() -> R, R>(f: F) -> std::thread::Result<R> { + std::panic::catch_unwind(AssertUnwindSafe(f)) + } + match self .state .compare_exchange(WAITING, REGISTERING, Acquire, Acquire) @@ -178,8 +187,24 @@ impl AtomicWaker { { WAITING => { unsafe { - // Locked acquired, update the waker cell - self.waker.with_mut(|t| *t = Some(waker.into_waker())); + // If `into_waker` panics (because it's code outside of + // AtomicWaker) we need to prime a guard that is called on + // unwind to restore the waker to a WAITING state. Otherwise + // any future calls to register will incorrectly be stuck + // believing it's being updated by someone else. + let new_waker_or_panic = catch_unwind(move || waker.into_waker()); + + // Set the field to contain the new waker, or if + // `into_waker` panicked, leave the old value. + let mut maybe_panic = None; + let mut old_waker = None; + match new_waker_or_panic { + Ok(new_waker) => { + old_waker = self.waker.with_mut(|t| (*t).take()); + self.waker.with_mut(|t| *t = Some(new_waker)); + } + Err(panic) => maybe_panic = Some(panic), + } // Release the lock. If the state transitioned to include // the `WAKING` bit, this means that a wake has been @@ -193,39 +218,71 @@ impl AtomicWaker { .compare_exchange(REGISTERING, WAITING, AcqRel, Acquire); match res { - Ok(_) => {} + Ok(_) => { + // We don't want to give the caller the panic if it + // was someone else who put in that waker. + let _ = catch_unwind(move || { + drop(old_waker); + }); + } Err(actual) => { // This branch can only be reached if a // concurrent thread called `wake`. In this // case, `actual` **must** be `REGISTERING | - // `WAKING`. + // WAKING`. debug_assert_eq!(actual, REGISTERING | WAKING); // Take the waker to wake once the atomic operation has // completed. - let waker = self.waker.with_mut(|t| (*t).take()).unwrap(); + let mut waker = self.waker.with_mut(|t| (*t).take()); // Just swap, because no one could change state // while state == `Registering | `Waking` self.state.swap(WAITING, AcqRel); - // The atomic swap was complete, now - // wake the waker and return. - waker.wake(); + // If `into_waker` panicked, then the waker in the + // waker slot is actually the old waker. + if maybe_panic.is_some() { + old_waker = waker.take(); + } + + // We don't want to give the caller the panic if it + // was someone else who put in that waker. + if let Some(old_waker) = old_waker { + let _ = catch_unwind(move || { + old_waker.wake(); + }); + } + + // The atomic swap was complete, now wake the waker + // and return. + // + // If this panics, we end up in a consumed state and + // return the panic to the caller. + if let Some(waker) = waker { + debug_assert!(maybe_panic.is_none()); + waker.wake(); + } } } + + if let Some(panic) = maybe_panic { + // If `into_waker` panicked, return the panic to the caller. + resume_unwind(panic); + } } } WAKING => { // Currently in the process of waking the task, i.e., // `wake` is currently being called on the old waker. // So, we call wake on the new waker. + // + // If this panics, someone else is responsible for restoring the + // state of the waker. waker.wake(); // This is equivalent to a spin lock, so use a spin hint. - // TODO: once we bump MSRV to 1.49+, use `hint::spin_loop` instead. - #[allow(deprecated)] - atomic::spin_loop_hint(); + hint::spin_loop(); } state => { // In this case, a concurrent thread is holding the @@ -245,6 +302,8 @@ impl AtomicWaker { /// If `register` has not been called yet, then this does nothing. pub(crate) fn wake(&self) { if let Some(waker) = self.take_waker() { + // If wake panics, we've consumed the waker which is a legitimate + // outcome. waker.wake(); } } diff --git a/vendor/tokio/src/sync/tests/atomic_waker.rs b/vendor/tokio/src/sync/tests/atomic_waker.rs index c832d62e9..8ebfb915f 100644 --- a/vendor/tokio/src/sync/tests/atomic_waker.rs +++ b/vendor/tokio/src/sync/tests/atomic_waker.rs @@ -4,7 +4,7 @@ use tokio_test::task; use std::task::Waker; trait AssertSend: Send {} -trait AssertSync: Send {} +trait AssertSync: Sync {} impl AssertSend for AtomicWaker {} impl AssertSync for AtomicWaker {} @@ -12,6 +12,9 @@ impl AssertSync for AtomicWaker {} impl AssertSend for Waker {} impl AssertSync for Waker {} +#[cfg(tokio_wasm_not_wasi)] +use wasm_bindgen_test::wasm_bindgen_test as test; + #[test] fn basic_usage() { let mut waker = task::spawn(AtomicWaker::new()); @@ -32,3 +35,43 @@ fn wake_without_register() { assert!(!waker.is_woken()); } + +#[test] +#[cfg(not(tokio_wasm))] // wasm currently doesn't support unwinding +fn atomic_waker_panic_safe() { + use std::panic; + use std::ptr; + use std::task::{RawWaker, RawWakerVTable, Waker}; + + static PANICKING_VTABLE: RawWakerVTable = RawWakerVTable::new( + |_| panic!("clone"), + |_| unimplemented!("wake"), + |_| unimplemented!("wake_by_ref"), + |_| (), + ); + + static NONPANICKING_VTABLE: RawWakerVTable = RawWakerVTable::new( + |_| RawWaker::new(ptr::null(), &NONPANICKING_VTABLE), + |_| unimplemented!("wake"), + |_| unimplemented!("wake_by_ref"), + |_| (), + ); + + let panicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &PANICKING_VTABLE)) }; + let nonpanicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NONPANICKING_VTABLE)) }; + + let atomic_waker = AtomicWaker::new(); + + let panicking = panic::AssertUnwindSafe(&panicking); + + let result = panic::catch_unwind(|| { + let panic::AssertUnwindSafe(panicking) = panicking; + atomic_waker.register_by_ref(panicking); + }); + + assert!(result.is_err()); + assert!(atomic_waker.take_waker().is_none()); + + atomic_waker.register_by_ref(&nonpanicking); + assert!(atomic_waker.take_waker().is_some()); +} diff --git a/vendor/tokio/src/sync/tests/loom_atomic_waker.rs b/vendor/tokio/src/sync/tests/loom_atomic_waker.rs index c148bcbe1..f8bae65d1 100644 --- a/vendor/tokio/src/sync/tests/loom_atomic_waker.rs +++ b/vendor/tokio/src/sync/tests/loom_atomic_waker.rs @@ -43,3 +43,58 @@ fn basic_notification() { })); }); } + +#[test] +fn test_panicky_waker() { + use std::panic; + use std::ptr; + use std::task::{RawWaker, RawWakerVTable, Waker}; + + static PANICKING_VTABLE: RawWakerVTable = + RawWakerVTable::new(|_| panic!("clone"), |_| (), |_| (), |_| ()); + + let panicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &PANICKING_VTABLE)) }; + + // If you're working with this test (and I sure hope you never have to!), + // uncomment the following section because there will be a lot of panics + // which would otherwise log. + // + // We can't however leaved it uncommented, because it's global. + // panic::set_hook(Box::new(|_| ())); + + const NUM_NOTIFY: usize = 2; + + loom::model(move || { + let chan = Arc::new(Chan { + num: AtomicUsize::new(0), + task: AtomicWaker::new(), + }); + + for _ in 0..NUM_NOTIFY { + let chan = chan.clone(); + + thread::spawn(move || { + chan.num.fetch_add(1, Relaxed); + chan.task.wake(); + }); + } + + // Note: this panic should have no effect on the overall state of the + // waker and it should proceed as normal. + // + // A thread above might race to flag a wakeup, and a WAKING state will + // be preserved if this expected panic races with that so the below + // procedure should be allowed to continue uninterrupted. + let _ = panic::catch_unwind(|| chan.task.register_by_ref(&panicking)); + + block_on(poll_fn(move |cx| { + chan.task.register_by_ref(cx.waker()); + + if NUM_NOTIFY == chan.num.load(Relaxed) { + return Ready(()); + } + + Pending + })); + }); +} diff --git a/vendor/tokio/src/sync/tests/loom_mpsc.rs b/vendor/tokio/src/sync/tests/loom_mpsc.rs index c12313bd3..f165e7076 100644 --- a/vendor/tokio/src/sync/tests/loom_mpsc.rs +++ b/vendor/tokio/src/sync/tests/loom_mpsc.rs @@ -132,3 +132,59 @@ fn dropping_unbounded_tx() { assert!(v.is_none()); }); } + +#[test] +fn try_recv() { + loom::model(|| { + use crate::sync::{mpsc, Semaphore}; + use loom::sync::{Arc, Mutex}; + + const PERMITS: usize = 2; + const TASKS: usize = 2; + const CYCLES: usize = 1; + + struct Context { + sem: Arc<Semaphore>, + tx: mpsc::Sender<()>, + rx: Mutex<mpsc::Receiver<()>>, + } + + fn run(ctx: &Context) { + block_on(async { + let permit = ctx.sem.acquire().await; + assert_ok!(ctx.rx.lock().unwrap().try_recv()); + crate::task::yield_now().await; + assert_ok!(ctx.tx.clone().try_send(())); + drop(permit); + }); + } + + let (tx, rx) = mpsc::channel(PERMITS); + let sem = Arc::new(Semaphore::new(PERMITS)); + let ctx = Arc::new(Context { + sem, + tx, + rx: Mutex::new(rx), + }); + + for _ in 0..PERMITS { + assert_ok!(ctx.tx.clone().try_send(())); + } + + let mut ths = Vec::new(); + + for _ in 0..TASKS { + let ctx = ctx.clone(); + + ths.push(thread::spawn(move || { + run(&ctx); + })); + } + + run(&ctx); + + for th in ths { + th.join().unwrap(); + } + }); +} diff --git a/vendor/tokio/src/sync/tests/loom_notify.rs b/vendor/tokio/src/sync/tests/loom_notify.rs index d484a7581..a4ded1d35 100644 --- a/vendor/tokio/src/sync/tests/loom_notify.rs +++ b/vendor/tokio/src/sync/tests/loom_notify.rs @@ -4,6 +4,11 @@ use loom::future::block_on; use loom::sync::Arc; use loom::thread; +use tokio_test::{assert_pending, assert_ready}; + +/// `util::wake_list::NUM_WAKERS` +const WAKE_LIST_SIZE: usize = 32; + #[test] fn notify_one() { loom::model(|| { @@ -138,3 +143,189 @@ fn notify_drop() { th2.join().unwrap(); }); } + +/// Polls two `Notified` futures and checks if poll results are consistent +/// with each other. If the first future is notified by a `notify_waiters` +/// call, then the second one must be notified as well. +#[test] +fn notify_waiters_poll_consistency() { + fn notify_waiters_poll_consistency_variant(poll_setting: [bool; 2]) { + let notify = Arc::new(Notify::new()); + let mut notified = [ + tokio_test::task::spawn(notify.notified()), + tokio_test::task::spawn(notify.notified()), + ]; + for i in 0..2 { + if poll_setting[i] { + assert_pending!(notified[i].poll()); + } + } + + let tx = notify.clone(); + let th = thread::spawn(move || { + tx.notify_waiters(); + }); + + let res1 = notified[0].poll(); + let res2 = notified[1].poll(); + + // If res1 is ready, then res2 must also be ready. + assert!(res1.is_pending() || res2.is_ready()); + + th.join().unwrap(); + } + + // We test different scenarios in which pending futures had or had not + // been polled before the call to `notify_waiters`. + loom::model(|| notify_waiters_poll_consistency_variant([false, false])); + loom::model(|| notify_waiters_poll_consistency_variant([true, false])); + loom::model(|| notify_waiters_poll_consistency_variant([false, true])); + loom::model(|| notify_waiters_poll_consistency_variant([true, true])); +} + +/// Polls two `Notified` futures and checks if poll results are consistent +/// with each other. If the first future is notified by a `notify_waiters` +/// call, then the second one must be notified as well. +/// +/// Here we also add other `Notified` futures in between to force the two +/// tested futures to end up in different chunks. +#[test] +fn notify_waiters_poll_consistency_many() { + fn notify_waiters_poll_consistency_many_variant(order: [usize; 2]) { + let notify = Arc::new(Notify::new()); + + let mut futs = (0..WAKE_LIST_SIZE + 1) + .map(|_| tokio_test::task::spawn(notify.notified())) + .collect::<Vec<_>>(); + + assert_pending!(futs[order[0]].poll()); + for i in 2..futs.len() { + assert_pending!(futs[i].poll()); + } + assert_pending!(futs[order[1]].poll()); + + let tx = notify.clone(); + let th = thread::spawn(move || { + tx.notify_waiters(); + }); + + let res1 = futs[0].poll(); + let res2 = futs[1].poll(); + + // If res1 is ready, then res2 must also be ready. + assert!(res1.is_pending() || res2.is_ready()); + + th.join().unwrap(); + } + + // We test different scenarios in which futures are polled in different order. + loom::model(|| notify_waiters_poll_consistency_many_variant([0, 1])); + loom::model(|| notify_waiters_poll_consistency_many_variant([1, 0])); +} + +/// Checks if a call to `notify_waiters` is observed as atomic when combined +/// with a concurrent call to `notify_one`. +#[test] +fn notify_waiters_is_atomic() { + fn notify_waiters_is_atomic_variant(tested_fut_index: usize) { + let notify = Arc::new(Notify::new()); + + let mut futs = (0..WAKE_LIST_SIZE + 1) + .map(|_| tokio_test::task::spawn(notify.notified())) + .collect::<Vec<_>>(); + + for fut in &mut futs { + assert_pending!(fut.poll()); + } + + let tx = notify.clone(); + let th = thread::spawn(move || { + tx.notify_waiters(); + }); + + block_on(async { + // If awaiting one of the futures completes, then we should be + // able to assume that all pending futures are notified. Therefore + // a notification from a subsequent `notify_one` call should not + // be consumed by an old future. + futs.remove(tested_fut_index).await; + + let mut new_fut = tokio_test::task::spawn(notify.notified()); + assert_pending!(new_fut.poll()); + + notify.notify_one(); + + // `new_fut` must consume the notification from `notify_one`. + assert_ready!(new_fut.poll()); + }); + + th.join().unwrap(); + } + + // We test different scenarios in which the tested future is at the beginning + // or at the end of the waiters queue used by `Notify`. + loom::model(|| notify_waiters_is_atomic_variant(0)); + loom::model(|| notify_waiters_is_atomic_variant(32)); +} + +/// Checks if a single call to `notify_waiters` does not get through two `Notified` +/// futures created and awaited sequentially like this: +/// ```ignore +/// notify.notified().await; +/// notify.notified().await; +/// ``` +#[test] +fn notify_waiters_sequential_notified_await() { + use crate::sync::oneshot; + + loom::model(|| { + let notify = Arc::new(Notify::new()); + + let (tx_fst, rx_fst) = oneshot::channel(); + let (tx_snd, rx_snd) = oneshot::channel(); + + let receiver = thread::spawn({ + let notify = notify.clone(); + move || { + block_on(async { + // Poll the first `Notified` to put it as the first waiter + // in the queue. + let mut first_notified = tokio_test::task::spawn(notify.notified()); + assert_pending!(first_notified.poll()); + + // Create additional waiters to force `notify_waiters` to + // release the lock at least once. + let _task_pile = (0..WAKE_LIST_SIZE + 1) + .map(|_| { + let mut fut = tokio_test::task::spawn(notify.notified()); + assert_pending!(fut.poll()); + fut + }) + .collect::<Vec<_>>(); + + // We are ready for the notify_waiters call. + tx_fst.send(()).unwrap(); + + first_notified.await; + + // Poll the second `Notified` future to try to insert + // it to the waiters queue. + let mut second_notified = tokio_test::task::spawn(notify.notified()); + assert_pending!(second_notified.poll()); + + // Wait for the `notify_waiters` to end and check if we + // are woken up. + rx_snd.await.unwrap(); + assert_pending!(second_notified.poll()); + }); + } + }); + + // Wait for the signal and call `notify_waiters`. + block_on(rx_fst).unwrap(); + notify.notify_waiters(); + tx_snd.send(()).unwrap(); + + receiver.join().unwrap(); + }); +} diff --git a/vendor/tokio/src/sync/tests/loom_watch.rs b/vendor/tokio/src/sync/tests/loom_watch.rs index c575b5b66..51589cd80 100644 --- a/vendor/tokio/src/sync/tests/loom_watch.rs +++ b/vendor/tokio/src/sync/tests/loom_watch.rs @@ -2,6 +2,7 @@ use crate::sync::watch; use loom::future::block_on; use loom::thread; +use std::sync::Arc; #[test] fn smoke() { @@ -34,3 +35,56 @@ fn smoke() { th.join().unwrap(); }) } + +#[test] +fn wait_for_test() { + loom::model(move || { + let (tx, mut rx) = watch::channel(false); + + let tx_arc = Arc::new(tx); + let tx1 = tx_arc.clone(); + let tx2 = tx_arc.clone(); + + let th1 = thread::spawn(move || { + for _ in 0..2 { + tx1.send_modify(|_x| {}); + } + }); + + let th2 = thread::spawn(move || { + tx2.send(true).unwrap(); + }); + + assert_eq!(*block_on(rx.wait_for(|x| *x)).unwrap(), true); + + th1.join().unwrap(); + th2.join().unwrap(); + }); +} + +#[test] +fn wait_for_returns_correct_value() { + loom::model(move || { + let (tx, mut rx) = watch::channel(0); + + let jh = thread::spawn(move || { + tx.send(1).unwrap(); + tx.send(2).unwrap(); + tx.send(3).unwrap(); + }); + + // Stop at the first value we are called at. + let mut stopped_at = usize::MAX; + let returned = *block_on(rx.wait_for(|x| { + stopped_at = *x; + true + })) + .unwrap(); + + // Check that it returned the same value as the one we returned + // `true` for. + assert_eq!(stopped_at, returned); + + jh.join().unwrap(); + }); +} diff --git a/vendor/tokio/src/sync/tests/mod.rs b/vendor/tokio/src/sync/tests/mod.rs index c5d560196..ee76418ac 100644 --- a/vendor/tokio/src/sync/tests/mod.rs +++ b/vendor/tokio/src/sync/tests/mod.rs @@ -1,5 +1,6 @@ cfg_not_loom! { mod atomic_waker; + mod notify; mod semaphore_batch; } diff --git a/vendor/tokio/src/sync/tests/notify.rs b/vendor/tokio/src/sync/tests/notify.rs new file mode 100644 index 000000000..4b5989597 --- /dev/null +++ b/vendor/tokio/src/sync/tests/notify.rs @@ -0,0 +1,120 @@ +use crate::sync::Notify; +use std::future::Future; +use std::mem::ManuallyDrop; +use std::sync::Arc; +use std::task::{Context, RawWaker, RawWakerVTable, Waker}; + +#[cfg(tokio_wasm_not_wasi)] +use wasm_bindgen_test::wasm_bindgen_test as test; + +#[test] +fn notify_clones_waker_before_lock() { + const VTABLE: &RawWakerVTable = &RawWakerVTable::new(clone_w, wake, wake_by_ref, drop_w); + + unsafe fn clone_w(data: *const ()) -> RawWaker { + let arc = ManuallyDrop::new(Arc::<Notify>::from_raw(data as *const Notify)); + // Or some other arbitrary code that shouldn't be executed while the + // Notify wait list is locked. + arc.notify_one(); + let _arc_clone: ManuallyDrop<_> = arc.clone(); + RawWaker::new(data, VTABLE) + } + + unsafe fn drop_w(data: *const ()) { + let _ = Arc::<Notify>::from_raw(data as *const Notify); + } + + unsafe fn wake(_data: *const ()) { + unreachable!() + } + + unsafe fn wake_by_ref(_data: *const ()) { + unreachable!() + } + + let notify = Arc::new(Notify::new()); + let notify2 = notify.clone(); + + let waker = + unsafe { Waker::from_raw(RawWaker::new(Arc::into_raw(notify2) as *const _, VTABLE)) }; + let mut cx = Context::from_waker(&waker); + + let future = notify.notified(); + pin!(future); + + // The result doesn't matter, we're just testing that we don't deadlock. + let _ = future.poll(&mut cx); +} + +#[cfg(panic = "unwind")] +#[test] +fn notify_waiters_handles_panicking_waker() { + use futures::task::ArcWake; + + let notify = Arc::new(Notify::new()); + + struct PanickingWaker(Arc<Notify>); + + impl ArcWake for PanickingWaker { + fn wake_by_ref(_arc_self: &Arc<Self>) { + panic!("waker panicked"); + } + } + + let bad_fut = notify.notified(); + pin!(bad_fut); + + let waker = futures::task::waker(Arc::new(PanickingWaker(notify.clone()))); + let mut cx = Context::from_waker(&waker); + let _ = bad_fut.poll(&mut cx); + + let mut futs = Vec::new(); + for _ in 0..32 { + let mut fut = tokio_test::task::spawn(notify.notified()); + assert!(fut.poll().is_pending()); + futs.push(fut); + } + + assert!(std::panic::catch_unwind(|| { + notify.notify_waiters(); + }) + .is_err()); + + for mut fut in futs { + assert!(fut.poll().is_ready()); + } +} + +#[test] +fn notify_simple() { + let notify = Notify::new(); + + let mut fut1 = tokio_test::task::spawn(notify.notified()); + assert!(fut1.poll().is_pending()); + + let mut fut2 = tokio_test::task::spawn(notify.notified()); + assert!(fut2.poll().is_pending()); + + notify.notify_waiters(); + + assert!(fut1.poll().is_ready()); + assert!(fut2.poll().is_ready()); +} + +#[test] +#[cfg(not(tokio_wasm))] +fn watch_test() { + let rt = crate::runtime::Builder::new_current_thread() + .build() + .unwrap(); + + rt.block_on(async { + let (tx, mut rx) = crate::sync::watch::channel(()); + + crate::spawn(async move { + let _ = tx.send(()); + }); + + let _ = rx.changed().await; + }); +} diff --git a/vendor/tokio/src/sync/tests/semaphore_batch.rs b/vendor/tokio/src/sync/tests/semaphore_batch.rs index 9342cd1cb..85089cd22 100644 --- a/vendor/tokio/src/sync/tests/semaphore_batch.rs +++ b/vendor/tokio/src/sync/tests/semaphore_batch.rs @@ -1,6 +1,11 @@ use crate::sync::batch_semaphore::Semaphore; use tokio_test::*; +const MAX_PERMITS: usize = crate::sync::Semaphore::MAX_PERMITS; + +#[cfg(tokio_wasm_not_wasi)] +use wasm_bindgen_test::wasm_bindgen_test as test; + #[test] fn poll_acquire_one_available() { let s = Semaphore::new(100); @@ -166,10 +171,15 @@ fn poll_acquire_one_zero_permits() { } #[test] +fn max_permits_doesnt_panic() { + Semaphore::new(MAX_PERMITS); +} + +#[test] #[should_panic] +#[cfg(not(tokio_wasm))] // wasm currently doesn't support unwinding fn validates_max_permits() { - use std::usize; - Semaphore::new((usize::MAX >> 2) + 1); + Semaphore::new(MAX_PERMITS + 1); } #[test] @@ -248,3 +258,32 @@ fn cancel_acquire_releases_permits() { assert_eq!(6, s.available_permits()); assert_ok!(s.try_acquire(6)); } + +#[test] +fn release_permits_at_drop() { + use crate::sync::semaphore::*; + use futures::task::ArcWake; + use std::future::Future; + use std::sync::Arc; + + let sem = Arc::new(Semaphore::new(1)); + + struct ReleaseOnDrop(Option<OwnedSemaphorePermit>); + + impl ArcWake for ReleaseOnDrop { + fn wake_by_ref(_arc_self: &Arc<Self>) {} + } + + let mut fut = Box::pin(async { + let _permit = sem.acquire().await.unwrap(); + }); + + // Second iteration shouldn't deadlock. + for _ in 0..=1 { + let waker = futures::task::waker(Arc::new(ReleaseOnDrop( + sem.clone().try_acquire_owned().ok(), + ))); + let mut cx = std::task::Context::from_waker(&waker); + assert!(fut.as_mut().poll(&mut cx).is_pending()); + } +} diff --git a/vendor/tokio/src/sync/watch.rs b/vendor/tokio/src/sync/watch.rs index 7852b0cb1..61049b71e 100644 --- a/vendor/tokio/src/sync/watch.rs +++ b/vendor/tokio/src/sync/watch.rs @@ -9,10 +9,10 @@ //! # Usage //! //! [`channel`] returns a [`Sender`] / [`Receiver`] pair. These are the producer -//! and sender halves of the channel. The channel is created with an initial +//! and consumer halves of the channel. The channel is created with an initial //! value. The **latest** value stored in the channel is accessed with //! [`Receiver::borrow()`]. Awaiting [`Receiver::changed()`] waits for a new -//! value to sent by the [`Sender`] half. +//! value to be sent by the [`Sender`] half. //! //! # Examples //! @@ -20,15 +20,15 @@ //! use tokio::sync::watch; //! //! # async fn dox() -> Result<(), Box<dyn std::error::Error>> { -//! let (tx, mut rx) = watch::channel("hello"); +//! let (tx, mut rx) = watch::channel("hello"); //! -//! tokio::spawn(async move { -//! while rx.changed().await.is_ok() { -//! println!("received = {:?}", *rx.borrow()); -//! } -//! }); +//! tokio::spawn(async move { +//! while rx.changed().await.is_ok() { +//! println!("received = {:?}", *rx.borrow()); +//! } +//! }); //! -//! tx.send("world")?; +//! tx.send("world")?; //! # Ok(()) //! # } //! ``` @@ -39,6 +39,9 @@ //! when all [`Receiver`] handles have been dropped. This indicates that there //! is no further interest in the values being produced and work can be stopped. //! +//! The value in the channel will not be dropped until the sender and all receivers +//! have been dropped. +//! //! # Thread safety //! //! Both [`Sender`] and [`Receiver`] are thread safe. They can be moved to other @@ -56,9 +59,12 @@ use crate::sync::notify::Notify; use crate::loom::sync::atomic::AtomicUsize; -use crate::loom::sync::atomic::Ordering::{Relaxed, SeqCst}; +use crate::loom::sync::atomic::Ordering::Relaxed; use crate::loom::sync::{Arc, RwLock, RwLockReadGuard}; +use std::fmt; +use std::mem; use std::ops; +use std::panic; /// Receives values from the associated [`Sender`](struct@Sender). /// @@ -74,7 +80,7 @@ pub struct Receiver<T> { shared: Arc<Shared<T>>, /// Last observed version - version: usize, + version: Version, } /// Sends values to the associated [`Receiver`](struct@Receiver). @@ -85,60 +91,144 @@ pub struct Sender<T> { shared: Arc<Shared<T>>, } -/// Returns a reference to the inner value +/// Returns a reference to the inner value. /// /// Outstanding borrows hold a read lock on the inner value. This means that -/// long lived borrows could cause the produce half to block. It is recommended -/// to keep the borrow as short lived as possible. +/// long-lived borrows could cause the producer half to block. It is recommended +/// to keep the borrow as short-lived as possible. Additionally, if you are +/// running in an environment that allows `!Send` futures, you must ensure that +/// the returned `Ref` type is never held alive across an `.await` point, +/// otherwise, it can lead to a deadlock. +/// +/// The priority policy of the lock is dependent on the underlying lock +/// implementation, and this type does not guarantee that any particular policy +/// will be used. In particular, a producer which is waiting to acquire the lock +/// in `send` might or might not block concurrent calls to `borrow`, e.g.: +/// +/// <details><summary>Potential deadlock example</summary> +/// +/// ```text +/// // Task 1 (on thread A) | // Task 2 (on thread B) +/// let _ref1 = rx.borrow(); | +/// | // will block +/// | let _ = tx.send(()); +/// // may deadlock | +/// let _ref2 = rx.borrow(); | +/// ``` +/// </details> #[derive(Debug)] pub struct Ref<'a, T> { inner: RwLockReadGuard<'a, T>, + has_changed: bool, +} + +impl<'a, T> Ref<'a, T> { + /// Indicates if the borrowed value is considered as _changed_ since the last + /// time it has been marked as seen. + /// + /// Unlike [`Receiver::has_changed()`], this method does not fail if the channel is closed. + /// + /// When borrowed from the [`Sender`] this function will always return `false`. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = watch::channel("hello"); + /// + /// tx.send("goodbye").unwrap(); + /// // The sender does never consider the value as changed. + /// assert!(!tx.borrow().has_changed()); + /// + /// // Drop the sender immediately, just for testing purposes. + /// drop(tx); + /// + /// // Even if the sender has already been dropped... + /// assert!(rx.has_changed().is_err()); + /// // ...the modified value is still readable and detected as changed. + /// assert_eq!(*rx.borrow(), "goodbye"); + /// assert!(rx.borrow().has_changed()); + /// + /// // Read the changed value and mark it as seen. + /// { + /// let received = rx.borrow_and_update(); + /// assert_eq!(*received, "goodbye"); + /// assert!(received.has_changed()); + /// // Release the read lock when leaving this scope. + /// } + /// + /// // Now the value has already been marked as seen and could + /// // never be modified again (after the sender has been dropped). + /// assert!(!rx.borrow().has_changed()); + /// } + /// ``` + pub fn has_changed(&self) -> bool { + self.has_changed + } } -#[derive(Debug)] struct Shared<T> { - /// The most recent value + /// The most recent value. value: RwLock<T>, - /// The current version + /// The current version. /// /// The lowest bit represents a "closed" state. The rest of the bits /// represent the current version. - version: AtomicUsize, + state: AtomicState, - /// Tracks the number of `Receiver` instances + /// Tracks the number of `Receiver` instances. ref_count_rx: AtomicUsize, /// Notifies waiting receivers that the value changed. - notify_rx: Notify, + notify_rx: big_notify::BigNotify, - /// Notifies any task listening for `Receiver` dropped events + /// Notifies any task listening for `Receiver` dropped events. notify_tx: Notify, } +impl<T: fmt::Debug> fmt::Debug for Shared<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let state = self.state.load(); + f.debug_struct("Shared") + .field("value", &self.value) + .field("version", &state.version()) + .field("is_closed", &state.is_closed()) + .field("ref_count_rx", &self.ref_count_rx) + .finish() + } +} + pub mod error { - //! Watch error types + //! Watch error types. use std::fmt; /// Error produced when sending a value fails. - #[derive(Debug)] - pub struct SendError<T> { - pub(crate) inner: T, - } + #[derive(PartialEq, Eq, Clone, Copy)] + pub struct SendError<T>(pub T); // ===== impl SendError ===== - impl<T: fmt::Debug> fmt::Display for SendError<T> { + impl<T> fmt::Debug for SendError<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SendError").finish_non_exhaustive() + } + } + + impl<T> fmt::Display for SendError<T> { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { write!(fmt, "channel closed") } } - impl<T: fmt::Debug> std::error::Error for SendError<T> {} + impl<T> std::error::Error for SendError<T> {} /// Error produced when receiving a change notification. - #[derive(Debug)] + #[derive(Debug, Clone)] pub struct RecvError(pub(super) ()); // ===== impl RecvError ===== @@ -152,7 +242,128 @@ pub mod error { impl std::error::Error for RecvError {} } -const CLOSED: usize = 1; +mod big_notify { + use super::*; + use crate::sync::notify::Notified; + + // To avoid contention on the lock inside the `Notify`, we store multiple + // copies of it. Then, we use either circular access or randomness to spread + // out threads over different `Notify` objects. + // + // Some simple benchmarks show that randomness performs slightly better than + // circular access (probably due to contention on `next`), so we prefer to + // use randomness when Tokio is compiled with a random number generator. + // + // When the random number generator is not available, we fall back to + // circular access. + + pub(super) struct BigNotify { + #[cfg(not(all(not(loom), feature = "sync", any(feature = "rt", feature = "macros"))))] + next: AtomicUsize, + inner: [Notify; 8], + } + + impl BigNotify { + pub(super) fn new() -> Self { + Self { + #[cfg(not(all( + not(loom), + feature = "sync", + any(feature = "rt", feature = "macros") + )))] + next: AtomicUsize::new(0), + inner: Default::default(), + } + } + + pub(super) fn notify_waiters(&self) { + for notify in &self.inner { + notify.notify_waiters(); + } + } + + /// This function implements the case where randomness is not available. + #[cfg(not(all(not(loom), feature = "sync", any(feature = "rt", feature = "macros"))))] + pub(super) fn notified(&self) -> Notified<'_> { + let i = self.next.fetch_add(1, Relaxed) % 8; + self.inner[i].notified() + } + + /// This function implements the case where randomness is available. + #[cfg(all(not(loom), feature = "sync", any(feature = "rt", feature = "macros")))] + pub(super) fn notified(&self) -> Notified<'_> { + let i = crate::runtime::context::thread_rng_n(8) as usize; + self.inner[i].notified() + } + } +} + +use self::state::{AtomicState, Version}; +mod state { + use crate::loom::sync::atomic::AtomicUsize; + use crate::loom::sync::atomic::Ordering::SeqCst; + + const CLOSED: usize = 1; + + /// The version part of the state. The lowest bit is always zero. + #[derive(Copy, Clone, Debug, Eq, PartialEq)] + pub(super) struct Version(usize); + + /// Snapshot of the state. The first bit is used as the CLOSED bit. + /// The remaining bits are used as the version. + /// + /// The CLOSED bit tracks whether the Sender has been dropped. Dropping all + /// receivers does not set it. + #[derive(Copy, Clone, Debug)] + pub(super) struct StateSnapshot(usize); + + /// The state stored in an atomic integer. + #[derive(Debug)] + pub(super) struct AtomicState(AtomicUsize); + + impl Version { + /// Get the initial version when creating the channel. + pub(super) fn initial() -> Self { + Version(0) + } + } + + impl StateSnapshot { + /// Extract the version from the state. + pub(super) fn version(self) -> Version { + Version(self.0 & !CLOSED) + } + + /// Is the closed bit set? + pub(super) fn is_closed(self) -> bool { + (self.0 & CLOSED) == CLOSED + } + } + + impl AtomicState { + /// Create a new `AtomicState` that is not closed and which has the + /// version set to `Version::initial()`. + pub(super) fn new() -> Self { + AtomicState(AtomicUsize::new(0)) + } + + /// Load the current value of the state. + pub(super) fn load(&self) -> StateSnapshot { + StateSnapshot(self.0.load(SeqCst)) + } + + /// Increment the version counter. + pub(super) fn increment_version(&self) { + // Increment by two to avoid touching the CLOSED bit. + self.0.fetch_add(2, SeqCst); + } + + /// Set the closed bit in the state. + pub(super) fn set_closed(&self) { + self.0.fetch_or(CLOSED, SeqCst); + } + } +} /// Creates a new watch channel, returning the "send" and "receive" handles. /// @@ -184,9 +395,9 @@ const CLOSED: usize = 1; pub fn channel<T>(init: T) -> (Sender<T>, Receiver<T>) { let shared = Arc::new(Shared { value: RwLock::new(init), - version: AtomicUsize::new(0), + state: AtomicState::new(), ref_count_rx: AtomicUsize::new(1), - notify_rx: Notify::new(), + notify_rx: big_notify::BigNotify::new(), notify_tx: Notify::new(), }); @@ -194,13 +405,16 @@ pub fn channel<T>(init: T) -> (Sender<T>, Receiver<T>) { shared: shared.clone(), }; - let rx = Receiver { shared, version: 0 }; + let rx = Receiver { + shared, + version: Version::initial(), + }; (tx, rx) } impl<T> Receiver<T> { - fn from_shared(version: usize, shared: Arc<Shared<T>>) -> Self { + fn from_shared(version: Version, shared: Arc<Shared<T>>) -> Self { // No synchronization necessary as this is only used as a counter and // not memory access. shared.ref_count_rx.fetch_add(1, Relaxed); @@ -214,9 +428,29 @@ impl<T> Receiver<T> { /// [`changed`] may return immediately even if you have already seen the /// value with a call to `borrow`. /// - /// Outstanding borrows hold a read lock. This means that long lived borrows - /// could cause the send half to block. It is recommended to keep the borrow - /// as short lived as possible. + /// Outstanding borrows hold a read lock on the inner value. This means that + /// long-lived borrows could cause the producer half to block. It is recommended + /// to keep the borrow as short-lived as possible. Additionally, if you are + /// running in an environment that allows `!Send` futures, you must ensure that + /// the returned `Ref` type is never held alive across an `.await` point, + /// otherwise, it can lead to a deadlock. + /// + /// The priority policy of the lock is dependent on the underlying lock + /// implementation, and this type does not guarantee that any particular policy + /// will be used. In particular, a producer which is waiting to acquire the lock + /// in `send` might or might not block concurrent calls to `borrow`, e.g.: + /// + /// <details><summary>Potential deadlock example</summary> + /// + /// ```text + /// // Task 1 (on thread A) | // Task 2 (on thread B) + /// let _ref1 = rx.borrow(); | + /// | // will block + /// | let _ = tx.send(()); + /// // may deadlock | + /// let _ref2 = rx.borrow(); | + /// ``` + /// </details> /// /// [`changed`]: Receiver::changed /// @@ -230,28 +464,104 @@ impl<T> Receiver<T> { /// ``` pub fn borrow(&self) -> Ref<'_, T> { let inner = self.shared.value.read().unwrap(); - Ref { inner } + + // After obtaining a read-lock no concurrent writes could occur + // and the loaded version matches that of the borrowed reference. + let new_version = self.shared.state.load().version(); + let has_changed = self.version != new_version; + + Ref { inner, has_changed } } - /// Returns a reference to the most recently sent value and mark that value + /// Returns a reference to the most recently sent value and marks that value /// as seen. /// - /// This method marks the value as seen, so [`changed`] will not return - /// immediately if the newest value is one previously returned by - /// `borrow_and_update`. + /// This method marks the current value as seen. Subsequent calls to [`changed`] + /// will not return immediately until the [`Sender`] has modified the shared + /// value again. + /// + /// Outstanding borrows hold a read lock on the inner value. This means that + /// long-lived borrows could cause the producer half to block. It is recommended + /// to keep the borrow as short-lived as possible. Additionally, if you are + /// running in an environment that allows `!Send` futures, you must ensure that + /// the returned `Ref` type is never held alive across an `.await` point, + /// otherwise, it can lead to a deadlock. + /// + /// The priority policy of the lock is dependent on the underlying lock + /// implementation, and this type does not guarantee that any particular policy + /// will be used. In particular, a producer which is waiting to acquire the lock + /// in `send` might or might not block concurrent calls to `borrow`, e.g.: /// - /// Outstanding borrows hold a read lock. This means that long lived borrows - /// could cause the send half to block. It is recommended to keep the borrow - /// as short lived as possible. + /// <details><summary>Potential deadlock example</summary> + /// + /// ```text + /// // Task 1 (on thread A) | // Task 2 (on thread B) + /// let _ref1 = rx1.borrow_and_update(); | + /// | // will block + /// | let _ = tx.send(()); + /// // may deadlock | + /// let _ref2 = rx2.borrow_and_update(); | + /// ``` + /// </details> /// /// [`changed`]: Receiver::changed pub fn borrow_and_update(&mut self) -> Ref<'_, T> { let inner = self.shared.value.read().unwrap(); - self.version = self.shared.version.load(SeqCst) & !CLOSED; - Ref { inner } + + // After obtaining a read-lock no concurrent writes could occur + // and the loaded version matches that of the borrowed reference. + let new_version = self.shared.state.load().version(); + let has_changed = self.version != new_version; + + // Mark the shared value as seen by updating the version + self.version = new_version; + + Ref { inner, has_changed } + } + + /// Checks if this channel contains a message that this receiver has not yet + /// seen. The new value is not marked as seen. + /// + /// Although this method is called `has_changed`, it does not check new + /// messages for equality, so this call will return true even if the new + /// message is equal to the old message. + /// + /// Returns an error if the channel has been closed. + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = watch::channel("hello"); + /// + /// tx.send("goodbye").unwrap(); + /// + /// assert!(rx.has_changed().unwrap()); + /// assert_eq!(*rx.borrow_and_update(), "goodbye"); + /// + /// // The value has been marked as seen + /// assert!(!rx.has_changed().unwrap()); + /// + /// drop(tx); + /// // The `tx` handle has been dropped + /// assert!(rx.has_changed().is_err()); + /// } + /// ``` + pub fn has_changed(&self) -> Result<bool, error::RecvError> { + // Load the version from the state + let state = self.shared.state.load(); + if state.is_closed() { + // The sender has dropped. + return Err(error::RecvError(())); + } + let new_version = state.version(); + + Ok(self.version != new_version) } - /// Wait for a change notification, then mark the newest value as seen. + /// Waits for a change notification, then marks the newest value as seen. /// /// If the newest value in the channel has not yet been marked seen when /// this method is called, the method marks that value seen and returns @@ -291,21 +601,112 @@ impl<T> Receiver<T> { /// } /// ``` pub async fn changed(&mut self) -> Result<(), error::RecvError> { + changed_impl(&self.shared, &mut self.version).await + } + + /// Waits for a value that satisifes the provided condition. + /// + /// This method will call the provided closure whenever something is sent on + /// the channel. Once the closure returns `true`, this method will return a + /// reference to the value that was passed to the closure. + /// + /// Before `wait_for` starts waiting for changes, it will call the closure + /// on the current value. If the closure returns `true` when given the + /// current value, then `wait_for` will immediately return a reference to + /// the current value. This is the case even if the current value is already + /// considered seen. + /// + /// The watch channel only keeps track of the most recent value, so if + /// several messages are sent faster than `wait_for` is able to call the + /// closure, then it may skip some updates. Whenever the closure is called, + /// it will be called with the most recent value. + /// + /// When this function returns, the value that was passed to the closure + /// when it returned `true` will be considered seen. + /// + /// If the channel is closed, then `wait_for` will return a `RecvError`. + /// Once this happens, no more messages can ever be sent on the channel. + /// When an error is returned, it is guaranteed that the closure has been + /// called on the last value, and that it returned `false` for that value. + /// (If the closure returned `true`, then the last value would have been + /// returned instead of the error.) + /// + /// Like the `borrow` method, the returned borrow holds a read lock on the + /// inner value. This means that long-lived borrows could cause the producer + /// half to block. It is recommended to keep the borrow as short-lived as + /// possible. See the documentation of `borrow` for more information on + /// this. + /// + /// [`Receiver::changed()`]: crate::sync::watch::Receiver::changed + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// #[tokio::main] + /// + /// async fn main() { + /// let (tx, _rx) = watch::channel("hello"); + /// + /// tx.send("goodbye").unwrap(); + /// + /// // here we subscribe to a second receiver + /// // now in case of using `changed` we would have + /// // to first check the current value and then wait + /// // for changes or else `changed` would hang. + /// let mut rx2 = tx.subscribe(); + /// + /// // in place of changed we have use `wait_for` + /// // which would automatically check the current value + /// // and wait for changes until the closure returns true. + /// assert!(rx2.wait_for(|val| *val == "goodbye").await.is_ok()); + /// assert_eq!(*rx2.borrow(), "goodbye"); + /// } + /// ``` + pub async fn wait_for( + &mut self, + mut f: impl FnMut(&T) -> bool, + ) -> Result<Ref<'_, T>, error::RecvError> { + let mut closed = false; loop { - // In order to avoid a race condition, we first request a notification, - // **then** check the current value's version. If a new version exists, - // the notification request is dropped. - let notified = self.shared.notify_rx.notified(); + { + let inner = self.shared.value.read().unwrap(); - if let Some(ret) = maybe_changed(&self.shared, &mut self.version) { - return ret; + let new_version = self.shared.state.load().version(); + let has_changed = self.version != new_version; + self.version = new_version; + + if (!closed || has_changed) && f(&inner) { + return Ok(Ref { inner, has_changed }); + } } - notified.await; - // loop around again in case the wake-up was spurious + if closed { + return Err(error::RecvError(())); + } + + // Wait for the value to change. + closed = changed_impl(&self.shared, &mut self.version).await.is_err(); } } + /// Returns `true` if receivers belong to the same channel. + /// + /// # Examples + /// + /// ``` + /// let (tx, rx) = tokio::sync::watch::channel(true); + /// let rx2 = rx.clone(); + /// assert!(rx.same_channel(&rx2)); + /// + /// let (tx3, rx3) = tokio::sync::watch::channel(true); + /// assert!(!rx3.same_channel(&rx2)); + /// ``` + pub fn same_channel(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.shared, &other.shared) + } + cfg_process_driver! { pub(crate) fn try_has_changed(&mut self) -> Option<Result<(), error::RecvError>> { maybe_changed(&self.shared, &mut self.version) @@ -315,11 +716,11 @@ impl<T> Receiver<T> { fn maybe_changed<T>( shared: &Shared<T>, - version: &mut usize, + version: &mut Version, ) -> Option<Result<(), error::RecvError>> { // Load the version from the state - let state = shared.version.load(SeqCst); - let new_version = state & !CLOSED; + let state = shared.state.load(); + let new_version = state.version(); if *version != new_version { // Observe the new version and return @@ -327,7 +728,7 @@ fn maybe_changed<T>( return Some(Ok(())); } - if CLOSED == state & CLOSED { + if state.is_closed() { // All receivers have dropped. return Some(Err(error::RecvError(()))); } @@ -335,6 +736,27 @@ fn maybe_changed<T>( None } +async fn changed_impl<T>( + shared: &Shared<T>, + version: &mut Version, +) -> Result<(), error::RecvError> { + crate::trace::async_trace_leaf().await; + + loop { + // In order to avoid a race condition, we first request a notification, + // **then** check the current value's version. If a new version exists, + // the notification request is dropped. + let notified = shared.notify_rx.notified(); + + if let Some(ret) = maybe_changed(shared, version) { + return ret; + } + + notified.await; + // loop around again in case the wake-up was spurious + } +} + impl<T> Clone for Receiver<T> { fn clone(&self) -> Self { let version = self.version; @@ -357,19 +779,158 @@ impl<T> Drop for Receiver<T> { impl<T> Sender<T> { /// Sends a new value via the channel, notifying all receivers. + /// + /// This method fails if the channel is closed, which is the case when + /// every receiver has been dropped. It is possible to reopen the channel + /// using the [`subscribe`] method. However, when `send` fails, the value + /// isn't made available for future receivers (but returned with the + /// [`SendError`]). + /// + /// To always make a new value available for future receivers, even if no + /// receiver currently exists, one of the other send methods + /// ([`send_if_modified`], [`send_modify`], or [`send_replace`]) can be + /// used instead. + /// + /// [`subscribe`]: Sender::subscribe + /// [`SendError`]: error::SendError + /// [`send_if_modified`]: Sender::send_if_modified + /// [`send_modify`]: Sender::send_modify + /// [`send_replace`]: Sender::send_replace pub fn send(&self, value: T) -> Result<(), error::SendError<T>> { // This is pretty much only useful as a hint anyway, so synchronization isn't critical. - if 0 == self.shared.ref_count_rx.load(Relaxed) { - return Err(error::SendError { inner: value }); + if 0 == self.receiver_count() { + return Err(error::SendError(value)); } + self.send_replace(value); + Ok(()) + } + + /// Modifies the watched value **unconditionally** in-place, + /// notifying all receivers. + /// + /// This can be useful for modifying the watched value, without + /// having to allocate a new instance. Additionally, this + /// method permits sending values even when there are no receivers. + /// + /// Prefer to use the more versatile function [`Self::send_if_modified()`] + /// if the value is only modified conditionally during the mutable borrow + /// to prevent unneeded change notifications for unmodified values. + /// + /// # Panics + /// + /// This function panics when the invocation of the `modify` closure panics. + /// No receivers are notified when panicking. All changes of the watched + /// value applied by the closure before panicking will be visible in + /// subsequent calls to `borrow`. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// struct State { + /// counter: usize, + /// } + /// let (state_tx, state_rx) = watch::channel(State { counter: 0 }); + /// state_tx.send_modify(|state| state.counter += 1); + /// assert_eq!(state_rx.borrow().counter, 1); + /// ``` + pub fn send_modify<F>(&self, modify: F) + where + F: FnOnce(&mut T), + { + self.send_if_modified(|value| { + modify(value); + true + }); + } + + /// Modifies the watched value **conditionally** in-place, + /// notifying all receivers only if modified. + /// + /// This can be useful for modifying the watched value, without + /// having to allocate a new instance. Additionally, this + /// method permits sending values even when there are no receivers. + /// + /// The `modify` closure must return `true` if the value has actually + /// been modified during the mutable borrow. It should only return `false` + /// if the value is guaranteed to be unmodified despite the mutable + /// borrow. + /// + /// Receivers are only notified if the closure returned `true`. If the + /// closure has modified the value but returned `false` this results + /// in a *silent modification*, i.e. the modified value will be visible + /// in subsequent calls to `borrow`, but receivers will not receive + /// a change notification. + /// + /// Returns the result of the closure, i.e. `true` if the value has + /// been modified and `false` otherwise. + /// + /// # Panics + /// + /// This function panics when the invocation of the `modify` closure panics. + /// No receivers are notified when panicking. All changes of the watched + /// value applied by the closure before panicking will be visible in + /// subsequent calls to `borrow`. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// struct State { + /// counter: usize, + /// } + /// let (state_tx, mut state_rx) = watch::channel(State { counter: 1 }); + /// let inc_counter_if_odd = |state: &mut State| { + /// if state.counter % 2 == 1 { + /// state.counter += 1; + /// return true; + /// } + /// false + /// }; + /// + /// assert_eq!(state_rx.borrow().counter, 1); + /// + /// assert!(!state_rx.has_changed().unwrap()); + /// assert!(state_tx.send_if_modified(inc_counter_if_odd)); + /// assert!(state_rx.has_changed().unwrap()); + /// assert_eq!(state_rx.borrow_and_update().counter, 2); + /// + /// assert!(!state_rx.has_changed().unwrap()); + /// assert!(!state_tx.send_if_modified(inc_counter_if_odd)); + /// assert!(!state_rx.has_changed().unwrap()); + /// assert_eq!(state_rx.borrow_and_update().counter, 2); + /// ``` + pub fn send_if_modified<F>(&self, modify: F) -> bool + where + F: FnOnce(&mut T) -> bool, + { { // Acquire the write lock and update the value. let mut lock = self.shared.value.write().unwrap(); - *lock = value; - // Update the version. 2 is used so that the CLOSED bit is not set. - self.shared.version.fetch_add(2, SeqCst); + // Update the value and catch possible panic inside func. + let result = panic::catch_unwind(panic::AssertUnwindSafe(|| modify(&mut lock))); + match result { + Ok(modified) => { + if !modified { + // Abort, i.e. don't notify receivers if unmodified + return false; + } + // Continue if modified + } + Err(panicked) => { + // Drop the lock to avoid poisoning it. + drop(lock); + // Forward the panic to the caller. + panic::resume_unwind(panicked); + // Unreachable + } + }; + + self.shared.state.increment_version(); // Release the write lock. // @@ -379,17 +940,42 @@ impl<T> Sender<T> { drop(lock); } - // Notify all watchers self.shared.notify_rx.notify_waiters(); - Ok(()) + true + } + + /// Sends a new value via the channel, notifying all receivers and returning + /// the previous value in the channel. + /// + /// This can be useful for reusing the buffers inside a watched value. + /// Additionally, this method permits sending values even when there are no + /// receivers. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// let (tx, _rx) = watch::channel(1); + /// assert_eq!(tx.send_replace(2), 1); + /// assert_eq!(tx.send_replace(3), 2); + /// ``` + pub fn send_replace(&self, mut value: T) -> T { + // swap old watched value with the new one + self.send_modify(|old| mem::swap(old, &mut value)); + + value } /// Returns a reference to the most recently sent value /// - /// Outstanding borrows hold a read lock. This means that long lived borrows - /// could cause the send half to block. It is recommended to keep the borrow - /// as short lived as possible. + /// Outstanding borrows hold a read lock on the inner value. This means that + /// long-lived borrows could cause the producer half to block. It is recommended + /// to keep the borrow as short-lived as possible. Additionally, if you are + /// running in an environment that allows `!Send` futures, you must ensure that + /// the returned `Ref` type is never held alive across an `.await` point, + /// otherwise, it can lead to a deadlock. /// /// # Examples /// @@ -401,7 +987,11 @@ impl<T> Sender<T> { /// ``` pub fn borrow(&self) -> Ref<'_, T> { let inner = self.shared.value.read().unwrap(); - Ref { inner } + + // The sender/producer always sees the current version + let has_changed = false; + + Ref { inner, has_changed } } /// Checks if the channel has been closed. This happens when all receivers @@ -417,7 +1007,7 @@ impl<T> Sender<T> { /// assert!(tx.is_closed()); /// ``` pub fn is_closed(&self) -> bool { - self.shared.ref_count_rx.load(Relaxed) == 0 + self.receiver_count() == 0 } /// Completes when all receivers have dropped. @@ -450,26 +1040,86 @@ impl<T> Sender<T> { /// } /// ``` pub async fn closed(&self) { - let notified = self.shared.notify_tx.notified(); + crate::trace::async_trace_leaf().await; - if self.shared.ref_count_rx.load(Relaxed) == 0 { - return; - } + while self.receiver_count() > 0 { + let notified = self.shared.notify_tx.notified(); - notified.await; - debug_assert_eq!(0, self.shared.ref_count_rx.load(Relaxed)); + if self.receiver_count() == 0 { + return; + } + + notified.await; + // The channel could have been reopened in the meantime by calling + // `subscribe`, so we loop again. + } } - cfg_signal_internal! { - pub(crate) fn subscribe(&self) -> Receiver<T> { - let shared = self.shared.clone(); - let version = shared.version.load(SeqCst); + /// Creates a new [`Receiver`] connected to this `Sender`. + /// + /// All messages sent before this call to `subscribe` are initially marked + /// as seen by the new `Receiver`. + /// + /// This method can be called even if there are no other receivers. In this + /// case, the channel is reopened. + /// + /// # Examples + /// + /// The new channel will receive messages sent on this `Sender`. + /// + /// ``` + /// use tokio::sync::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, _rx) = watch::channel(0u64); + /// + /// tx.send(5).unwrap(); + /// + /// let rx = tx.subscribe(); + /// assert_eq!(5, *rx.borrow()); + /// + /// tx.send(10).unwrap(); + /// assert_eq!(10, *rx.borrow()); + /// } + /// ``` + /// + /// The most recent message is considered seen by the channel, so this test + /// is guaranteed to pass. + /// + /// ``` + /// use tokio::sync::watch; + /// use tokio::time::Duration; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, _rx) = watch::channel(0u64); + /// tx.send(5).unwrap(); + /// let mut rx = tx.subscribe(); + /// + /// tokio::spawn(async move { + /// // by spawning and sleeping, the message is sent after `main` + /// // hits the call to `changed`. + /// # if false { + /// tokio::time::sleep(Duration::from_millis(10)).await; + /// # } + /// tx.send(100).unwrap(); + /// }); + /// + /// rx.changed().await.unwrap(); + /// assert_eq!(100, *rx.borrow()); + /// } + /// ``` + pub fn subscribe(&self) -> Receiver<T> { + let shared = self.shared.clone(); + let version = shared.state.load().version(); - Receiver::from_shared(version, shared) - } + // The CLOSED bit in the state tracks only whether the sender is + // dropped, so we do not need to unset it if this reopens the channel. + Receiver::from_shared(version, shared) } - /// Returns the number of receivers that currently exist + /// Returns the number of receivers that currently exist. /// /// # Examples /// @@ -494,7 +1144,7 @@ impl<T> Sender<T> { impl<T> Drop for Sender<T> { fn drop(&mut self) { - self.shared.version.fetch_or(CLOSED, SeqCst); + self.shared.state.set_closed(); self.shared.notify_rx.notify_waiters(); } } |