diff options
Diffstat (limited to 'third_party/rust/tokio/src/sync')
33 files changed, 10063 insertions, 0 deletions
diff --git a/third_party/rust/tokio/src/sync/barrier.rs b/third_party/rust/tokio/src/sync/barrier.rs new file mode 100644 index 0000000000..628633493a --- /dev/null +++ b/third_party/rust/tokio/src/sync/barrier.rs @@ -0,0 +1,136 @@ +use crate::sync::watch; + +use std::sync::Mutex; + +/// A barrier enables multiple threads to synchronize the beginning of some computation. +/// +/// ``` +/// # #[tokio::main] +/// # async fn main() { +/// use tokio::sync::Barrier; +/// +/// use futures::future::join_all; +/// use std::sync::Arc; +/// +/// let mut handles = Vec::with_capacity(10); +/// let barrier = Arc::new(Barrier::new(10)); +/// for _ in 0..10 { +/// let c = barrier.clone(); +/// // The same messages will be printed together. +/// // You will NOT see any interleaving. +/// handles.push(async move { +/// println!("before wait"); +/// let wr = c.wait().await; +/// println!("after wait"); +/// wr +/// }); +/// } +/// // Will not resolve until all "before wait" messages have been printed +/// let wrs = join_all(handles).await; +/// // Exactly one barrier will resolve as the "leader" +/// assert_eq!(wrs.into_iter().filter(|wr| wr.is_leader()).count(), 1); +/// # } +/// ``` +#[derive(Debug)] +pub struct Barrier { + state: Mutex<BarrierState>, + wait: watch::Receiver<usize>, + n: usize, +} + +#[derive(Debug)] +struct BarrierState { + waker: watch::Sender<usize>, + arrived: usize, + generation: usize, +} + +impl Barrier { + /// Creates a new barrier that can block a given number of threads. + /// + /// A barrier will block `n`-1 threads which call [`Barrier::wait`] and then wake up all + /// threads at once when the `n`th thread calls `wait`. + pub fn new(mut n: usize) -> Barrier { + let (waker, wait) = crate::sync::watch::channel(0); + + if n == 0 { + // if n is 0, it's not clear what behavior the user wants. + // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every + // .wait() immediately unblocks, so we adopt that here as well. + n = 1; + } + + Barrier { + state: Mutex::new(BarrierState { + waker, + arrived: 0, + generation: 1, + }), + n, + wait, + } + } + + /// Does not resolve until all tasks have rendezvoused here. + /// + /// Barriers are re-usable after all threads have rendezvoused once, and can + /// be used continuously. + /// + /// A single (arbitrary) future will receive a [`BarrierWaitResult`] that returns `true` from + /// [`BarrierWaitResult::is_leader`] when returning from this function, and all other threads + /// will receive a result that will return `false` from `is_leader`. + pub async fn wait(&self) -> BarrierWaitResult { + // NOTE: we are taking a _synchronous_ lock here. + // It is okay to do so because the critical section is fast and never yields, so it cannot + // deadlock even if another future is concurrently holding the lock. + // It is _desireable_ to do so as synchronous Mutexes are, at least in theory, faster than + // the asynchronous counter-parts, so we should use them where possible [citation needed]. + // NOTE: the extra scope here is so that the compiler doesn't think `state` is held across + // a yield point, and thus marks the returned future as !Send. + let generation = { + let mut state = self.state.lock().unwrap(); + let generation = state.generation; + state.arrived += 1; + if state.arrived == self.n { + // we are the leader for this generation + // wake everyone, increment the generation, and return + state + .waker + .broadcast(state.generation) + .expect("there is at least one receiver"); + state.arrived = 0; + state.generation += 1; + return BarrierWaitResult(true); + } + + generation + }; + + // we're going to have to wait for the last of the generation to arrive + let mut wait = self.wait.clone(); + + loop { + // note that the first time through the loop, this _will_ yield a generation + // immediately, since we cloned a receiver that has never seen any values. + if wait.recv().await.expect("sender hasn't been closed") >= generation { + break; + } + } + + BarrierWaitResult(false) + } +} + +/// A `BarrierWaitResult` is returned by `wait` when all threads in the `Barrier` have rendezvoused. +#[derive(Debug, Clone)] +pub struct BarrierWaitResult(bool); + +impl BarrierWaitResult { + /// Returns `true` if this thread from wait is the "leader thread". + /// + /// Only one thread will have `true` returned from their result, all other threads will have + /// `false` returned. + pub fn is_leader(&self) -> bool { + self.0 + } +} diff --git a/third_party/rust/tokio/src/sync/batch_semaphore.rs b/third_party/rust/tokio/src/sync/batch_semaphore.rs new file mode 100644 index 0000000000..436737a670 --- /dev/null +++ b/third_party/rust/tokio/src/sync/batch_semaphore.rs @@ -0,0 +1,547 @@ +//! # Implementation Details +//! +//! The semaphore is implemented using an intrusive linked list of waiters. An +//! atomic counter tracks the number of available permits. If the semaphore does +//! not contain the required number of permits, the task attempting to acquire +//! permits places its waker at the end of a queue. When new permits are made +//! available (such as by releasing an initial acquisition), they are assigned +//! to the task at the front of the queue, waking that task if its requested +//! number of permits is met. +//! +//! Because waiters are enqueued at the back of the linked list and dequeued +//! from the front, the semaphore is fair. Tasks trying to acquire large numbers +//! of permits at a time will always be woken eventually, even if many other +//! tasks are acquiring smaller numbers of permits. This means that in a +//! use-case like tokio's read-write lock, writers will not be starved by +//! readers. +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::{Mutex, MutexGuard}; +use crate::util::linked_list::{self, LinkedList}; + +use std::future::Future; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::ptr::NonNull; +use std::sync::atomic::Ordering::*; +use std::task::Poll::*; +use std::task::{Context, Poll, Waker}; +use std::{cmp, fmt}; + +/// An asynchronous counting semaphore which permits waiting on multiple permits at once. +pub(crate) struct Semaphore { + waiters: Mutex<Waitlist>, + /// The current number of available permits in the semaphore. + permits: AtomicUsize, +} + +struct Waitlist { + queue: LinkedList<Waiter>, + closed: bool, +} + +/// Error returned by `Semaphore::try_acquire`. +#[derive(Debug)] +pub(crate) enum TryAcquireError { + Closed, + NoPermits, +} +/// Error returned by `Semaphore::acquire`. +#[derive(Debug)] +pub(crate) struct AcquireError(()); + +pub(crate) struct Acquire<'a> { + node: Waiter, + semaphore: &'a Semaphore, + num_permits: u16, + queued: bool, +} + +/// An entry in the wait queue. +struct Waiter { + /// The current state of the waiter. + /// + /// This is either the number of remaining permits required by + /// the waiter, or a flag indicating that the waiter is not yet queued. + state: AtomicUsize, + + /// The waker to notify the task awaiting permits. + /// + /// # Safety + /// + /// This may only be accessed while the wait queue is locked. + waker: UnsafeCell<Option<Waker>>, + + /// Intrusive linked-list pointers. + /// + /// # Safety + /// + /// This may only be accessed while the wait queue is locked. + /// + /// TODO: Ideally, we would be able to use loom to enforce that + /// this isn't accessed concurrently. However, it is difficult to + /// use a `UnsafeCell` here, since the `Link` trait requires _returning_ + /// references to `Pointers`, and `UnsafeCell` requires that checked access + /// take place inside a closure. We should consider changing `Pointers` to + /// use `UnsafeCell` internally. + pointers: linked_list::Pointers<Waiter>, + + /// Should not be `Unpin`. + _p: PhantomPinned, +} + +impl Semaphore { + /// The maximum number of permits which a semaphore can hold. + /// + /// Note that this reserves three bits of flags in the permit counter, but + /// we only actually use one of them. However, the previous semaphore + /// implementation used three bits, so we will continue to reserve them to + /// avoid a breaking change if additional flags need to be aadded in the + /// future. + pub(crate) const MAX_PERMITS: usize = std::usize::MAX >> 3; + const CLOSED: usize = 1; + const PERMIT_SHIFT: usize = 1; + + /// Creates a new semaphore with the initial number of permits + pub(crate) fn new(permits: usize) -> Self { + assert!( + permits <= Self::MAX_PERMITS, + "a semaphore may not have more than MAX_PERMITS permits ({})", + Self::MAX_PERMITS + ); + Self { + permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT), + waiters: Mutex::new(Waitlist { + queue: LinkedList::new(), + closed: false, + }), + } + } + + /// Returns the current number of available permits + pub(crate) fn available_permits(&self) -> usize { + self.permits.load(Acquire) >> Self::PERMIT_SHIFT + } + + /// Adds `n` new permits to the semaphore. + pub(crate) fn release(&self, added: usize) { + if added == 0 { + return; + } + + // Assign permits to the wait queue + self.add_permits_locked(added, self.waiters.lock().unwrap()); + } + + /// Closes the semaphore. This prevents the semaphore from issuing new + /// permits and notifies all pending waiters. + // This will be used once the bounded MPSC is updated to use the new + // semaphore implementation. + #[allow(dead_code)] + pub(crate) fn close(&self) { + let mut waiters = self.waiters.lock().unwrap(); + // If the semaphore's permits counter has enough permits for an + // unqueued waiter to acquire all the permits it needs immediately, + // it won't touch the wait list. Therefore, we have to set a bit on + // the permit counter as well. However, we must do this while + // holding the lock --- otherwise, if we set the bit and then wait + // to acquire the lock we'll enter an inconsistent state where the + // permit counter is closed, but the wait list is not. + self.permits.fetch_or(Self::CLOSED, Release); + waiters.closed = true; + while let Some(mut waiter) = waiters.queue.pop_back() { + let waker = unsafe { waiter.as_mut().waker.with_mut(|waker| (*waker).take()) }; + if let Some(waker) = waker { + waker.wake(); + } + } + } + + pub(crate) fn try_acquire(&self, num_permits: u16) -> Result<(), TryAcquireError> { + let mut curr = self.permits.load(Acquire); + let num_permits = (num_permits as usize) << Self::PERMIT_SHIFT; + loop { + // Has the semaphore closed?git + if curr & Self::CLOSED > 0 { + return Err(TryAcquireError::Closed); + } + + // Are there enough permits remaining? + if curr < num_permits { + return Err(TryAcquireError::NoPermits); + } + + let next = curr - num_permits; + + match self.permits.compare_exchange(curr, next, AcqRel, Acquire) { + Ok(_) => return Ok(()), + Err(actual) => curr = actual, + } + } + } + + pub(crate) fn acquire(&self, num_permits: u16) -> Acquire<'_> { + Acquire::new(self, num_permits) + } + + /// Release `rem` permits to the semaphore's wait list, starting from the + /// end of the queue. + /// + /// If `rem` exceeds the number of permits needed by the wait list, the + /// remainder are assigned back to the semaphore. + fn add_permits_locked(&self, mut rem: usize, waiters: MutexGuard<'_, Waitlist>) { + let mut wakers: [Option<Waker>; 8] = Default::default(); + let mut lock = Some(waiters); + let mut is_empty = false; + while rem > 0 { + let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock().unwrap()); + 'inner: for slot in &mut wakers[..] { + // Was the waiter assigned enough permits to wake it? + match waiters.queue.last() { + Some(waiter) => { + if !waiter.assign_permits(&mut rem) { + break 'inner; + } + } + None => { + is_empty = true; + // If we assigned permits to all the waiters in the queue, and there are + // still permits left over, assign them back to the semaphore. + break 'inner; + } + }; + let mut waiter = waiters.queue.pop_back().unwrap(); + *slot = unsafe { waiter.as_mut().waker.with_mut(|waker| (*waker).take()) }; + } + + if rem > 0 && is_empty { + let permits = rem << Self::PERMIT_SHIFT; + assert!( + permits < Self::MAX_PERMITS, + "cannot add more than MAX_PERMITS permits ({})", + Self::MAX_PERMITS + ); + let prev = self.permits.fetch_add(rem << Self::PERMIT_SHIFT, Release); + assert!( + prev + permits <= Self::MAX_PERMITS, + "number of added permits ({}) would overflow MAX_PERMITS ({})", + rem, + Self::MAX_PERMITS + ); + rem = 0; + } + + drop(waiters); // release the lock + + wakers + .iter_mut() + .filter_map(Option::take) + .for_each(Waker::wake); + } + + assert_eq!(rem, 0); + } + + fn poll_acquire( + &self, + cx: &mut Context<'_>, + num_permits: u16, + node: Pin<&mut Waiter>, + queued: bool, + ) -> Poll<Result<(), AcquireError>> { + let mut acquired = 0; + + let needed = if queued { + node.state.load(Acquire) << Self::PERMIT_SHIFT + } else { + (num_permits as usize) << Self::PERMIT_SHIFT + }; + + let mut lock = None; + // First, try to take the requested number of permits from the + // semaphore. + let mut curr = self.permits.load(Acquire); + let mut waiters = loop { + // Has the semaphore closed? + if curr & Self::CLOSED > 0 { + return Ready(Err(AcquireError::closed())); + } + + let mut remaining = 0; + let total = curr + .checked_add(acquired) + .expect("number of permits must not overflow"); + let (next, acq) = if total >= needed { + let next = curr - (needed - acquired); + (next, needed >> Self::PERMIT_SHIFT) + } else { + remaining = (needed - acquired) - curr; + (0, curr >> Self::PERMIT_SHIFT) + }; + + if remaining > 0 && lock.is_none() { + // No permits were immediately available, so this permit will + // (probably) need to wait. We'll need to acquire a lock on the + // wait queue before continuing. We need to do this _before_ the + // CAS that sets the new value of the semaphore's `permits` + // counter. Otherwise, if we subtract the permits and then + // acquire the lock, we might miss additional permits being + // added while waiting for the lock. + lock = Some(self.waiters.lock().unwrap()); + } + + match self.permits.compare_exchange(curr, next, AcqRel, Acquire) { + Ok(_) => { + acquired += acq; + if remaining == 0 { + if !queued { + return Ready(Ok(())); + } else if lock.is_none() { + break self.waiters.lock().unwrap(); + } + } + break lock.expect("lock must be acquired before waiting"); + } + Err(actual) => curr = actual, + } + }; + + if waiters.closed { + return Ready(Err(AcquireError::closed())); + } + + if node.assign_permits(&mut acquired) { + self.add_permits_locked(acquired, waiters); + return Ready(Ok(())); + } + + assert_eq!(acquired, 0); + + // Otherwise, register the waker & enqueue the node. + node.waker.with_mut(|waker| { + // Safety: the wait list is locked, so we may modify the waker. + let waker = unsafe { &mut *waker }; + // Do we need to register the new waker? + if waker + .as_ref() + .map(|waker| !waker.will_wake(cx.waker())) + .unwrap_or(true) + { + *waker = Some(cx.waker().clone()); + } + }); + + // If the waiter is not already in the wait queue, enqueue it. + if !queued { + let node = unsafe { + let node = Pin::into_inner_unchecked(node) as *mut _; + NonNull::new_unchecked(node) + }; + + waiters.queue.push_front(node); + } + + Pending + } +} + +impl fmt::Debug for Semaphore { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Semaphore") + .field("permits", &self.permits.load(Relaxed)) + .finish() + } +} + +impl Waiter { + fn new(num_permits: u16) -> Self { + Waiter { + waker: UnsafeCell::new(None), + state: AtomicUsize::new(num_permits as usize), + pointers: linked_list::Pointers::new(), + _p: PhantomPinned, + } + } + + /// Assign permits to the waiter. + /// + /// Returns `true` if the waiter should be removed from the queue + fn assign_permits(&self, n: &mut usize) -> bool { + let mut curr = self.state.load(Acquire); + loop { + let assign = cmp::min(curr, *n); + let next = curr - assign; + match self.state.compare_exchange(curr, next, AcqRel, Acquire) { + Ok(_) => { + *n -= assign; + return next == 0; + } + Err(actual) => curr = actual, + } + } + } +} + +impl Future for Acquire<'_> { + type Output = Result<(), AcquireError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let (node, semaphore, needed, queued) = self.project(); + match semaphore.poll_acquire(cx, needed, node, *queued) { + Pending => { + *queued = true; + Pending + } + Ready(r) => { + r?; + *queued = false; + Ready(Ok(())) + } + } + } +} + +impl<'a> Acquire<'a> { + fn new(semaphore: &'a Semaphore, num_permits: u16) -> Self { + Self { + node: Waiter::new(num_permits), + semaphore, + num_permits, + queued: false, + } + } + + fn project(self: Pin<&mut Self>) -> (Pin<&mut Waiter>, &Semaphore, u16, &mut bool) { + fn is_unpin<T: Unpin>() {} + unsafe { + // Safety: all fields other than `node` are `Unpin` + + is_unpin::<&Semaphore>(); + is_unpin::<&mut bool>(); + is_unpin::<u16>(); + + let this = self.get_unchecked_mut(); + ( + Pin::new_unchecked(&mut this.node), + &this.semaphore, + this.num_permits, + &mut this.queued, + ) + } + } +} + +impl Drop for Acquire<'_> { + fn drop(&mut self) { + // If the future is completed, there is no node in the wait list, so we + // can skip acquiring the lock. + if !self.queued { + return; + } + + // This is where we ensure safety. The future is being dropped, + // which means we must ensure that the waiter entry is no longer stored + // in the linked list. + let mut waiters = match self.semaphore.waiters.lock() { + Ok(lock) => lock, + // Removing the node from the linked list is necessary to ensure + // safety. Even if the lock was poisoned, we need to make sure it is + // removed from the linked list before dropping it --- otherwise, + // the list will contain a dangling pointer to this node. + Err(e) => e.into_inner(), + }; + + // remove the entry from the list + let node = NonNull::from(&mut self.node); + // Safety: we have locked the wait list. + unsafe { waiters.queue.remove(node) }; + + let acquired_permits = self.num_permits as usize - self.node.state.load(Acquire); + if acquired_permits > 0 { + self.semaphore.add_permits_locked(acquired_permits, waiters); + } + } +} + +// Safety: the `Acquire` future is not `Sync` automatically because it contains +// a `Waiter`, which, in turn, contains an `UnsafeCell`. However, the +// `UnsafeCell` is only accessed when the future is borrowed mutably (either in +// `poll` or in `drop`). Therefore, it is safe (although not particularly +// _useful_) for the future to be borrowed immutably across threads. +unsafe impl Sync for Acquire<'_> {} + +// ===== impl AcquireError ==== + +impl AcquireError { + fn closed() -> AcquireError { + AcquireError(()) + } +} + +impl fmt::Display for AcquireError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "semaphore closed") + } +} + +impl std::error::Error for AcquireError {} + +// ===== impl TryAcquireError ===== + +impl TryAcquireError { + /// Returns `true` if the error was caused by a closed semaphore. + #[allow(dead_code)] // may be used later! + pub(crate) fn is_closed(&self) -> bool { + match self { + TryAcquireError::Closed => true, + _ => false, + } + } + + /// Returns `true` if the error was caused by calling `try_acquire` on a + /// semaphore with no available permits. + #[allow(dead_code)] // may be used later! + pub(crate) fn is_no_permits(&self) -> bool { + match self { + TryAcquireError::NoPermits => true, + _ => false, + } + } +} + +impl fmt::Display for TryAcquireError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TryAcquireError::Closed => write!(fmt, "{}", "semaphore closed"), + TryAcquireError::NoPermits => write!(fmt, "{}", "no permits available"), + } + } +} + +impl std::error::Error for TryAcquireError {} + +/// # Safety +/// +/// `Waiter` is forced to be !Unpin. +unsafe impl linked_list::Link for Waiter { + // XXX: ideally, we would be able to use `Pin` here, to enforce the + // invariant that list entries may not move while in the list. However, we + // can't do this currently, as using `Pin<&'a mut Waiter>` as the `Handle` + // type would require `Semaphore` to be generic over a lifetime. We can't + // use `Pin<*mut Waiter>`, as raw pointers are `Unpin` regardless of whether + // or not they dereference to an `!Unpin` target. + type Handle = NonNull<Waiter>; + type Target = Waiter; + + fn as_raw(handle: &Self::Handle) -> NonNull<Waiter> { + *handle + } + + unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> { + ptr + } + + unsafe fn pointers(mut target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> { + NonNull::from(&mut target.as_mut().pointers) + } +} diff --git a/third_party/rust/tokio/src/sync/broadcast.rs b/third_party/rust/tokio/src/sync/broadcast.rs new file mode 100644 index 0000000000..05a58070ee --- /dev/null +++ b/third_party/rust/tokio/src/sync/broadcast.rs @@ -0,0 +1,1046 @@ +//! A multi-producer, multi-consumer broadcast queue. Each sent value is seen by +//! all consumers. +//! +//! A [`Sender`] is used to broadcast values to **all** connected [`Receiver`] +//! values. [`Sender`] handles are clone-able, allowing concurrent send and +//! receive actions. [`Sender`] and [`Receiver`] are both `Send` and `Sync` as +//! long as `T` is also `Send` or `Sync` respectively. +//! +//! When a value is sent, **all** [`Receiver`] handles are notified and will +//! receive the value. The value is stored once inside the channel and cloned on +//! demand for each receiver. Once all receivers have received a clone of the +//! value, the value is released from the channel. +//! +//! A channel is created by calling [`channel`], specifying the maximum number +//! of messages the channel can retain at any given time. +//! +//! New [`Receiver`] handles are created by calling [`Sender::subscribe`]. The +//! returned [`Receiver`] will receive values sent **after** the call to +//! `subscribe`. +//! +//! ## Lagging +//! +//! As sent messages must be retained until **all** [`Receiver`] handles receive +//! a clone, broadcast channels are suspectible to the "slow receiver" problem. +//! In this case, all but one receiver are able to receive values at the rate +//! they are sent. Because one receiver is stalled, the channel starts to fill +//! up. +//! +//! This broadcast channel implementation handles this case by setting a hard +//! upper bound on the number of values the channel may retain at any given +//! time. This upper bound is passed to the [`channel`] function as an argument. +//! +//! If a value is sent when the channel is at capacity, the oldest value +//! currently held by the channel is released. This frees up space for the new +//! value. Any receiver that has not yet seen the released value will return +//! [`RecvError::Lagged`] the next time [`recv`] is called. +//! +//! Once [`RecvError::Lagged`] is returned, the lagging receiver's position is +//! updated to the oldest value contained by the channel. The next call to +//! [`recv`] will return this value. +//! +//! This behavior enables a receiver to detect when it has lagged so far behind +//! that data has been dropped. The caller may decide how to respond to this: +//! either by aborting its task or by tolerating lost messages and resuming +//! consumption of the channel. +//! +//! ## Closing +//! +//! When **all** [`Sender`] handles have been dropped, no new values may be +//! sent. At this point, the channel is "closed". Once a receiver has received +//! all values retained by the channel, the next call to [`recv`] will return +//! with [`RecvError::Closed`]. +//! +//! [`Sender`]: crate::sync::broadcast::Sender +//! [`Sender::subscribe`]: crate::sync::broadcast::Sender::subscribe +//! [`Receiver`]: crate::sync::broadcast::Receiver +//! [`channel`]: crate::sync::broadcast::channel +//! [`RecvError::Lagged`]: crate::sync::broadcast::RecvError::Lagged +//! [`RecvError::Closed`]: crate::sync::broadcast::RecvError::Closed +//! [`recv`]: crate::sync::broadcast::Receiver::recv +//! +//! # Examples +//! +//! Basic usage +//! +//! ``` +//! use tokio::sync::broadcast; +//! +//! #[tokio::main] +//! async fn main() { +//! let (tx, mut rx1) = broadcast::channel(16); +//! let mut rx2 = tx.subscribe(); +//! +//! tokio::spawn(async move { +//! assert_eq!(rx1.recv().await.unwrap(), 10); +//! assert_eq!(rx1.recv().await.unwrap(), 20); +//! }); +//! +//! tokio::spawn(async move { +//! assert_eq!(rx2.recv().await.unwrap(), 10); +//! assert_eq!(rx2.recv().await.unwrap(), 20); +//! }); +//! +//! tx.send(10).unwrap(); +//! tx.send(20).unwrap(); +//! } +//! ``` +//! +//! Handling lag +//! +//! ``` +//! use tokio::sync::broadcast; +//! +//! #[tokio::main] +//! async fn main() { +//! let (tx, mut rx) = broadcast::channel(2); +//! +//! tx.send(10).unwrap(); +//! tx.send(20).unwrap(); +//! tx.send(30).unwrap(); +//! +//! // The receiver lagged behind +//! assert!(rx.recv().await.is_err()); +//! +//! // At this point, we can abort or continue with lost messages +//! +//! assert_eq!(20, rx.recv().await.unwrap()); +//! assert_eq!(30, rx.recv().await.unwrap()); +//! } + +use crate::loom::cell::UnsafeCell; +use crate::loom::future::AtomicWaker; +use crate::loom::sync::atomic::{spin_loop_hint, AtomicBool, AtomicPtr, AtomicUsize}; +use crate::loom::sync::{Arc, Condvar, Mutex}; + +use std::fmt; +use std::mem; +use std::ptr; +use std::sync::atomic::Ordering::SeqCst; +use std::task::{Context, Poll, Waker}; +use std::usize; + +/// Sending-half of the [`broadcast`] channel. +/// +/// May be used from many threads. Messages can be sent with +/// [`send`][Sender::send]. +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::broadcast; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, mut rx1) = broadcast::channel(16); +/// let mut rx2 = tx.subscribe(); +/// +/// tokio::spawn(async move { +/// assert_eq!(rx1.recv().await.unwrap(), 10); +/// assert_eq!(rx1.recv().await.unwrap(), 20); +/// }); +/// +/// tokio::spawn(async move { +/// assert_eq!(rx2.recv().await.unwrap(), 10); +/// assert_eq!(rx2.recv().await.unwrap(), 20); +/// }); +/// +/// tx.send(10).unwrap(); +/// tx.send(20).unwrap(); +/// } +/// ``` +/// +/// [`broadcast`]: crate::sync::broadcast +pub struct Sender<T> { + shared: Arc<Shared<T>>, +} + +/// Receiving-half of the [`broadcast`] channel. +/// +/// Must not be used concurrently. Messages may be retrieved using +/// [`recv`][Receiver::recv]. +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::broadcast; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, mut rx1) = broadcast::channel(16); +/// let mut rx2 = tx.subscribe(); +/// +/// tokio::spawn(async move { +/// assert_eq!(rx1.recv().await.unwrap(), 10); +/// assert_eq!(rx1.recv().await.unwrap(), 20); +/// }); +/// +/// tokio::spawn(async move { +/// assert_eq!(rx2.recv().await.unwrap(), 10); +/// assert_eq!(rx2.recv().await.unwrap(), 20); +/// }); +/// +/// tx.send(10).unwrap(); +/// tx.send(20).unwrap(); +/// } +/// ``` +/// +/// [`broadcast`]: crate::sync::broadcast +pub struct Receiver<T> { + /// State shared with all receivers and senders. + shared: Arc<Shared<T>>, + + /// Next position to read from + next: u64, + + /// Waiter state + wait: Arc<WaitNode>, +} + +/// Error returned by [`Sender::send`][Sender::send]. +/// +/// A **send** operation can only fail if there are no active receivers, +/// implying that the message could never be received. The error contains the +/// message being sent as a payload so it can be recovered. +#[derive(Debug)] +pub struct SendError<T>(pub T); + +/// An error returned from the [`recv`] function on a [`Receiver`]. +/// +/// [`recv`]: crate::sync::broadcast::Receiver::recv +/// [`Receiver`]: crate::sync::broadcast::Receiver +#[derive(Debug, PartialEq)] +pub enum RecvError { + /// There are no more active senders implying no further messages will ever + /// be sent. + Closed, + + /// The receiver lagged too far behind. Attempting to receive again will + /// return the oldest message still retained by the channel. + /// + /// Includes the number of skipped messages. + Lagged(u64), +} + +/// An error returned from the [`try_recv`] function on a [`Receiver`]. +/// +/// [`try_recv`]: crate::sync::broadcast::Receiver::try_recv +/// [`Receiver`]: crate::sync::broadcast::Receiver +#[derive(Debug, PartialEq)] +pub enum TryRecvError { + /// The channel is currently empty. There are still active + /// [`Sender`][Sender] handles, so data may yet become available. + Empty, + + /// There are no more active senders implying no further messages will ever + /// be sent. + Closed, + + /// The receiver lagged too far behind and has been forcibly disconnected. + /// Attempting to receive again will return the oldest message still + /// retained by the channel. + /// + /// Includes the number of skipped messages. + Lagged(u64), +} + +/// Data shared between senders and receivers +struct Shared<T> { + /// slots in the channel + buffer: Box<[Slot<T>]>, + + /// Mask a position -> index + mask: usize, + + /// Tail of the queue + tail: Mutex<Tail>, + + /// Notifies a sender that the slot is unlocked + condvar: Condvar, + + /// Stack of pending waiters + wait_stack: AtomicPtr<WaitNode>, + + /// Number of outstanding Sender handles + num_tx: AtomicUsize, +} + +/// Next position to write a value +struct Tail { + /// Next position to write to + pos: u64, + + /// Number of active receivers + rx_cnt: usize, +} + +/// Slot in the buffer +struct Slot<T> { + /// Remaining number of receivers that are expected to see this value. + /// + /// When this goes to zero, the value is released. + rem: AtomicUsize, + + /// Used to lock the `write` field. + lock: AtomicUsize, + + /// The value being broadcast + /// + /// Synchronized by `state` + write: Write<T>, +} + +/// A write in the buffer +struct Write<T> { + /// Uniquely identifies this write + pos: UnsafeCell<u64>, + + /// The written value + val: UnsafeCell<Option<T>>, +} + +/// Tracks a waiting receiver +#[derive(Debug)] +struct WaitNode { + /// `true` if queued + queued: AtomicBool, + + /// Task to wake when a permit is made available. + waker: AtomicWaker, + + /// Next pointer in the stack of waiting senders. + next: UnsafeCell<*const WaitNode>, +} + +struct RecvGuard<'a, T> { + slot: &'a Slot<T>, + tail: &'a Mutex<Tail>, + condvar: &'a Condvar, +} + +/// Max number of receivers. Reserve space to lock. +const MAX_RECEIVERS: usize = usize::MAX >> 1; + +/// Create a bounded, multi-producer, multi-consumer channel where each sent +/// value is broadcasted to all active receivers. +/// +/// All data sent on [`Sender`] will become available on every active +/// [`Receiver`] in the same order as it was sent. +/// +/// The `Sender` can be cloned to `send` to the same channel from multiple +/// points in the process or it can be used concurrently from an `Arc`. New +/// `Receiver` handles are created by calling [`Sender::subscribe`]. +/// +/// If all [`Receiver`] handles are dropped, the `send` method will return a +/// [`SendError`]. Similarly, if all [`Sender`] handles are dropped, the [`recv`] +/// method will return a [`RecvError`]. +/// +/// [`Sender`]: crate::sync::broadcast::Sender +/// [`Sender::subscribe`]: crate::sync::broadcast::Sender::subscribe +/// [`Receiver`]: crate::sync::broadcast::Receiver +/// [`recv`]: crate::sync::broadcast::Receiver::recv +/// [`SendError`]: crate::sync::broadcast::SendError +/// [`RecvError`]: crate::sync::broadcast::RecvError +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::broadcast; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, mut rx1) = broadcast::channel(16); +/// let mut rx2 = tx.subscribe(); +/// +/// tokio::spawn(async move { +/// assert_eq!(rx1.recv().await.unwrap(), 10); +/// assert_eq!(rx1.recv().await.unwrap(), 20); +/// }); +/// +/// tokio::spawn(async move { +/// assert_eq!(rx2.recv().await.unwrap(), 10); +/// assert_eq!(rx2.recv().await.unwrap(), 20); +/// }); +/// +/// tx.send(10).unwrap(); +/// tx.send(20).unwrap(); +/// } +/// ``` +pub fn channel<T>(mut capacity: usize) -> (Sender<T>, Receiver<T>) { + assert!(capacity > 0, "capacity is empty"); + assert!(capacity <= usize::MAX >> 1, "requested capacity too large"); + + // Round to a power of two + capacity = capacity.next_power_of_two(); + + let mut buffer = Vec::with_capacity(capacity); + + for i in 0..capacity { + buffer.push(Slot { + rem: AtomicUsize::new(0), + lock: AtomicUsize::new(0), + write: Write { + pos: UnsafeCell::new((i as u64).wrapping_sub(capacity as u64)), + val: UnsafeCell::new(None), + }, + }); + } + + let shared = Arc::new(Shared { + buffer: buffer.into_boxed_slice(), + mask: capacity - 1, + tail: Mutex::new(Tail { pos: 0, rx_cnt: 1 }), + condvar: Condvar::new(), + wait_stack: AtomicPtr::new(ptr::null_mut()), + num_tx: AtomicUsize::new(1), + }); + + let rx = Receiver { + shared: shared.clone(), + next: 0, + wait: Arc::new(WaitNode { + queued: AtomicBool::new(false), + waker: AtomicWaker::new(), + next: UnsafeCell::new(ptr::null()), + }), + }; + + let tx = Sender { shared }; + + (tx, rx) +} + +unsafe impl<T: Send> Send for Sender<T> {} +unsafe impl<T: Send> Sync for Sender<T> {} + +unsafe impl<T: Send> Send for Receiver<T> {} +unsafe impl<T: Send> Sync for Receiver<T> {} + +impl<T> Sender<T> { + /// Attempts to send a value to all active [`Receiver`] handles, returning + /// it back if it could not be sent. + /// + /// A successful send occurs when there is at least one active [`Receiver`] + /// handle. An unsuccessful send would be one where all associated + /// [`Receiver`] handles have already been dropped. + /// + /// # Return + /// + /// On success, the number of subscribed [`Receiver`] handles is returned. + /// This does not mean that this number of receivers will see the message as + /// a receiver may drop before receiving the message. + /// + /// # Note + /// + /// A return value of `Ok` **does not** mean that the sent value will be + /// observed by all or any of the active [`Receiver`] handles. [`Receiver`] + /// handles may be dropped before receiving the sent message. + /// + /// A return value of `Err` **does not** mean that future calls to `send` + /// will fail. New [`Receiver`] handles may be created by calling + /// [`subscribe`]. + /// + /// [`Receiver`]: crate::sync::broadcast::Receiver + /// [`subscribe`]: crate::sync::broadcast::Sender::subscribe + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx1) = broadcast::channel(16); + /// let mut rx2 = tx.subscribe(); + /// + /// tokio::spawn(async move { + /// assert_eq!(rx1.recv().await.unwrap(), 10); + /// assert_eq!(rx1.recv().await.unwrap(), 20); + /// }); + /// + /// tokio::spawn(async move { + /// assert_eq!(rx2.recv().await.unwrap(), 10); + /// assert_eq!(rx2.recv().await.unwrap(), 20); + /// }); + /// + /// tx.send(10).unwrap(); + /// tx.send(20).unwrap(); + /// } + /// ``` + pub fn send(&self, value: T) -> Result<usize, SendError<T>> { + self.send2(Some(value)) + .map_err(|SendError(maybe_v)| SendError(maybe_v.unwrap())) + } + + /// Creates a new [`Receiver`] handle that will receive values sent **after** + /// this call to `subscribe`. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, _rx) = broadcast::channel(16); + /// + /// // Will not be seen + /// tx.send(10).unwrap(); + /// + /// let mut rx = tx.subscribe(); + /// + /// tx.send(20).unwrap(); + /// + /// let value = rx.recv().await.unwrap(); + /// assert_eq!(20, value); + /// } + /// ``` + pub fn subscribe(&self) -> Receiver<T> { + let shared = self.shared.clone(); + + let mut tail = shared.tail.lock().unwrap(); + + if tail.rx_cnt == MAX_RECEIVERS { + panic!("max receivers"); + } + + tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow"); + let next = tail.pos; + + drop(tail); + + Receiver { + shared, + next, + wait: Arc::new(WaitNode { + queued: AtomicBool::new(false), + waker: AtomicWaker::new(), + next: UnsafeCell::new(ptr::null()), + }), + } + } + + /// Returns the number of active receivers + /// + /// An active receiver is a [`Receiver`] handle returned from [`channel`] or + /// [`subscribe`]. These are the handles that will receive values sent on + /// this [`Sender`]. + /// + /// # Note + /// + /// It is not guaranteed that a sent message will reach this number of + /// receivers. Active receivers may never call [`recv`] again before + /// dropping. + /// + /// [`recv`]: crate::sync::broadcast::Receiver::recv + /// [`Receiver`]: crate::sync::broadcast::Receiver + /// [`Sender`]: crate::sync::broadcast::Sender + /// [`subscribe`]: crate::sync::broadcast::Sender::subscribe + /// [`channel`]: crate::sync::broadcast::channel + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, _rx1) = broadcast::channel(16); + /// + /// assert_eq!(1, tx.receiver_count()); + /// + /// let mut _rx2 = tx.subscribe(); + /// + /// assert_eq!(2, tx.receiver_count()); + /// + /// tx.send(10).unwrap(); + /// } + /// ``` + pub fn receiver_count(&self) -> usize { + let tail = self.shared.tail.lock().unwrap(); + tail.rx_cnt + } + + fn send2(&self, value: Option<T>) -> Result<usize, SendError<Option<T>>> { + let mut tail = self.shared.tail.lock().unwrap(); + + if tail.rx_cnt == 0 { + return Err(SendError(value)); + } + + // Position to write into + let pos = tail.pos; + let rem = tail.rx_cnt; + let idx = (pos & self.shared.mask as u64) as usize; + + // Update the tail position + tail.pos = tail.pos.wrapping_add(1); + + // Get the slot + let slot = &self.shared.buffer[idx]; + + // Acquire the write lock + let mut prev = slot.lock.fetch_or(1, SeqCst); + + while prev & !1 != 0 { + // Concurrent readers, we must go to sleep + tail = self.shared.condvar.wait(tail).unwrap(); + + prev = slot.lock.load(SeqCst); + + if prev & 1 == 0 { + // The writer lock bit was cleared while this thread was + // sleeping. This can only happen if a newer write happened on + // this slot by another thread. Bail early as an optimization, + // there is nothing left to do. + return Ok(rem); + } + } + + if tail.pos.wrapping_sub(pos) > self.shared.buffer.len() as u64 { + // There is a newer pending write to the same slot. + return Ok(rem); + } + + // Slot lock acquired + slot.write.pos.with_mut(|ptr| unsafe { *ptr = pos }); + slot.write.val.with_mut(|ptr| unsafe { *ptr = value }); + + // Set remaining receivers + slot.rem.store(rem, SeqCst); + + // Release the slot lock + slot.lock.store(0, SeqCst); + + // Release the mutex. This must happen after the slot lock is released, + // otherwise the writer lock bit could be cleared while another thread + // is in the critical section. + drop(tail); + + // Notify waiting receivers + self.notify_rx(); + + Ok(rem) + } + + fn notify_rx(&self) { + let mut curr = self.shared.wait_stack.swap(ptr::null_mut(), SeqCst) as *const WaitNode; + + while !curr.is_null() { + let waiter = unsafe { Arc::from_raw(curr) }; + + // Update `curr` before toggling `queued` and waking + curr = waiter.next.with(|ptr| unsafe { *ptr }); + + // Unset queued + waiter.queued.store(false, SeqCst); + + // Wake + waiter.waker.wake(); + } + } +} + +impl<T> Clone for Sender<T> { + fn clone(&self) -> Sender<T> { + let shared = self.shared.clone(); + shared.num_tx.fetch_add(1, SeqCst); + + Sender { shared } + } +} + +impl<T> Drop for Sender<T> { + fn drop(&mut self) { + if 1 == self.shared.num_tx.fetch_sub(1, SeqCst) { + let _ = self.send2(None); + } + } +} + +impl<T> Receiver<T> { + /// Locks the next value if there is one. + /// + /// The caller is responsible for unlocking + fn recv_ref(&mut self, spin: bool) -> Result<RecvGuard<'_, T>, TryRecvError> { + let idx = (self.next & self.shared.mask as u64) as usize; + + // The slot holding the next value to read + let slot = &self.shared.buffer[idx]; + + // Lock the slot + if !slot.try_rx_lock() { + if spin { + while !slot.try_rx_lock() { + spin_loop_hint(); + } + } else { + return Err(TryRecvError::Empty); + } + } + + let guard = RecvGuard { + slot, + tail: &self.shared.tail, + condvar: &self.shared.condvar, + }; + + if guard.pos() != self.next { + let pos = guard.pos(); + + guard.drop_no_rem_dec(); + + if pos.wrapping_add(self.shared.buffer.len() as u64) == self.next { + return Err(TryRecvError::Empty); + } else { + let tail = self.shared.tail.lock().unwrap(); + + // `tail.pos` points to the slot the **next** send writes to. + // Because a receiver is lagging, this slot also holds the + // oldest value. To make the positions match, we subtract the + // capacity. + let next = tail.pos.wrapping_sub(self.shared.buffer.len() as u64); + let missed = next.wrapping_sub(self.next); + + self.next = next; + + return Err(TryRecvError::Lagged(missed)); + } + } + + self.next = self.next.wrapping_add(1); + + Ok(guard) + } +} + +impl<T> Receiver<T> +where + T: Clone, +{ + /// Attempts to return a pending value on this receiver without awaiting. + /// + /// This is useful for a flavor of "optimistic check" before deciding to + /// await on a receiver. + /// + /// Compared with [`recv`], this function has three failure cases instead of one + /// (one for closed, one for an empty buffer, one for a lagging receiver). + /// + /// `Err(TryRecvError::Closed)` is returned when all `Sender` halves have + /// dropped, indicating that no further values can be sent on the channel. + /// + /// If the [`Receiver`] handle falls behind, once the channel is full, newly + /// sent values will overwrite old values. At this point, a call to [`recv`] + /// will return with `Err(TryRecvError::Lagged)` and the [`Receiver`]'s + /// internal cursor is updated to point to the oldest value still held by + /// the channel. A subsequent call to [`try_recv`] will return this value + /// **unless** it has been since overwritten. If there are no values to + /// receive, `Err(TryRecvError::Empty)` is returned. + /// + /// [`recv`]: crate::sync::broadcast::Receiver::recv + /// [`Receiver`]: crate::sync::broadcast::Receiver + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = broadcast::channel(16); + /// + /// assert!(rx.try_recv().is_err()); + /// + /// tx.send(10).unwrap(); + /// + /// let value = rx.try_recv().unwrap(); + /// assert_eq!(10, value); + /// } + /// ``` + pub fn try_recv(&mut self) -> Result<T, TryRecvError> { + let guard = self.recv_ref(false)?; + guard.clone_value().ok_or(TryRecvError::Closed) + } + + #[doc(hidden)] // TODO: document + pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> { + if let Some(value) = ok_empty(self.try_recv())? { + return Poll::Ready(Ok(value)); + } + + self.register_waker(cx.waker()); + + if let Some(value) = ok_empty(self.try_recv())? { + Poll::Ready(Ok(value)) + } else { + Poll::Pending + } + } + + /// Receives the next value for this receiver. + /// + /// Each [`Receiver`] handle will receive a clone of all values sent + /// **after** it has subscribed. + /// + /// `Err(RecvError::Closed)` is returned when all `Sender` halves have + /// dropped, indicating that no further values can be sent on the channel. + /// + /// If the [`Receiver`] handle falls behind, once the channel is full, newly + /// sent values will overwrite old values. At this point, a call to [`recv`] + /// will return with `Err(RecvError::Lagged)` and the [`Receiver`]'s + /// internal cursor is updated to point to the oldest value still held by + /// the channel. A subsequent call to [`recv`] will return this value + /// **unless** it has been since overwritten. + /// + /// [`Receiver`]: crate::sync::broadcast::Receiver + /// [`recv`]: crate::sync::broadcast::Receiver::recv + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx1) = broadcast::channel(16); + /// let mut rx2 = tx.subscribe(); + /// + /// tokio::spawn(async move { + /// assert_eq!(rx1.recv().await.unwrap(), 10); + /// assert_eq!(rx1.recv().await.unwrap(), 20); + /// }); + /// + /// tokio::spawn(async move { + /// assert_eq!(rx2.recv().await.unwrap(), 10); + /// assert_eq!(rx2.recv().await.unwrap(), 20); + /// }); + /// + /// tx.send(10).unwrap(); + /// tx.send(20).unwrap(); + /// } + /// ``` + /// + /// Handling lag + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = broadcast::channel(2); + /// + /// tx.send(10).unwrap(); + /// tx.send(20).unwrap(); + /// tx.send(30).unwrap(); + /// + /// // The receiver lagged behind + /// assert!(rx.recv().await.is_err()); + /// + /// // At this point, we can abort or continue with lost messages + /// + /// assert_eq!(20, rx.recv().await.unwrap()); + /// assert_eq!(30, rx.recv().await.unwrap()); + /// } + pub async fn recv(&mut self) -> Result<T, RecvError> { + use crate::future::poll_fn; + + poll_fn(|cx| self.poll_recv(cx)).await + } + + fn register_waker(&self, cx: &Waker) { + self.wait.waker.register_by_ref(cx); + + if !self.wait.queued.load(SeqCst) { + // Set `queued` before queuing. + self.wait.queued.store(true, SeqCst); + + let mut curr = self.shared.wait_stack.load(SeqCst); + + // The ref count is decremented in `notify_rx` when all nodes are + // removed from the waiter stack. + let node = Arc::into_raw(self.wait.clone()) as *mut _; + + loop { + // Safety: `queued == false` means the caller has exclusive + // access to `self.wait.next`. + self.wait.next.with_mut(|ptr| unsafe { *ptr = curr }); + + let res = self + .shared + .wait_stack + .compare_exchange(curr, node, SeqCst, SeqCst); + + match res { + Ok(_) => return, + Err(actual) => curr = actual, + } + } + } + } +} + +#[cfg(feature = "stream")] +impl<T> crate::stream::Stream for Receiver<T> +where + T: Clone, +{ + type Item = Result<T, RecvError>; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Result<T, RecvError>>> { + self.poll_recv(cx).map(|v| match v { + Ok(v) => Some(Ok(v)), + lag @ Err(RecvError::Lagged(_)) => Some(lag), + Err(RecvError::Closed) => None, + }) + } +} + +impl<T> Drop for Receiver<T> { + fn drop(&mut self) { + let mut tail = self.shared.tail.lock().unwrap(); + + tail.rx_cnt -= 1; + let until = tail.pos; + + drop(tail); + + while self.next != until { + match self.recv_ref(true) { + // Ignore the value + Ok(_) => {} + // The channel is closed + Err(TryRecvError::Closed) => break, + // Ignore lagging, we will catch up + Err(TryRecvError::Lagged(..)) => {} + // Can't be empty + Err(TryRecvError::Empty) => panic!("unexpected empty broadcast channel"), + } + } + } +} + +impl<T> Drop for Shared<T> { + fn drop(&mut self) { + // Clear the wait stack + let mut curr = self.wait_stack.with_mut(|ptr| *ptr as *const WaitNode); + + while !curr.is_null() { + let waiter = unsafe { Arc::from_raw(curr) }; + curr = waiter.next.with(|ptr| unsafe { *ptr }); + } + } +} + +impl<T> fmt::Debug for Sender<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "broadcast::Sender") + } +} + +impl<T> fmt::Debug for Receiver<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "broadcast::Receiver") + } +} + +impl<T> Slot<T> { + /// Tries to lock the slot for a receiver. If `false`, then a sender holds the + /// lock and the calling task will be notified once the sender has released + /// the lock. + fn try_rx_lock(&self) -> bool { + let mut curr = self.lock.load(SeqCst); + + loop { + if curr & 1 == 1 { + // Locked by sender + return false; + } + + // Only increment (by 2) if the LSB "lock" bit is not set. + let res = self.lock.compare_exchange(curr, curr + 2, SeqCst, SeqCst); + + match res { + Ok(_) => return true, + Err(actual) => curr = actual, + } + } + } + + fn rx_unlock(&self, tail: &Mutex<Tail>, condvar: &Condvar, rem_dec: bool) { + if rem_dec { + // Decrement the remaining counter + if 1 == self.rem.fetch_sub(1, SeqCst) { + // Last receiver, drop the value + self.write.val.with_mut(|ptr| unsafe { *ptr = None }); + } + } + + if 1 == self.lock.fetch_sub(2, SeqCst) - 2 { + // First acquire the lock to make sure our sender is waiting on the + // condition variable, otherwise the notification could be lost. + mem::drop(tail.lock().unwrap()); + // Wake up senders + condvar.notify_all(); + } + } +} + +impl<'a, T> RecvGuard<'a, T> { + fn pos(&self) -> u64 { + self.slot.write.pos.with(|ptr| unsafe { *ptr }) + } + + fn clone_value(&self) -> Option<T> + where + T: Clone, + { + self.slot.write.val.with(|ptr| unsafe { (*ptr).clone() }) + } + + fn drop_no_rem_dec(self) { + self.slot.rx_unlock(self.tail, self.condvar, false); + + mem::forget(self); + } +} + +impl<'a, T> Drop for RecvGuard<'a, T> { + fn drop(&mut self) { + self.slot.rx_unlock(self.tail, self.condvar, true) + } +} + +fn ok_empty<T>(res: Result<T, TryRecvError>) -> Result<Option<T>, RecvError> { + match res { + Ok(value) => Ok(Some(value)), + Err(TryRecvError::Empty) => Ok(None), + Err(TryRecvError::Lagged(n)) => Err(RecvError::Lagged(n)), + Err(TryRecvError::Closed) => Err(RecvError::Closed), + } +} + +impl fmt::Display for RecvError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + RecvError::Closed => write!(f, "channel closed"), + RecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt), + } + } +} + +impl std::error::Error for RecvError {} + +impl fmt::Display for TryRecvError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TryRecvError::Empty => write!(f, "channel empty"), + TryRecvError::Closed => write!(f, "channel closed"), + TryRecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt), + } + } +} + +impl std::error::Error for TryRecvError {} diff --git a/third_party/rust/tokio/src/sync/mod.rs b/third_party/rust/tokio/src/sync/mod.rs new file mode 100644 index 0000000000..0607f78ad4 --- /dev/null +++ b/third_party/rust/tokio/src/sync/mod.rs @@ -0,0 +1,472 @@ +#![cfg_attr(loom, allow(dead_code, unreachable_pub, unused_imports))] + +//! Synchronization primitives for use in asynchronous contexts. +//! +//! Tokio programs tend to be organized as a set of [tasks] where each task +//! operates independently and may be executed on separate physical threads. The +//! synchronization primitives provided in this module permit these independent +//! tasks to communicate together. +//! +//! [tasks]: crate::task +//! +//! # Message passing +//! +//! The most common form of synchronization in a Tokio program is message +//! passing. Two tasks operate independently and send messages to each other to +//! synchronize. Doing so has the advantage of avoiding shared state. +//! +//! Message passing is implemented using channels. A channel supports sending a +//! message from one producer task to one or more consumer tasks. There are a +//! few flavors of channels provided by Tokio. Each channel flavor supports +//! different message passing patterns. When a channel supports multiple +//! producers, many separate tasks may **send** messages. When a channel +//! supports muliple consumers, many different separate tasks may **receive** +//! messages. +//! +//! Tokio provides many different channel flavors as different message passing +//! patterns are best handled with different implementations. +//! +//! ## `oneshot` channel +//! +//! The [`oneshot` channel][oneshot] supports sending a **single** value from a +//! single producer to a single consumer. This channel is usually used to send +//! the result of a computation to a waiter. +//! +//! **Example:** using a `oneshot` channel to receive the result of a +//! computation. +//! +//! ``` +//! use tokio::sync::oneshot; +//! +//! async fn some_computation() -> String { +//! "represents the result of the computation".to_string() +//! } +//! +//! #[tokio::main] +//! async fn main() { +//! let (tx, rx) = oneshot::channel(); +//! +//! tokio::spawn(async move { +//! let res = some_computation().await; +//! tx.send(res).unwrap(); +//! }); +//! +//! // Do other work while the computation is happening in the background +//! +//! // Wait for the computation result +//! let res = rx.await.unwrap(); +//! } +//! ``` +//! +//! Note, if the task produces the the computation result as its final action +//! before terminating, the [`JoinHandle`] can be used to receive the +//! computation result instead of allocating resources for the `oneshot` +//! channel. Awaiting on [`JoinHandle`] returns `Result`. If the task panics, +//! the `Joinhandle` yields `Err` with the panic cause. +//! +//! **Example:** +//! +//! ``` +//! async fn some_computation() -> String { +//! "the result of the computation".to_string() +//! } +//! +//! #[tokio::main] +//! async fn main() { +//! let join_handle = tokio::spawn(async move { +//! some_computation().await +//! }); +//! +//! // Do other work while the computation is happening in the background +//! +//! // Wait for the computation result +//! let res = join_handle.await.unwrap(); +//! } +//! ``` +//! +//! [`JoinHandle`]: crate::task::JoinHandle +//! +//! ## `mpsc` channel +//! +//! The [`mpsc` channel][mpsc] supports sending **many** values from **many** +//! producers to a single consumer. This channel is often used to send work to a +//! task or to receive the result of many computations. +//! +//! **Example:** using an mpsc to incrementally stream the results of a series +//! of computations. +//! +//! ``` +//! use tokio::sync::mpsc; +//! +//! async fn some_computation(input: u32) -> String { +//! format!("the result of computation {}", input) +//! } +//! +//! #[tokio::main] +//! async fn main() { +//! let (mut tx, mut rx) = mpsc::channel(100); +//! +//! tokio::spawn(async move { +//! for i in 0..10 { +//! let res = some_computation(i).await; +//! tx.send(res).await.unwrap(); +//! } +//! }); +//! +//! while let Some(res) = rx.recv().await { +//! println!("got = {}", res); +//! } +//! } +//! ``` +//! +//! The argument to `mpsc::channel` is the channel capacity. This is the maximum +//! number of values that can be stored in the channel pending receipt at any +//! given time. Properly setting this value is key in implementing robust +//! programs as the channel capacity plays a critical part in handling back +//! pressure. +//! +//! A common concurrency pattern for resource management is to spawn a task +//! dedicated to managing that resource and using message passing betwen other +//! tasks to interact with the resource. The resource may be anything that may +//! not be concurrently used. Some examples include a socket and program state. +//! For example, if multiple tasks need to send data over a single socket, spawn +//! a task to manage the socket and use a channel to synchronize. +//! +//! **Example:** sending data from many tasks over a single socket using message +//! passing. +//! +//! ```no_run +//! use tokio::io::{self, AsyncWriteExt}; +//! use tokio::net::TcpStream; +//! use tokio::sync::mpsc; +//! +//! #[tokio::main] +//! async fn main() -> io::Result<()> { +//! let mut socket = TcpStream::connect("www.example.com:1234").await?; +//! let (tx, mut rx) = mpsc::channel(100); +//! +//! for _ in 0..10 { +//! // Each task needs its own `tx` handle. This is done by cloning the +//! // original handle. +//! let mut tx = tx.clone(); +//! +//! tokio::spawn(async move { +//! tx.send(&b"data to write"[..]).await.unwrap(); +//! }); +//! } +//! +//! // The `rx` half of the channel returns `None` once **all** `tx` clones +//! // drop. To ensure `None` is returned, drop the handle owned by the +//! // current task. If this `tx` handle is not dropped, there will always +//! // be a single outstanding `tx` handle. +//! drop(tx); +//! +//! while let Some(res) = rx.recv().await { +//! socket.write_all(res).await?; +//! } +//! +//! Ok(()) +//! } +//! ``` +//! +//! The [`mpsc`][mpsc] and [`oneshot`][oneshot] channels can be combined to +//! provide a request / response type synchronization pattern with a shared +//! resource. A task is spawned to synchronize a resource and waits on commands +//! received on a [`mpsc`][mpsc] channel. Each command includes a +//! [`oneshot`][oneshot] `Sender` on which the result of the command is sent. +//! +//! **Example:** use a task to synchronize a `u64` counter. Each task sends an +//! "fetch and increment" command. The counter value **before** the increment is +//! sent over the provided `oneshot` channel. +//! +//! ``` +//! use tokio::sync::{oneshot, mpsc}; +//! use Command::Increment; +//! +//! enum Command { +//! Increment, +//! // Other commands can be added here +//! } +//! +//! #[tokio::main] +//! async fn main() { +//! let (cmd_tx, mut cmd_rx) = mpsc::channel::<(Command, oneshot::Sender<u64>)>(100); +//! +//! // Spawn a task to manage the counter +//! tokio::spawn(async move { +//! let mut counter: u64 = 0; +//! +//! while let Some((cmd, response)) = cmd_rx.recv().await { +//! match cmd { +//! Increment => { +//! let prev = counter; +//! counter += 1; +//! response.send(prev).unwrap(); +//! } +//! } +//! } +//! }); +//! +//! let mut join_handles = vec![]; +//! +//! // Spawn tasks that will send the increment command. +//! for _ in 0..10 { +//! let mut cmd_tx = cmd_tx.clone(); +//! +//! join_handles.push(tokio::spawn(async move { +//! let (resp_tx, resp_rx) = oneshot::channel(); +//! +//! cmd_tx.send((Increment, resp_tx)).await.ok().unwrap(); +//! let res = resp_rx.await.unwrap(); +//! +//! println!("previous value = {}", res); +//! })); +//! } +//! +//! // Wait for all tasks to complete +//! for join_handle in join_handles.drain(..) { +//! join_handle.await.unwrap(); +//! } +//! } +//! ``` +//! +//! ## `broadcast` channel +//! +//! The [`broadcast` channel][broadcast] supports sending **many** values from +//! **many** producers to **many** consumers. Each consumer will receive +//! **each** value. This channel can be used to implement "fan out" style +//! patterns common with pub / sub or "chat" systems. +//! +//! This channel tends to be used less often than `oneshot` and `mpsc` but still +//! has its use cases. +//! +//! Basic usage +//! +//! ``` +//! use tokio::sync::broadcast; +//! +//! #[tokio::main] +//! async fn main() { +//! let (tx, mut rx1) = broadcast::channel(16); +//! let mut rx2 = tx.subscribe(); +//! +//! tokio::spawn(async move { +//! assert_eq!(rx1.recv().await.unwrap(), 10); +//! assert_eq!(rx1.recv().await.unwrap(), 20); +//! }); +//! +//! tokio::spawn(async move { +//! assert_eq!(rx2.recv().await.unwrap(), 10); +//! assert_eq!(rx2.recv().await.unwrap(), 20); +//! }); +//! +//! tx.send(10).unwrap(); +//! tx.send(20).unwrap(); +//! } +//! ``` +//! +//! ## `watch` channel +//! +//! The [`watch` channel][watch] supports sending **many** values from a +//! **single** producer to **many** consumers. However, only the **most recent** +//! value is stored in the channel. Consumers are notified when a new value is +//! sent, but there is no guarantee that consumers will see **all** values. +//! +//! The [`watch` channel] is similar to a [`broadcast` channel] with capacity 1. +//! +//! Use cases for the [`watch` channel] include broadcasting configuration +//! changes or signalling program state changes, such as transitioning to +//! shutdown. +//! +//! **Example:** use a `watch` channel to notify tasks of configuration changes. +//! In this example, a configuration file is checked periodically. When the file +//! changes, the configuration changes are signalled to consumers. +//! +//! ``` +//! use tokio::sync::watch; +//! use tokio::time::{self, Duration, Instant}; +//! +//! use std::io; +//! +//! #[derive(Debug, Clone, Eq, PartialEq)] +//! struct Config { +//! timeout: Duration, +//! } +//! +//! impl Config { +//! async fn load_from_file() -> io::Result<Config> { +//! // file loading and deserialization logic here +//! # Ok(Config { timeout: Duration::from_secs(1) }) +//! } +//! } +//! +//! async fn my_async_operation() { +//! // Do something here +//! } +//! +//! #[tokio::main] +//! async fn main() { +//! // Load initial configuration value +//! let mut config = Config::load_from_file().await.unwrap(); +//! +//! // Create the watch channel, initialized with the loaded configuration +//! let (tx, rx) = watch::channel(config.clone()); +//! +//! // Spawn a task to monitor the file. +//! tokio::spawn(async move { +//! loop { +//! // Wait 10 seconds between checks +//! time::delay_for(Duration::from_secs(10)).await; +//! +//! // Load the configuration file +//! let new_config = Config::load_from_file().await.unwrap(); +//! +//! // If the configuration changed, send the new config value +//! // on the watch channel. +//! if new_config != config { +//! tx.broadcast(new_config.clone()).unwrap(); +//! config = new_config; +//! } +//! } +//! }); +//! +//! let mut handles = vec![]; +//! +//! // Spawn tasks that runs the async operation for at most `timeout`. If +//! // the timeout elapses, restart the operation. +//! // +//! // The task simultaneously watches the `Config` for changes. When the +//! // timeout duration changes, the timeout is updated without restarting +//! // the in-flight operation. +//! for _ in 0..5 { +//! // Clone a config watch handle for use in this task +//! let mut rx = rx.clone(); +//! +//! let handle = tokio::spawn(async move { +//! // Start the initial operation and pin the future to the stack. +//! // Pinning to the stack is required to resume the operation +//! // across multiple calls to `select!` +//! let op = my_async_operation(); +//! tokio::pin!(op); +//! +//! // Receive the **initial** configuration value. As this is the +//! // first time the config is received from the watch, it will +//! // always complete immediatedly. +//! let mut conf = rx.recv().await.unwrap(); +//! +//! let mut op_start = Instant::now(); +//! let mut delay = time::delay_until(op_start + conf.timeout); +//! +//! loop { +//! tokio::select! { +//! _ = &mut delay => { +//! // The operation elapsed. Restart it +//! op.set(my_async_operation()); +//! +//! // Track the new start time +//! op_start = Instant::now(); +//! +//! // Restart the timeout +//! delay = time::delay_until(op_start + conf.timeout); +//! } +//! new_conf = rx.recv() => { +//! conf = new_conf.unwrap(); +//! +//! // The configuration has been updated. Update the +//! // `delay` using the new `timeout` value. +//! delay.reset(op_start + conf.timeout); +//! } +//! _ = &mut op => { +//! // The operation completed! +//! return +//! } +//! } +//! } +//! }); +//! +//! handles.push(handle); +//! } +//! +//! for handle in handles.drain(..) { +//! handle.await.unwrap(); +//! } +//! } +//! ``` +//! +//! # State synchronization +//! +//! The remaining synchronization primitives focus on synchronizing state. +//! These are asynchronous equivalents to versions provided by `std`. They +//! operate in a similar way as their `std` counterparts parts but will wait +//! asynchronously instead of blocking the thread. +//! +//! * [`Barrier`][Barrier] Ensures multiple tasks will wait for each other to +//! reach a point in the program, before continuing execution all together. +//! +//! * [`Mutex`][Mutex] Mutual Exclusion mechanism, which ensures that at most +//! one thread at a time is able to access some data. +//! +//! * [`Notify`][Notify] Basic task notification. `Notify` supports notifying a +//! receiving task without sending data. In this case, the task wakes up and +//! resumes processing. +//! +//! * [`RwLock`][RwLock] Provides a mutual exclusion mechanism which allows +//! multiple readers at the same time, while allowing only one writer at a +//! time. In some cases, this can be more efficient than a mutex. +//! +//! * [`Semaphore`][Semaphore] Limits the amount of concurrency. A semaphore +//! holds a number of permits, which tasks may request in order to enter a +//! critical section. Semaphores are useful for implementing limiting of +//! bounding of any kind. + +cfg_sync! { + mod barrier; + pub use barrier::{Barrier, BarrierWaitResult}; + + pub mod broadcast; + + pub mod mpsc; + + mod mutex; + pub use mutex::{Mutex, MutexGuard}; + + mod notify; + pub use notify::Notify; + + pub mod oneshot; + + pub(crate) mod batch_semaphore; + pub(crate) mod semaphore_ll; + mod semaphore; + pub use semaphore::{Semaphore, SemaphorePermit}; + + mod rwlock; + pub use rwlock::{RwLock, RwLockReadGuard, RwLockWriteGuard}; + + mod task; + pub(crate) use task::AtomicWaker; + + pub mod watch; +} + +cfg_not_sync! { + cfg_atomic_waker_impl! { + mod task; + pub(crate) use task::AtomicWaker; + } + + #[cfg(any( + feature = "rt-core", + feature = "process", + feature = "signal"))] + pub(crate) mod oneshot; + + cfg_signal! { + pub(crate) mod mpsc; + pub(crate) mod semaphore_ll; + } +} + +/// Unit tests +#[cfg(test)] +mod tests; diff --git a/third_party/rust/tokio/src/sync/mpsc/block.rs b/third_party/rust/tokio/src/sync/mpsc/block.rs new file mode 100644 index 0000000000..7bf161967b --- /dev/null +++ b/third_party/rust/tokio/src/sync/mpsc/block.rs @@ -0,0 +1,387 @@ +use crate::loom::{ + cell::UnsafeCell, + sync::atomic::{AtomicPtr, AtomicUsize}, + thread, +}; + +use std::mem::MaybeUninit; +use std::ops; +use std::ptr::{self, NonNull}; +use std::sync::atomic::Ordering::{self, AcqRel, Acquire, Release}; + +/// A block in a linked list. +/// +/// Each block in the list can hold up to `BLOCK_CAP` messages. +pub(crate) struct Block<T> { + /// The start index of this block. + /// + /// Slots in this block have indices in `start_index .. start_index + BLOCK_CAP`. + start_index: usize, + + /// The next block in the linked list. + next: AtomicPtr<Block<T>>, + + /// Bitfield tracking slots that are ready to have their values consumed. + ready_slots: AtomicUsize, + + /// The observed `tail_position` value *after* the block has been passed by + /// `block_tail`. + observed_tail_position: UnsafeCell<usize>, + + /// Array containing values pushed into the block. Values are stored in a + /// continuous array in order to improve cache line behavior when reading. + /// The values must be manually dropped. + values: Values<T>, +} + +pub(crate) enum Read<T> { + Value(T), + Closed, +} + +struct Values<T>([UnsafeCell<MaybeUninit<T>>; BLOCK_CAP]); + +use super::BLOCK_CAP; + +/// Masks an index to get the block identifier +const BLOCK_MASK: usize = !(BLOCK_CAP - 1); + +/// Masks an index to get the value offset in a block. +const SLOT_MASK: usize = BLOCK_CAP - 1; + +/// Flag tracking that a block has gone through the sender's release routine. +/// +/// When this is set, the receiver may consider freeing the block. +const RELEASED: usize = 1 << BLOCK_CAP; + +/// Flag tracking all senders dropped. +/// +/// When this flag is set, the send half of the channel has closed. +const TX_CLOSED: usize = RELEASED << 1; + +/// Mask covering all bits used to track slot readiness. +const READY_MASK: usize = RELEASED - 1; + +/// Returns the index of the first slot in the block referenced by `slot_index`. +#[inline(always)] +pub(crate) fn start_index(slot_index: usize) -> usize { + BLOCK_MASK & slot_index +} + +/// Returns the offset into the block referenced by `slot_index`. +#[inline(always)] +pub(crate) fn offset(slot_index: usize) -> usize { + SLOT_MASK & slot_index +} + +impl<T> Block<T> { + pub(crate) fn new(start_index: usize) -> Block<T> { + Block { + // The absolute index in the channel of the first slot in the block. + start_index, + + // Pointer to the next block in the linked list. + next: AtomicPtr::new(ptr::null_mut()), + + ready_slots: AtomicUsize::new(0), + + observed_tail_position: UnsafeCell::new(0), + + // Value storage + values: unsafe { Values::uninitialized() }, + } + } + + /// Returns `true` if the block matches the given index + pub(crate) fn is_at_index(&self, index: usize) -> bool { + debug_assert!(offset(index) == 0); + self.start_index == index + } + + /// Returns the number of blocks between `self` and the block at the + /// specified index. + /// + /// `start_index` must represent a block *after* `self`. + pub(crate) fn distance(&self, other_index: usize) -> usize { + debug_assert!(offset(other_index) == 0); + other_index.wrapping_sub(self.start_index) / BLOCK_CAP + } + + /// Reads the value at the given offset. + /// + /// Returns `None` if the slot is empty. + /// + /// # Safety + /// + /// To maintain safety, the caller must ensure: + /// + /// * No concurrent access to the slot. + pub(crate) unsafe fn read(&self, slot_index: usize) -> Option<Read<T>> { + let offset = offset(slot_index); + + let ready_bits = self.ready_slots.load(Acquire); + + if !is_ready(ready_bits, offset) { + if is_tx_closed(ready_bits) { + return Some(Read::Closed); + } + + return None; + } + + // Get the value + let value = self.values[offset].with(|ptr| ptr::read(ptr)); + + Some(Read::Value(value.assume_init())) + } + + /// Writes a value to the block at the given offset. + /// + /// # Safety + /// + /// To maintain safety, the caller must ensure: + /// + /// * The slot is empty. + /// * No concurrent access to the slot. + pub(crate) unsafe fn write(&self, slot_index: usize, value: T) { + // Get the offset into the block + let slot_offset = offset(slot_index); + + self.values[slot_offset].with_mut(|ptr| { + ptr::write(ptr, MaybeUninit::new(value)); + }); + + // Release the value. After this point, the slot ref may no longer + // be used. It is possible for the receiver to free the memory at + // any point. + self.set_ready(slot_offset); + } + + /// Signal to the receiver that the sender half of the list is closed. + pub(crate) unsafe fn tx_close(&self) { + self.ready_slots.fetch_or(TX_CLOSED, Release); + } + + /// Resets the block to a blank state. This enables reusing blocks in the + /// channel. + /// + /// # Safety + /// + /// To maintain safety, the caller must ensure: + /// + /// * All slots are empty. + /// * The caller holds a unique pointer to the block. + pub(crate) unsafe fn reclaim(&mut self) { + self.start_index = 0; + self.next = AtomicPtr::new(ptr::null_mut()); + self.ready_slots = AtomicUsize::new(0); + } + + /// Releases the block to the rx half for freeing. + /// + /// This function is called by the tx half once it can be guaranteed that no + /// more senders will attempt to access the block. + /// + /// # Safety + /// + /// To maintain safety, the caller must ensure: + /// + /// * The block will no longer be accessed by any sender. + pub(crate) unsafe fn tx_release(&self, tail_position: usize) { + // Track the observed tail_position. Any sender targetting a greater + // tail_position is guaranteed to not access this block. + self.observed_tail_position + .with_mut(|ptr| *ptr = tail_position); + + // Set the released bit, signalling to the receiver that it is safe to + // free the block's memory as soon as all slots **prior** to + // `observed_tail_position` have been filled. + self.ready_slots.fetch_or(RELEASED, Release); + } + + /// Mark a slot as ready + fn set_ready(&self, slot: usize) { + let mask = 1 << slot; + self.ready_slots.fetch_or(mask, Release); + } + + /// Returns `true` when all slots have their `ready` bits set. + /// + /// This indicates that the block is in its final state and will no longer + /// be mutated. + /// + /// # Implementation + /// + /// The implementation walks each slot checking the `ready` flag. It might + /// be that it would make more sense to coalesce ready flags as bits in a + /// single atomic cell. However, this could have negative impact on cache + /// behavior as there would be many more mutations to a single slot. + pub(crate) fn is_final(&self) -> bool { + self.ready_slots.load(Acquire) & READY_MASK == READY_MASK + } + + /// Returns the `observed_tail_position` value, if set + pub(crate) fn observed_tail_position(&self) -> Option<usize> { + if 0 == RELEASED & self.ready_slots.load(Acquire) { + None + } else { + Some(self.observed_tail_position.with(|ptr| unsafe { *ptr })) + } + } + + /// Loads the next block + pub(crate) fn load_next(&self, ordering: Ordering) -> Option<NonNull<Block<T>>> { + let ret = NonNull::new(self.next.load(ordering)); + + debug_assert!(unsafe { + ret.map(|block| block.as_ref().start_index == self.start_index.wrapping_add(BLOCK_CAP)) + .unwrap_or(true) + }); + + ret + } + + /// Pushes `block` as the next block in the link. + /// + /// Returns Ok if successful, otherwise, a pointer to the next block in + /// the list is returned. + /// + /// This requires that the next pointer is null. + /// + /// # Ordering + /// + /// This performs a compare-and-swap on `next` using AcqRel ordering. + /// + /// # Safety + /// + /// To maintain safety, the caller must ensure: + /// + /// * `block` is not freed until it has been removed from the list. + pub(crate) unsafe fn try_push( + &self, + block: &mut NonNull<Block<T>>, + ordering: Ordering, + ) -> Result<(), NonNull<Block<T>>> { + block.as_mut().start_index = self.start_index.wrapping_add(BLOCK_CAP); + + let next_ptr = self + .next + .compare_and_swap(ptr::null_mut(), block.as_ptr(), ordering); + + match NonNull::new(next_ptr) { + Some(next_ptr) => Err(next_ptr), + None => Ok(()), + } + } + + /// Grows the `Block` linked list by allocating and appending a new block. + /// + /// The next block in the linked list is returned. This may or may not be + /// the one allocated by the function call. + /// + /// # Implementation + /// + /// It is assumed that `self.next` is null. A new block is allocated with + /// `start_index` set to be the next block. A compare-and-swap is performed + /// with AcqRel memory ordering. If the compare-and-swap is successful, the + /// newly allocated block is released to other threads walking the block + /// linked list. If the compare-and-swap fails, the current thread acquires + /// the next block in the linked list, allowing the current thread to access + /// the slots. + pub(crate) fn grow(&self) -> NonNull<Block<T>> { + // Create the new block. It is assumed that the block will become the + // next one after `&self`. If this turns out to not be the case, + // `start_index` is updated accordingly. + let new_block = Box::new(Block::new(self.start_index + BLOCK_CAP)); + + let mut new_block = unsafe { NonNull::new_unchecked(Box::into_raw(new_block)) }; + + // Attempt to store the block. The first compare-and-swap attempt is + // "unrolled" due to minor differences in logic + // + // `AcqRel` is used as the ordering **only** when attempting the + // compare-and-swap on self.next. + // + // If the compare-and-swap fails, then the actual value of the cell is + // returned from this function and accessed by the caller. Given this, + // the memory must be acquired. + // + // `Release` ensures that the newly allocated block is available to + // other threads acquiring the next pointer. + let next = NonNull::new(self.next.compare_and_swap( + ptr::null_mut(), + new_block.as_ptr(), + AcqRel, + )); + + let next = match next { + Some(next) => next, + None => { + // The compare-and-swap succeeded and the newly allocated block + // is successfully pushed. + return new_block; + } + }; + + // There already is a next block in the linked list. The newly allocated + // block could be dropped and the discovered next block returned; + // however, that would be wasteful. Instead, the linked list is walked + // by repeatedly attempting to compare-and-swap the pointer into the + // `next` register until the compare-and-swap succeed. + // + // Care is taken to update new_block's start_index field as appropriate. + + let mut curr = next; + + // TODO: Should this iteration be capped? + loop { + let actual = unsafe { curr.as_ref().try_push(&mut new_block, AcqRel) }; + + curr = match actual { + Ok(_) => { + return next; + } + Err(curr) => curr, + }; + + // When running outside of loom, this calls `spin_loop_hint`. + thread::yield_now(); + } + } +} + +/// Returns `true` if the specificed slot has a value ready to be consumed. +fn is_ready(bits: usize, slot: usize) -> bool { + let mask = 1 << slot; + mask == mask & bits +} + +/// Returns `true` if the closed flag has been set. +fn is_tx_closed(bits: usize) -> bool { + TX_CLOSED == bits & TX_CLOSED +} + +impl<T> Values<T> { + unsafe fn uninitialized() -> Values<T> { + let mut vals = MaybeUninit::uninit(); + + // When fuzzing, `UnsafeCell` needs to be initialized. + if_loom! { + let p = vals.as_mut_ptr() as *mut UnsafeCell<MaybeUninit<T>>; + for i in 0..BLOCK_CAP { + p.add(i) + .write(UnsafeCell::new(MaybeUninit::uninit())); + } + } + + Values(vals.assume_init()) + } +} + +impl<T> ops::Index<usize> for Values<T> { + type Output = UnsafeCell<MaybeUninit<T>>; + + fn index(&self, index: usize) -> &Self::Output { + self.0.index(index) + } +} diff --git a/third_party/rust/tokio/src/sync/mpsc/bounded.rs b/third_party/rust/tokio/src/sync/mpsc/bounded.rs new file mode 100644 index 0000000000..afca8c524d --- /dev/null +++ b/third_party/rust/tokio/src/sync/mpsc/bounded.rs @@ -0,0 +1,479 @@ +use crate::sync::mpsc::chan; +use crate::sync::mpsc::error::{ClosedError, SendError, TryRecvError, TrySendError}; +use crate::sync::semaphore_ll as semaphore; + +cfg_time! { + use crate::sync::mpsc::error::SendTimeoutError; + use crate::time::Duration; +} + +use std::fmt; +use std::task::{Context, Poll}; + +/// Send values to the associated `Receiver`. +/// +/// Instances are created by the [`channel`](channel) function. +pub struct Sender<T> { + chan: chan::Tx<T, Semaphore>, +} + +impl<T> Clone for Sender<T> { + fn clone(&self) -> Self { + Sender { + chan: self.chan.clone(), + } + } +} + +impl<T> fmt::Debug for Sender<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Sender") + .field("chan", &self.chan) + .finish() + } +} + +/// Receive values from the associated `Sender`. +/// +/// Instances are created by the [`channel`](channel) function. +pub struct Receiver<T> { + /// The channel receiver + chan: chan::Rx<T, Semaphore>, +} + +impl<T> fmt::Debug for Receiver<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Receiver") + .field("chan", &self.chan) + .finish() + } +} + +/// Creates a bounded mpsc channel for communicating between asynchronous tasks, +/// returning the sender/receiver halves. +/// +/// All data sent on `Sender` will become available on `Receiver` in the same +/// order as it was sent. +/// +/// The `Sender` can be cloned to `send` to the same channel from multiple code +/// locations. Only one `Receiver` is supported. +/// +/// If the `Receiver` is disconnected while trying to `send`, the `send` method +/// will return a `SendError`. Similarly, if `Sender` is disconnected while +/// trying to `recv`, the `recv` method will return a `RecvError`. +/// +/// # Examples +/// +/// ```rust +/// use tokio::sync::mpsc; +/// +/// #[tokio::main] +/// async fn main() { +/// let (mut tx, mut rx) = mpsc::channel(100); +/// +/// tokio::spawn(async move { +/// for i in 0..10 { +/// if let Err(_) = tx.send(i).await { +/// println!("receiver dropped"); +/// return; +/// } +/// } +/// }); +/// +/// while let Some(i) = rx.recv().await { +/// println!("got = {}", i); +/// } +/// } +/// ``` +pub fn channel<T>(buffer: usize) -> (Sender<T>, Receiver<T>) { + assert!(buffer > 0, "mpsc bounded channel requires buffer > 0"); + let semaphore = (semaphore::Semaphore::new(buffer), buffer); + let (tx, rx) = chan::channel(semaphore); + + let tx = Sender::new(tx); + let rx = Receiver::new(rx); + + (tx, rx) +} + +/// Channel semaphore is a tuple of the semaphore implementation and a `usize` +/// representing the channel bound. +type Semaphore = (semaphore::Semaphore, usize); + +impl<T> Receiver<T> { + pub(crate) fn new(chan: chan::Rx<T, Semaphore>) -> Receiver<T> { + Receiver { chan } + } + + /// Receives the next value for this receiver. + /// + /// `None` is returned when all `Sender` halves have dropped, indicating + /// that no further values can be sent on the channel. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (mut tx, mut rx) = mpsc::channel(100); + /// + /// tokio::spawn(async move { + /// tx.send("hello").await.unwrap(); + /// }); + /// + /// assert_eq!(Some("hello"), rx.recv().await); + /// assert_eq!(None, rx.recv().await); + /// } + /// ``` + /// + /// Values are buffered: + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (mut tx, mut rx) = mpsc::channel(100); + /// + /// tx.send("hello").await.unwrap(); + /// tx.send("world").await.unwrap(); + /// + /// assert_eq!(Some("hello"), rx.recv().await); + /// assert_eq!(Some("world"), rx.recv().await); + /// } + /// ``` + pub async fn recv(&mut self) -> Option<T> { + use crate::future::poll_fn; + + poll_fn(|cx| self.poll_recv(cx)).await + } + + #[doc(hidden)] // TODO: document + pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { + self.chan.recv(cx) + } + + /// Attempts to return a pending value on this receiver without blocking. + /// + /// This method will never block the caller in order to wait for data to + /// become available. Instead, this will always return immediately with + /// a possible option of pending data on the channel. + /// + /// This is useful for a flavor of "optimistic check" before deciding to + /// block on a receiver. + /// + /// Compared with recv, this function has two failure cases instead of + /// one (one for disconnection, one for an empty buffer). + pub fn try_recv(&mut self) -> Result<T, TryRecvError> { + self.chan.try_recv() + } + + /// Closes the receiving half of a channel, without dropping it. + /// + /// This prevents any further messages from being sent on the channel while + /// still enabling the receiver to drain messages that are buffered. + pub fn close(&mut self) { + self.chan.close(); + } +} + +impl<T> Unpin for Receiver<T> {} + +cfg_stream! { + impl<T> crate::stream::Stream for Receiver<T> { + type Item = T; + + fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> { + self.poll_recv(cx) + } + } +} + +impl<T> Sender<T> { + pub(crate) fn new(chan: chan::Tx<T, Semaphore>) -> Sender<T> { + Sender { chan } + } + + /// Sends a value, waiting until there is capacity. + /// + /// A successful send occurs when it is determined that the other end of the + /// channel has not hung up already. An unsuccessful send would be one where + /// the corresponding receiver has already been closed. Note that a return + /// value of `Err` means that the data will never be received, but a return + /// value of `Ok` does not mean that the data will be received. It is + /// possible for the corresponding receiver to hang up immediately after + /// this function returns `Ok`. + /// + /// # Errors + /// + /// If the receive half of the channel is closed, either due to [`close`] + /// being called or the [`Receiver`] handle dropping, the function returns + /// an error. The error includes the value passed to `send`. + /// + /// [`close`]: Receiver::close + /// [`Receiver`]: Receiver + /// + /// # Examples + /// + /// In the following example, each call to `send` will block until the + /// previously sent value was received. + /// + /// ```rust + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (mut tx, mut rx) = mpsc::channel(1); + /// + /// tokio::spawn(async move { + /// for i in 0..10 { + /// if let Err(_) = tx.send(i).await { + /// println!("receiver dropped"); + /// return; + /// } + /// } + /// }); + /// + /// while let Some(i) = rx.recv().await { + /// println!("got = {}", i); + /// } + /// } + /// ``` + pub async fn send(&mut self, value: T) -> Result<(), SendError<T>> { + use crate::future::poll_fn; + + if poll_fn(|cx| self.poll_ready(cx)).await.is_err() { + return Err(SendError(value)); + } + + match self.try_send(value) { + Ok(()) => Ok(()), + Err(TrySendError::Full(_)) => unreachable!(), + Err(TrySendError::Closed(value)) => Err(SendError(value)), + } + } + + /// Attempts to immediately send a message on this `Sender` + /// + /// This method differs from [`send`] by returning immediately if the channel's + /// buffer is full or no receiver is waiting to acquire some data. Compared + /// with [`send`], this function has two failure cases instead of one (one for + /// disconnection, one for a full buffer). + /// + /// This function may be paired with [`poll_ready`] in order to wait for + /// channel capacity before trying to send a value. + /// + /// # Errors + /// + /// If the channel capacity has been reached, i.e., the channel has `n` + /// buffered values where `n` is the argument passed to [`channel`], then an + /// error is returned. + /// + /// If the receive half of the channel is closed, either due to [`close`] + /// being called or the [`Receiver`] handle dropping, the function returns + /// an error. The error includes the value passed to `send`. + /// + /// [`send`]: Sender::send + /// [`poll_ready`]: Sender::poll_ready + /// [`channel`]: channel + /// [`close`]: Receiver::close + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// // Create a channel with buffer size 1 + /// let (mut tx1, mut rx) = mpsc::channel(1); + /// let mut tx2 = tx1.clone(); + /// + /// tokio::spawn(async move { + /// tx1.send(1).await.unwrap(); + /// tx1.send(2).await.unwrap(); + /// // task waits until the receiver receives a value. + /// }); + /// + /// tokio::spawn(async move { + /// // This will return an error and send + /// // no message if the buffer is full + /// let _ = tx2.try_send(3); + /// }); + /// + /// let mut msg; + /// msg = rx.recv().await.unwrap(); + /// println!("message {} received", msg); + /// + /// msg = rx.recv().await.unwrap(); + /// println!("message {} received", msg); + /// + /// // Third message may have never been sent + /// match rx.recv().await { + /// Some(msg) => println!("message {} received", msg), + /// None => println!("the third message was never sent"), + /// } + /// } + /// ``` + pub fn try_send(&mut self, message: T) -> Result<(), TrySendError<T>> { + self.chan.try_send(message)?; + Ok(()) + } + + /// Sends a value, waiting until there is capacity, but only for a limited time. + /// + /// Shares the same success and error conditions as [`send`], adding one more + /// condition for an unsuccessful send, which is when the provided timeout has + /// elapsed, and there is no capacity available. + /// + /// [`send`]: Sender::send + /// + /// # Errors + /// + /// If the receive half of the channel is closed, either due to [`close`] + /// being called or the [`Receiver`] having been dropped, + /// the function returns an error. The error includes the value passed to `send`. + /// + /// [`close`]: Receiver::close + /// [`Receiver`]: Receiver + /// + /// # Examples + /// + /// In the following example, each call to `send_timeout` will block until the + /// previously sent value was received, unless the timeout has elapsed. + /// + /// ```rust + /// use tokio::sync::mpsc; + /// use tokio::time::{delay_for, Duration}; + /// + /// #[tokio::main] + /// async fn main() { + /// let (mut tx, mut rx) = mpsc::channel(1); + /// + /// tokio::spawn(async move { + /// for i in 0..10 { + /// if let Err(e) = tx.send_timeout(i, Duration::from_millis(100)).await { + /// println!("send error: #{:?}", e); + /// return; + /// } + /// } + /// }); + /// + /// while let Some(i) = rx.recv().await { + /// println!("got = {}", i); + /// delay_for(Duration::from_millis(200)).await; + /// } + /// } + /// ``` + #[cfg(feature = "time")] + #[cfg_attr(docsrs, doc(cfg(feature = "time")))] + pub async fn send_timeout( + &mut self, + value: T, + timeout: Duration, + ) -> Result<(), SendTimeoutError<T>> { + use crate::future::poll_fn; + + match crate::time::timeout(timeout, poll_fn(|cx| self.poll_ready(cx))).await { + Err(_) => { + return Err(SendTimeoutError::Timeout(value)); + } + Ok(Err(_)) => { + return Err(SendTimeoutError::Closed(value)); + } + Ok(_) => {} + } + + match self.try_send(value) { + Ok(()) => Ok(()), + Err(TrySendError::Full(_)) => unreachable!(), + Err(TrySendError::Closed(value)) => Err(SendTimeoutError::Closed(value)), + } + } + + /// Returns `Poll::Ready(Ok(()))` when the channel is able to accept another item. + /// + /// If the channel is full, then `Poll::Pending` is returned and the task is notified when a + /// slot becomes available. + /// + /// Once `poll_ready` returns `Poll::Ready(Ok(()))`, a call to `try_send` will succeed unless + /// the channel has since been closed. To provide this guarantee, the channel reserves one slot + /// in the channel for the coming send. This reserved slot is not available to other `Sender` + /// instances, so you need to be careful to not end up with deadlocks by blocking after calling + /// `poll_ready` but before sending an element. + /// + /// If, after `poll_ready` succeeds, you decide you do not wish to send an item after all, you + /// can use [`disarm`](Sender::disarm) to release the reserved slot. + /// + /// Until an item is sent or [`disarm`](Sender::disarm) is called, repeated calls to + /// `poll_ready` will return either `Poll::Ready(Ok(()))` or `Poll::Ready(Err(_))` if channel + /// is closed. + pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), ClosedError>> { + self.chan.poll_ready(cx).map_err(|_| ClosedError::new()) + } + + /// Undo a successful call to `poll_ready`. + /// + /// Once a call to `poll_ready` returns `Poll::Ready(Ok(()))`, it holds up one slot in the + /// channel to make room for the coming send. `disarm` allows you to give up that slot if you + /// decide you do not wish to send an item after all. After calling `disarm`, you must call + /// `poll_ready` until it returns `Poll::Ready(Ok(()))` before attempting to send again. + /// + /// Returns `false` if no slot is reserved for this sender (usually because `poll_ready` was + /// not previously called, or did not succeed). + /// + /// # Motivation + /// + /// Since `poll_ready` takes up one of the finite number of slots in a bounded channel, callers + /// need to send an item shortly after `poll_ready` succeeds. If they do not, idle senders may + /// take up all the slots of the channel, and prevent active senders from getting any requests + /// through. Consider this code that forwards from one channel to another: + /// + /// ```rust,ignore + /// loop { + /// ready!(tx.poll_ready(cx))?; + /// if let Some(item) = ready!(rx.poll_recv(cx)) { + /// tx.try_send(item)?; + /// } else { + /// break; + /// } + /// } + /// ``` + /// + /// If many such forwarders exist, and they all forward into a single (cloned) `Sender`, then + /// any number of forwarders may be waiting for `rx.poll_recv` at the same time. While they do, + /// they are effectively each reducing the channel's capacity by 1. If enough of these + /// forwarders are idle, forwarders whose `rx` _do_ have elements will be unable to find a spot + /// for them through `poll_ready`, and the system will deadlock. + /// + /// `disarm` solves this problem by allowing you to give up the reserved slot if you find that + /// you have to block. We can then fix the code above by writing: + /// + /// ```rust,ignore + /// loop { + /// ready!(tx.poll_ready(cx))?; + /// let item = rx.poll_recv(cx); + /// if let Poll::Ready(Ok(_)) = item { + /// // we're going to send the item below, so don't disarm + /// } else { + /// // give up our send slot, we won't need it for a while + /// tx.disarm(); + /// } + /// if let Some(item) = ready!(item) { + /// tx.try_send(item)?; + /// } else { + /// break; + /// } + /// } + /// ``` + pub fn disarm(&mut self) -> bool { + if self.chan.is_ready() { + self.chan.disarm(); + true + } else { + false + } + } +} diff --git a/third_party/rust/tokio/src/sync/mpsc/chan.rs b/third_party/rust/tokio/src/sync/mpsc/chan.rs new file mode 100644 index 0000000000..3466395788 --- /dev/null +++ b/third_party/rust/tokio/src/sync/mpsc/chan.rs @@ -0,0 +1,524 @@ +use crate::loom::cell::UnsafeCell; +use crate::loom::future::AtomicWaker; +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::Arc; +use crate::sync::mpsc::error::{ClosedError, TryRecvError}; +use crate::sync::mpsc::{error, list}; + +use std::fmt; +use std::process; +use std::sync::atomic::Ordering::{AcqRel, Relaxed}; +use std::task::Poll::{Pending, Ready}; +use std::task::{Context, Poll}; + +/// Channel sender +pub(crate) struct Tx<T, S: Semaphore> { + inner: Arc<Chan<T, S>>, + permit: S::Permit, +} + +impl<T, S: Semaphore> fmt::Debug for Tx<T, S> +where + S::Permit: fmt::Debug, + S: fmt::Debug, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Tx") + .field("inner", &self.inner) + .field("permit", &self.permit) + .finish() + } +} + +/// Channel receiver +pub(crate) struct Rx<T, S: Semaphore> { + inner: Arc<Chan<T, S>>, +} + +impl<T, S: Semaphore> fmt::Debug for Rx<T, S> +where + S: fmt::Debug, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Rx").field("inner", &self.inner).finish() + } +} + +#[derive(Debug, Eq, PartialEq)] +pub(crate) enum TrySendError { + Closed, + Full, +} + +impl<T> From<(T, TrySendError)> for error::SendError<T> { + fn from(src: (T, TrySendError)) -> error::SendError<T> { + match src.1 { + TrySendError::Closed => error::SendError(src.0), + TrySendError::Full => unreachable!(), + } + } +} + +impl<T> From<(T, TrySendError)> for error::TrySendError<T> { + fn from(src: (T, TrySendError)) -> error::TrySendError<T> { + match src.1 { + TrySendError::Closed => error::TrySendError::Closed(src.0), + TrySendError::Full => error::TrySendError::Full(src.0), + } + } +} + +pub(crate) trait Semaphore { + type Permit; + + fn new_permit() -> Self::Permit; + + /// The permit is dropped without a value being sent. In this case, the + /// permit must be returned to the semaphore. + fn drop_permit(&self, permit: &mut Self::Permit); + + fn is_idle(&self) -> bool; + + fn add_permit(&self); + + fn poll_acquire( + &self, + cx: &mut Context<'_>, + permit: &mut Self::Permit, + ) -> Poll<Result<(), ClosedError>>; + + fn try_acquire(&self, permit: &mut Self::Permit) -> Result<(), TrySendError>; + + /// A value was sent into the channel and the permit held by `tx` is + /// dropped. In this case, the permit should not immeditely be returned to + /// the semaphore. Instead, the permit is returnred to the semaphore once + /// the sent value is read by the rx handle. + fn forget(&self, permit: &mut Self::Permit); + + fn close(&self); +} + +struct Chan<T, S> { + /// Handle to the push half of the lock-free list. + tx: list::Tx<T>, + + /// Coordinates access to channel's capacity. + semaphore: S, + + /// Receiver waker. Notified when a value is pushed into the channel. + rx_waker: AtomicWaker, + + /// Tracks the number of outstanding sender handles. + /// + /// When this drops to zero, the send half of the channel is closed. + tx_count: AtomicUsize, + + /// Only accessed by `Rx` handle. + rx_fields: UnsafeCell<RxFields<T>>, +} + +impl<T, S> fmt::Debug for Chan<T, S> +where + S: fmt::Debug, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Chan") + .field("tx", &self.tx) + .field("semaphore", &self.semaphore) + .field("rx_waker", &self.rx_waker) + .field("tx_count", &self.tx_count) + .field("rx_fields", &"...") + .finish() + } +} + +/// Fields only accessed by `Rx` handle. +struct RxFields<T> { + /// Channel receiver. This field is only accessed by the `Receiver` type. + list: list::Rx<T>, + + /// `true` if `Rx::close` is called. + rx_closed: bool, +} + +impl<T> fmt::Debug for RxFields<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("RxFields") + .field("list", &self.list) + .field("rx_closed", &self.rx_closed) + .finish() + } +} + +unsafe impl<T: Send, S: Send> Send for Chan<T, S> {} +unsafe impl<T: Send, S: Sync> Sync for Chan<T, S> {} + +pub(crate) fn channel<T, S>(semaphore: S) -> (Tx<T, S>, Rx<T, S>) +where + S: Semaphore, +{ + let (tx, rx) = list::channel(); + + let chan = Arc::new(Chan { + tx, + semaphore, + rx_waker: AtomicWaker::new(), + tx_count: AtomicUsize::new(1), + rx_fields: UnsafeCell::new(RxFields { + list: rx, + rx_closed: false, + }), + }); + + (Tx::new(chan.clone()), Rx::new(chan)) +} + +// ===== impl Tx ===== + +impl<T, S> Tx<T, S> +where + S: Semaphore, +{ + fn new(chan: Arc<Chan<T, S>>) -> Tx<T, S> { + Tx { + inner: chan, + permit: S::new_permit(), + } + } + + pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), ClosedError>> { + self.inner.semaphore.poll_acquire(cx, &mut self.permit) + } + + pub(crate) fn disarm(&mut self) { + // TODO: should this error if not acquired? + self.inner.semaphore.drop_permit(&mut self.permit) + } + + /// Send a message and notify the receiver. + pub(crate) fn try_send(&mut self, value: T) -> Result<(), (T, TrySendError)> { + self.inner.try_send(value, &mut self.permit) + } +} + +impl<T> Tx<T, (crate::sync::semaphore_ll::Semaphore, usize)> { + pub(crate) fn is_ready(&self) -> bool { + self.permit.is_acquired() + } +} + +impl<T> Tx<T, AtomicUsize> { + pub(crate) fn send_unbounded(&self, value: T) -> Result<(), (T, TrySendError)> { + self.inner.try_send(value, &mut ()) + } +} + +impl<T, S> Clone for Tx<T, S> +where + S: Semaphore, +{ + fn clone(&self) -> Tx<T, S> { + // Using a Relaxed ordering here is sufficient as the caller holds a + // strong ref to `self`, preventing a concurrent decrement to zero. + self.inner.tx_count.fetch_add(1, Relaxed); + + Tx { + inner: self.inner.clone(), + permit: S::new_permit(), + } + } +} + +impl<T, S> Drop for Tx<T, S> +where + S: Semaphore, +{ + fn drop(&mut self) { + self.inner.semaphore.drop_permit(&mut self.permit); + + if self.inner.tx_count.fetch_sub(1, AcqRel) != 1 { + return; + } + + // Close the list, which sends a `Close` message + self.inner.tx.close(); + + // Notify the receiver + self.inner.rx_waker.wake(); + } +} + +// ===== impl Rx ===== + +impl<T, S> Rx<T, S> +where + S: Semaphore, +{ + fn new(chan: Arc<Chan<T, S>>) -> Rx<T, S> { + Rx { inner: chan } + } + + pub(crate) fn close(&mut self) { + self.inner.rx_fields.with_mut(|rx_fields_ptr| { + let rx_fields = unsafe { &mut *rx_fields_ptr }; + + if rx_fields.rx_closed { + return; + } + + rx_fields.rx_closed = true; + }); + + self.inner.semaphore.close(); + } + + /// Receive the next value + pub(crate) fn recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { + use super::block::Read::*; + + // Keep track of task budget + ready!(crate::coop::poll_proceed(cx)); + + self.inner.rx_fields.with_mut(|rx_fields_ptr| { + let rx_fields = unsafe { &mut *rx_fields_ptr }; + + macro_rules! try_recv { + () => { + match rx_fields.list.pop(&self.inner.tx) { + Some(Value(value)) => { + self.inner.semaphore.add_permit(); + return Ready(Some(value)); + } + Some(Closed) => { + // TODO: This check may not be required as it most + // likely can only return `true` at this point. A + // channel is closed when all tx handles are + // dropped. Dropping a tx handle releases memory, + // which ensures that if dropping the tx handle is + // visible, then all messages sent are also visible. + assert!(self.inner.semaphore.is_idle()); + return Ready(None); + } + None => {} // fall through + } + }; + } + + try_recv!(); + + self.inner.rx_waker.register_by_ref(cx.waker()); + + // It is possible that a value was pushed between attempting to read + // and registering the task, so we have to check the channel a + // second time here. + try_recv!(); + + if rx_fields.rx_closed && self.inner.semaphore.is_idle() { + Ready(None) + } else { + Pending + } + }) + } + + /// Receives the next value without blocking + pub(crate) fn try_recv(&mut self) -> Result<T, TryRecvError> { + use super::block::Read::*; + self.inner.rx_fields.with_mut(|rx_fields_ptr| { + let rx_fields = unsafe { &mut *rx_fields_ptr }; + match rx_fields.list.pop(&self.inner.tx) { + Some(Value(value)) => { + self.inner.semaphore.add_permit(); + Ok(value) + } + Some(Closed) => Err(TryRecvError::Closed), + None => Err(TryRecvError::Empty), + } + }) + } +} + +impl<T, S> Drop for Rx<T, S> +where + S: Semaphore, +{ + fn drop(&mut self) { + use super::block::Read::Value; + + self.close(); + + self.inner.rx_fields.with_mut(|rx_fields_ptr| { + let rx_fields = unsafe { &mut *rx_fields_ptr }; + + while let Some(Value(_)) = rx_fields.list.pop(&self.inner.tx) { + self.inner.semaphore.add_permit(); + } + }) + } +} + +// ===== impl Chan ===== + +impl<T, S> Chan<T, S> +where + S: Semaphore, +{ + fn try_send(&self, value: T, permit: &mut S::Permit) -> Result<(), (T, TrySendError)> { + if let Err(e) = self.semaphore.try_acquire(permit) { + return Err((value, e)); + } + + // Push the value + self.tx.push(value); + + // Notify the rx task + self.rx_waker.wake(); + + // Release the permit + self.semaphore.forget(permit); + + Ok(()) + } +} + +impl<T, S> Drop for Chan<T, S> { + fn drop(&mut self) { + use super::block::Read::Value; + + // Safety: the only owner of the rx fields is Chan, and eing + // inside its own Drop means we're the last ones to touch it. + self.rx_fields.with_mut(|rx_fields_ptr| { + let rx_fields = unsafe { &mut *rx_fields_ptr }; + + while let Some(Value(_)) = rx_fields.list.pop(&self.tx) {} + unsafe { rx_fields.list.free_blocks() }; + }); + } +} + +use crate::sync::semaphore_ll::TryAcquireError; + +impl From<TryAcquireError> for TrySendError { + fn from(src: TryAcquireError) -> TrySendError { + if src.is_closed() { + TrySendError::Closed + } else if src.is_no_permits() { + TrySendError::Full + } else { + unreachable!(); + } + } +} + +// ===== impl Semaphore for (::Semaphore, capacity) ===== + +use crate::sync::semaphore_ll::Permit; + +impl Semaphore for (crate::sync::semaphore_ll::Semaphore, usize) { + type Permit = Permit; + + fn new_permit() -> Permit { + Permit::new() + } + + fn drop_permit(&self, permit: &mut Permit) { + permit.release(1, &self.0); + } + + fn add_permit(&self) { + self.0.add_permits(1) + } + + fn is_idle(&self) -> bool { + self.0.available_permits() == self.1 + } + + fn poll_acquire( + &self, + cx: &mut Context<'_>, + permit: &mut Permit, + ) -> Poll<Result<(), ClosedError>> { + // Keep track of task budget + ready!(crate::coop::poll_proceed(cx)); + + permit + .poll_acquire(cx, 1, &self.0) + .map_err(|_| ClosedError::new()) + } + + fn try_acquire(&self, permit: &mut Permit) -> Result<(), TrySendError> { + permit.try_acquire(1, &self.0)?; + Ok(()) + } + + fn forget(&self, permit: &mut Self::Permit) { + permit.forget(1); + } + + fn close(&self) { + self.0.close(); + } +} + +// ===== impl Semaphore for AtomicUsize ===== + +use std::sync::atomic::Ordering::{Acquire, Release}; +use std::usize; + +impl Semaphore for AtomicUsize { + type Permit = (); + + fn new_permit() {} + + fn drop_permit(&self, _permit: &mut ()) {} + + fn add_permit(&self) { + let prev = self.fetch_sub(2, Release); + + if prev >> 1 == 0 { + // Something went wrong + process::abort(); + } + } + + fn is_idle(&self) -> bool { + self.load(Acquire) >> 1 == 0 + } + + fn poll_acquire( + &self, + _cx: &mut Context<'_>, + permit: &mut (), + ) -> Poll<Result<(), ClosedError>> { + Ready(self.try_acquire(permit).map_err(|_| ClosedError::new())) + } + + fn try_acquire(&self, _permit: &mut ()) -> Result<(), TrySendError> { + let mut curr = self.load(Acquire); + + loop { + if curr & 1 == 1 { + return Err(TrySendError::Closed); + } + + if curr == usize::MAX ^ 1 { + // Overflowed the ref count. There is no safe way to recover, so + // abort the process. In practice, this should never happen. + process::abort() + } + + match self.compare_exchange(curr, curr + 2, AcqRel, Acquire) { + Ok(_) => return Ok(()), + Err(actual) => { + curr = actual; + } + } + } + } + + fn forget(&self, _permit: &mut ()) {} + + fn close(&self) { + self.fetch_or(1, Release); + } +} diff --git a/third_party/rust/tokio/src/sync/mpsc/error.rs b/third_party/rust/tokio/src/sync/mpsc/error.rs new file mode 100644 index 0000000000..72c42aa53e --- /dev/null +++ b/third_party/rust/tokio/src/sync/mpsc/error.rs @@ -0,0 +1,146 @@ +//! Channel error types + +use std::error::Error; +use std::fmt; + +/// Error returned by the `Sender`. +#[derive(Debug)] +pub struct SendError<T>(pub T); + +impl<T> fmt::Display for SendError<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } +} + +impl<T: fmt::Debug> std::error::Error for SendError<T> {} + +// ===== TrySendError ===== + +/// This enumeration is the list of the possible error outcomes for the +/// [try_send](super::Sender::try_send) method. +#[derive(Debug)] +pub enum TrySendError<T> { + /// The data could not be sent on the channel because the channel is + /// currently full and sending would require blocking. + Full(T), + + /// The receive half of the channel was explicitly closed or has been + /// dropped. + Closed(T), +} + +impl<T: fmt::Debug> Error for TrySendError<T> {} + +impl<T> fmt::Display for TrySendError<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + fmt, + "{}", + match self { + TrySendError::Full(..) => "no available capacity", + TrySendError::Closed(..) => "channel closed", + } + ) + } +} + +impl<T> From<SendError<T>> for TrySendError<T> { + fn from(src: SendError<T>) -> TrySendError<T> { + TrySendError::Closed(src.0) + } +} + +// ===== RecvError ===== + +/// Error returned by `Receiver`. +#[derive(Debug)] +pub struct RecvError(()); + +impl fmt::Display for RecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } +} + +impl Error for RecvError {} + +// ===== TryRecvError ===== + +/// This enumeration is the list of the possible reasons that try_recv +/// could not return data when called. +#[derive(Debug, PartialEq)] +pub enum TryRecvError { + /// This channel is currently empty, but the Sender(s) have not yet + /// disconnected, so data may yet become available. + Empty, + /// The channel's sending half has been closed, and there will + /// never be any more data received on it. + Closed, +} + +impl fmt::Display for TryRecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + fmt, + "{}", + match self { + TryRecvError::Empty => "channel empty", + TryRecvError::Closed => "channel closed", + } + ) + } +} + +impl Error for TryRecvError {} + +// ===== ClosedError ===== + +/// Error returned by [`Sender::poll_ready`](super::Sender::poll_ready). +#[derive(Debug)] +pub struct ClosedError(()); + +impl ClosedError { + pub(crate) fn new() -> ClosedError { + ClosedError(()) + } +} + +impl fmt::Display for ClosedError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } +} + +impl Error for ClosedError {} + +cfg_time! { + // ===== SendTimeoutError ===== + + #[derive(Debug)] + /// Error returned by [`Sender::send_timeout`](super::Sender::send_timeout)]. + pub enum SendTimeoutError<T> { + /// The data could not be sent on the channel because the channel is + /// full, and the timeout to send has elapsed. + Timeout(T), + + /// The receive half of the channel was explicitly closed or has been + /// dropped. + Closed(T), + } + + impl<T: fmt::Debug> Error for SendTimeoutError<T> {} + + impl<T> fmt::Display for SendTimeoutError<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + fmt, + "{}", + match self { + SendTimeoutError::Timeout(..) => "timed out waiting on send operation", + SendTimeoutError::Closed(..) => "channel closed", + } + ) + } + } +} diff --git a/third_party/rust/tokio/src/sync/mpsc/list.rs b/third_party/rust/tokio/src/sync/mpsc/list.rs new file mode 100644 index 0000000000..53f82a25ef --- /dev/null +++ b/third_party/rust/tokio/src/sync/mpsc/list.rs @@ -0,0 +1,341 @@ +//! A concurrent, lock-free, FIFO list. + +use crate::loom::{ + sync::atomic::{AtomicPtr, AtomicUsize}, + thread, +}; +use crate::sync::mpsc::block::{self, Block}; + +use std::fmt; +use std::ptr::NonNull; +use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release}; + +/// List queue transmit handle +pub(crate) struct Tx<T> { + /// Tail in the `Block` mpmc list. + block_tail: AtomicPtr<Block<T>>, + + /// Position to push the next message. This reference a block and offset + /// into the block. + tail_position: AtomicUsize, +} + +/// List queue receive handle +pub(crate) struct Rx<T> { + /// Pointer to the block being processed + head: NonNull<Block<T>>, + + /// Next slot index to process + index: usize, + + /// Pointer to the next block pending release + free_head: NonNull<Block<T>>, +} + +pub(crate) fn channel<T>() -> (Tx<T>, Rx<T>) { + // Create the initial block shared between the tx and rx halves. + let initial_block = Box::new(Block::new(0)); + let initial_block_ptr = Box::into_raw(initial_block); + + let tx = Tx { + block_tail: AtomicPtr::new(initial_block_ptr), + tail_position: AtomicUsize::new(0), + }; + + let head = NonNull::new(initial_block_ptr).unwrap(); + + let rx = Rx { + head, + index: 0, + free_head: head, + }; + + (tx, rx) +} + +impl<T> Tx<T> { + /// Pushes a value into the list. + pub(crate) fn push(&self, value: T) { + // First, claim a slot for the value. `Acquire` is used here to + // synchronize with the `fetch_add` in `reclaim_blocks`. + let slot_index = self.tail_position.fetch_add(1, Acquire); + + // Load the current block and write the value + let block = self.find_block(slot_index); + + unsafe { + // Write the value to the block + block.as_ref().write(slot_index, value); + } + } + + /// Closes the send half of the list + /// + /// Similar process as pushing a value, but instead of writing the value & + /// setting the ready flag, the TX_CLOSED flag is set on the block. + pub(crate) fn close(&self) { + // First, claim a slot for the value. This is the last slot that will be + // claimed. + let slot_index = self.tail_position.fetch_add(1, Acquire); + + let block = self.find_block(slot_index); + + unsafe { block.as_ref().tx_close() } + } + + fn find_block(&self, slot_index: usize) -> NonNull<Block<T>> { + // The start index of the block that contains `index`. + let start_index = block::start_index(slot_index); + + // The index offset into the block + let offset = block::offset(slot_index); + + // Load the current head of the block + let mut block_ptr = self.block_tail.load(Acquire); + + let block = unsafe { &*block_ptr }; + + // Calculate the distance between the tail ptr and the target block + let distance = block.distance(start_index); + + // Decide if this call to `find_block` should attempt to update the + // `block_tail` pointer. + // + // Updating `block_tail` is not always performed in order to reduce + // contention. + // + // When set, as the routine walks the linked list, it attempts to update + // `block_tail`. If the update cannot be performed, `try_updating_tail` + // is unset. + let mut try_updating_tail = distance > offset; + + // Walk the linked list of blocks until the block with `start_index` is + // found. + loop { + let block = unsafe { &(*block_ptr) }; + + if block.is_at_index(start_index) { + return unsafe { NonNull::new_unchecked(block_ptr) }; + } + + let next_block = block + .load_next(Acquire) + // There is no allocated next block, grow the linked list. + .unwrap_or_else(|| block.grow()); + + // If the block is **not** final, then the tail pointer cannot be + // advanced any more. + try_updating_tail &= block.is_final(); + + if try_updating_tail { + // Advancing `block_tail` must happen when walking the linked + // list. `block_tail` may not advance passed any blocks that are + // not "final". At the point a block is finalized, it is unknown + // if there are any prior blocks that are unfinalized, which + // makes it impossible to advance `block_tail`. + // + // While walking the linked list, `block_tail` can be advanced + // as long as finalized blocks are traversed. + // + // Release ordering is used to ensure that any subsequent reads + // are able to see the memory pointed to by `block_tail`. + // + // Acquire is not needed as any "actual" value is not accessed. + // At this point, the linked list is walked to acquire blocks. + let actual = + self.block_tail + .compare_and_swap(block_ptr, next_block.as_ptr(), Release); + + if actual == block_ptr { + // Synchronize with any senders + let tail_position = self.tail_position.fetch_add(0, Release); + + unsafe { + block.tx_release(tail_position); + } + } else { + // A concurrent sender is also working on advancing + // `block_tail` and this thread is falling behind. + // + // Stop trying to advance the tail pointer + try_updating_tail = false; + } + } + + block_ptr = next_block.as_ptr(); + + thread::yield_now(); + } + } + + pub(crate) unsafe fn reclaim_block(&self, mut block: NonNull<Block<T>>) { + // The block has been removed from the linked list and ownership + // is reclaimed. + // + // Before dropping the block, see if it can be reused by + // inserting it back at the end of the linked list. + // + // First, reset the data + block.as_mut().reclaim(); + + let mut reused = false; + + // Attempt to insert the block at the end + // + // Walk at most three times + // + let curr_ptr = self.block_tail.load(Acquire); + + // The pointer can never be null + debug_assert!(!curr_ptr.is_null()); + + let mut curr = NonNull::new_unchecked(curr_ptr); + + // TODO: Unify this logic with Block::grow + for _ in 0..3 { + match curr.as_ref().try_push(&mut block, AcqRel) { + Ok(_) => { + reused = true; + break; + } + Err(next) => { + curr = next; + } + } + } + + if !reused { + let _ = Box::from_raw(block.as_ptr()); + } + } +} + +impl<T> fmt::Debug for Tx<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Tx") + .field("block_tail", &self.block_tail.load(Relaxed)) + .field("tail_position", &self.tail_position.load(Relaxed)) + .finish() + } +} + +impl<T> Rx<T> { + /// Pops the next value off the queue + pub(crate) fn pop(&mut self, tx: &Tx<T>) -> Option<block::Read<T>> { + // Advance `head`, if needed + if !self.try_advancing_head() { + return None; + } + + self.reclaim_blocks(tx); + + unsafe { + let block = self.head.as_ref(); + + let ret = block.read(self.index); + + if let Some(block::Read::Value(..)) = ret { + self.index = self.index.wrapping_add(1); + } + + ret + } + } + + /// Tries advancing the block pointer to the block referenced by `self.index`. + /// + /// Returns `true` if successful, `false` if there is no next block to load. + fn try_advancing_head(&mut self) -> bool { + let block_index = block::start_index(self.index); + + loop { + let next_block = { + let block = unsafe { self.head.as_ref() }; + + if block.is_at_index(block_index) { + return true; + } + + block.load_next(Acquire) + }; + + let next_block = match next_block { + Some(next_block) => next_block, + None => { + return false; + } + }; + + self.head = next_block; + + thread::yield_now(); + } + } + + fn reclaim_blocks(&mut self, tx: &Tx<T>) { + while self.free_head != self.head { + unsafe { + // Get a handle to the block that will be freed and update + // `free_head` to point to the next block. + let block = self.free_head; + + let observed_tail_position = block.as_ref().observed_tail_position(); + + let required_index = match observed_tail_position { + Some(i) => i, + None => return, + }; + + if required_index > self.index { + return; + } + + // We may read the next pointer with `Relaxed` ordering as it is + // guaranteed that the `reclaim_blocks` routine trails the `recv` + // routine. Any memory accessed by `reclaim_blocks` has already + // been acquired by `recv`. + let next_block = block.as_ref().load_next(Relaxed); + + // Update the free list head + self.free_head = next_block.unwrap(); + + // Push the emptied block onto the back of the queue, making it + // available to senders. + tx.reclaim_block(block); + } + + thread::yield_now(); + } + } + + /// Effectively `Drop` all the blocks. Should only be called once, when + /// the list is dropping. + pub(super) unsafe fn free_blocks(&mut self) { + debug_assert_ne!(self.free_head, NonNull::dangling()); + + let mut cur = Some(self.free_head); + + #[cfg(debug_assertions)] + { + // to trigger the debug assert above so as to catch that we + // don't call `free_blocks` more than once. + self.free_head = NonNull::dangling(); + self.head = NonNull::dangling(); + } + + while let Some(block) = cur { + cur = block.as_ref().load_next(Relaxed); + drop(Box::from_raw(block.as_ptr())); + } + } +} + +impl<T> fmt::Debug for Rx<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Rx") + .field("head", &self.head) + .field("index", &self.index) + .field("free_head", &self.free_head) + .finish() + } +} diff --git a/third_party/rust/tokio/src/sync/mpsc/mod.rs b/third_party/rust/tokio/src/sync/mpsc/mod.rs new file mode 100644 index 0000000000..4cfd6150f3 --- /dev/null +++ b/third_party/rust/tokio/src/sync/mpsc/mod.rs @@ -0,0 +1,64 @@ +#![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))] + +//! A multi-producer, single-consumer queue for sending values across +//! asynchronous tasks. +//! +//! Similar to `std`, channel creation provides [`Receiver`] and [`Sender`] +//! handles. [`Receiver`] implements `Stream` and allows a task to read values +//! out of the channel. If there is no message to read, the current task will be +//! notified when a new value is sent. [`Sender`] implements the `Sink` trait +//! and allows sending messages into the channel. If the channel is at capacity, +//! the send is rejected and the task will be notified when additional capacity +//! is available. In other words, the channel provides backpressure. +//! +//! Unbounded channels are also available using the `unbounded_channel` +//! constructor. +//! +//! # Disconnection +//! +//! When all [`Sender`] handles have been dropped, it is no longer +//! possible to send values into the channel. This is considered the termination +//! event of the stream. As such, `Receiver::poll` returns `Ok(Ready(None))`. +//! +//! If the [`Receiver`] handle is dropped, then messages can no longer +//! be read out of the channel. In this case, all further attempts to send will +//! result in an error. +//! +//! # Clean Shutdown +//! +//! When the [`Receiver`] is dropped, it is possible for unprocessed messages to +//! remain in the channel. Instead, it is usually desirable to perform a "clean" +//! shutdown. To do this, the receiver first calls `close`, which will prevent +//! any further messages to be sent into the channel. Then, the receiver +//! consumes the channel to completion, at which point the receiver can be +//! dropped. +//! +//! [`Sender`]: crate::sync::mpsc::Sender +//! [`Receiver`]: crate::sync::mpsc::Receiver + +pub(super) mod block; + +mod bounded; +pub use self::bounded::{channel, Receiver, Sender}; + +mod chan; + +pub(super) mod list; + +mod unbounded; +pub use self::unbounded::{unbounded_channel, UnboundedReceiver, UnboundedSender}; + +pub mod error; + +/// The number of values a block can contain. +/// +/// This value must be a power of 2. It also must be smaller than the number of +/// bits in `usize`. +#[cfg(all(target_pointer_width = "64", not(loom)))] +const BLOCK_CAP: usize = 32; + +#[cfg(all(not(target_pointer_width = "64"), not(loom)))] +const BLOCK_CAP: usize = 16; + +#[cfg(loom)] +const BLOCK_CAP: usize = 2; diff --git a/third_party/rust/tokio/src/sync/mpsc/unbounded.rs b/third_party/rust/tokio/src/sync/mpsc/unbounded.rs new file mode 100644 index 0000000000..ba543fe4c8 --- /dev/null +++ b/third_party/rust/tokio/src/sync/mpsc/unbounded.rs @@ -0,0 +1,176 @@ +use crate::loom::sync::atomic::AtomicUsize; +use crate::sync::mpsc::chan; +use crate::sync::mpsc::error::{SendError, TryRecvError}; + +use std::fmt; +use std::task::{Context, Poll}; + +/// Send values to the associated `UnboundedReceiver`. +/// +/// Instances are created by the +/// [`unbounded_channel`](unbounded_channel) function. +pub struct UnboundedSender<T> { + chan: chan::Tx<T, Semaphore>, +} + +impl<T> Clone for UnboundedSender<T> { + fn clone(&self) -> Self { + UnboundedSender { + chan: self.chan.clone(), + } + } +} + +impl<T> fmt::Debug for UnboundedSender<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("UnboundedSender") + .field("chan", &self.chan) + .finish() + } +} + +/// Receive values from the associated `UnboundedSender`. +/// +/// Instances are created by the +/// [`unbounded_channel`](unbounded_channel) function. +pub struct UnboundedReceiver<T> { + /// The channel receiver + chan: chan::Rx<T, Semaphore>, +} + +impl<T> fmt::Debug for UnboundedReceiver<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("UnboundedReceiver") + .field("chan", &self.chan) + .finish() + } +} + +/// Creates an unbounded mpsc channel for communicating between asynchronous +/// tasks. +/// +/// A `send` on this channel will always succeed as long as the receive half has +/// not been closed. If the receiver falls behind, messages will be arbitrarily +/// buffered. +/// +/// **Note** that the amount of available system memory is an implicit bound to +/// the channel. Using an `unbounded` channel has the ability of causing the +/// process to run out of memory. In this case, the process will be aborted. +pub fn unbounded_channel<T>() -> (UnboundedSender<T>, UnboundedReceiver<T>) { + let (tx, rx) = chan::channel(AtomicUsize::new(0)); + + let tx = UnboundedSender::new(tx); + let rx = UnboundedReceiver::new(rx); + + (tx, rx) +} + +/// No capacity +type Semaphore = AtomicUsize; + +impl<T> UnboundedReceiver<T> { + pub(crate) fn new(chan: chan::Rx<T, Semaphore>) -> UnboundedReceiver<T> { + UnboundedReceiver { chan } + } + + #[doc(hidden)] // TODO: doc + pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { + self.chan.recv(cx) + } + + /// Receives the next value for this receiver. + /// + /// `None` is returned when all `Sender` halves have dropped, indicating + /// that no further values can be sent on the channel. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::unbounded_channel(); + /// + /// tokio::spawn(async move { + /// tx.send("hello").unwrap(); + /// }); + /// + /// assert_eq!(Some("hello"), rx.recv().await); + /// assert_eq!(None, rx.recv().await); + /// } + /// ``` + /// + /// Values are buffered: + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::unbounded_channel(); + /// + /// tx.send("hello").unwrap(); + /// tx.send("world").unwrap(); + /// + /// assert_eq!(Some("hello"), rx.recv().await); + /// assert_eq!(Some("world"), rx.recv().await); + /// } + /// ``` + pub async fn recv(&mut self) -> Option<T> { + use crate::future::poll_fn; + + poll_fn(|cx| self.poll_recv(cx)).await + } + + /// Attempts to return a pending value on this receiver without blocking. + /// + /// This method will never block the caller in order to wait for data to + /// become available. Instead, this will always return immediately with + /// a possible option of pending data on the channel. + /// + /// This is useful for a flavor of "optimistic check" before deciding to + /// block on a receiver. + /// + /// Compared with recv, this function has two failure cases instead of + /// one (one for disconnection, one for an empty buffer). + pub fn try_recv(&mut self) -> Result<T, TryRecvError> { + self.chan.try_recv() + } + + /// Closes the receiving half of a channel, without dropping it. + /// + /// This prevents any further messages from being sent on the channel while + /// still enabling the receiver to drain messages that are buffered. + pub fn close(&mut self) { + self.chan.close(); + } +} + +#[cfg(feature = "stream")] +impl<T> crate::stream::Stream for UnboundedReceiver<T> { + type Item = T; + + fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> { + self.poll_recv(cx) + } +} + +impl<T> UnboundedSender<T> { + pub(crate) fn new(chan: chan::Tx<T, Semaphore>) -> UnboundedSender<T> { + UnboundedSender { chan } + } + + /// Attempts to send a message on this `UnboundedSender` without blocking. + /// + /// If the receive half of the channel is closed, either due to [`close`] + /// being called or the [`UnboundedReceiver`] having been dropped, + /// the function returns an error. The error includes the value passed to `send`. + /// + /// [`close`]: UnboundedReceiver::close + /// [`UnboundedReceiver`]: UnboundedReceiver + pub fn send(&self, message: T) -> Result<(), SendError<T>> { + self.chan.send_unbounded(message)?; + Ok(()) + } +} diff --git a/third_party/rust/tokio/src/sync/mutex.rs b/third_party/rust/tokio/src/sync/mutex.rs new file mode 100644 index 0000000000..7167906de1 --- /dev/null +++ b/third_party/rust/tokio/src/sync/mutex.rs @@ -0,0 +1,228 @@ +//! An asynchronous `Mutex`-like type. +//! +//! This module provides [`Mutex`], a type that acts similarly to an asynchronous `Mutex`, with one +//! major difference: the [`MutexGuard`] returned by `lock` is not tied to the lifetime of the +//! `Mutex`. This enables you to acquire a lock, and then pass that guard into a future, and then +//! release it at some later point in time. +//! +//! This allows you to do something along the lines of: +//! +//! ```rust,no_run +//! use tokio::sync::Mutex; +//! use std::sync::Arc; +//! +//! #[tokio::main] +//! async fn main() { +//! let data1 = Arc::new(Mutex::new(0)); +//! let data2 = Arc::clone(&data1); +//! +//! tokio::spawn(async move { +//! let mut lock = data2.lock().await; +//! *lock += 1; +//! }); +//! +//! let mut lock = data1.lock().await; +//! *lock += 1; +//! } +//! ``` +//! +//! Another example +//! ```rust,no_run +//! #![warn(rust_2018_idioms)] +//! +//! use tokio::sync::Mutex; +//! use std::sync::Arc; +//! +//! +//! #[tokio::main] +//! async fn main() { +//! let count = Arc::new(Mutex::new(0)); +//! +//! for _ in 0..5 { +//! let my_count = Arc::clone(&count); +//! tokio::spawn(async move { +//! for _ in 0..10 { +//! let mut lock = my_count.lock().await; +//! *lock += 1; +//! println!("{}", lock); +//! } +//! }); +//! } +//! +//! loop { +//! if *count.lock().await >= 50 { +//! break; +//! } +//! } +//! println!("Count hit 50."); +//! } +//! ``` +//! There are a few things of note here to pay attention to in this example. +//! 1. The mutex is wrapped in an [`std::sync::Arc`] to allow it to be shared across threads. +//! 2. Each spawned task obtains a lock and releases it on every iteration. +//! 3. Mutation of the data the Mutex is protecting is done by de-referencing the the obtained lock +//! as seen on lines 23 and 30. +//! +//! Tokio's Mutex works in a simple FIFO (first in, first out) style where as requests for a lock are +//! made Tokio will queue them up and provide a lock when it is that requester's turn. In that way +//! the Mutex is "fair" and predictable in how it distributes the locks to inner data. This is why +//! the output of this program is an in-order count to 50. Locks are released and reacquired +//! after every iteration, so basically, each thread goes to the back of the line after it increments +//! the value once. Also, since there is only a single valid lock at any given time there is no +//! possibility of a race condition when mutating the inner value. +//! +//! Note that in contrast to `std::sync::Mutex`, this implementation does not +//! poison the mutex when a thread holding the `MutexGuard` panics. In such a +//! case, the mutex will be unlocked. If the panic is caught, this might leave +//! the data protected by the mutex in an inconsistent state. +//! +//! [`Mutex`]: struct@Mutex +//! [`MutexGuard`]: struct@MutexGuard +use crate::coop::CoopFutureExt; +use crate::sync::batch_semaphore as semaphore; + +use std::cell::UnsafeCell; +use std::error::Error; +use std::fmt; +use std::ops::{Deref, DerefMut}; + +/// An asynchronous mutual exclusion primitive useful for protecting shared data +/// +/// Each mutex has a type parameter (`T`) which represents the data that it is protecting. The data +/// can only be accessed through the RAII guards returned from `lock`, which +/// guarantees that the data is only ever accessed when the mutex is locked. +#[derive(Debug)] +pub struct Mutex<T> { + c: UnsafeCell<T>, + s: semaphore::Semaphore, +} + +/// A handle to a held `Mutex`. +/// +/// As long as you have this guard, you have exclusive access to the underlying `T`. The guard +/// internally keeps a reference-couned pointer to the original `Mutex`, so even if the lock goes +/// away, the guard remains valid. +/// +/// The lock is automatically released whenever the guard is dropped, at which point `lock` +/// will succeed yet again. +pub struct MutexGuard<'a, T> { + lock: &'a Mutex<T>, +} + +// As long as T: Send, it's fine to send and share Mutex<T> between threads. +// If T was not Send, sending and sharing a Mutex<T> would be bad, since you can access T through +// Mutex<T>. +unsafe impl<T> Send for Mutex<T> where T: Send {} +unsafe impl<T> Sync for Mutex<T> where T: Send {} +unsafe impl<'a, T> Sync for MutexGuard<'a, T> where T: Send + Sync {} + +/// Error returned from the [`Mutex::try_lock`] function. +/// +/// A `try_lock` operation can only fail if the mutex is already locked. +/// +/// [`Mutex::try_lock`]: Mutex::try_lock +#[derive(Debug)] +pub struct TryLockError(()); + +impl fmt::Display for TryLockError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "{}", "operation would block") + } +} + +impl Error for TryLockError {} + +#[test] +#[cfg(not(loom))] +fn bounds() { + fn check_send<T: Send>() {} + fn check_unpin<T: Unpin>() {} + // This has to take a value, since the async fn's return type is unnameable. + fn check_send_sync_val<T: Send + Sync>(_t: T) {} + fn check_send_sync<T: Send + Sync>() {} + check_send::<MutexGuard<'_, u32>>(); + check_unpin::<Mutex<u32>>(); + check_send_sync::<Mutex<u32>>(); + + let mutex = Mutex::new(1); + check_send_sync_val(mutex.lock()); +} + +impl<T> Mutex<T> { + /// Creates a new lock in an unlocked state ready for use. + pub fn new(t: T) -> Self { + Self { + c: UnsafeCell::new(t), + s: semaphore::Semaphore::new(1), + } + } + + /// A future that resolves on acquiring the lock and returns the `MutexGuard`. + pub async fn lock(&self) -> MutexGuard<'_, T> { + self.s.acquire(1).cooperate().await.unwrap_or_else(|_| { + // The semaphore was closed. but, we never explicitly close it, and we have a + // handle to it through the Arc, which means that this can never happen. + unreachable!() + }); + MutexGuard { lock: self } + } + + /// Tries to acquire the lock + pub fn try_lock(&self) -> Result<MutexGuard<'_, T>, TryLockError> { + match self.s.try_acquire(1) { + Ok(_) => Ok(MutexGuard { lock: self }), + Err(_) => Err(TryLockError(())), + } + } + + /// Consumes the mutex, returning the underlying data. + pub fn into_inner(self) -> T { + self.c.into_inner() + } +} + +impl<'a, T> Drop for MutexGuard<'a, T> { + fn drop(&mut self) { + self.lock.s.release(1) + } +} + +impl<T> From<T> for Mutex<T> { + fn from(s: T) -> Self { + Self::new(s) + } +} + +impl<T> Default for Mutex<T> +where + T: Default, +{ + fn default() -> Self { + Self::new(T::default()) + } +} + +impl<'a, T> Deref for MutexGuard<'a, T> { + type Target = T; + fn deref(&self) -> &Self::Target { + unsafe { &*self.lock.c.get() } + } +} + +impl<'a, T> DerefMut for MutexGuard<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.lock.c.get() } + } +} + +impl<'a, T: fmt::Debug> fmt::Debug for MutexGuard<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<'a, T: fmt::Display> fmt::Display for MutexGuard<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} diff --git a/third_party/rust/tokio/src/sync/notify.rs b/third_party/rust/tokio/src/sync/notify.rs new file mode 100644 index 0000000000..5cb41e89ea --- /dev/null +++ b/third_party/rust/tokio/src/sync/notify.rs @@ -0,0 +1,556 @@ +use crate::loom::sync::atomic::AtomicU8; +use crate::loom::sync::Mutex; +use crate::util::linked_list::{self, LinkedList}; + +use std::cell::UnsafeCell; +use std::future::Future; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::ptr::NonNull; +use std::sync::atomic::Ordering::SeqCst; +use std::task::{Context, Poll, Waker}; + +/// Notify a single task to wake up. +/// +/// `Notify` provides a basic mechanism to notify a single task of an event. +/// `Notify` itself does not carry any data. Instead, it is to be used to signal +/// another task to perform an operation. +/// +/// `Notify` can be thought of as a [`Semaphore`] starting with 0 permits. +/// [`notified().await`] waits for a permit to become available, and [`notify()`] +/// sets a permit **if there currently are no available permits**. +/// +/// The synchronization details of `Notify` are similar to +/// [`thread::park`][park] and [`Thread::unpark`][unpark] from std. A [`Notify`] +/// value contains a single permit. [`notified().await`] waits for the permit to +/// be made available, consumes the permit, and resumes. [`notify()`] sets the +/// permit, waking a pending task if there is one. +/// +/// If `notify()` is called **before** `notfied().await`, then the next call to +/// `notified().await` will complete immediately, consuming the permit. Any +/// subsequent calls to `notified().await` will wait for a new permit. +/// +/// If `notify()` is called **multiple** times before `notified().await`, only a +/// **single** permit is stored. The next call to `notified().await` will +/// complete immediately, but the one after will wait for a new permit. +/// +/// # Examples +/// +/// Basic usage. +/// +/// ``` +/// use tokio::sync::Notify; +/// use std::sync::Arc; +/// +/// #[tokio::main] +/// async fn main() { +/// let notify = Arc::new(Notify::new()); +/// let notify2 = notify.clone(); +/// +/// tokio::spawn(async move { +/// notify2.notified().await; +/// println!("received notification"); +/// }); +/// +/// println!("sending notification"); +/// notify.notify(); +/// } +/// ``` +/// +/// Unbound mpsc channel. +/// +/// ``` +/// use tokio::sync::Notify; +/// +/// use std::collections::VecDeque; +/// use std::sync::Mutex; +/// +/// struct Channel<T> { +/// values: Mutex<VecDeque<T>>, +/// notify: Notify, +/// } +/// +/// impl<T> Channel<T> { +/// pub fn send(&self, value: T) { +/// self.values.lock().unwrap() +/// .push_back(value); +/// +/// // Notify the consumer a value is available +/// self.notify.notify(); +/// } +/// +/// pub async fn recv(&self) -> T { +/// loop { +/// // Drain values +/// if let Some(value) = self.values.lock().unwrap().pop_front() { +/// return value; +/// } +/// +/// // Wait for values to be available +/// self.notify.notified().await; +/// } +/// } +/// } +/// ``` +/// +/// [park]: std::thread::park +/// [unpark]: std::thread::Thread::unpark +/// [`notified().await`]: Notify::notified() +/// [`notify()`]: Notify::notify() +/// [`Semaphore`]: crate::sync::Semaphore +#[derive(Debug)] +pub struct Notify { + state: AtomicU8, + waiters: Mutex<LinkedList<Waiter>>, +} + +#[derive(Debug)] +struct Waiter { + /// Intrusive linked-list pointers + pointers: linked_list::Pointers<Waiter>, + + /// Waiting task's waker + waker: Option<Waker>, + + /// `true` if the notification has been assigned to this waiter. + notified: bool, + + /// Should not be `Unpin`. + _p: PhantomPinned, +} + +/// Future returned from `notified()` +#[derive(Debug)] +struct Notified<'a> { + /// The `Notify` being received on. + notify: &'a Notify, + + /// The current state of the receiving process. + state: State, + + /// Entry in the waiter `LinkedList`. + waiter: UnsafeCell<Waiter>, +} + +unsafe impl<'a> Send for Notified<'a> {} +unsafe impl<'a> Sync for Notified<'a> {} + +#[derive(Debug)] +enum State { + Init, + Waiting, + Done, +} + +/// Initial "idle" state +const EMPTY: u8 = 0; + +/// One or more threads are currently waiting to be notified. +const WAITING: u8 = 1; + +/// Pending notification +const NOTIFIED: u8 = 2; + +impl Notify { + /// Create a new `Notify`, initialized without a permit. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Notify; + /// + /// let notify = Notify::new(); + /// ``` + pub fn new() -> Notify { + Notify { + state: AtomicU8::new(0), + waiters: Mutex::new(LinkedList::new()), + } + } + + /// Wait for a notification. + /// + /// Each `Notify` value holds a single permit. If a permit is available from + /// an earlier call to [`notify()`], then `notified().await` will complete + /// immediately, consuming that permit. Otherwise, `notified().await` waits + /// for a permit to be made available by the next call to `notify()`. + /// + /// [`notify()`]: Notify::notify + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Notify; + /// use std::sync::Arc; + /// + /// #[tokio::main] + /// async fn main() { + /// let notify = Arc::new(Notify::new()); + /// let notify2 = notify.clone(); + /// + /// tokio::spawn(async move { + /// notify2.notified().await; + /// println!("received notification"); + /// }); + /// + /// println!("sending notification"); + /// notify.notify(); + /// } + /// ``` + pub async fn notified(&self) { + Notified { + notify: self, + state: State::Init, + waiter: UnsafeCell::new(Waiter { + pointers: linked_list::Pointers::new(), + waker: None, + notified: false, + _p: PhantomPinned, + }), + } + .await + } + + /// Notifies a waiting task + /// + /// If a task is currently waiting, that task is notified. Otherwise, a + /// permit is stored in this `Notify` value and the **next** call to + /// [`notified().await`] will complete immediately consuming the permit made + /// available by this call to `notify()`. + /// + /// At most one permit may be stored by `Notify`. Many sequential calls to + /// `notify` will result in a single permit being stored. The next call to + /// `notified().await` will complete immediately, but the one after that + /// will wait. + /// + /// [`notified().await`]: Notify::notified() + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Notify; + /// use std::sync::Arc; + /// + /// #[tokio::main] + /// async fn main() { + /// let notify = Arc::new(Notify::new()); + /// let notify2 = notify.clone(); + /// + /// tokio::spawn(async move { + /// notify2.notified().await; + /// println!("received notification"); + /// }); + /// + /// println!("sending notification"); + /// notify.notify(); + /// } + /// ``` + pub fn notify(&self) { + // Load the current state + let mut curr = self.state.load(SeqCst); + + // If the state is `EMPTY`, transition to `NOTIFIED` and return. + while let EMPTY | NOTIFIED = curr { + // The compare-exchange from `NOTIFIED` -> `NOTIFIED` is intended. A + // happens-before synchronization must happen between this atomic + // operation and a task calling `notified().await`. + let res = self.state.compare_exchange(curr, NOTIFIED, SeqCst, SeqCst); + + match res { + // No waiters, no further work to do + Ok(_) => return, + Err(actual) => { + curr = actual; + } + } + } + + // There are waiters, the lock must be acquired to notify. + let mut waiters = self.waiters.lock().unwrap(); + + // The state must be reloaded while the lock is held. The state may only + // transition out of WAITING while the lock is held. + curr = self.state.load(SeqCst); + + if let Some(waker) = notify_locked(&mut waiters, &self.state, curr) { + drop(waiters); + waker.wake(); + } + } +} + +impl Default for Notify { + fn default() -> Notify { + Notify::new() + } +} + +fn notify_locked(waiters: &mut LinkedList<Waiter>, state: &AtomicU8, curr: u8) -> Option<Waker> { + loop { + match curr { + EMPTY | NOTIFIED => { + let res = state.compare_exchange(curr, NOTIFIED, SeqCst, SeqCst); + + match res { + Ok(_) => return None, + Err(actual) => { + assert!(actual == EMPTY || actual == NOTIFIED); + state.store(NOTIFIED, SeqCst); + return None; + } + } + } + WAITING => { + // At this point, it is guaranteed that the state will not + // concurrently change as holding the lock is required to + // transition **out** of `WAITING`. + // + // Get a pending waiter + let mut waiter = waiters.pop_back().unwrap(); + + // Safety: `waiters` lock is still held. + let waiter = unsafe { waiter.as_mut() }; + + assert!(!waiter.notified); + + waiter.notified = true; + let waker = waiter.waker.take(); + + if waiters.is_empty() { + // As this the **final** waiter in the list, the state + // must be transitioned to `EMPTY`. As transitioning + // **from** `WAITING` requires the lock to be held, a + // `store` is sufficient. + state.store(EMPTY, SeqCst); + } + + return waker; + } + _ => unreachable!(), + } + } +} + +// ===== impl Notified ===== + +impl Notified<'_> { + /// A custom `project` implementation is used in place of `pin-project-lite` + /// as a custom drop implementation is needed. + fn project(self: Pin<&mut Self>) -> (&Notify, &mut State, &UnsafeCell<Waiter>) { + unsafe { + // Safety: both `notify` and `state` are `Unpin`. + + is_unpin::<&Notify>(); + is_unpin::<AtomicU8>(); + + let me = self.get_unchecked_mut(); + (&me.notify, &mut me.state, &me.waiter) + } + } +} + +impl Future for Notified<'_> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + use State::*; + + let (notify, state, waiter) = self.project(); + + loop { + match *state { + Init => { + // Optimistically try acquiring a pending notification + let res = notify + .state + .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst); + + if res.is_ok() { + // Acquired the notification + *state = Done; + return Poll::Ready(()); + } + + // Acquire the lock and attempt to transition to the waiting + // state. + let mut waiters = notify.waiters.lock().unwrap(); + + // Reload the state with the lock held + let mut curr = notify.state.load(SeqCst); + + // Transition the state to WAITING. + loop { + match curr { + EMPTY => { + // Transition to WAITING + let res = notify + .state + .compare_exchange(EMPTY, WAITING, SeqCst, SeqCst); + + if let Err(actual) = res { + assert_eq!(actual, NOTIFIED); + curr = actual; + } else { + break; + } + } + WAITING => break, + NOTIFIED => { + // Try consuming the notification + let res = notify + .state + .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst); + + match res { + Ok(_) => { + // Acquired the notification + *state = Done; + return Poll::Ready(()); + } + Err(actual) => { + assert_eq!(actual, EMPTY); + curr = actual; + } + } + } + _ => unreachable!(), + } + } + + // Safety: called while locked. + unsafe { + (*waiter.get()).waker = Some(cx.waker().clone()); + } + + // Insert the waiter into the linked list + // + // safety: pointers from `UnsafeCell` are never null. + waiters.push_front(unsafe { NonNull::new_unchecked(waiter.get()) }); + + *state = Waiting; + } + Waiting => { + // Currently in the "Waiting" state, implying the caller has + // a waiter stored in the waiter list (guarded by + // `notify.waiters`). In order to access the waker fields, + // we must hold the lock. + + let waiters = notify.waiters.lock().unwrap(); + + // Safety: called while locked + let w = unsafe { &mut *waiter.get() }; + + if w.notified { + // Our waker has been notified. Reset the fields and + // remove it from the list. + w.waker = None; + w.notified = false; + + *state = Done; + } else { + // Update the waker, if necessary. + if !w.waker.as_ref().unwrap().will_wake(cx.waker()) { + w.waker = Some(cx.waker().clone()); + } + + return Poll::Pending; + } + + // Explicit drop of the lock to indicate the scope that the + // lock is held. Because holding the lock is required to + // ensure safe access to fields not held within the lock, it + // is helpful to visualize the scope of the critical + // section. + drop(waiters); + } + Done => { + return Poll::Ready(()); + } + } + } + } +} + +impl Drop for Notified<'_> { + fn drop(&mut self) { + use State::*; + + // Safety: The type only transitions to a "Waiting" state when pinned. + let (notify, state, waiter) = unsafe { Pin::new_unchecked(self).project() }; + + // This is where we ensure safety. The `Notified` value is being + // dropped, which means we must ensure that the waiter entry is no + // longer stored in the linked list. + if let Waiting = *state { + let mut notify_state = WAITING; + let mut waiters = notify.waiters.lock().unwrap(); + + // `Notify.state` may be in any of the three states (Empty, Waiting, + // Notified). It doesn't actually matter what the atomic is set to + // at this point. We hold the lock and will ensure the atomic is in + // the correct state once th elock is dropped. + // + // Because the atomic state is not checked, at first glance, it may + // seem like this routine does not handle the case where the + // receiver is notified but has not yet observed the notification. + // If this happens, no matter how many notifications happen between + // this receiver being notified and the receive future dropping, all + // we need to do is ensure that one notification is returned back to + // the `Notify`. This is done by calling `notify_locked` if `self` + // has the `notified` flag set. + + // remove the entry from the list + // + // safety: the waiter is only added to `waiters` by virtue of it + // being the only `LinkedList` available to the type. + unsafe { waiters.remove(NonNull::new_unchecked(waiter.get())) }; + + if waiters.is_empty() { + notify_state = EMPTY; + // If the state *should* be `NOTIFIED`, the call to + // `notify_locked` below will end up doing the + // `store(NOTIFIED)`. If a concurrent receiver races and + // observes the incorrect `EMPTY` state, it will then obtain the + // lock and block until `notify.state` is in the correct final + // state. + notify.state.store(EMPTY, SeqCst); + } + + // See if the node was notified but not received. In this case, the + // notification must be sent to another waiter. + // + // Safety: with the entry removed from the linked list, there can be + // no concurrent access to the entry + let notified = unsafe { (*waiter.get()).notified }; + + if notified { + if let Some(waker) = notify_locked(&mut waiters, ¬ify.state, notify_state) { + drop(waiters); + waker.wake(); + } + } + } + } +} + +/// # Safety +/// +/// `Waiter` is forced to be !Unpin. +unsafe impl linked_list::Link for Waiter { + type Handle = NonNull<Waiter>; + type Target = Waiter; + + fn as_raw(handle: &NonNull<Waiter>) -> NonNull<Waiter> { + *handle + } + + unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> { + ptr + } + + unsafe fn pointers(mut target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> { + NonNull::from(&mut target.as_mut().pointers) + } +} + +fn is_unpin<T: Unpin>() {} diff --git a/third_party/rust/tokio/src/sync/oneshot.rs b/third_party/rust/tokio/src/sync/oneshot.rs new file mode 100644 index 0000000000..62ad484eec --- /dev/null +++ b/third_party/rust/tokio/src/sync/oneshot.rs @@ -0,0 +1,784 @@ +#![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))] + +//! A channel for sending a single message between asynchronous tasks. + +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::Arc; + +use std::fmt; +use std::future::Future; +use std::mem::MaybeUninit; +use std::pin::Pin; +use std::sync::atomic::Ordering::{self, AcqRel, Acquire}; +use std::task::Poll::{Pending, Ready}; +use std::task::{Context, Poll, Waker}; + +/// Sends a value to the associated `Receiver`. +/// +/// Instances are created by the [`channel`](fn@channel) function. +#[derive(Debug)] +pub struct Sender<T> { + inner: Option<Arc<Inner<T>>>, +} + +/// Receive a value from the associated `Sender`. +/// +/// Instances are created by the [`channel`](fn@channel) function. +#[derive(Debug)] +pub struct Receiver<T> { + inner: Option<Arc<Inner<T>>>, +} + +pub mod error { + //! Oneshot error types + + use std::fmt; + + /// Error returned by the `Future` implementation for `Receiver`. + #[derive(Debug, Eq, PartialEq)] + pub struct RecvError(pub(super) ()); + + /// Error returned by the `try_recv` function on `Receiver`. + #[derive(Debug, Eq, PartialEq)] + pub enum TryRecvError { + /// The send half of the channel has not yet sent a value. + Empty, + + /// The send half of the channel was dropped without sending a value. + Closed, + } + + // ===== impl RecvError ===== + + impl fmt::Display for RecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } + } + + impl std::error::Error for RecvError {} + + // ===== impl TryRecvError ===== + + impl fmt::Display for TryRecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TryRecvError::Empty => write!(fmt, "channel empty"), + TryRecvError::Closed => write!(fmt, "channel closed"), + } + } + } + + impl std::error::Error for TryRecvError {} +} + +use self::error::*; + +struct Inner<T> { + /// Manages the state of the inner cell + state: AtomicUsize, + + /// The value. This is set by `Sender` and read by `Receiver`. The state of + /// the cell is tracked by `state`. + value: UnsafeCell<Option<T>>, + + /// The task to notify when the receiver drops without consuming the value. + tx_task: UnsafeCell<MaybeUninit<Waker>>, + + /// The task to notify when the value is sent. + rx_task: UnsafeCell<MaybeUninit<Waker>>, +} + +#[derive(Clone, Copy)] +struct State(usize); + +/// Create a new one-shot channel for sending single values across asynchronous +/// tasks. +/// +/// The function returns separate "send" and "receive" handles. The `Sender` +/// handle is used by the producer to send the value. The `Receiver` handle is +/// used by the consumer to receive the value. +/// +/// Each handle can be used on separate tasks. +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::oneshot; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, rx) = oneshot::channel(); +/// +/// tokio::spawn(async move { +/// if let Err(_) = tx.send(3) { +/// println!("the receiver dropped"); +/// } +/// }); +/// +/// match rx.await { +/// Ok(v) => println!("got = {:?}", v), +/// Err(_) => println!("the sender dropped"), +/// } +/// } +/// ``` +pub fn channel<T>() -> (Sender<T>, Receiver<T>) { + #[allow(deprecated)] + let inner = Arc::new(Inner { + state: AtomicUsize::new(State::new().as_usize()), + value: UnsafeCell::new(None), + tx_task: UnsafeCell::new(MaybeUninit::uninit()), + rx_task: UnsafeCell::new(MaybeUninit::uninit()), + }); + + let tx = Sender { + inner: Some(inner.clone()), + }; + let rx = Receiver { inner: Some(inner) }; + + (tx, rx) +} + +impl<T> Sender<T> { + /// Attempts to send a value on this channel, returning it back if it could + /// not be sent. + /// + /// The function consumes `self` as only one value may ever be sent on a + /// one-shot channel. + /// + /// A successful send occurs when it is determined that the other end of the + /// channel has not hung up already. An unsuccessful send would be one where + /// the corresponding receiver has already been deallocated. Note that a + /// return value of `Err` means that the data will never be received, but + /// a return value of `Ok` does *not* mean that the data will be received. + /// It is possible for the corresponding receiver to hang up immediately + /// after this function returns `Ok`. + /// + /// # Examples + /// + /// Send a value to another task + /// + /// ``` + /// use tokio::sync::oneshot; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = oneshot::channel(); + /// + /// tokio::spawn(async move { + /// if let Err(_) = tx.send(3) { + /// println!("the receiver dropped"); + /// } + /// }); + /// + /// match rx.await { + /// Ok(v) => println!("got = {:?}", v), + /// Err(_) => println!("the sender dropped"), + /// } + /// } + /// ``` + pub fn send(mut self, t: T) -> Result<(), T> { + let inner = self.inner.take().unwrap(); + + inner.value.with_mut(|ptr| unsafe { + *ptr = Some(t); + }); + + if !inner.complete() { + return Err(inner + .value + .with_mut(|ptr| unsafe { (*ptr).take() }.unwrap())); + } + + Ok(()) + } + + #[doc(hidden)] // TODO: remove + pub fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> { + // Keep track of task budget + ready!(crate::coop::poll_proceed(cx)); + + let inner = self.inner.as_ref().unwrap(); + + let mut state = State::load(&inner.state, Acquire); + + if state.is_closed() { + return Poll::Ready(()); + } + + if state.is_tx_task_set() { + let will_notify = unsafe { inner.with_tx_task(|w| w.will_wake(cx.waker())) }; + + if !will_notify { + state = State::unset_tx_task(&inner.state); + + if state.is_closed() { + // Set the flag again so that the waker is released in drop + State::set_tx_task(&inner.state); + return Ready(()); + } else { + unsafe { inner.drop_tx_task() }; + } + } + } + + if !state.is_tx_task_set() { + // Attempt to set the task + unsafe { + inner.set_tx_task(cx); + } + + // Update the state + state = State::set_tx_task(&inner.state); + + if state.is_closed() { + return Ready(()); + } + } + + Pending + } + + /// Waits for the associated [`Receiver`] handle to close. + /// + /// A [`Receiver`] is closed by either calling [`close`] explicitly or the + /// [`Receiver`] value is dropped. + /// + /// This function is useful when paired with `select!` to abort a + /// computation when the receiver is no longer interested in the result. + /// + /// # Return + /// + /// Returns a `Future` which must be awaited on. + /// + /// [`Receiver`]: Receiver + /// [`close`]: Receiver::close + /// + /// # Examples + /// + /// Basic usage + /// + /// ``` + /// use tokio::sync::oneshot; + /// + /// #[tokio::main] + /// async fn main() { + /// let (mut tx, rx) = oneshot::channel::<()>(); + /// + /// tokio::spawn(async move { + /// drop(rx); + /// }); + /// + /// tx.closed().await; + /// println!("the receiver dropped"); + /// } + /// ``` + /// + /// Paired with select + /// + /// ``` + /// use tokio::sync::oneshot; + /// use tokio::time::{self, Duration}; + /// + /// use futures::{select, FutureExt}; + /// + /// async fn compute() -> String { + /// // Complex computation returning a `String` + /// # "hello".to_string() + /// } + /// + /// #[tokio::main] + /// async fn main() { + /// let (mut tx, rx) = oneshot::channel(); + /// + /// tokio::spawn(async move { + /// select! { + /// _ = tx.closed().fuse() => { + /// // The receiver dropped, no need to do any further work + /// } + /// value = compute().fuse() => { + /// tx.send(value).unwrap() + /// } + /// } + /// }); + /// + /// // Wait for up to 10 seconds + /// let _ = time::timeout(Duration::from_secs(10), rx).await; + /// } + /// ``` + pub async fn closed(&mut self) { + use crate::future::poll_fn; + + poll_fn(|cx| self.poll_closed(cx)).await + } + + /// Returns `true` if the associated [`Receiver`] handle has been dropped. + /// + /// A [`Receiver`] is closed by either calling [`close`] explicitly or the + /// [`Receiver`] value is dropped. + /// + /// If `true` is returned, a call to `send` will always result in an error. + /// + /// [`Receiver`]: Receiver + /// [`close`]: Receiver::close + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::oneshot; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = oneshot::channel(); + /// + /// assert!(!tx.is_closed()); + /// + /// drop(rx); + /// + /// assert!(tx.is_closed()); + /// assert!(tx.send("never received").is_err()); + /// } + /// ``` + pub fn is_closed(&self) -> bool { + let inner = self.inner.as_ref().unwrap(); + + let state = State::load(&inner.state, Acquire); + state.is_closed() + } +} + +impl<T> Drop for Sender<T> { + fn drop(&mut self) { + if let Some(inner) = self.inner.as_ref() { + inner.complete(); + } + } +} + +impl<T> Receiver<T> { + /// Prevents the associated [`Sender`] handle from sending a value. + /// + /// Any `send` operation which happens after calling `close` is guaranteed + /// to fail. After calling `close`, `Receiver::poll`] should be called to + /// receive a value if one was sent **before** the call to `close` + /// completed. + /// + /// This function is useful to perform a graceful shutdown and ensure that a + /// value will not be sent into the channel and never received. + /// + /// [`Sender`]: Sender + /// + /// # Examples + /// + /// Prevent a value from being sent + /// + /// ``` + /// use tokio::sync::oneshot; + /// use tokio::sync::oneshot::error::TryRecvError; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = oneshot::channel(); + /// + /// assert!(!tx.is_closed()); + /// + /// rx.close(); + /// + /// assert!(tx.is_closed()); + /// assert!(tx.send("never received").is_err()); + /// + /// match rx.try_recv() { + /// Err(TryRecvError::Closed) => {} + /// _ => unreachable!(), + /// } + /// } + /// ``` + /// + /// Receive a value sent **before** calling `close` + /// + /// ``` + /// use tokio::sync::oneshot; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = oneshot::channel(); + /// + /// assert!(tx.send("will receive").is_ok()); + /// + /// rx.close(); + /// + /// let msg = rx.try_recv().unwrap(); + /// assert_eq!(msg, "will receive"); + /// } + /// ``` + pub fn close(&mut self) { + let inner = self.inner.as_ref().unwrap(); + inner.close(); + } + + /// Attempts to receive a value. + /// + /// If a pending value exists in the channel, it is returned. If no value + /// has been sent, the current task **will not** be registered for + /// future notification. + /// + /// This function is useful to call from outside the context of an + /// asynchronous task. + /// + /// # Return + /// + /// - `Ok(T)` if a value is pending in the channel. + /// - `Err(TryRecvError::Empty)` if no value has been sent yet. + /// - `Err(TryRecvError::Closed)` if the sender has dropped without sending + /// a value. + /// + /// # Examples + /// + /// `try_recv` before a value is sent, then after. + /// + /// ``` + /// use tokio::sync::oneshot; + /// use tokio::sync::oneshot::error::TryRecvError; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = oneshot::channel(); + /// + /// match rx.try_recv() { + /// // The channel is currently empty + /// Err(TryRecvError::Empty) => {} + /// _ => unreachable!(), + /// } + /// + /// // Send a value + /// tx.send("hello").unwrap(); + /// + /// match rx.try_recv() { + /// Ok(value) => assert_eq!(value, "hello"), + /// _ => unreachable!(), + /// } + /// } + /// ``` + /// + /// `try_recv` when the sender dropped before sending a value + /// + /// ``` + /// use tokio::sync::oneshot; + /// use tokio::sync::oneshot::error::TryRecvError; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = oneshot::channel::<()>(); + /// + /// drop(tx); + /// + /// match rx.try_recv() { + /// // The channel will never receive a value. + /// Err(TryRecvError::Closed) => {} + /// _ => unreachable!(), + /// } + /// } + /// ``` + pub fn try_recv(&mut self) -> Result<T, TryRecvError> { + let result = if let Some(inner) = self.inner.as_ref() { + let state = State::load(&inner.state, Acquire); + + if state.is_complete() { + match unsafe { inner.consume_value() } { + Some(value) => Ok(value), + None => Err(TryRecvError::Closed), + } + } else if state.is_closed() { + Err(TryRecvError::Closed) + } else { + // Not ready, this does not clear `inner` + return Err(TryRecvError::Empty); + } + } else { + panic!("called after complete"); + }; + + self.inner = None; + result + } +} + +impl<T> Drop for Receiver<T> { + fn drop(&mut self) { + if let Some(inner) = self.inner.as_ref() { + inner.close(); + } + } +} + +impl<T> Future for Receiver<T> { + type Output = Result<T, RecvError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + // If `inner` is `None`, then `poll()` has already completed. + let ret = if let Some(inner) = self.as_ref().get_ref().inner.as_ref() { + ready!(inner.poll_recv(cx))? + } else { + panic!("called after complete"); + }; + + self.inner = None; + Ready(Ok(ret)) + } +} + +impl<T> Inner<T> { + fn complete(&self) -> bool { + let prev = State::set_complete(&self.state); + + if prev.is_closed() { + return false; + } + + if prev.is_rx_task_set() { + // TODO: Consume waker? + unsafe { + self.with_rx_task(Waker::wake_by_ref); + } + } + + true + } + + fn poll_recv(&self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> { + // Keep track of task budget + ready!(crate::coop::poll_proceed(cx)); + + // Load the state + let mut state = State::load(&self.state, Acquire); + + if state.is_complete() { + match unsafe { self.consume_value() } { + Some(value) => Ready(Ok(value)), + None => Ready(Err(RecvError(()))), + } + } else if state.is_closed() { + Ready(Err(RecvError(()))) + } else { + if state.is_rx_task_set() { + let will_notify = unsafe { self.with_rx_task(|w| w.will_wake(cx.waker())) }; + + // Check if the task is still the same + if !will_notify { + // Unset the task + state = State::unset_rx_task(&self.state); + if state.is_complete() { + // Set the flag again so that the waker is released in drop + State::set_rx_task(&self.state); + + return match unsafe { self.consume_value() } { + Some(value) => Ready(Ok(value)), + None => Ready(Err(RecvError(()))), + }; + } else { + unsafe { self.drop_rx_task() }; + } + } + } + + if !state.is_rx_task_set() { + // Attempt to set the task + unsafe { + self.set_rx_task(cx); + } + + // Update the state + state = State::set_rx_task(&self.state); + + if state.is_complete() { + match unsafe { self.consume_value() } { + Some(value) => Ready(Ok(value)), + None => Ready(Err(RecvError(()))), + } + } else { + Pending + } + } else { + Pending + } + } + } + + /// Called by `Receiver` to indicate that the value will never be received. + fn close(&self) { + let prev = State::set_closed(&self.state); + + if prev.is_tx_task_set() && !prev.is_complete() { + unsafe { + self.with_tx_task(Waker::wake_by_ref); + } + } + } + + /// Consumes the value. This function does not check `state`. + unsafe fn consume_value(&self) -> Option<T> { + self.value.with_mut(|ptr| (*ptr).take()) + } + + unsafe fn with_rx_task<F, R>(&self, f: F) -> R + where + F: FnOnce(&Waker) -> R, + { + self.rx_task.with(|ptr| { + let waker: *const Waker = (&*ptr).as_ptr(); + f(&*waker) + }) + } + + unsafe fn with_tx_task<F, R>(&self, f: F) -> R + where + F: FnOnce(&Waker) -> R, + { + self.tx_task.with(|ptr| { + let waker: *const Waker = (&*ptr).as_ptr(); + f(&*waker) + }) + } + + unsafe fn drop_rx_task(&self) { + self.rx_task.with_mut(|ptr| { + let ptr: *mut Waker = (&mut *ptr).as_mut_ptr(); + ptr.drop_in_place(); + }); + } + + unsafe fn drop_tx_task(&self) { + self.tx_task.with_mut(|ptr| { + let ptr: *mut Waker = (&mut *ptr).as_mut_ptr(); + ptr.drop_in_place(); + }); + } + + unsafe fn set_rx_task(&self, cx: &mut Context<'_>) { + self.rx_task.with_mut(|ptr| { + let ptr: *mut Waker = (&mut *ptr).as_mut_ptr(); + ptr.write(cx.waker().clone()); + }); + } + + unsafe fn set_tx_task(&self, cx: &mut Context<'_>) { + self.tx_task.with_mut(|ptr| { + let ptr: *mut Waker = (&mut *ptr).as_mut_ptr(); + ptr.write(cx.waker().clone()); + }); + } +} + +unsafe impl<T: Send> Send for Inner<T> {} +unsafe impl<T: Send> Sync for Inner<T> {} + +impl<T> Drop for Inner<T> { + fn drop(&mut self) { + let state = State(self.state.with_mut(|v| *v)); + + if state.is_rx_task_set() { + unsafe { + self.drop_rx_task(); + } + } + + if state.is_tx_task_set() { + unsafe { + self.drop_tx_task(); + } + } + } +} + +impl<T: fmt::Debug> fmt::Debug for Inner<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + use std::sync::atomic::Ordering::Relaxed; + + fmt.debug_struct("Inner") + .field("state", &State::load(&self.state, Relaxed)) + .finish() + } +} + +const RX_TASK_SET: usize = 0b00001; +const VALUE_SENT: usize = 0b00010; +const CLOSED: usize = 0b00100; +const TX_TASK_SET: usize = 0b01000; + +impl State { + fn new() -> State { + State(0) + } + + fn is_complete(self) -> bool { + self.0 & VALUE_SENT == VALUE_SENT + } + + fn set_complete(cell: &AtomicUsize) -> State { + // TODO: This could be `Release`, followed by an `Acquire` fence *if* + // the `RX_TASK_SET` flag is set. However, `loom` does not support + // fences yet. + let val = cell.fetch_or(VALUE_SENT, AcqRel); + State(val) + } + + fn is_rx_task_set(self) -> bool { + self.0 & RX_TASK_SET == RX_TASK_SET + } + + fn set_rx_task(cell: &AtomicUsize) -> State { + let val = cell.fetch_or(RX_TASK_SET, AcqRel); + State(val | RX_TASK_SET) + } + + fn unset_rx_task(cell: &AtomicUsize) -> State { + let val = cell.fetch_and(!RX_TASK_SET, AcqRel); + State(val & !RX_TASK_SET) + } + + fn is_closed(self) -> bool { + self.0 & CLOSED == CLOSED + } + + fn set_closed(cell: &AtomicUsize) -> State { + // Acquire because we want all later writes (attempting to poll) to be + // ordered after this. + let val = cell.fetch_or(CLOSED, Acquire); + State(val) + } + + fn set_tx_task(cell: &AtomicUsize) -> State { + let val = cell.fetch_or(TX_TASK_SET, AcqRel); + State(val | TX_TASK_SET) + } + + fn unset_tx_task(cell: &AtomicUsize) -> State { + let val = cell.fetch_and(!TX_TASK_SET, AcqRel); + State(val & !TX_TASK_SET) + } + + fn is_tx_task_set(self) -> bool { + self.0 & TX_TASK_SET == TX_TASK_SET + } + + fn as_usize(self) -> usize { + self.0 + } + + fn load(cell: &AtomicUsize, order: Ordering) -> State { + let val = cell.load(order); + State(val) + } +} + +impl fmt::Debug for State { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("State") + .field("is_complete", &self.is_complete()) + .field("is_closed", &self.is_closed()) + .field("is_rx_task_set", &self.is_rx_task_set()) + .field("is_tx_task_set", &self.is_tx_task_set()) + .finish() + } +} diff --git a/third_party/rust/tokio/src/sync/rwlock.rs b/third_party/rust/tokio/src/sync/rwlock.rs new file mode 100644 index 0000000000..68cf710e84 --- /dev/null +++ b/third_party/rust/tokio/src/sync/rwlock.rs @@ -0,0 +1,294 @@ +use crate::coop::CoopFutureExt; +use crate::sync::batch_semaphore::{AcquireError, Semaphore}; +use std::cell::UnsafeCell; +use std::ops; + +#[cfg(not(loom))] +const MAX_READS: usize = 32; + +#[cfg(loom)] +const MAX_READS: usize = 10; + +/// An asynchronous reader-writer lock +/// +/// This type of lock allows a number of readers or at most one writer at any +/// point in time. The write portion of this lock typically allows modification +/// of the underlying data (exclusive access) and the read portion of this lock +/// typically allows for read-only access (shared access). +/// +/// In comparison, a [`Mutex`] does not distinguish between readers or writers +/// that acquire the lock, therefore causing any tasks waiting for the lock to +/// become available to yield. An `RwLock` will allow any number of readers to +/// acquire the lock as long as a writer is not holding the lock. +/// +/// The priority policy of Tokio's read-write lock is _fair_ (or +/// [_write-preferring_]), in order to ensure that readers cannot starve +/// writers. Fairness is ensured using a first-in, first-out queue for the tasks +/// awaiting the lock; if a task that wishes to acquire the write lock is at the +/// head of the queue, read locks will not be given out until the write lock has +/// been released. This is in contrast to the Rust standard library's +/// `std::sync::RwLock`, where the priority policy is dependent on the +/// operating system's implementation. +/// +/// The type parameter `T` represents the data that this lock protects. It is +/// required that `T` satisfies [`Send`] to be shared across threads. The RAII guards +/// returned from the locking methods implement [`Deref`](https://doc.rust-lang.org/std/ops/trait.Deref.html) +/// (and [`DerefMut`](https://doc.rust-lang.org/std/ops/trait.DerefMut.html) +/// for the `write` methods) to allow access to the content of the lock. +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::RwLock; +/// +/// #[tokio::main] +/// async fn main() { +/// let lock = RwLock::new(5); +/// +/// // many reader locks can be held at once +/// { +/// let r1 = lock.read().await; +/// let r2 = lock.read().await; +/// assert_eq!(*r1, 5); +/// assert_eq!(*r2, 5); +/// } // read locks are dropped at this point +/// +/// // only one write lock may be held, however +/// { +/// let mut w = lock.write().await; +/// *w += 1; +/// assert_eq!(*w, 6); +/// } // write lock is dropped here +/// } +/// ``` +/// +/// [`Mutex`]: struct@super::Mutex +/// [`RwLock`]: struct@RwLock +/// [`RwLockReadGuard`]: struct@RwLockReadGuard +/// [`RwLockWriteGuard`]: struct@RwLockWriteGuard +/// [`Send`]: https://doc.rust-lang.org/std/marker/trait.Send.html +/// [_write-preferring_]: https://en.wikipedia.org/wiki/Readers%E2%80%93writer_lock#Priority_policies +#[derive(Debug)] +pub struct RwLock<T> { + //semaphore to coordinate read and write access to T + s: Semaphore, + + //inner data T + c: UnsafeCell<T>, +} + +/// RAII structure used to release the shared read access of a lock when +/// dropped. +/// +/// This structure is created by the [`read`] method on +/// [`RwLock`]. +/// +/// [`read`]: method@RwLock::read +#[derive(Debug)] +pub struct RwLockReadGuard<'a, T> { + permit: ReleasingPermit<'a, T>, + lock: &'a RwLock<T>, +} + +/// RAII structure used to release the exclusive write access of a lock when +/// dropped. +/// +/// This structure is created by the [`write`] and method +/// on [`RwLock`]. +/// +/// [`write`]: method@RwLock::write +/// [`RwLock`]: struct@RwLock +#[derive(Debug)] +pub struct RwLockWriteGuard<'a, T> { + permit: ReleasingPermit<'a, T>, + lock: &'a RwLock<T>, +} + +// Wrapper arround Permit that releases on Drop +#[derive(Debug)] +struct ReleasingPermit<'a, T> { + num_permits: u16, + lock: &'a RwLock<T>, +} + +impl<'a, T> ReleasingPermit<'a, T> { + async fn acquire( + lock: &'a RwLock<T>, + num_permits: u16, + ) -> Result<ReleasingPermit<'a, T>, AcquireError> { + lock.s.acquire(num_permits).cooperate().await?; + Ok(Self { num_permits, lock }) + } +} + +impl<'a, T> Drop for ReleasingPermit<'a, T> { + fn drop(&mut self) { + self.lock.s.release(self.num_permits as usize); + } +} + +#[test] +#[cfg(not(loom))] +fn bounds() { + fn check_send<T: Send>() {} + fn check_sync<T: Sync>() {} + fn check_unpin<T: Unpin>() {} + // This has to take a value, since the async fn's return type is unnameable. + fn check_send_sync_val<T: Send + Sync>(_t: T) {} + + check_send::<RwLock<u32>>(); + check_sync::<RwLock<u32>>(); + check_unpin::<RwLock<u32>>(); + + check_sync::<RwLockReadGuard<'_, u32>>(); + check_unpin::<RwLockReadGuard<'_, u32>>(); + + check_sync::<RwLockWriteGuard<'_, u32>>(); + check_unpin::<RwLockWriteGuard<'_, u32>>(); + + let rwlock = RwLock::new(0); + check_send_sync_val(rwlock.read()); + check_send_sync_val(rwlock.write()); +} + +// As long as T: Send + Sync, it's fine to send and share RwLock<T> between threads. +// If T were not Send, sending and sharing a RwLock<T> would be bad, since you can access T through +// RwLock<T>. +unsafe impl<T> Send for RwLock<T> where T: Send {} +unsafe impl<T> Sync for RwLock<T> where T: Send + Sync {} +unsafe impl<'a, T> Sync for RwLockReadGuard<'a, T> where T: Send + Sync {} +unsafe impl<'a, T> Sync for RwLockWriteGuard<'a, T> where T: Send + Sync {} + +impl<T> RwLock<T> { + /// Creates a new instance of an `RwLock<T>` which is unlocked. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::RwLock; + /// + /// let lock = RwLock::new(5); + /// ``` + pub fn new(value: T) -> RwLock<T> { + RwLock { + c: UnsafeCell::new(value), + s: Semaphore::new(MAX_READS), + } + } + + /// Locks this rwlock with shared read access, causing the current task + /// to yield until the lock has been acquired. + /// + /// The calling task will yield until there are no more writers which + /// hold the lock. There may be other readers currently inside the lock when + /// this method returns. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::RwLock; + /// + /// #[tokio::main] + /// async fn main() { + /// let lock = Arc::new(RwLock::new(1)); + /// let c_lock = lock.clone(); + /// + /// let n = lock.read().await; + /// assert_eq!(*n, 1); + /// + /// tokio::spawn(async move { + /// // While main has an active read lock, we acquire one too. + /// let r = c_lock.read().await; + /// assert_eq!(*r, 1); + /// }).await.expect("The spawned task has paniced"); + /// + /// // Drop the guard after the spawned task finishes. + /// drop(n); + ///} + /// ``` + pub async fn read(&self) -> RwLockReadGuard<'_, T> { + let permit = ReleasingPermit::acquire(self, 1).await.unwrap_or_else(|_| { + // The semaphore was closed. but, we never explicitly close it, and we have a + // handle to it through the Arc, which means that this can never happen. + unreachable!() + }); + RwLockReadGuard { lock: self, permit } + } + + /// Locks this rwlock with exclusive write access, causing the current task + /// to yield until the lock has been acquired. + /// + /// This function will not return while other writers or other readers + /// currently have access to the lock. + /// + /// Returns an RAII guard which will drop the write access of this rwlock + /// when dropped. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::RwLock; + /// + /// #[tokio::main] + /// async fn main() { + /// let lock = RwLock::new(1); + /// + /// let mut n = lock.write().await; + /// *n = 2; + ///} + /// ``` + pub async fn write(&self) -> RwLockWriteGuard<'_, T> { + let permit = ReleasingPermit::acquire(self, MAX_READS as u16) + .await + .unwrap_or_else(|_| { + // The semaphore was closed. but, we never explicitly close it, and we have a + // handle to it through the Arc, which means that this can never happen. + unreachable!() + }); + + RwLockWriteGuard { lock: self, permit } + } + + /// Consumes the lock, returning the underlying data. + pub fn into_inner(self) -> T { + self.c.into_inner() + } +} + +impl<T> ops::Deref for RwLockReadGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + unsafe { &*self.lock.c.get() } + } +} + +impl<T> ops::Deref for RwLockWriteGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + unsafe { &*self.lock.c.get() } + } +} + +impl<T> ops::DerefMut for RwLockWriteGuard<'_, T> { + fn deref_mut(&mut self) -> &mut T { + unsafe { &mut *self.lock.c.get() } + } +} + +impl<T> From<T> for RwLock<T> { + fn from(s: T) -> Self { + Self::new(s) + } +} + +impl<T> Default for RwLock<T> +where + T: Default, +{ + fn default() -> Self { + Self::new(T::default()) + } +} diff --git a/third_party/rust/tokio/src/sync/semaphore.rs b/third_party/rust/tokio/src/sync/semaphore.rs new file mode 100644 index 0000000000..4cce7e8f5b --- /dev/null +++ b/third_party/rust/tokio/src/sync/semaphore.rs @@ -0,0 +1,105 @@ +use super::batch_semaphore as ll; // low level implementation +use crate::coop::CoopFutureExt; + +/// Counting semaphore performing asynchronous permit aquisition. +/// +/// A semaphore maintains a set of permits. Permits are used to synchronize +/// access to a shared resource. A semaphore differs from a mutex in that it +/// can allow more than one concurrent caller to access the shared resource at a +/// time. +/// +/// When `acquire` is called and the semaphore has remaining permits, the +/// function immediately returns a permit. However, if no remaining permits are +/// available, `acquire` (asynchronously) waits until an outstanding permit is +/// dropped. At this point, the freed permit is assigned to the caller. +#[derive(Debug)] +pub struct Semaphore { + /// The low level semaphore + ll_sem: ll::Semaphore, +} + +/// A permit from the semaphore +#[must_use] +#[derive(Debug)] +pub struct SemaphorePermit<'a> { + sem: &'a Semaphore, + permits: u16, +} + +/// Error returned from the [`Semaphore::try_acquire`] function. +/// +/// A `try_acquire` operation can only fail if the semaphore has no available +/// permits. +/// +/// [`Semaphore::try_acquire`]: Semaphore::try_acquire +#[derive(Debug)] +pub struct TryAcquireError(()); + +#[test] +#[cfg(not(loom))] +fn bounds() { + fn check_unpin<T: Unpin>() {} + // This has to take a value, since the async fn's return type is unnameable. + fn check_send_sync_val<T: Send + Sync>(_t: T) {} + fn check_send_sync<T: Send + Sync>() {} + check_unpin::<Semaphore>(); + check_unpin::<SemaphorePermit<'_>>(); + check_send_sync::<Semaphore>(); + + let semaphore = Semaphore::new(0); + check_send_sync_val(semaphore.acquire()); +} + +impl Semaphore { + /// Creates a new semaphore with the initial number of permits + pub fn new(permits: usize) -> Self { + Self { + ll_sem: ll::Semaphore::new(permits), + } + } + + /// Returns the current number of available permits + pub fn available_permits(&self) -> usize { + self.ll_sem.available_permits() + } + + /// Adds `n` new permits to the semaphore. + pub fn add_permits(&self, n: usize) { + self.ll_sem.release(n); + } + + /// Acquires permit from the semaphore + pub async fn acquire(&self) -> SemaphorePermit<'_> { + self.ll_sem.acquire(1).cooperate().await.unwrap(); + SemaphorePermit { + sem: &self, + permits: 1, + } + } + + /// Tries to acquire a permit form the semaphore + pub fn try_acquire(&self) -> Result<SemaphorePermit<'_>, TryAcquireError> { + match self.ll_sem.try_acquire(1) { + Ok(_) => Ok(SemaphorePermit { + sem: self, + permits: 1, + }), + Err(_) => Err(TryAcquireError(())), + } + } +} + +impl<'a> SemaphorePermit<'a> { + /// Forgets the permit **without** releasing it back to the semaphore. + /// This can be used to reduce the amount of permits available from a + /// semaphore. + pub fn forget(mut self) { + self.permits = 0; + } +} + +impl<'a> Drop for SemaphorePermit<'_> { + fn drop(&mut self) { + self.sem.add_permits(self.permits as usize); + } +} diff --git a/third_party/rust/tokio/src/sync/semaphore_ll.rs b/third_party/rust/tokio/src/sync/semaphore_ll.rs new file mode 100644 index 0000000000..0bdc4e2761 --- /dev/null +++ b/third_party/rust/tokio/src/sync/semaphore_ll.rs @@ -0,0 +1,1220 @@ +#![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))] + +//! Thread-safe, asynchronous counting semaphore. +//! +//! A `Semaphore` instance holds a set of permits. Permits are used to +//! synchronize access to a shared resource. +//! +//! Before accessing the shared resource, callers acquire a permit from the +//! semaphore. Once the permit is acquired, the caller then enters the critical +//! section. If no permits are available, then acquiring the semaphore returns +//! `Pending`. The task is woken once a permit becomes available. + +use crate::loom::cell::UnsafeCell; +use crate::loom::future::AtomicWaker; +use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize}; +use crate::loom::thread; + +use std::cmp; +use std::fmt; +use std::ptr::{self, NonNull}; +use std::sync::atomic::Ordering::{self, AcqRel, Acquire, Relaxed, Release}; +use std::task::Poll::{Pending, Ready}; +use std::task::{Context, Poll}; +use std::usize; + +/// Futures-aware semaphore. +pub(crate) struct Semaphore { + /// Tracks both the waiter queue tail pointer and the number of remaining + /// permits. + state: AtomicUsize, + + /// waiter queue head pointer. + head: UnsafeCell<NonNull<Waiter>>, + + /// Coordinates access to the queue head. + rx_lock: AtomicUsize, + + /// Stub waiter node used as part of the MPSC channel algorithm. + stub: Box<Waiter>, +} + +/// A semaphore permit +/// +/// Tracks the lifecycle of a semaphore permit. +/// +/// An instance of `Permit` is intended to be used with a **single** instance of +/// `Semaphore`. Using a single instance of `Permit` with multiple semaphore +/// instances will result in unexpected behavior. +/// +/// `Permit` does **not** release the permit back to the semaphore on drop. It +/// is the user's responsibility to ensure that `Permit::release` is called +/// before dropping the permit. +#[derive(Debug)] +pub(crate) struct Permit { + waiter: Option<Box<Waiter>>, + state: PermitState, +} + +/// Error returned by `Permit::poll_acquire`. +#[derive(Debug)] +pub(crate) struct AcquireError(()); + +/// Error returned by `Permit::try_acquire`. +#[derive(Debug)] +pub(crate) enum TryAcquireError { + Closed, + NoPermits, +} + +/// Node used to notify the semaphore waiter when permit is available. +#[derive(Debug)] +struct Waiter { + /// Stores waiter state. + /// + /// See `WaiterState` for more details. + state: AtomicUsize, + + /// Task to wake when a permit is made available. + waker: AtomicWaker, + + /// Next pointer in the queue of waiting senders. + next: AtomicPtr<Waiter>, +} + +/// Semaphore state +/// +/// The 2 low bits track the modes. +/// +/// - Closed +/// - Full +/// +/// When not full, the rest of the `usize` tracks the total number of messages +/// in the channel. When full, the rest of the `usize` is a pointer to the tail +/// of the "waiting senders" queue. +#[derive(Copy, Clone)] +struct SemState(usize); + +/// Permit state +#[derive(Debug, Copy, Clone)] +enum PermitState { + /// Currently waiting for permits to be made available and assigned to the + /// waiter. + Waiting(u16), + + /// The number of acquired permits + Acquired(u16), +} + +/// State for an individual waker node +#[derive(Debug, Copy, Clone)] +struct WaiterState(usize); + +/// Waiter node is in the semaphore queue +const QUEUED: usize = 0b001; + +/// Semaphore has been closed, no more permits will be issued. +const CLOSED: usize = 0b10; + +/// The permit that owns the `Waiter` dropped. +const DROPPED: usize = 0b100; + +/// Represents "one requested permit" in the waiter state +const PERMIT_ONE: usize = 0b1000; + +/// Masks the waiter state to only contain bits tracking number of requested +/// permits. +const PERMIT_MASK: usize = usize::MAX - (PERMIT_ONE - 1); + +/// How much to shift a permit count to pack it into the waker state +const PERMIT_SHIFT: u32 = PERMIT_ONE.trailing_zeros(); + +/// Flag differentiating between available permits and waiter pointers. +/// +/// If we assume pointers are properly aligned, then the least significant bit +/// will always be zero. So, we use that bit to track if the value represents a +/// number. +const NUM_FLAG: usize = 0b01; + +/// Signal the semaphore is closed +const CLOSED_FLAG: usize = 0b10; + +/// Maximum number of permits a semaphore can manage +const MAX_PERMITS: usize = usize::MAX >> NUM_SHIFT; + +/// When representing "numbers", the state has to be shifted this much (to get +/// rid of the flag bit). +const NUM_SHIFT: usize = 2; + +// ===== impl Semaphore ===== + +impl Semaphore { + /// Creates a new semaphore with the initial number of permits + /// + /// # Panics + /// + /// Panics if `permits` is zero. + pub(crate) fn new(permits: usize) -> Semaphore { + let stub = Box::new(Waiter::new()); + let ptr = NonNull::from(&*stub); + + // Allocations are aligned + debug_assert!(ptr.as_ptr() as usize & NUM_FLAG == 0); + + let state = SemState::new(permits, &stub); + + Semaphore { + state: AtomicUsize::new(state.to_usize()), + head: UnsafeCell::new(ptr), + rx_lock: AtomicUsize::new(0), + stub, + } + } + + /// Returns the current number of available permits + pub(crate) fn available_permits(&self) -> usize { + let curr = SemState(self.state.load(Acquire)); + curr.available_permits() + } + + /// Tries to acquire the requested number of permits, registering the waiter + /// if not enough permits are available. + fn poll_acquire( + &self, + cx: &mut Context<'_>, + num_permits: u16, + permit: &mut Permit, + ) -> Poll<Result<(), AcquireError>> { + self.poll_acquire2(num_permits, || { + let waiter = permit.waiter.get_or_insert_with(|| Box::new(Waiter::new())); + + waiter.waker.register_by_ref(cx.waker()); + + Some(NonNull::from(&**waiter)) + }) + } + + fn try_acquire(&self, num_permits: u16) -> Result<(), TryAcquireError> { + match self.poll_acquire2(num_permits, || None) { + Poll::Ready(res) => res.map_err(to_try_acquire), + Poll::Pending => Err(TryAcquireError::NoPermits), + } + } + + /// Polls for a permit + /// + /// Tries to acquire available permits first. If unable to acquire a + /// sufficient number of permits, the caller's waiter is pushed onto the + /// semaphore's wait queue. + fn poll_acquire2<F>( + &self, + num_permits: u16, + mut get_waiter: F, + ) -> Poll<Result<(), AcquireError>> + where + F: FnMut() -> Option<NonNull<Waiter>>, + { + let num_permits = num_permits as usize; + + // Load the current state + let mut curr = SemState(self.state.load(Acquire)); + + // Saves a ref to the waiter node + let mut maybe_waiter: Option<NonNull<Waiter>> = None; + + /// Used in branches where we attempt to push the waiter into the wait + /// queue but fail due to permits becoming available or the wait queue + /// transitioning to "closed". In this case, the waiter must be + /// transitioned back to the "idle" state. + macro_rules! revert_to_idle { + () => { + if let Some(waiter) = maybe_waiter { + unsafe { waiter.as_ref() }.revert_to_idle(); + } + }; + } + + loop { + let mut next = curr; + + if curr.is_closed() { + revert_to_idle!(); + return Ready(Err(AcquireError::closed())); + } + + let acquired = next.acquire_permits(num_permits, &self.stub); + + if !acquired { + // There are not enough available permits to satisfy the + // request. The permit transitions to a waiting state. + debug_assert!(curr.waiter().is_some() || curr.available_permits() < num_permits); + + if let Some(waiter) = maybe_waiter.as_ref() { + // Safety: the caller owns the waiter. + let w = unsafe { waiter.as_ref() }; + w.set_permits_to_acquire(num_permits - curr.available_permits()); + } else { + // Get the waiter for the permit. + if let Some(waiter) = get_waiter() { + // Safety: the caller owns the waiter. + let w = unsafe { waiter.as_ref() }; + + // If there are any currently available permits, the + // waiter acquires those immediately and waits for the + // remaining permits to become available. + if !w.to_queued(num_permits - curr.available_permits()) { + // The node is alrady queued, there is no further work + // to do. + return Pending; + } + + maybe_waiter = Some(waiter); + } else { + // No waiter, this indicates the caller does not wish to + // "wait", so there is nothing left to do. + return Pending; + } + } + + next.set_waiter(maybe_waiter.unwrap()); + } + + debug_assert_ne!(curr.0, 0); + debug_assert_ne!(next.0, 0); + + match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { + Ok(_) => { + if acquired { + // Successfully acquire permits **without** queuing the + // waiter node. The waiter node is not currently in the + // queue. + revert_to_idle!(); + return Ready(Ok(())); + } else { + // The node is pushed into the queue, the final step is + // to set the node's "next" pointer to return the wait + // queue into a consistent state. + + let prev_waiter = + curr.waiter().unwrap_or_else(|| NonNull::from(&*self.stub)); + + let waiter = maybe_waiter.unwrap(); + + // Link the nodes. + // + // Safety: the mpsc algorithm guarantees the old tail of + // the queue is not removed from the queue during the + // push process. + unsafe { + prev_waiter.as_ref().store_next(waiter); + } + + return Pending; + } + } + Err(actual) => { + curr = SemState(actual); + } + } + } + } + + /// Closes the semaphore. This prevents the semaphore from issuing new + /// permits and notifies all pending waiters. + pub(crate) fn close(&self) { + // Acquire the `rx_lock`, setting the "closed" flag on the lock. + let prev = self.rx_lock.fetch_or(1, AcqRel); + + if prev != 0 { + // Another thread has the lock and will be responsible for notifying + // pending waiters. + return; + } + + self.add_permits_locked(0, true); + } + + /// Adds `n` new permits to the semaphore. + pub(crate) fn add_permits(&self, n: usize) { + if n == 0 { + return; + } + + // TODO: Handle overflow. A panic is not sufficient, the process must + // abort. + let prev = self.rx_lock.fetch_add(n << 1, AcqRel); + + if prev != 0 { + // Another thread has the lock and will be responsible for notifying + // pending waiters. + return; + } + + self.add_permits_locked(n, false); + } + + fn add_permits_locked(&self, mut rem: usize, mut closed: bool) { + while rem > 0 || closed { + if closed { + SemState::fetch_set_closed(&self.state, AcqRel); + } + + // Release the permits and notify + self.add_permits_locked2(rem, closed); + + let n = rem << 1; + + let actual = if closed { + let actual = self.rx_lock.fetch_sub(n | 1, AcqRel); + closed = false; + actual + } else { + let actual = self.rx_lock.fetch_sub(n, AcqRel); + closed = actual & 1 == 1; + actual + }; + + rem = (actual >> 1) - rem; + } + } + + /// Releases a specific amount of permits to the semaphore + /// + /// This function is called by `add_permits` after the add lock has been + /// acquired. + fn add_permits_locked2(&self, mut n: usize, closed: bool) { + // If closing the semaphore, we want to drain the entire queue. The + // number of permits being assigned doesn't matter. + if closed { + n = usize::MAX; + } + + 'outer: while n > 0 { + unsafe { + let mut head = self.head.with(|head| *head); + let mut next_ptr = head.as_ref().next.load(Acquire); + + let stub = self.stub(); + + if head == stub { + // The stub node indicates an empty queue. Any remaining + // permits get assigned back to the semaphore. + let next = match NonNull::new(next_ptr) { + Some(next) => next, + None => { + // This loop is not part of the standard intrusive mpsc + // channel algorithm. This is where we atomically pop + // the last task and add `n` to the remaining capacity. + // + // This modification to the pop algorithm works because, + // at this point, we have not done any work (only done + // reading). We have a *pretty* good idea that there is + // no concurrent pusher. + // + // The capacity is then atomically added by doing an + // AcqRel CAS on `state`. The `state` cell is the + // linchpin of the algorithm. + // + // By successfully CASing `head` w/ AcqRel, we ensure + // that, if any thread was racing and entered a push, we + // see that and abort pop, retrying as it is + // "inconsistent". + let mut curr = SemState::load(&self.state, Acquire); + + loop { + if curr.has_waiter(&self.stub) { + // A waiter is being added concurrently. + // This is the MPSC queue's "inconsistent" + // state and we must loop and try again. + thread::yield_now(); + continue 'outer; + } + + // If closing, nothing more to do. + if closed { + debug_assert!(curr.is_closed(), "state = {:?}", curr); + return; + } + + let mut next = curr; + next.release_permits(n, &self.stub); + + match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { + Ok(_) => return, + Err(actual) => { + curr = SemState(actual); + } + } + } + } + }; + + self.head.with_mut(|head| *head = next); + head = next; + next_ptr = next.as_ref().next.load(Acquire); + } + + // `head` points to a waiter assign permits to the waiter. If + // all requested permits are satisfied, then we can continue, + // otherwise the node stays in the wait queue. + if !head.as_ref().assign_permits(&mut n, closed) { + assert_eq!(n, 0); + return; + } + + if let Some(next) = NonNull::new(next_ptr) { + self.head.with_mut(|head| *head = next); + + self.remove_queued(head, closed); + continue 'outer; + } + + let state = SemState::load(&self.state, Acquire); + + // This must always be a pointer as the wait list is not empty. + let tail = state.waiter().unwrap(); + + if tail != head { + // Inconsistent + thread::yield_now(); + continue 'outer; + } + + self.push_stub(closed); + + next_ptr = head.as_ref().next.load(Acquire); + + if let Some(next) = NonNull::new(next_ptr) { + self.head.with_mut(|head| *head = next); + + self.remove_queued(head, closed); + continue 'outer; + } + + // Inconsistent state, loop + thread::yield_now(); + } + } + } + + /// The wait node has had all of its permits assigned and has been removed + /// from the wait queue. + /// + /// Attempt to remove the QUEUED bit from the node. If additional permits + /// are concurrently requested, the node must be pushed back into the wait + /// queued. + fn remove_queued(&self, waiter: NonNull<Waiter>, closed: bool) { + let mut curr = WaiterState(unsafe { waiter.as_ref() }.state.load(Acquire)); + + loop { + if curr.is_dropped() { + // The Permit dropped, it is on us to release the memory + let _ = unsafe { Box::from_raw(waiter.as_ptr()) }; + return; + } + + // The node is removed from the queue. We attempt to unset the + // queued bit, but concurrently the waiter has requested more + // permits. When the waiter requested more permits, it saw the + // queued bit set so took no further action. This requires us to + // push the node back into the queue. + if curr.permits_to_acquire() > 0 { + // More permits are requested. The waiter must be re-queued + unsafe { + self.push_waiter(waiter, closed); + } + return; + } + + let mut next = curr; + next.unset_queued(); + + let w = unsafe { waiter.as_ref() }; + + match w.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { + Ok(_) => return, + Err(actual) => { + curr = WaiterState(actual); + } + } + } + } + + unsafe fn push_stub(&self, closed: bool) { + self.push_waiter(self.stub(), closed); + } + + unsafe fn push_waiter(&self, waiter: NonNull<Waiter>, closed: bool) { + // Set the next pointer. This does not require an atomic operation as + // this node is not accessible. The write will be flushed with the next + // operation + waiter.as_ref().next.store(ptr::null_mut(), Relaxed); + + // Update the tail to point to the new node. We need to see the previous + // node in order to update the next pointer as well as release `task` + // to any other threads calling `push`. + let next = SemState::new_ptr(waiter, closed); + let prev = SemState(self.state.swap(next.0, AcqRel)); + + debug_assert_eq!(closed, prev.is_closed()); + + // This function is only called when there are pending tasks. Because of + // this, the state must *always* be in pointer mode. + let prev = prev.waiter().unwrap(); + + // No cycles plz + debug_assert_ne!(prev, waiter); + + // Release `task` to the consume end. + prev.as_ref().next.store(waiter.as_ptr(), Release); + } + + fn stub(&self) -> NonNull<Waiter> { + unsafe { NonNull::new_unchecked(&*self.stub as *const _ as *mut _) } + } +} + +impl Drop for Semaphore { + fn drop(&mut self) { + self.close(); + } +} + +impl fmt::Debug for Semaphore { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Semaphore") + .field("state", &SemState::load(&self.state, Relaxed)) + .field("head", &self.head.with(|ptr| ptr)) + .field("rx_lock", &self.rx_lock.load(Relaxed)) + .field("stub", &self.stub) + .finish() + } +} + +unsafe impl Send for Semaphore {} +unsafe impl Sync for Semaphore {} + +// ===== impl Permit ===== + +impl Permit { + /// Creates a new `Permit`. + /// + /// The permit begins in the "unacquired" state. + pub(crate) fn new() -> Permit { + use PermitState::Acquired; + + Permit { + waiter: None, + state: Acquired(0), + } + } + + /// Returns `true` if the permit has been acquired + #[allow(dead_code)] // may be used later + pub(crate) fn is_acquired(&self) -> bool { + match self.state { + PermitState::Acquired(num) if num > 0 => true, + _ => false, + } + } + + /// Tries to acquire the permit. If no permits are available, the current task + /// is notified once a new permit becomes available. + pub(crate) fn poll_acquire( + &mut self, + cx: &mut Context<'_>, + num_permits: u16, + semaphore: &Semaphore, + ) -> Poll<Result<(), AcquireError>> { + use std::cmp::Ordering::*; + use PermitState::*; + + match self.state { + Waiting(requested) => { + // There must be a waiter + let waiter = self.waiter.as_ref().unwrap(); + + match requested.cmp(&num_permits) { + Less => { + let delta = num_permits - requested; + + // Request additional permits. If the waiter has been + // dequeued, it must be re-queued. + if !waiter.try_inc_permits_to_acquire(delta as usize) { + let waiter = NonNull::from(&**waiter); + + // Ignore the result. The check for + // `permits_to_acquire()` will converge the state as + // needed + let _ = semaphore.poll_acquire2(delta, || Some(waiter))?; + } + + self.state = Waiting(num_permits); + } + Greater => { + let delta = requested - num_permits; + let to_release = waiter.try_dec_permits_to_acquire(delta as usize); + + semaphore.add_permits(to_release); + self.state = Waiting(num_permits); + } + Equal => {} + } + + if waiter.permits_to_acquire()? == 0 { + self.state = Acquired(requested); + return Ready(Ok(())); + } + + waiter.waker.register_by_ref(cx.waker()); + + if waiter.permits_to_acquire()? == 0 { + self.state = Acquired(requested); + return Ready(Ok(())); + } + + Pending + } + Acquired(acquired) => { + if acquired >= num_permits { + Ready(Ok(())) + } else { + match semaphore.poll_acquire(cx, num_permits - acquired, self)? { + Ready(()) => { + self.state = Acquired(num_permits); + Ready(Ok(())) + } + Pending => { + self.state = Waiting(num_permits); + Pending + } + } + } + } + } + } + + /// Tries to acquire the permit. + pub(crate) fn try_acquire( + &mut self, + num_permits: u16, + semaphore: &Semaphore, + ) -> Result<(), TryAcquireError> { + use PermitState::*; + + match self.state { + Waiting(requested) => { + // There must be a waiter + let waiter = self.waiter.as_ref().unwrap(); + + if requested > num_permits { + let delta = requested - num_permits; + let to_release = waiter.try_dec_permits_to_acquire(delta as usize); + + semaphore.add_permits(to_release); + self.state = Waiting(num_permits); + } + + let res = waiter.permits_to_acquire().map_err(to_try_acquire)?; + + if res == 0 { + if requested < num_permits { + // Try to acquire the additional permits + semaphore.try_acquire(num_permits - requested)?; + } + + self.state = Acquired(num_permits); + Ok(()) + } else { + Err(TryAcquireError::NoPermits) + } + } + Acquired(acquired) => { + if acquired < num_permits { + semaphore.try_acquire(num_permits - acquired)?; + self.state = Acquired(num_permits); + } + + Ok(()) + } + } + } + + /// Releases a permit back to the semaphore + pub(crate) fn release(&mut self, n: u16, semaphore: &Semaphore) { + let n = self.forget(n); + semaphore.add_permits(n as usize); + } + + /// Forgets the permit **without** releasing it back to the semaphore. + /// + /// After calling `forget`, `poll_acquire` is able to acquire new permit + /// from the sempahore. + /// + /// Repeatedly calling `forget` without associated calls to `add_permit` + /// will result in the semaphore losing all permits. + /// + /// Will forget **at most** the number of acquired permits. This number is + /// returned. + pub(crate) fn forget(&mut self, n: u16) -> u16 { + use PermitState::*; + + match self.state { + Waiting(requested) => { + let n = cmp::min(n, requested); + + // Decrement + let acquired = self + .waiter + .as_ref() + .unwrap() + .try_dec_permits_to_acquire(n as usize) as u16; + + if n == requested { + self.state = Acquired(0); + } else if acquired == requested - n { + self.state = Waiting(acquired); + } else { + self.state = Waiting(requested - n); + } + + acquired + } + Acquired(acquired) => { + let n = cmp::min(n, acquired); + self.state = Acquired(acquired - n); + n + } + } + } +} + +impl Default for Permit { + fn default() -> Self { + Self::new() + } +} + +impl Drop for Permit { + fn drop(&mut self) { + if let Some(waiter) = self.waiter.take() { + // Set the dropped flag + let state = WaiterState(waiter.state.fetch_or(DROPPED, AcqRel)); + + if state.is_queued() { + // The waiter is stored in the queue. The semaphore will drop it + std::mem::forget(waiter); + } + } + } +} + +// ===== impl AcquireError ==== + +impl AcquireError { + fn closed() -> AcquireError { + AcquireError(()) + } +} + +fn to_try_acquire(_: AcquireError) -> TryAcquireError { + TryAcquireError::Closed +} + +impl fmt::Display for AcquireError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "semaphore closed") + } +} + +impl std::error::Error for AcquireError {} + +// ===== impl TryAcquireError ===== + +impl TryAcquireError { + /// Returns `true` if the error was caused by a closed semaphore. + pub(crate) fn is_closed(&self) -> bool { + match self { + TryAcquireError::Closed => true, + _ => false, + } + } + + /// Returns `true` if the error was caused by calling `try_acquire` on a + /// semaphore with no available permits. + pub(crate) fn is_no_permits(&self) -> bool { + match self { + TryAcquireError::NoPermits => true, + _ => false, + } + } +} + +impl fmt::Display for TryAcquireError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TryAcquireError::Closed => write!(fmt, "{}", "semaphore closed"), + TryAcquireError::NoPermits => write!(fmt, "{}", "no permits available"), + } + } +} + +impl std::error::Error for TryAcquireError {} + +// ===== impl Waiter ===== + +impl Waiter { + fn new() -> Waiter { + Waiter { + state: AtomicUsize::new(0), + waker: AtomicWaker::new(), + next: AtomicPtr::new(ptr::null_mut()), + } + } + + fn permits_to_acquire(&self) -> Result<usize, AcquireError> { + let state = WaiterState(self.state.load(Acquire)); + + if state.is_closed() { + Err(AcquireError(())) + } else { + Ok(state.permits_to_acquire()) + } + } + + /// Only increments the number of permits *if* the waiter is currently + /// queued. + /// + /// # Returns + /// + /// `true` if the number of permits to acquire has been incremented. `false` + /// otherwise. On `false`, the caller should use `Semaphore::poll_acquire`. + fn try_inc_permits_to_acquire(&self, n: usize) -> bool { + let mut curr = WaiterState(self.state.load(Acquire)); + + loop { + if !curr.is_queued() { + assert_eq!(0, curr.permits_to_acquire()); + return false; + } + + let mut next = curr; + next.set_permits_to_acquire(n + curr.permits_to_acquire()); + + match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { + Ok(_) => return true, + Err(actual) => curr = WaiterState(actual), + } + } + } + + /// Try to decrement the number of permits to acquire. This returns the + /// actual number of permits that were decremented. The delta betweeen `n` + /// and the return has been assigned to the permit and the caller must + /// assign these back to the semaphore. + fn try_dec_permits_to_acquire(&self, n: usize) -> usize { + let mut curr = WaiterState(self.state.load(Acquire)); + + loop { + if !curr.is_queued() { + assert_eq!(0, curr.permits_to_acquire()); + } + + let delta = cmp::min(n, curr.permits_to_acquire()); + let rem = curr.permits_to_acquire() - delta; + + let mut next = curr; + next.set_permits_to_acquire(rem); + + match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { + Ok(_) => return n - delta, + Err(actual) => curr = WaiterState(actual), + } + } + } + + /// Store the number of remaining permits needed to satisfy the waiter and + /// transition to the "QUEUED" state. + /// + /// # Returns + /// + /// `true` if the `QUEUED` bit was set as part of the transition. + fn to_queued(&self, num_permits: usize) -> bool { + let mut curr = WaiterState(self.state.load(Acquire)); + + // The waiter should **not** be waiting for any permits. + debug_assert_eq!(curr.permits_to_acquire(), 0); + + loop { + let mut next = curr; + next.set_permits_to_acquire(num_permits); + next.set_queued(); + + match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { + Ok(_) => { + if curr.is_queued() { + return false; + } else { + // Make sure the next pointer is null + self.next.store(ptr::null_mut(), Relaxed); + return true; + } + } + Err(actual) => curr = WaiterState(actual), + } + } + } + + /// Set the number of permits to acquire. + /// + /// This function is only called when the waiter is being inserted into the + /// wait queue. Because of this, there are no concurrent threads that can + /// modify the state and using `store` is safe. + fn set_permits_to_acquire(&self, num_permits: usize) { + debug_assert!(WaiterState(self.state.load(Acquire)).is_queued()); + + let mut state = WaiterState(QUEUED); + state.set_permits_to_acquire(num_permits); + + self.state.store(state.0, Release); + } + + /// Assign permits to the waiter. + /// + /// Returns `true` if the waiter should be removed from the queue + fn assign_permits(&self, n: &mut usize, closed: bool) -> bool { + let mut curr = WaiterState(self.state.load(Acquire)); + + loop { + let mut next = curr; + + // Number of permits to assign to this waiter + let assign = cmp::min(curr.permits_to_acquire(), *n); + + // Assign the permits + next.set_permits_to_acquire(curr.permits_to_acquire() - assign); + + if closed { + next.set_closed(); + } + + match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { + Ok(_) => { + // Update `n` + *n -= assign; + + if next.permits_to_acquire() == 0 { + if curr.permits_to_acquire() > 0 { + self.waker.wake(); + } + + return true; + } else { + return false; + } + } + Err(actual) => curr = WaiterState(actual), + } + } + } + + fn revert_to_idle(&self) { + // An idle node is not waiting on any permits + self.state.store(0, Relaxed); + } + + fn store_next(&self, next: NonNull<Waiter>) { + self.next.store(next.as_ptr(), Release); + } +} + +// ===== impl SemState ===== + +impl SemState { + /// Returns a new default `State` value. + fn new(permits: usize, stub: &Waiter) -> SemState { + assert!(permits <= MAX_PERMITS); + + if permits > 0 { + SemState((permits << NUM_SHIFT) | NUM_FLAG) + } else { + SemState(stub as *const _ as usize) + } + } + + /// Returns a `State` tracking `ptr` as the tail of the queue. + fn new_ptr(tail: NonNull<Waiter>, closed: bool) -> SemState { + let mut val = tail.as_ptr() as usize; + + if closed { + val |= CLOSED_FLAG; + } + + SemState(val) + } + + /// Returns the amount of remaining capacity + fn available_permits(self) -> usize { + if !self.has_available_permits() { + return 0; + } + + self.0 >> NUM_SHIFT + } + + /// Returns `true` if the state has permits that can be claimed by a waiter. + fn has_available_permits(self) -> bool { + self.0 & NUM_FLAG == NUM_FLAG + } + + fn has_waiter(self, stub: &Waiter) -> bool { + !self.has_available_permits() && !self.is_stub(stub) + } + + /// Tries to atomically acquire specified number of permits. + /// + /// # Return + /// + /// Returns `true` if the specified number of permits were acquired, `false` + /// otherwise. Returning false does not mean that there are no more + /// available permits. + fn acquire_permits(&mut self, num: usize, stub: &Waiter) -> bool { + debug_assert!(num > 0); + + if self.available_permits() < num { + return false; + } + + debug_assert!(self.waiter().is_none()); + + self.0 -= num << NUM_SHIFT; + + if self.0 == NUM_FLAG { + // Set the state to the stub pointer. + self.0 = stub as *const _ as usize; + } + + true + } + + /// Releases permits + /// + /// Returns `true` if the permits were accepted. + fn release_permits(&mut self, permits: usize, stub: &Waiter) { + debug_assert!(permits > 0); + + if self.is_stub(stub) { + self.0 = (permits << NUM_SHIFT) | NUM_FLAG | (self.0 & CLOSED_FLAG); + return; + } + + debug_assert!(self.has_available_permits()); + + self.0 += permits << NUM_SHIFT; + } + + fn is_waiter(self) -> bool { + self.0 & NUM_FLAG == 0 + } + + /// Returns the waiter, if one is set. + fn waiter(self) -> Option<NonNull<Waiter>> { + if self.is_waiter() { + let waiter = NonNull::new(self.as_ptr()).expect("null pointer stored"); + + Some(waiter) + } else { + None + } + } + + /// Assumes `self` represents a pointer + fn as_ptr(self) -> *mut Waiter { + (self.0 & !CLOSED_FLAG) as *mut Waiter + } + + /// Sets to a pointer to a waiter. + /// + /// This can only be done from the full state. + fn set_waiter(&mut self, waiter: NonNull<Waiter>) { + let waiter = waiter.as_ptr() as usize; + debug_assert!(!self.is_closed()); + + self.0 = waiter; + } + + fn is_stub(self, stub: &Waiter) -> bool { + self.as_ptr() as usize == stub as *const _ as usize + } + + /// Loads the state from an AtomicUsize. + fn load(cell: &AtomicUsize, ordering: Ordering) -> SemState { + let value = cell.load(ordering); + SemState(value) + } + + fn fetch_set_closed(cell: &AtomicUsize, ordering: Ordering) -> SemState { + let value = cell.fetch_or(CLOSED_FLAG, ordering); + SemState(value) + } + + fn is_closed(self) -> bool { + self.0 & CLOSED_FLAG == CLOSED_FLAG + } + + /// Converts the state into a `usize` representation. + fn to_usize(self) -> usize { + self.0 + } +} + +impl fmt::Debug for SemState { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut fmt = fmt.debug_struct("SemState"); + + if self.is_waiter() { + fmt.field("state", &"<waiter>"); + } else { + fmt.field("permits", &self.available_permits()); + } + + fmt.finish() + } +} + +// ===== impl WaiterState ===== + +impl WaiterState { + fn permits_to_acquire(self) -> usize { + self.0 >> PERMIT_SHIFT + } + + fn set_permits_to_acquire(&mut self, val: usize) { + self.0 = (val << PERMIT_SHIFT) | (self.0 & !PERMIT_MASK) + } + + fn is_queued(self) -> bool { + self.0 & QUEUED == QUEUED + } + + fn set_queued(&mut self) { + self.0 |= QUEUED; + } + + fn is_closed(self) -> bool { + self.0 & CLOSED == CLOSED + } + + fn set_closed(&mut self) { + self.0 |= CLOSED; + } + + fn unset_queued(&mut self) { + assert!(self.is_queued()); + self.0 -= QUEUED; + } + + fn is_dropped(self) -> bool { + self.0 & DROPPED == DROPPED + } +} diff --git a/third_party/rust/tokio/src/sync/task/atomic_waker.rs b/third_party/rust/tokio/src/sync/task/atomic_waker.rs new file mode 100644 index 0000000000..73b1745f1a --- /dev/null +++ b/third_party/rust/tokio/src/sync/task/atomic_waker.rs @@ -0,0 +1,318 @@ +#![cfg_attr(any(loom, not(feature = "sync")), allow(dead_code, unreachable_pub))] + +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::atomic::{self, AtomicUsize}; + +use std::fmt; +use std::sync::atomic::Ordering::{AcqRel, Acquire, Release}; +use std::task::Waker; + +/// A synchronization primitive for task waking. +/// +/// `AtomicWaker` will coordinate concurrent wakes with the consumer +/// potentially "waking" the underlying task. This is useful in scenarios +/// where a computation completes in another thread and wants to wake the +/// consumer, but the consumer is in the process of being migrated to a new +/// logical task. +/// +/// Consumers should call `register` before checking the result of a computation +/// and producers should call `wake` after producing the computation (this +/// differs from the usual `thread::park` pattern). It is also permitted for +/// `wake` to be called **before** `register`. This results in a no-op. +/// +/// A single `AtomicWaker` may be reused for any number of calls to `register` or +/// `wake`. +pub(crate) struct AtomicWaker { + state: AtomicUsize, + waker: UnsafeCell<Option<Waker>>, +} + +// `AtomicWaker` is a multi-consumer, single-producer transfer cell. The cell +// stores a `Waker` value produced by calls to `register` and many threads can +// race to take the waker by calling `wake. +// +// If a new `Waker` instance is produced by calling `register` before an existing +// one is consumed, then the existing one is overwritten. +// +// While `AtomicWaker` is single-producer, the implementation ensures memory +// safety. In the event of concurrent calls to `register`, there will be a +// single winner whose waker will get stored in the cell. The losers will not +// have their tasks woken. As such, callers should ensure to add synchronization +// to calls to `register`. +// +// The implementation uses a single `AtomicUsize` value to coordinate access to +// the `Waker` cell. There are two bits that are operated on independently. These +// are represented by `REGISTERING` and `WAKING`. +// +// The `REGISTERING` bit is set when a producer enters the critical section. The +// `WAKING` bit is set when a consumer enters the critical section. Neither +// bit being set is represented by `WAITING`. +// +// A thread obtains an exclusive lock on the waker cell by transitioning the +// state from `WAITING` to `REGISTERING` or `WAKING`, depending on the +// operation the thread wishes to perform. When this transition is made, it is +// guaranteed that no other thread will access the waker cell. +// +// # Registering +// +// On a call to `register`, an attempt to transition the state from WAITING to +// REGISTERING is made. On success, the caller obtains a lock on the waker cell. +// +// If the lock is obtained, then the thread sets the waker cell to the waker +// provided as an argument. Then it attempts to transition the state back from +// `REGISTERING` -> `WAITING`. +// +// If this transition is successful, then the registering process is complete +// and the next call to `wake` will observe the waker. +// +// If the transition fails, then there was a concurrent call to `wake` that +// was unable to access the waker cell (due to the registering thread holding the +// lock). To handle this, the registering thread removes the waker it just set +// from the cell and calls `wake` on it. This call to wake represents the +// attempt to wake by the other thread (that set the `WAKING` bit). The +// state is then transitioned from `REGISTERING | WAKING` back to `WAITING`. +// This transition must succeed because, at this point, the state cannot be +// transitioned by another thread. +// +// # Waking +// +// On a call to `wake`, an attempt to transition the state from `WAITING` to +// `WAKING` is made. On success, the caller obtains a lock on the waker cell. +// +// If the lock is obtained, then the thread takes ownership of the current value +// in the waker cell, and calls `wake` on it. The state is then transitioned +// back to `WAITING`. This transition must succeed as, at this point, the state +// cannot be transitioned by another thread. +// +// If the thread is unable to obtain the lock, the `WAKING` bit is still. +// This is because it has either been set by the current thread but the previous +// value included the `REGISTERING` bit **or** a concurrent thread is in the +// `WAKING` critical section. Either way, no action must be taken. +// +// If the current thread is the only concurrent call to `wake` and another +// thread is in the `register` critical section, when the other thread **exits** +// the `register` critical section, it will observe the `WAKING` bit and +// handle the waker itself. +// +// If another thread is in the `waker` critical section, then it will handle +// waking the caller task. +// +// # A potential race (is safely handled). +// +// Imagine the following situation: +// +// * Thread A obtains the `wake` lock and wakes a task. +// +// * Before thread A releases the `wake` lock, the woken task is scheduled. +// +// * Thread B attempts to wake the task. In theory this should result in the +// task being woken, but it cannot because thread A still holds the wake +// lock. +// +// This case is handled by requiring users of `AtomicWaker` to call `register` +// **before** attempting to observe the application state change that resulted +// in the task being woken. The wakers also change the application state +// before calling wake. +// +// Because of this, the task will do one of two things. +// +// 1) Observe the application state change that Thread B is waking on. In +// this case, it is OK for Thread B's wake to be lost. +// +// 2) Call register before attempting to observe the application state. Since +// Thread A still holds the `wake` lock, the call to `register` will result +// in the task waking itself and get scheduled again. + +/// Idle state +const WAITING: usize = 0; + +/// A new waker value is being registered with the `AtomicWaker` cell. +const REGISTERING: usize = 0b01; + +/// The task currently registered with the `AtomicWaker` cell is being woken. +const WAKING: usize = 0b10; + +impl AtomicWaker { + /// Create an `AtomicWaker` + pub(crate) fn new() -> AtomicWaker { + AtomicWaker { + state: AtomicUsize::new(WAITING), + waker: UnsafeCell::new(None), + } + } + + /// Registers the current waker to be notified on calls to `wake`. + /// + /// This is the same as calling `register_task` with `task::current()`. + #[cfg(feature = "io-driver")] + pub(crate) fn register(&self, waker: Waker) { + self.do_register(waker); + } + + /// Registers the provided waker to be notified on calls to `wake`. + /// + /// The new waker will take place of any previous wakers that were registered + /// by previous calls to `register`. Any calls to `wake` that happen after + /// a call to `register` (as defined by the memory ordering rules), will + /// wake the `register` caller's task. + /// + /// It is safe to call `register` with multiple other threads concurrently + /// calling `wake`. This will result in the `register` caller's current + /// task being woken once. + /// + /// This function is safe to call concurrently, but this is generally a bad + /// idea. Concurrent calls to `register` will attempt to register different + /// tasks to be woken. One of the callers will win and have its task set, + /// but there is no guarantee as to which caller will succeed. + pub(crate) fn register_by_ref(&self, waker: &Waker) { + self.do_register(waker); + } + + fn do_register<W>(&self, waker: W) + where + W: WakerRef, + { + match self.state.compare_and_swap(WAITING, REGISTERING, Acquire) { + WAITING => { + unsafe { + // Locked acquired, update the waker cell + self.waker.with_mut(|t| *t = Some(waker.into_waker())); + + // Release the lock. If the state transitioned to include + // the `WAKING` bit, this means that a wake has been + // called concurrently, so we have to remove the waker and + // wake it.` + // + // Start by assuming that the state is `REGISTERING` as this + // is what we jut set it to. + let res = self + .state + .compare_exchange(REGISTERING, WAITING, AcqRel, Acquire); + + match res { + Ok(_) => {} + Err(actual) => { + // This branch can only be reached if a + // concurrent thread called `wake`. In this + // case, `actual` **must** be `REGISTERING | + // `WAKING`. + debug_assert_eq!(actual, REGISTERING | WAKING); + + // Take the waker to wake once the atomic operation has + // completed. + let waker = self.waker.with_mut(|t| (*t).take()).unwrap(); + + // Just swap, because no one could change state + // while state == `Registering | `Waking` + self.state.swap(WAITING, AcqRel); + + // The atomic swap was complete, now + // wake the waker and return. + waker.wake(); + } + } + } + } + WAKING => { + // Currently in the process of waking the task, i.e., + // `wake` is currently being called on the old waker. + // So, we call wake on the new waker. + waker.wake(); + + // This is equivalent to a spin lock, so use a spin hint. + atomic::spin_loop_hint(); + } + state => { + // In this case, a concurrent thread is holding the + // "registering" lock. This probably indicates a bug in the + // caller's code as racing to call `register` doesn't make much + // sense. + // + // We just want to maintain memory safety. It is ok to drop the + // call to `register`. + debug_assert!(state == REGISTERING || state == REGISTERING | WAKING); + } + } + } + + /// Wakes the task that last called `register`. + /// + /// If `register` has not been called yet, then this does nothing. + pub(crate) fn wake(&self) { + if let Some(waker) = self.take_waker() { + waker.wake(); + } + } + + /// Attempts to take the `Waker` value out of the `AtomicWaker` with the + /// intention that the caller will wake the task later. + pub(crate) fn take_waker(&self) -> Option<Waker> { + // AcqRel ordering is used in order to acquire the value of the `waker` + // cell as well as to establish a `release` ordering with whatever + // memory the `AtomicWaker` is associated with. + match self.state.fetch_or(WAKING, AcqRel) { + WAITING => { + // The waking lock has been acquired. + let waker = unsafe { self.waker.with_mut(|t| (*t).take()) }; + + // Release the lock + self.state.fetch_and(!WAKING, Release); + + waker + } + state => { + // There is a concurrent thread currently updating the + // associated waker. + // + // Nothing more to do as the `WAKING` bit has been set. It + // doesn't matter if there are concurrent registering threads or + // not. + // + debug_assert!( + state == REGISTERING || state == REGISTERING | WAKING || state == WAKING + ); + None + } + } + } +} + +impl Default for AtomicWaker { + fn default() -> Self { + AtomicWaker::new() + } +} + +impl fmt::Debug for AtomicWaker { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "AtomicWaker") + } +} + +unsafe impl Send for AtomicWaker {} +unsafe impl Sync for AtomicWaker {} + +trait WakerRef { + fn wake(self); + fn into_waker(self) -> Waker; +} + +impl WakerRef for Waker { + fn wake(self) { + self.wake() + } + + fn into_waker(self) -> Waker { + self + } +} + +impl WakerRef for &Waker { + fn wake(self) { + self.wake_by_ref() + } + + fn into_waker(self) -> Waker { + self.clone() + } +} diff --git a/third_party/rust/tokio/src/sync/task/mod.rs b/third_party/rust/tokio/src/sync/task/mod.rs new file mode 100644 index 0000000000..a6bc6ed06e --- /dev/null +++ b/third_party/rust/tokio/src/sync/task/mod.rs @@ -0,0 +1,4 @@ +//! Thread-safe task notification primitives. + +mod atomic_waker; +pub(crate) use self::atomic_waker::AtomicWaker; diff --git a/third_party/rust/tokio/src/sync/tests/atomic_waker.rs b/third_party/rust/tokio/src/sync/tests/atomic_waker.rs new file mode 100644 index 0000000000..c832d62e9a --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/atomic_waker.rs @@ -0,0 +1,34 @@ +use crate::sync::AtomicWaker; +use tokio_test::task; + +use std::task::Waker; + +trait AssertSend: Send {} +trait AssertSync: Send {} + +impl AssertSend for AtomicWaker {} +impl AssertSync for AtomicWaker {} + +impl AssertSend for Waker {} +impl AssertSync for Waker {} + +#[test] +fn basic_usage() { + let mut waker = task::spawn(AtomicWaker::new()); + + waker.enter(|cx, waker| waker.register_by_ref(cx.waker())); + waker.wake(); + + assert!(waker.is_woken()); +} + +#[test] +fn wake_without_register() { + let mut waker = task::spawn(AtomicWaker::new()); + waker.wake(); + + // Registering should not result in a notification + waker.enter(|cx, waker| waker.register_by_ref(cx.waker())); + + assert!(!waker.is_woken()); +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_atomic_waker.rs b/third_party/rust/tokio/src/sync/tests/loom_atomic_waker.rs new file mode 100644 index 0000000000..c148bcbe11 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_atomic_waker.rs @@ -0,0 +1,45 @@ +use crate::sync::task::AtomicWaker; + +use futures::future::poll_fn; +use loom::future::block_on; +use loom::sync::atomic::AtomicUsize; +use loom::thread; +use std::sync::atomic::Ordering::Relaxed; +use std::sync::Arc; +use std::task::Poll::{Pending, Ready}; + +struct Chan { + num: AtomicUsize, + task: AtomicWaker, +} + +#[test] +fn basic_notification() { + const NUM_NOTIFY: usize = 2; + + loom::model(|| { + let chan = Arc::new(Chan { + num: AtomicUsize::new(0), + task: AtomicWaker::new(), + }); + + for _ in 0..NUM_NOTIFY { + let chan = chan.clone(); + + thread::spawn(move || { + chan.num.fetch_add(1, Relaxed); + chan.task.wake(); + }); + } + + block_on(poll_fn(move |cx| { + chan.task.register_by_ref(cx.waker()); + + if NUM_NOTIFY == chan.num.load(Relaxed) { + return Ready(()); + } + + Pending + })); + }); +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_broadcast.rs b/third_party/rust/tokio/src/sync/tests/loom_broadcast.rs new file mode 100644 index 0000000000..da12fb9ff0 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_broadcast.rs @@ -0,0 +1,180 @@ +use crate::sync::broadcast; +use crate::sync::broadcast::RecvError::{Closed, Lagged}; + +use loom::future::block_on; +use loom::sync::Arc; +use loom::thread; +use tokio_test::{assert_err, assert_ok}; + +#[test] +fn broadcast_send() { + loom::model(|| { + let (tx1, mut rx) = broadcast::channel(2); + let tx1 = Arc::new(tx1); + let tx2 = tx1.clone(); + + let th1 = thread::spawn(move || { + block_on(async { + assert_ok!(tx1.send("one")); + assert_ok!(tx1.send("two")); + assert_ok!(tx1.send("three")); + }); + }); + + let th2 = thread::spawn(move || { + block_on(async { + assert_ok!(tx2.send("eins")); + assert_ok!(tx2.send("zwei")); + assert_ok!(tx2.send("drei")); + }); + }); + + block_on(async { + let mut num = 0; + loop { + match rx.recv().await { + Ok(_) => num += 1, + Err(Closed) => break, + Err(Lagged(n)) => num += n as usize, + } + } + assert_eq!(num, 6); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + }); +} + +// An `Arc` is used as the value in order to detect memory leaks. +#[test] +fn broadcast_two() { + loom::model(|| { + let (tx, mut rx1) = broadcast::channel::<Arc<&'static str>>(16); + let mut rx2 = tx.subscribe(); + + let th1 = thread::spawn(move || { + block_on(async { + let v = assert_ok!(rx1.recv().await); + assert_eq!(*v, "hello"); + + let v = assert_ok!(rx1.recv().await); + assert_eq!(*v, "world"); + + match assert_err!(rx1.recv().await) { + Closed => {} + _ => panic!(), + } + }); + }); + + let th2 = thread::spawn(move || { + block_on(async { + let v = assert_ok!(rx2.recv().await); + assert_eq!(*v, "hello"); + + let v = assert_ok!(rx2.recv().await); + assert_eq!(*v, "world"); + + match assert_err!(rx2.recv().await) { + Closed => {} + _ => panic!(), + } + }); + }); + + assert_ok!(tx.send(Arc::new("hello"))); + assert_ok!(tx.send(Arc::new("world"))); + drop(tx); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + }); +} + +#[test] +fn broadcast_wrap() { + loom::model(|| { + let (tx, mut rx1) = broadcast::channel(2); + let mut rx2 = tx.subscribe(); + + let th1 = thread::spawn(move || { + block_on(async { + let mut num = 0; + + loop { + match rx1.recv().await { + Ok(_) => num += 1, + Err(Closed) => break, + Err(Lagged(n)) => num += n as usize, + } + } + + assert_eq!(num, 3); + }); + }); + + let th2 = thread::spawn(move || { + block_on(async { + let mut num = 0; + + loop { + match rx2.recv().await { + Ok(_) => num += 1, + Err(Closed) => break, + Err(Lagged(n)) => num += n as usize, + } + } + + assert_eq!(num, 3); + }); + }); + + assert_ok!(tx.send("one")); + assert_ok!(tx.send("two")); + assert_ok!(tx.send("three")); + + drop(tx); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + }); +} + +#[test] +fn drop_rx() { + loom::model(|| { + let (tx, mut rx1) = broadcast::channel(16); + let rx2 = tx.subscribe(); + + let th1 = thread::spawn(move || { + block_on(async { + let v = assert_ok!(rx1.recv().await); + assert_eq!(v, "one"); + + let v = assert_ok!(rx1.recv().await); + assert_eq!(v, "two"); + + let v = assert_ok!(rx1.recv().await); + assert_eq!(v, "three"); + + match assert_err!(rx1.recv().await) { + Closed => {} + _ => panic!(), + } + }); + }); + + let th2 = thread::spawn(move || { + drop(rx2); + }); + + assert_ok!(tx.send("one")); + assert_ok!(tx.send("two")); + assert_ok!(tx.send("three")); + drop(tx); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + }); +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_list.rs b/third_party/rust/tokio/src/sync/tests/loom_list.rs new file mode 100644 index 0000000000..4067f865ce --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_list.rs @@ -0,0 +1,48 @@ +use crate::sync::mpsc::list; + +use loom::thread; +use std::sync::Arc; + +#[test] +fn smoke() { + use crate::sync::mpsc::block::Read::*; + + const NUM_TX: usize = 2; + const NUM_MSG: usize = 2; + + loom::model(|| { + let (tx, mut rx) = list::channel(); + let tx = Arc::new(tx); + + for th in 0..NUM_TX { + let tx = tx.clone(); + + thread::spawn(move || { + for i in 0..NUM_MSG { + tx.push((th, i)); + } + }); + } + + let mut next = vec![0; NUM_TX]; + + loop { + match rx.pop(&tx) { + Some(Value((th, v))) => { + assert_eq!(v, next[th]); + next[th] += 1; + + if next.iter().all(|&i| i == NUM_MSG) { + break; + } + } + Some(Closed) => { + panic!(); + } + None => { + thread::yield_now(); + } + } + } + }); +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_mpsc.rs b/third_party/rust/tokio/src/sync/tests/loom_mpsc.rs new file mode 100644 index 0000000000..6a1a6abedd --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_mpsc.rs @@ -0,0 +1,77 @@ +use crate::sync::mpsc; + +use futures::future::poll_fn; +use loom::future::block_on; +use loom::thread; + +#[test] +fn closing_tx() { + loom::model(|| { + let (mut tx, mut rx) = mpsc::channel(16); + + thread::spawn(move || { + tx.try_send(()).unwrap(); + drop(tx); + }); + + let v = block_on(poll_fn(|cx| rx.poll_recv(cx))); + assert!(v.is_some()); + + let v = block_on(poll_fn(|cx| rx.poll_recv(cx))); + assert!(v.is_none()); + }); +} + +#[test] +fn closing_unbounded_tx() { + loom::model(|| { + let (tx, mut rx) = mpsc::unbounded_channel(); + + thread::spawn(move || { + tx.send(()).unwrap(); + drop(tx); + }); + + let v = block_on(poll_fn(|cx| rx.poll_recv(cx))); + assert!(v.is_some()); + + let v = block_on(poll_fn(|cx| rx.poll_recv(cx))); + assert!(v.is_none()); + }); +} + +#[test] +fn dropping_tx() { + loom::model(|| { + let (tx, mut rx) = mpsc::channel::<()>(16); + + for _ in 0..2 { + let tx = tx.clone(); + thread::spawn(move || { + drop(tx); + }); + } + drop(tx); + + let v = block_on(poll_fn(|cx| rx.poll_recv(cx))); + assert!(v.is_none()); + }); +} + +#[test] +fn dropping_unbounded_tx() { + loom::model(|| { + let (tx, mut rx) = mpsc::unbounded_channel::<()>(); + + for _ in 0..2 { + let tx = tx.clone(); + thread::spawn(move || { + drop(tx); + }); + } + drop(tx); + + let v = block_on(poll_fn(|cx| rx.poll_recv(cx))); + assert!(v.is_none()); + }); +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_notify.rs b/third_party/rust/tokio/src/sync/tests/loom_notify.rs new file mode 100644 index 0000000000..60981d4669 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_notify.rs @@ -0,0 +1,90 @@ +use crate::sync::Notify; + +use loom::future::block_on; +use loom::sync::Arc; +use loom::thread; + +#[test] +fn notify_one() { + loom::model(|| { + let tx = Arc::new(Notify::new()); + let rx = tx.clone(); + + let th = thread::spawn(move || { + block_on(async { + rx.notified().await; + }); + }); + + tx.notify(); + th.join().unwrap(); + }); +} + +#[test] +fn notify_multi() { + loom::model(|| { + let notify = Arc::new(Notify::new()); + + let mut ths = vec![]; + + for _ in 0..2 { + let notify = notify.clone(); + + ths.push(thread::spawn(move || { + block_on(async { + notify.notified().await; + notify.notify(); + }) + })); + } + + notify.notify(); + + for th in ths.drain(..) { + th.join().unwrap(); + } + + block_on(async { + notify.notified().await; + }); + }); +} + +#[test] +fn notify_drop() { + use crate::future::poll_fn; + use std::future::Future; + use std::task::Poll; + + loom::model(|| { + let notify = Arc::new(Notify::new()); + let rx1 = notify.clone(); + let rx2 = notify.clone(); + + let th1 = thread::spawn(move || { + let mut recv = Box::pin(rx1.notified()); + + block_on(poll_fn(|cx| { + if recv.as_mut().poll(cx).is_ready() { + rx1.notify(); + } + Poll::Ready(()) + })); + }); + + let th2 = thread::spawn(move || { + block_on(async { + rx2.notified().await; + // Trigger second notification + rx2.notify(); + rx2.notified().await; + }); + }); + + notify.notify(); + + th1.join().unwrap(); + th2.join().unwrap(); + }); +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_oneshot.rs b/third_party/rust/tokio/src/sync/tests/loom_oneshot.rs new file mode 100644 index 0000000000..dfa7459da7 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_oneshot.rs @@ -0,0 +1,109 @@ +use crate::sync::oneshot; + +use futures::future::poll_fn; +use loom::future::block_on; +use loom::thread; +use std::task::Poll::{Pending, Ready}; + +#[test] +fn smoke() { + loom::model(|| { + let (tx, rx) = oneshot::channel(); + + thread::spawn(move || { + tx.send(1).unwrap(); + }); + + let value = block_on(rx).unwrap(); + assert_eq!(1, value); + }); +} + +#[test] +fn changing_rx_task() { + loom::model(|| { + let (tx, mut rx) = oneshot::channel(); + + thread::spawn(move || { + tx.send(1).unwrap(); + }); + + let rx = thread::spawn(move || { + let ready = block_on(poll_fn(|cx| match Pin::new(&mut rx).poll(cx) { + Ready(Ok(value)) => { + assert_eq!(1, value); + Ready(true) + } + Ready(Err(_)) => unimplemented!(), + Pending => Ready(false), + })); + + if ready { + None + } else { + Some(rx) + } + }) + .join() + .unwrap(); + + if let Some(rx) = rx { + // Previous task parked, use a new task... + let value = block_on(rx).unwrap(); + assert_eq!(1, value); + } + }); +} + +// TODO: Move this into `oneshot` proper. + +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +struct OnClose<'a> { + tx: &'a mut oneshot::Sender<i32>, +} + +impl<'a> OnClose<'a> { + fn new(tx: &'a mut oneshot::Sender<i32>) -> Self { + OnClose { tx } + } +} + +impl Future for OnClose<'_> { + type Output = bool; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<bool> { + let res = self.get_mut().tx.poll_closed(cx); + Ready(res.is_ready()) + } +} + +#[test] +fn changing_tx_task() { + loom::model(|| { + let (mut tx, rx) = oneshot::channel::<i32>(); + + thread::spawn(move || { + drop(rx); + }); + + let tx = thread::spawn(move || { + let t1 = block_on(OnClose::new(&mut tx)); + + if t1 { + None + } else { + Some(tx) + } + }) + .join() + .unwrap(); + + if let Some(mut tx) = tx { + // Previous task parked, use a new task... + block_on(OnClose::new(&mut tx)); + } + }); +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_rwlock.rs b/third_party/rust/tokio/src/sync/tests/loom_rwlock.rs new file mode 100644 index 0000000000..48d06e1d5f --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_rwlock.rs @@ -0,0 +1,78 @@ +use crate::sync::rwlock::*; + +use loom::future::block_on; +use loom::thread; +use std::sync::Arc; + +#[test] +fn concurrent_write() { + let mut b = loom::model::Builder::new(); + + b.check(|| { + let rwlock = Arc::new(RwLock::<u32>::new(0)); + + let rwclone = rwlock.clone(); + let t1 = thread::spawn(move || { + block_on(async { + let mut guard = rwclone.write().await; + *guard += 5; + }); + }); + + let rwclone = rwlock.clone(); + let t2 = thread::spawn(move || { + block_on(async { + let mut guard = rwclone.write().await; + *guard += 5; + }); + }); + + t1.join().expect("thread 1 write should not panic"); + t2.join().expect("thread 2 write should not panic"); + //when all threads have finished the value on the lock should be 10 + let guard = block_on(rwlock.read()); + assert_eq!(10, *guard); + }); +} + +#[test] +fn concurrent_read_write() { + let mut b = loom::model::Builder::new(); + + b.check(|| { + let rwlock = Arc::new(RwLock::<u32>::new(0)); + + let rwclone = rwlock.clone(); + let t1 = thread::spawn(move || { + block_on(async { + let mut guard = rwclone.write().await; + *guard += 5; + }); + }); + + let rwclone = rwlock.clone(); + let t2 = thread::spawn(move || { + block_on(async { + let mut guard = rwclone.write().await; + *guard += 5; + }); + }); + + let rwclone = rwlock.clone(); + let t3 = thread::spawn(move || { + block_on(async { + let guard = rwclone.read().await; + //at this state the value on the lock may either be 0, 5, or 10 + assert!(*guard == 0 || *guard == 5 || *guard == 10); + }); + }); + + t1.join().expect("thread 1 write should not panic"); + t2.join().expect("thread 2 write should not panic"); + t3.join().expect("thread 3 read should not panic"); + + let guard = block_on(rwlock.read()); + //when all threads have finished the value on the lock should be 10 + assert_eq!(10, *guard); + }); +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_semaphore_batch.rs b/third_party/rust/tokio/src/sync/tests/loom_semaphore_batch.rs new file mode 100644 index 0000000000..76a1bc0062 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_semaphore_batch.rs @@ -0,0 +1,215 @@ +use crate::sync::batch_semaphore::*; + +use futures::future::poll_fn; +use loom::future::block_on; +use loom::sync::atomic::AtomicUsize; +use loom::thread; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::Ordering::SeqCst; +use std::sync::Arc; +use std::task::Poll::Ready; +use std::task::{Context, Poll}; + +#[test] +fn basic_usage() { + const NUM: usize = 2; + + struct Shared { + semaphore: Semaphore, + active: AtomicUsize, + } + + async fn actor(shared: Arc<Shared>) { + shared.semaphore.acquire(1).await.unwrap(); + let actual = shared.active.fetch_add(1, SeqCst); + assert!(actual <= NUM - 1); + + let actual = shared.active.fetch_sub(1, SeqCst); + assert!(actual <= NUM); + shared.semaphore.release(1); + } + + loom::model(|| { + let shared = Arc::new(Shared { + semaphore: Semaphore::new(NUM), + active: AtomicUsize::new(0), + }); + + for _ in 0..NUM { + let shared = shared.clone(); + + thread::spawn(move || { + block_on(actor(shared)); + }); + } + + block_on(actor(shared)); + }); +} + +#[test] +fn release() { + loom::model(|| { + let semaphore = Arc::new(Semaphore::new(1)); + + { + let semaphore = semaphore.clone(); + thread::spawn(move || { + block_on(semaphore.acquire(1)).unwrap(); + semaphore.release(1); + }); + } + + block_on(semaphore.acquire(1)).unwrap(); + + semaphore.release(1); + }); +} + +#[test] +fn basic_closing() { + const NUM: usize = 2; + + loom::model(|| { + let semaphore = Arc::new(Semaphore::new(1)); + + for _ in 0..NUM { + let semaphore = semaphore.clone(); + + thread::spawn(move || { + for _ in 0..2 { + block_on(semaphore.acquire(1)).map_err(|_| ())?; + + semaphore.release(1); + } + + Ok::<(), ()>(()) + }); + } + + semaphore.close(); + }); +} + +#[test] +fn concurrent_close() { + const NUM: usize = 3; + + loom::model(|| { + let semaphore = Arc::new(Semaphore::new(1)); + + for _ in 0..NUM { + let semaphore = semaphore.clone(); + + thread::spawn(move || { + block_on(semaphore.acquire(1)).map_err(|_| ())?; + semaphore.release(1); + semaphore.close(); + + Ok::<(), ()>(()) + }); + } + }); +} + +#[test] +fn concurrent_cancel() { + async fn poll_and_cancel(semaphore: Arc<Semaphore>) { + let mut acquire1 = Some(semaphore.acquire(1)); + let mut acquire2 = Some(semaphore.acquire(1)); + poll_fn(|cx| { + // poll the acquire future once, and then immediately throw + // it away. this simulates a situation where a future is + // polled and then cancelled, such as by a timeout. + if let Some(acquire) = acquire1.take() { + pin!(acquire); + let _ = acquire.poll(cx); + } + if let Some(acquire) = acquire2.take() { + pin!(acquire); + let _ = acquire.poll(cx); + } + Poll::Ready(()) + }) + .await + } + + loom::model(|| { + let semaphore = Arc::new(Semaphore::new(0)); + let t1 = { + let semaphore = semaphore.clone(); + thread::spawn(move || block_on(poll_and_cancel(semaphore))) + }; + let t2 = { + let semaphore = semaphore.clone(); + thread::spawn(move || block_on(poll_and_cancel(semaphore))) + }; + let t3 = { + let semaphore = semaphore.clone(); + thread::spawn(move || block_on(poll_and_cancel(semaphore))) + }; + + t1.join().unwrap(); + semaphore.release(10); + t2.join().unwrap(); + t3.join().unwrap(); + }); +} + +#[test] +fn batch() { + let mut b = loom::model::Builder::new(); + b.preemption_bound = Some(1); + + b.check(|| { + let semaphore = Arc::new(Semaphore::new(10)); + let active = Arc::new(AtomicUsize::new(0)); + let mut ths = vec![]; + + for _ in 0..2 { + let semaphore = semaphore.clone(); + let active = active.clone(); + + ths.push(thread::spawn(move || { + for n in &[4, 10, 8] { + block_on(semaphore.acquire(*n)).unwrap(); + + active.fetch_add(*n as usize, SeqCst); + + let num_active = active.load(SeqCst); + assert!(num_active <= 10); + + thread::yield_now(); + + active.fetch_sub(*n as usize, SeqCst); + + semaphore.release(*n as usize); + } + })); + } + + for th in ths.into_iter() { + th.join().unwrap(); + } + + assert_eq!(10, semaphore.available_permits()); + }); +} + +#[test] +fn release_during_acquire() { + loom::model(|| { + let semaphore = Arc::new(Semaphore::new(10)); + semaphore + .try_acquire(8) + .expect("try_acquire should succeed; semaphore uncontended"); + let semaphore2 = semaphore.clone(); + let thread = thread::spawn(move || block_on(semaphore2.acquire(4)).unwrap()); + + semaphore.release(8); + thread.join().unwrap(); + semaphore.release(4); + assert_eq!(10, semaphore.available_permits()); + }) +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_semaphore_ll.rs b/third_party/rust/tokio/src/sync/tests/loom_semaphore_ll.rs new file mode 100644 index 0000000000..b5e5efba82 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_semaphore_ll.rs @@ -0,0 +1,192 @@ +use crate::sync::semaphore_ll::*; + +use futures::future::poll_fn; +use loom::future::block_on; +use loom::thread; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering::SeqCst; +use std::sync::Arc; +use std::task::Poll::Ready; +use std::task::{Context, Poll}; + +#[test] +fn basic_usage() { + const NUM: usize = 2; + + struct Actor { + waiter: Permit, + shared: Arc<Shared>, + } + + struct Shared { + semaphore: Semaphore, + active: AtomicUsize, + } + + impl Future for Actor { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let me = &mut *self; + + ready!(me.waiter.poll_acquire(cx, 1, &me.shared.semaphore)).unwrap(); + + let actual = me.shared.active.fetch_add(1, SeqCst); + assert!(actual <= NUM - 1); + + let actual = me.shared.active.fetch_sub(1, SeqCst); + assert!(actual <= NUM); + + me.waiter.release(1, &me.shared.semaphore); + + Ready(()) + } + } + + loom::model(|| { + let shared = Arc::new(Shared { + semaphore: Semaphore::new(NUM), + active: AtomicUsize::new(0), + }); + + for _ in 0..NUM { + let shared = shared.clone(); + + thread::spawn(move || { + block_on(Actor { + waiter: Permit::new(), + shared, + }); + }); + } + + block_on(Actor { + waiter: Permit::new(), + shared, + }); + }); +} + +#[test] +fn release() { + loom::model(|| { + let semaphore = Arc::new(Semaphore::new(1)); + + { + let semaphore = semaphore.clone(); + thread::spawn(move || { + let mut permit = Permit::new(); + + block_on(poll_fn(|cx| permit.poll_acquire(cx, 1, &semaphore))).unwrap(); + + permit.release(1, &semaphore); + }); + } + + let mut permit = Permit::new(); + + block_on(poll_fn(|cx| permit.poll_acquire(cx, 1, &semaphore))).unwrap(); + + permit.release(1, &semaphore); + }); +} + +#[test] +fn basic_closing() { + const NUM: usize = 2; + + loom::model(|| { + let semaphore = Arc::new(Semaphore::new(1)); + + for _ in 0..NUM { + let semaphore = semaphore.clone(); + + thread::spawn(move || { + let mut permit = Permit::new(); + + for _ in 0..2 { + block_on(poll_fn(|cx| { + permit.poll_acquire(cx, 1, &semaphore).map_err(|_| ()) + }))?; + + permit.release(1, &semaphore); + } + + Ok::<(), ()>(()) + }); + } + + semaphore.close(); + }); +} + +#[test] +fn concurrent_close() { + const NUM: usize = 3; + + loom::model(|| { + let semaphore = Arc::new(Semaphore::new(1)); + + for _ in 0..NUM { + let semaphore = semaphore.clone(); + + thread::spawn(move || { + let mut permit = Permit::new(); + + block_on(poll_fn(|cx| { + permit.poll_acquire(cx, 1, &semaphore).map_err(|_| ()) + }))?; + + permit.release(1, &semaphore); + + semaphore.close(); + + Ok::<(), ()>(()) + }); + } + }); +} + +#[test] +fn batch() { + let mut b = loom::model::Builder::new(); + b.preemption_bound = Some(1); + + b.check(|| { + let semaphore = Arc::new(Semaphore::new(10)); + let active = Arc::new(AtomicUsize::new(0)); + let mut ths = vec![]; + + for _ in 0..2 { + let semaphore = semaphore.clone(); + let active = active.clone(); + + ths.push(thread::spawn(move || { + let mut permit = Permit::new(); + + for n in &[4, 10, 8] { + block_on(poll_fn(|cx| permit.poll_acquire(cx, *n, &semaphore))).unwrap(); + + active.fetch_add(*n as usize, SeqCst); + + let num_active = active.load(SeqCst); + assert!(num_active <= 10); + + thread::yield_now(); + + active.fetch_sub(*n as usize, SeqCst); + + permit.release(*n, &semaphore); + } + })); + } + + for th in ths.into_iter() { + th.join().unwrap(); + } + + assert_eq!(10, semaphore.available_permits()); + }); +} diff --git a/third_party/rust/tokio/src/sync/tests/mod.rs b/third_party/rust/tokio/src/sync/tests/mod.rs new file mode 100644 index 0000000000..d571754c01 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/mod.rs @@ -0,0 +1,16 @@ +cfg_not_loom! { + mod atomic_waker; + mod semaphore_ll; + mod semaphore_batch; +} + +cfg_loom! { + mod loom_atomic_waker; + mod loom_broadcast; + mod loom_list; + mod loom_mpsc; + mod loom_notify; + mod loom_oneshot; + mod loom_semaphore_batch; + mod loom_semaphore_ll; +} diff --git a/third_party/rust/tokio/src/sync/tests/semaphore_batch.rs b/third_party/rust/tokio/src/sync/tests/semaphore_batch.rs new file mode 100644 index 0000000000..60f3f231e7 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/semaphore_batch.rs @@ -0,0 +1,250 @@ +use crate::sync::batch_semaphore::Semaphore; +use tokio_test::*; + +#[test] +fn poll_acquire_one_available() { + let s = Semaphore::new(100); + assert_eq!(s.available_permits(), 100); + + // Polling for a permit succeeds immediately + assert_ready_ok!(task::spawn(s.acquire(1)).poll()); + assert_eq!(s.available_permits(), 99); +} + +#[test] +fn poll_acquire_many_available() { + let s = Semaphore::new(100); + assert_eq!(s.available_permits(), 100); + + // Polling for a permit succeeds immediately + assert_ready_ok!(task::spawn(s.acquire(5)).poll()); + assert_eq!(s.available_permits(), 95); + + assert_ready_ok!(task::spawn(s.acquire(5)).poll()); + assert_eq!(s.available_permits(), 90); +} + +#[test] +fn try_acquire_one_available() { + let s = Semaphore::new(100); + assert_eq!(s.available_permits(), 100); + + assert_ok!(s.try_acquire(1)); + assert_eq!(s.available_permits(), 99); + + assert_ok!(s.try_acquire(1)); + assert_eq!(s.available_permits(), 98); +} + +#[test] +fn try_acquire_many_available() { + let s = Semaphore::new(100); + assert_eq!(s.available_permits(), 100); + + assert_ok!(s.try_acquire(5)); + assert_eq!(s.available_permits(), 95); + + assert_ok!(s.try_acquire(5)); + assert_eq!(s.available_permits(), 90); +} + +#[test] +fn poll_acquire_one_unavailable() { + let s = Semaphore::new(1); + + // Acquire the first permit + assert_ready_ok!(task::spawn(s.acquire(1)).poll()); + assert_eq!(s.available_permits(), 0); + + let mut acquire_2 = task::spawn(s.acquire(1)); + // Try to acquire the second permit + assert_pending!(acquire_2.poll()); + assert_eq!(s.available_permits(), 0); + + s.release(1); + + assert_eq!(s.available_permits(), 0); + assert!(acquire_2.is_woken()); + assert_ready_ok!(acquire_2.poll()); + assert_eq!(s.available_permits(), 0); + + s.release(1); + assert_eq!(s.available_permits(), 1); +} + +#[test] +fn poll_acquire_many_unavailable() { + let s = Semaphore::new(5); + + // Acquire the first permit + assert_ready_ok!(task::spawn(s.acquire(1)).poll()); + assert_eq!(s.available_permits(), 4); + + // Try to acquire the second permit + let mut acquire_2 = task::spawn(s.acquire(5)); + assert_pending!(acquire_2.poll()); + assert_eq!(s.available_permits(), 0); + + // Try to acquire the third permit + let mut acquire_3 = task::spawn(s.acquire(3)); + assert_pending!(acquire_3.poll()); + assert_eq!(s.available_permits(), 0); + + s.release(1); + + assert_eq!(s.available_permits(), 0); + assert!(acquire_2.is_woken()); + assert_ready_ok!(acquire_2.poll()); + + assert!(!acquire_3.is_woken()); + assert_eq!(s.available_permits(), 0); + + s.release(1); + assert!(!acquire_3.is_woken()); + assert_eq!(s.available_permits(), 0); + + s.release(2); + assert!(acquire_3.is_woken()); + + assert_ready_ok!(acquire_3.poll()); +} + +#[test] +fn try_acquire_one_unavailable() { + let s = Semaphore::new(1); + + // Acquire the first permit + assert_ok!(s.try_acquire(1)); + assert_eq!(s.available_permits(), 0); + + assert_err!(s.try_acquire(1)); + + s.release(1); + + assert_eq!(s.available_permits(), 1); + assert_ok!(s.try_acquire(1)); + + s.release(1); + assert_eq!(s.available_permits(), 1); +} + +#[test] +fn try_acquire_many_unavailable() { + let s = Semaphore::new(5); + + // Acquire the first permit + assert_ok!(s.try_acquire(1)); + assert_eq!(s.available_permits(), 4); + + assert_err!(s.try_acquire(5)); + + s.release(1); + assert_eq!(s.available_permits(), 5); + + assert_ok!(s.try_acquire(5)); + + s.release(1); + assert_eq!(s.available_permits(), 1); + + s.release(1); + assert_eq!(s.available_permits(), 2); +} + +#[test] +fn poll_acquire_one_zero_permits() { + let s = Semaphore::new(0); + assert_eq!(s.available_permits(), 0); + + // Try to acquire the permit + let mut acquire = task::spawn(s.acquire(1)); + assert_pending!(acquire.poll()); + + s.release(1); + + assert!(acquire.is_woken()); + assert_ready_ok!(acquire.poll()); +} + +#[test] +#[should_panic] +fn validates_max_permits() { + use std::usize; + Semaphore::new((usize::MAX >> 2) + 1); +} + +#[test] +fn close_semaphore_prevents_acquire() { + let s = Semaphore::new(5); + s.close(); + + assert_eq!(5, s.available_permits()); + + assert_ready_err!(task::spawn(s.acquire(1)).poll()); + assert_eq!(5, s.available_permits()); + + assert_ready_err!(task::spawn(s.acquire(1)).poll()); + assert_eq!(5, s.available_permits()); +} + +#[test] +fn close_semaphore_notifies_permit1() { + let s = Semaphore::new(0); + let mut acquire = task::spawn(s.acquire(1)); + + assert_pending!(acquire.poll()); + + s.close(); + + assert!(acquire.is_woken()); + assert_ready_err!(acquire.poll()); +} + +#[test] +fn close_semaphore_notifies_permit2() { + let s = Semaphore::new(2); + + // Acquire a couple of permits + assert_ready_ok!(task::spawn(s.acquire(1)).poll()); + assert_ready_ok!(task::spawn(s.acquire(1)).poll()); + + let mut acquire3 = task::spawn(s.acquire(1)); + let mut acquire4 = task::spawn(s.acquire(1)); + assert_pending!(acquire3.poll()); + assert_pending!(acquire4.poll()); + + s.close(); + + assert!(acquire3.is_woken()); + assert!(acquire4.is_woken()); + + assert_ready_err!(acquire3.poll()); + assert_ready_err!(acquire4.poll()); + + assert_eq!(0, s.available_permits()); + + s.release(1); + + assert_eq!(1, s.available_permits()); + + assert_ready_err!(task::spawn(s.acquire(1)).poll()); + + s.release(1); + + assert_eq!(2, s.available_permits()); +} + +#[test] +fn cancel_acquire_releases_permits() { + let s = Semaphore::new(10); + let _permit1 = s.try_acquire(4).expect("uncontended try_acquire succeeds"); + assert_eq!(6, s.available_permits()); + + let mut acquire = task::spawn(s.acquire(8)); + assert_pending!(acquire.poll()); + + assert_eq!(0, s.available_permits()); + drop(acquire); + + assert_eq!(6, s.available_permits()); + assert_ok!(s.try_acquire(6)); +} diff --git a/third_party/rust/tokio/src/sync/tests/semaphore_ll.rs b/third_party/rust/tokio/src/sync/tests/semaphore_ll.rs new file mode 100644 index 0000000000..bfb075780b --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/semaphore_ll.rs @@ -0,0 +1,470 @@ +use crate::sync::semaphore_ll::{Permit, Semaphore}; +use tokio_test::*; + +#[test] +fn poll_acquire_one_available() { + let s = Semaphore::new(100); + assert_eq!(s.available_permits(), 100); + + // Polling for a permit succeeds immediately + let mut permit = task::spawn(Permit::new()); + assert!(!permit.is_acquired()); + + assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); + assert_eq!(s.available_permits(), 99); + assert!(permit.is_acquired()); + + // Polling again on the same waiter does not claim a new permit + assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); + assert_eq!(s.available_permits(), 99); + assert!(permit.is_acquired()); +} + +#[test] +fn poll_acquire_many_available() { + let s = Semaphore::new(100); + assert_eq!(s.available_permits(), 100); + + // Polling for a permit succeeds immediately + let mut permit = task::spawn(Permit::new()); + assert!(!permit.is_acquired()); + + assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 5, &s))); + assert_eq!(s.available_permits(), 95); + assert!(permit.is_acquired()); + + // Polling again on the same waiter does not claim a new permit + assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); + assert_eq!(s.available_permits(), 95); + assert!(permit.is_acquired()); + + assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 5, &s))); + assert_eq!(s.available_permits(), 95); + assert!(permit.is_acquired()); + + // Polling for a larger number of permits acquires more + assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 8, &s))); + assert_eq!(s.available_permits(), 92); + assert!(permit.is_acquired()); +} + +#[test] +fn try_acquire_one_available() { + let s = Semaphore::new(100); + assert_eq!(s.available_permits(), 100); + + // Polling for a permit succeeds immediately + let mut permit = Permit::new(); + assert!(!permit.is_acquired()); + + assert_ok!(permit.try_acquire(1, &s)); + assert_eq!(s.available_permits(), 99); + assert!(permit.is_acquired()); + + // Polling again on the same waiter does not claim a new permit + assert_ok!(permit.try_acquire(1, &s)); + assert_eq!(s.available_permits(), 99); + assert!(permit.is_acquired()); +} + +#[test] +fn try_acquire_many_available() { + let s = Semaphore::new(100); + assert_eq!(s.available_permits(), 100); + + // Polling for a permit succeeds immediately + let mut permit = Permit::new(); + assert!(!permit.is_acquired()); + + assert_ok!(permit.try_acquire(5, &s)); + assert_eq!(s.available_permits(), 95); + assert!(permit.is_acquired()); + + // Polling again on the same waiter does not claim a new permit + assert_ok!(permit.try_acquire(5, &s)); + assert_eq!(s.available_permits(), 95); + assert!(permit.is_acquired()); +} + +#[test] +fn poll_acquire_one_unavailable() { + let s = Semaphore::new(1); + + let mut permit_1 = task::spawn(Permit::new()); + let mut permit_2 = task::spawn(Permit::new()); + + // Acquire the first permit + assert_ready_ok!(permit_1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); + assert_eq!(s.available_permits(), 0); + + permit_2.enter(|cx, mut p| { + // Try to acquire the second permit + assert_pending!(p.poll_acquire(cx, 1, &s)); + }); + + permit_1.release(1, &s); + + assert_eq!(s.available_permits(), 0); + assert!(permit_2.is_woken()); + assert_ready_ok!(permit_2.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); + + permit_2.release(1, &s); + assert_eq!(s.available_permits(), 1); +} + +#[test] +fn forget_acquired() { + let s = Semaphore::new(1); + + // Polling for a permit succeeds immediately + let mut permit = task::spawn(Permit::new()); + + assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); + + assert_eq!(s.available_permits(), 0); + + permit.forget(1); + assert_eq!(s.available_permits(), 0); +} + +#[test] +fn forget_waiting() { + let s = Semaphore::new(0); + + // Polling for a permit succeeds immediately + let mut permit = task::spawn(Permit::new()); + + assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); + + assert_eq!(s.available_permits(), 0); + + permit.forget(1); + + s.add_permits(1); + + assert!(!permit.is_woken()); + assert_eq!(s.available_permits(), 1); +} + +#[test] +fn poll_acquire_many_unavailable() { + let s = Semaphore::new(5); + + let mut permit_1 = task::spawn(Permit::new()); + let mut permit_2 = task::spawn(Permit::new()); + let mut permit_3 = task::spawn(Permit::new()); + + // Acquire the first permit + assert_ready_ok!(permit_1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); + assert_eq!(s.available_permits(), 4); + + permit_2.enter(|cx, mut p| { + // Try to acquire the second permit + assert_pending!(p.poll_acquire(cx, 5, &s)); + }); + + assert_eq!(s.available_permits(), 0); + + permit_3.enter(|cx, mut p| { + // Try to acquire the third permit + assert_pending!(p.poll_acquire(cx, 3, &s)); + }); + + permit_1.release(1, &s); + + assert_eq!(s.available_permits(), 0); + assert!(permit_2.is_woken()); + assert_ready_ok!(permit_2.enter(|cx, mut p| p.poll_acquire(cx, 5, &s))); + + assert!(!permit_3.is_woken()); + assert_eq!(s.available_permits(), 0); + + permit_2.release(1, &s); + assert!(!permit_3.is_woken()); + assert_eq!(s.available_permits(), 0); + + permit_2.release(2, &s); + assert!(permit_3.is_woken()); + + assert_ready_ok!(permit_3.enter(|cx, mut p| p.poll_acquire(cx, 3, &s))); +} + +#[test] +fn try_acquire_one_unavailable() { + let s = Semaphore::new(1); + + let mut permit_1 = Permit::new(); + let mut permit_2 = Permit::new(); + + // Acquire the first permit + assert_ok!(permit_1.try_acquire(1, &s)); + assert_eq!(s.available_permits(), 0); + + assert_err!(permit_2.try_acquire(1, &s)); + + permit_1.release(1, &s); + + assert_eq!(s.available_permits(), 1); + assert_ok!(permit_2.try_acquire(1, &s)); + + permit_2.release(1, &s); + assert_eq!(s.available_permits(), 1); +} + +#[test] +fn try_acquire_many_unavailable() { + let s = Semaphore::new(5); + + let mut permit_1 = Permit::new(); + let mut permit_2 = Permit::new(); + + // Acquire the first permit + assert_ok!(permit_1.try_acquire(1, &s)); + assert_eq!(s.available_permits(), 4); + + assert_err!(permit_2.try_acquire(5, &s)); + + permit_1.release(1, &s); + assert_eq!(s.available_permits(), 5); + + assert_ok!(permit_2.try_acquire(5, &s)); + + permit_2.release(1, &s); + assert_eq!(s.available_permits(), 1); + + permit_2.release(1, &s); + assert_eq!(s.available_permits(), 2); +} + +#[test] +fn poll_acquire_one_zero_permits() { + let s = Semaphore::new(0); + assert_eq!(s.available_permits(), 0); + + let mut permit = task::spawn(Permit::new()); + + // Try to acquire the permit + permit.enter(|cx, mut p| { + assert_pending!(p.poll_acquire(cx, 1, &s)); + }); + + s.add_permits(1); + + assert!(permit.is_woken()); + assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); +} + +#[test] +#[should_panic] +fn validates_max_permits() { + use std::usize; + Semaphore::new((usize::MAX >> 2) + 1); +} + +#[test] +fn close_semaphore_prevents_acquire() { + let s = Semaphore::new(5); + s.close(); + + assert_eq!(5, s.available_permits()); + + let mut permit_1 = task::spawn(Permit::new()); + let mut permit_2 = task::spawn(Permit::new()); + + assert_ready_err!(permit_1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); + assert_eq!(5, s.available_permits()); + + assert_ready_err!(permit_2.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); + assert_eq!(5, s.available_permits()); +} + +#[test] +fn close_semaphore_notifies_permit1() { + let s = Semaphore::new(0); + let mut permit = task::spawn(Permit::new()); + + assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); + + s.close(); + + assert!(permit.is_woken()); + assert_ready_err!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); +} + +#[test] +fn close_semaphore_notifies_permit2() { + let s = Semaphore::new(2); + + let mut permit1 = task::spawn(Permit::new()); + let mut permit2 = task::spawn(Permit::new()); + let mut permit3 = task::spawn(Permit::new()); + let mut permit4 = task::spawn(Permit::new()); + + // Acquire a couple of permits + assert_ready_ok!(permit1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); + assert_ready_ok!(permit2.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); + + assert_pending!(permit3.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); + assert_pending!(permit4.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); + + s.close(); + + assert!(permit3.is_woken()); + assert!(permit4.is_woken()); + + assert_ready_err!(permit3.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); + assert_ready_err!(permit4.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); + + assert_eq!(0, s.available_permits()); + + permit1.release(1, &s); + + assert_eq!(1, s.available_permits()); + + assert_ready_err!(permit1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); + + permit2.release(1, &s); + + assert_eq!(2, s.available_permits()); +} + +#[test] +fn poll_acquire_additional_permits_while_waiting_before_assigned() { + let s = Semaphore::new(1); + + let mut permit = task::spawn(Permit::new()); + + assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); + assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 3, &s))); + + s.add_permits(1); + assert!(!permit.is_woken()); + + s.add_permits(1); + assert!(permit.is_woken()); + + assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 3, &s))); +} + +#[test] +fn try_acquire_additional_permits_while_waiting_before_assigned() { + let s = Semaphore::new(1); + + let mut permit = task::spawn(Permit::new()); + + assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); + + assert_err!(permit.enter(|_, mut p| p.try_acquire(3, &s))); + + s.add_permits(1); + assert!(permit.is_woken()); + + assert_ok!(permit.enter(|_, mut p| p.try_acquire(2, &s))); +} + +#[test] +fn poll_acquire_additional_permits_while_waiting_after_assigned_success() { + let s = Semaphore::new(1); + + let mut permit = task::spawn(Permit::new()); + + assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); + + s.add_permits(2); + + assert!(permit.is_woken()); + assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 3, &s))); +} + +#[test] +fn poll_acquire_additional_permits_while_waiting_after_assigned_requeue() { + let s = Semaphore::new(1); + + let mut permit = task::spawn(Permit::new()); + + assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); + + s.add_permits(2); + + assert!(permit.is_woken()); + assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 4, &s))); + + s.add_permits(1); + + assert!(permit.is_woken()); + assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 4, &s))); +} + +#[test] +fn poll_acquire_fewer_permits_while_waiting() { + let s = Semaphore::new(1); + + let mut permit = task::spawn(Permit::new()); + + assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); + assert_eq!(s.available_permits(), 0); + + assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); + assert_eq!(s.available_permits(), 0); +} + +#[test] +fn poll_acquire_fewer_permits_after_assigned() { + let s = Semaphore::new(1); + + let mut permit1 = task::spawn(Permit::new()); + let mut permit2 = task::spawn(Permit::new()); + + assert_pending!(permit1.enter(|cx, mut p| p.poll_acquire(cx, 5, &s))); + assert_eq!(s.available_permits(), 0); + + assert_pending!(permit2.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); + + s.add_permits(4); + assert!(permit1.is_woken()); + assert!(!permit2.is_woken()); + + assert_ready_ok!(permit1.enter(|cx, mut p| p.poll_acquire(cx, 3, &s))); + + assert!(permit2.is_woken()); + assert_eq!(s.available_permits(), 1); + + assert_ready_ok!(permit2.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); +} + +#[test] +fn forget_partial_1() { + let s = Semaphore::new(0); + + let mut permit = task::spawn(Permit::new()); + + assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); + s.add_permits(1); + + assert_eq!(0, s.available_permits()); + + permit.release(1, &s); + + assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); + + assert_eq!(s.available_permits(), 0); +} + +#[test] +fn forget_partial_2() { + let s = Semaphore::new(0); + + let mut permit = task::spawn(Permit::new()); + + assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); + s.add_permits(1); + + assert_eq!(0, s.available_permits()); + + permit.release(1, &s); + + s.add_permits(1); + + assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); + assert_eq!(s.available_permits(), 0); +} diff --git a/third_party/rust/tokio/src/sync/watch.rs b/third_party/rust/tokio/src/sync/watch.rs new file mode 100644 index 0000000000..ba609a8c6d --- /dev/null +++ b/third_party/rust/tokio/src/sync/watch.rs @@ -0,0 +1,432 @@ +//! A single-producer, multi-consumer channel that only retains the *last* sent +//! value. +//! +//! This channel is useful for watching for changes to a value from multiple +//! points in the code base, for example, changes to configuration values. +//! +//! # Usage +//! +//! [`channel`] returns a [`Sender`] / [`Receiver`] pair. These are +//! the producer and sender halves of the channel. The channel is +//! created with an initial value. [`Receiver::recv`] will always +//! be ready upon creation and will yield either this initial value or +//! the latest value that has been sent by `Sender`. +//! +//! Calls to [`Receiver::recv`] will always yield the latest value. +//! +//! # Examples +//! +//! ``` +//! use tokio::sync::watch; +//! +//! # async fn dox() -> Result<(), Box<dyn std::error::Error>> { +//! let (tx, mut rx) = watch::channel("hello"); +//! +//! tokio::spawn(async move { +//! while let Some(value) = rx.recv().await { +//! println!("received = {:?}", value); +//! } +//! }); +//! +//! tx.broadcast("world")?; +//! # Ok(()) +//! # } +//! ``` +//! +//! # Closing +//! +//! [`Sender::closed`] allows the producer to detect when all [`Receiver`] +//! handles have been dropped. This indicates that there is no further interest +//! in the values being produced and work can be stopped. +//! +//! # Thread safety +//! +//! Both [`Sender`] and [`Receiver`] are thread safe. They can be moved to other +//! threads and can be used in a concurrent environment. Clones of [`Receiver`] +//! handles may be moved to separate threads and also used concurrently. +//! +//! [`Sender`]: crate::sync::watch::Sender +//! [`Receiver`]: crate::sync::watch::Receiver +//! [`Receiver::recv`]: crate::sync::watch::Receiver::recv +//! [`channel`]: crate::sync::watch::channel +//! [`Sender::closed`]: crate::sync::watch::Sender::closed + +use crate::future::poll_fn; +use crate::sync::task::AtomicWaker; + +use fnv::FnvHashSet; +use std::ops; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering::{Relaxed, SeqCst}; +use std::sync::{Arc, Mutex, RwLock, RwLockReadGuard, Weak}; +use std::task::Poll::{Pending, Ready}; +use std::task::{Context, Poll}; + +/// Receives values from the associated [`Sender`](struct@Sender). +/// +/// Instances are created by the [`channel`](fn@channel) function. +#[derive(Debug)] +pub struct Receiver<T> { + /// Pointer to the shared state + shared: Arc<Shared<T>>, + + /// Pointer to the watcher's internal state + inner: Watcher, +} + +/// Sends values to the associated [`Receiver`](struct@Receiver). +/// +/// Instances are created by the [`channel`](fn@channel) function. +#[derive(Debug)] +pub struct Sender<T> { + shared: Weak<Shared<T>>, +} + +/// Returns a reference to the inner value +/// +/// Outstanding borrows hold a read lock on the inner value. This means that +/// long lived borrows could cause the produce half to block. It is recommended +/// to keep the borrow as short lived as possible. +#[derive(Debug)] +pub struct Ref<'a, T> { + inner: RwLockReadGuard<'a, T>, +} + +pub mod error { + //! Watch error types + + use std::fmt; + + /// Error produced when sending a value fails. + #[derive(Debug)] + pub struct SendError<T> { + pub(crate) inner: T, + } + + // ===== impl SendError ===== + + impl<T: fmt::Debug> fmt::Display for SendError<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } + } + + impl<T: fmt::Debug> std::error::Error for SendError<T> {} +} + +#[derive(Debug)] +struct Shared<T> { + /// The most recent value + value: RwLock<T>, + + /// The current version + /// + /// The lowest bit represents a "closed" state. The rest of the bits + /// represent the current version. + version: AtomicUsize, + + /// All watchers + watchers: Mutex<Watchers>, + + /// Task to notify when all watchers drop + cancel: AtomicWaker, +} + +type Watchers = FnvHashSet<Watcher>; + +/// The watcher's ID is based on the Arc's pointer. +#[derive(Clone, Debug)] +struct Watcher(Arc<WatchInner>); + +#[derive(Debug)] +struct WatchInner { + /// Last observed version + version: AtomicUsize, + waker: AtomicWaker, +} + +const CLOSED: usize = 1; + +/// Creates a new watch channel, returning the "send" and "receive" handles. +/// +/// All values sent by [`Sender`] will become visible to the [`Receiver`] handles. +/// Only the last value sent is made available to the [`Receiver`] half. All +/// intermediate values are dropped. +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::watch; +/// +/// # async fn dox() -> Result<(), Box<dyn std::error::Error>> { +/// let (tx, mut rx) = watch::channel("hello"); +/// +/// tokio::spawn(async move { +/// while let Some(value) = rx.recv().await { +/// println!("received = {:?}", value); +/// } +/// }); +/// +/// tx.broadcast("world")?; +/// # Ok(()) +/// # } +/// ``` +/// +/// [`Sender`]: struct@Sender +/// [`Receiver`]: struct@Receiver +pub fn channel<T: Clone>(init: T) -> (Sender<T>, Receiver<T>) { + const VERSION_0: usize = 0; + const VERSION_1: usize = 2; + + // We don't start knowing VERSION_1 + let inner = Watcher::new_version(VERSION_0); + + // Insert the watcher + let mut watchers = FnvHashSet::with_capacity_and_hasher(0, Default::default()); + watchers.insert(inner.clone()); + + let shared = Arc::new(Shared { + value: RwLock::new(init), + version: AtomicUsize::new(VERSION_1), + watchers: Mutex::new(watchers), + cancel: AtomicWaker::new(), + }); + + let tx = Sender { + shared: Arc::downgrade(&shared), + }; + + let rx = Receiver { shared, inner }; + + (tx, rx) +} + +impl<T> Receiver<T> { + /// Returns a reference to the most recently sent value + /// + /// Outstanding borrows hold a read lock. This means that long lived borrows + /// could cause the send half to block. It is recommended to keep the borrow + /// as short lived as possible. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// let (_, rx) = watch::channel("hello"); + /// assert_eq!(*rx.borrow(), "hello"); + /// ``` + pub fn borrow(&self) -> Ref<'_, T> { + let inner = self.shared.value.read().unwrap(); + Ref { inner } + } + + // TODO: document + #[doc(hidden)] + pub fn poll_recv_ref<'a>(&'a mut self, cx: &mut Context<'_>) -> Poll<Option<Ref<'a, T>>> { + // Make sure the task is up to date + self.inner.waker.register_by_ref(cx.waker()); + + let state = self.shared.version.load(SeqCst); + let version = state & !CLOSED; + + if self.inner.version.swap(version, Relaxed) != version { + let inner = self.shared.value.read().unwrap(); + + return Ready(Some(Ref { inner })); + } + + if CLOSED == state & CLOSED { + // The `Store` handle has been dropped. + return Ready(None); + } + + Pending + } +} + +impl<T: Clone> Receiver<T> { + /// Attempts to clone the latest value sent via the channel. + /// + /// If this is the first time the function is called on a `Receiver` + /// instance, then the function completes immediately with the **current** + /// value held by the channel. On the next call, the function waits until + /// a new value is sent in the channel. + /// + /// `None` is returned if the `Sender` half is dropped. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = watch::channel("hello"); + /// + /// let v = rx.recv().await.unwrap(); + /// assert_eq!(v, "hello"); + /// + /// tokio::spawn(async move { + /// tx.broadcast("goodbye").unwrap(); + /// }); + /// + /// // Waits for the new task to spawn and send the value. + /// let v = rx.recv().await.unwrap(); + /// assert_eq!(v, "goodbye"); + /// + /// let v = rx.recv().await; + /// assert!(v.is_none()); + /// } + /// ``` + pub async fn recv(&mut self) -> Option<T> { + poll_fn(|cx| { + let v_ref = ready!(self.poll_recv_ref(cx)); + Poll::Ready(v_ref.map(|v_ref| (*v_ref).clone())) + }) + .await + } +} + +#[cfg(feature = "stream")] +impl<T: Clone> crate::stream::Stream for Receiver<T> { + type Item = T; + + fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> { + let v_ref = ready!(self.poll_recv_ref(cx)); + + Poll::Ready(v_ref.map(|v_ref| (*v_ref).clone())) + } +} + +impl<T> Clone for Receiver<T> { + fn clone(&self) -> Self { + let ver = self.inner.version.load(Relaxed); + let inner = Watcher::new_version(ver); + let shared = self.shared.clone(); + + shared.watchers.lock().unwrap().insert(inner.clone()); + + Receiver { shared, inner } + } +} + +impl<T> Drop for Receiver<T> { + fn drop(&mut self) { + self.shared.watchers.lock().unwrap().remove(&self.inner); + } +} + +impl<T> Sender<T> { + /// Broadcasts a new value via the channel, notifying all receivers. + pub fn broadcast(&self, value: T) -> Result<(), error::SendError<T>> { + let shared = match self.shared.upgrade() { + Some(shared) => shared, + // All `Watch` handles have been canceled + None => return Err(error::SendError { inner: value }), + }; + + // Replace the value + { + let mut lock = shared.value.write().unwrap(); + *lock = value; + } + + // Update the version. 2 is used so that the CLOSED bit is not set. + shared.version.fetch_add(2, SeqCst); + + // Notify all watchers + notify_all(&*shared); + + // Return the old value + Ok(()) + } + + /// Completes when all receivers have dropped. + /// + /// This allows the producer to get notified when interest in the produced + /// values is canceled and immediately stop doing work. + pub async fn closed(&mut self) { + poll_fn(|cx| self.poll_close(cx)).await + } + + fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<()> { + match self.shared.upgrade() { + Some(shared) => { + shared.cancel.register_by_ref(cx.waker()); + Pending + } + None => Ready(()), + } + } +} + +/// Notifies all watchers of a change +fn notify_all<T>(shared: &Shared<T>) { + let watchers = shared.watchers.lock().unwrap(); + + for watcher in watchers.iter() { + // Notify the task + watcher.waker.wake(); + } +} + +impl<T> Drop for Sender<T> { + fn drop(&mut self) { + if let Some(shared) = self.shared.upgrade() { + shared.version.fetch_or(CLOSED, SeqCst); + notify_all(&*shared); + } + } +} + +// ===== impl Ref ===== + +impl<T> ops::Deref for Ref<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + self.inner.deref() + } +} + +// ===== impl Shared ===== + +impl<T> Drop for Shared<T> { + fn drop(&mut self) { + self.cancel.wake(); + } +} + +// ===== impl Watcher ===== + +impl Watcher { + fn new_version(version: usize) -> Self { + Watcher(Arc::new(WatchInner { + version: AtomicUsize::new(version), + waker: AtomicWaker::new(), + })) + } +} + +impl std::cmp::PartialEq for Watcher { + fn eq(&self, other: &Watcher) -> bool { + Arc::ptr_eq(&self.0, &other.0) + } +} + +impl std::cmp::Eq for Watcher {} + +impl std::hash::Hash for Watcher { + fn hash<H: std::hash::Hasher>(&self, state: &mut H) { + (&*self.0 as *const WatchInner).hash(state) + } +} + +impl std::ops::Deref for Watcher { + type Target = WatchInner; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} |