From 698f8c2f01ea549d77d7dc3338a12e04c11057b9 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 17 Apr 2024 14:02:58 +0200 Subject: Adding upstream version 1.64.0+dfsg1. Signed-off-by: Daniel Baumann --- library/std/src/sync/mpsc/shared.rs | 501 ++++++++++++++++++++++++++++++++++++ 1 file changed, 501 insertions(+) create mode 100644 library/std/src/sync/mpsc/shared.rs (limited to 'library/std/src/sync/mpsc/shared.rs') diff --git a/library/std/src/sync/mpsc/shared.rs b/library/std/src/sync/mpsc/shared.rs new file mode 100644 index 000000000..51917bd96 --- /dev/null +++ b/library/std/src/sync/mpsc/shared.rs @@ -0,0 +1,501 @@ +/// Shared channels. +/// +/// This is the flavor of channels which are not necessarily optimized for any +/// particular use case, but are the most general in how they are used. Shared +/// channels are cloneable allowing for multiple senders. +/// +/// High level implementation details can be found in the comment of the parent +/// module. You'll also note that the implementation of the shared and stream +/// channels are quite similar, and this is no coincidence! +pub use self::Failure::*; +use self::StartResult::*; + +use core::cmp; +use core::intrinsics::abort; + +use crate::cell::UnsafeCell; +use crate::ptr; +use crate::sync::atomic::{AtomicBool, AtomicIsize, AtomicPtr, AtomicUsize, Ordering}; +use crate::sync::mpsc::blocking::{self, SignalToken}; +use crate::sync::mpsc::mpsc_queue as mpsc; +use crate::sync::{Mutex, MutexGuard}; +use crate::thread; +use crate::time::Instant; + +const DISCONNECTED: isize = isize::MIN; +const FUDGE: isize = 1024; +const MAX_REFCOUNT: usize = (isize::MAX) as usize; +#[cfg(test)] +const MAX_STEALS: isize = 5; +#[cfg(not(test))] +const MAX_STEALS: isize = 1 << 20; +const EMPTY: *mut u8 = ptr::null_mut(); // initial state: no data, no blocked receiver + +pub struct Packet { + queue: mpsc::Queue, + cnt: AtomicIsize, // How many items are on this channel + steals: UnsafeCell, // How many times has a port received without blocking? + to_wake: AtomicPtr, // SignalToken for wake up + + // The number of channels which are currently using this packet. + channels: AtomicUsize, + + // See the discussion in Port::drop and the channel send methods for what + // these are used for + port_dropped: AtomicBool, + sender_drain: AtomicIsize, + + // this lock protects various portions of this implementation during + // select() + select_lock: Mutex<()>, +} + +pub enum Failure { + Empty, + Disconnected, +} + +#[derive(PartialEq, Eq)] +enum StartResult { + Installed, + Abort, +} + +impl Packet { + // Creation of a packet *must* be followed by a call to postinit_lock + // and later by inherit_blocker + pub fn new() -> Packet { + Packet { + queue: mpsc::Queue::new(), + cnt: AtomicIsize::new(0), + steals: UnsafeCell::new(0), + to_wake: AtomicPtr::new(EMPTY), + channels: AtomicUsize::new(2), + port_dropped: AtomicBool::new(false), + sender_drain: AtomicIsize::new(0), + select_lock: Mutex::new(()), + } + } + + // This function should be used after newly created Packet + // was wrapped with an Arc + // In other case mutex data will be duplicated while cloning + // and that could cause problems on platforms where it is + // represented by opaque data structure + pub fn postinit_lock(&self) -> MutexGuard<'_, ()> { + self.select_lock.lock().unwrap() + } + + // This function is used at the creation of a shared packet to inherit a + // previously blocked thread. This is done to prevent spurious wakeups of + // threads in select(). + // + // This can only be called at channel-creation time + pub fn inherit_blocker(&self, token: Option, guard: MutexGuard<'_, ()>) { + if let Some(token) = token { + assert_eq!(self.cnt.load(Ordering::SeqCst), 0); + assert_eq!(self.to_wake.load(Ordering::SeqCst), EMPTY); + self.to_wake.store(unsafe { token.to_raw() }, Ordering::SeqCst); + self.cnt.store(-1, Ordering::SeqCst); + + // This store is a little sketchy. What's happening here is that + // we're transferring a blocker from a oneshot or stream channel to + // this shared channel. In doing so, we never spuriously wake them + // up and rather only wake them up at the appropriate time. This + // implementation of shared channels assumes that any blocking + // recv() will undo the increment of steals performed in try_recv() + // once the recv is complete. This thread that we're inheriting, + // however, is not in the middle of recv. Hence, the first time we + // wake them up, they're going to wake up from their old port, move + // on to the upgraded port, and then call the block recv() function. + // + // When calling this function, they'll find there's data immediately + // available, counting it as a steal. This in fact wasn't a steal + // because we appropriately blocked them waiting for data. + // + // To offset this bad increment, we initially set the steal count to + // -1. You'll find some special code in abort_selection() as well to + // ensure that this -1 steal count doesn't escape too far. + unsafe { + *self.steals.get() = -1; + } + } + + // When the shared packet is constructed, we grabbed this lock. The + // purpose of this lock is to ensure that abort_selection() doesn't + // interfere with this method. After we unlock this lock, we're + // signifying that we're done modifying self.cnt and self.to_wake and + // the port is ready for the world to continue using it. + drop(guard); + } + + pub fn send(&self, t: T) -> Result<(), T> { + // See Port::drop for what's going on + if self.port_dropped.load(Ordering::SeqCst) { + return Err(t); + } + + // Note that the multiple sender case is a little trickier + // semantically than the single sender case. The logic for + // incrementing is "add and if disconnected store disconnected". + // This could end up leading some senders to believe that there + // wasn't a disconnect if in fact there was a disconnect. This means + // that while one thread is attempting to re-store the disconnected + // states, other threads could walk through merrily incrementing + // this very-negative disconnected count. To prevent senders from + // spuriously attempting to send when the channels is actually + // disconnected, the count has a ranged check here. + // + // This is also done for another reason. Remember that the return + // value of this function is: + // + // `true` == the data *may* be received, this essentially has no + // meaning + // `false` == the data will *never* be received, this has a lot of + // meaning + // + // In the SPSC case, we have a check of 'queue.is_empty()' to see + // whether the data was actually received, but this same condition + // means nothing in a multi-producer context. As a result, this + // preflight check serves as the definitive "this will never be + // received". Once we get beyond this check, we have permanently + // entered the realm of "this may be received" + if self.cnt.load(Ordering::SeqCst) < DISCONNECTED + FUDGE { + return Err(t); + } + + self.queue.push(t); + match self.cnt.fetch_add(1, Ordering::SeqCst) { + -1 => { + self.take_to_wake().signal(); + } + + // In this case, we have possibly failed to send our data, and + // we need to consider re-popping the data in order to fully + // destroy it. We must arbitrate among the multiple senders, + // however, because the queues that we're using are + // single-consumer queues. In order to do this, all exiting + // pushers will use an atomic count in order to count those + // flowing through. Pushers who see 0 are required to drain as + // much as possible, and then can only exit when they are the + // only pusher (otherwise they must try again). + n if n < DISCONNECTED + FUDGE => { + // see the comment in 'try' for a shared channel for why this + // window of "not disconnected" is ok. + self.cnt.store(DISCONNECTED, Ordering::SeqCst); + + if self.sender_drain.fetch_add(1, Ordering::SeqCst) == 0 { + loop { + // drain the queue, for info on the thread yield see the + // discussion in try_recv + loop { + match self.queue.pop() { + mpsc::Data(..) => {} + mpsc::Empty => break, + mpsc::Inconsistent => thread::yield_now(), + } + } + // maybe we're done, if we're not the last ones + // here, then we need to go try again. + if self.sender_drain.fetch_sub(1, Ordering::SeqCst) == 1 { + break; + } + } + + // At this point, there may still be data on the queue, + // but only if the count hasn't been incremented and + // some other sender hasn't finished pushing data just + // yet. That sender in question will drain its own data. + } + } + + // Can't make any assumptions about this case like in the SPSC case. + _ => {} + } + + Ok(()) + } + + pub fn recv(&self, deadline: Option) -> Result { + // This code is essentially the exact same as that found in the stream + // case (see stream.rs) + match self.try_recv() { + Err(Empty) => {} + data => return data, + } + + let (wait_token, signal_token) = blocking::tokens(); + if self.decrement(signal_token) == Installed { + if let Some(deadline) = deadline { + let timed_out = !wait_token.wait_max_until(deadline); + if timed_out { + self.abort_selection(false); + } + } else { + wait_token.wait(); + } + } + + match self.try_recv() { + data @ Ok(..) => unsafe { + *self.steals.get() -= 1; + data + }, + data => data, + } + } + + // Essentially the exact same thing as the stream decrement function. + // Returns true if blocking should proceed. + fn decrement(&self, token: SignalToken) -> StartResult { + unsafe { + assert_eq!( + self.to_wake.load(Ordering::SeqCst), + EMPTY, + "This is a known bug in the Rust standard library. See https://github.com/rust-lang/rust/issues/39364" + ); + let ptr = token.to_raw(); + self.to_wake.store(ptr, Ordering::SeqCst); + + let steals = ptr::replace(self.steals.get(), 0); + + match self.cnt.fetch_sub(1 + steals, Ordering::SeqCst) { + DISCONNECTED => { + self.cnt.store(DISCONNECTED, Ordering::SeqCst); + } + // If we factor in our steals and notice that the channel has no + // data, we successfully sleep + n => { + assert!(n >= 0); + if n - steals <= 0 { + return Installed; + } + } + } + + self.to_wake.store(EMPTY, Ordering::SeqCst); + drop(SignalToken::from_raw(ptr)); + Abort + } + } + + pub fn try_recv(&self) -> Result { + let ret = match self.queue.pop() { + mpsc::Data(t) => Some(t), + mpsc::Empty => None, + + // This is a bit of an interesting case. The channel is reported as + // having data available, but our pop() has failed due to the queue + // being in an inconsistent state. This means that there is some + // pusher somewhere which has yet to complete, but we are guaranteed + // that a pop will eventually succeed. In this case, we spin in a + // yield loop because the remote sender should finish their enqueue + // operation "very quickly". + // + // Avoiding this yield loop would require a different queue + // abstraction which provides the guarantee that after M pushes have + // succeeded, at least M pops will succeed. The current queues + // guarantee that if there are N active pushes, you can pop N times + // once all N have finished. + mpsc::Inconsistent => { + let data; + loop { + thread::yield_now(); + match self.queue.pop() { + mpsc::Data(t) => { + data = t; + break; + } + mpsc::Empty => panic!("inconsistent => empty"), + mpsc::Inconsistent => {} + } + } + Some(data) + } + }; + match ret { + // See the discussion in the stream implementation for why we + // might decrement steals. + Some(data) => unsafe { + if *self.steals.get() > MAX_STEALS { + match self.cnt.swap(0, Ordering::SeqCst) { + DISCONNECTED => { + self.cnt.store(DISCONNECTED, Ordering::SeqCst); + } + n => { + let m = cmp::min(n, *self.steals.get()); + *self.steals.get() -= m; + self.bump(n - m); + } + } + assert!(*self.steals.get() >= 0); + } + *self.steals.get() += 1; + Ok(data) + }, + + // See the discussion in the stream implementation for why we try + // again. + None => { + match self.cnt.load(Ordering::SeqCst) { + n if n != DISCONNECTED => Err(Empty), + _ => { + match self.queue.pop() { + mpsc::Data(t) => Ok(t), + mpsc::Empty => Err(Disconnected), + // with no senders, an inconsistency is impossible. + mpsc::Inconsistent => unreachable!(), + } + } + } + } + } + } + + // Prepares this shared packet for a channel clone, essentially just bumping + // a refcount. + pub fn clone_chan(&self) { + let old_count = self.channels.fetch_add(1, Ordering::SeqCst); + + // See comments on Arc::clone() on why we do this (for `mem::forget`). + if old_count > MAX_REFCOUNT { + abort(); + } + } + + // Decrement the reference count on a channel. This is called whenever a + // Chan is dropped and may end up waking up a receiver. It's the receiver's + // responsibility on the other end to figure out that we've disconnected. + pub fn drop_chan(&self) { + match self.channels.fetch_sub(1, Ordering::SeqCst) { + 1 => {} + n if n > 1 => return, + n => panic!("bad number of channels left {n}"), + } + + match self.cnt.swap(DISCONNECTED, Ordering::SeqCst) { + -1 => { + self.take_to_wake().signal(); + } + DISCONNECTED => {} + n => { + assert!(n >= 0); + } + } + } + + // See the long discussion inside of stream.rs for why the queue is drained, + // and why it is done in this fashion. + pub fn drop_port(&self) { + self.port_dropped.store(true, Ordering::SeqCst); + let mut steals = unsafe { *self.steals.get() }; + while { + match self.cnt.compare_exchange( + steals, + DISCONNECTED, + Ordering::SeqCst, + Ordering::SeqCst, + ) { + Ok(_) => false, + Err(old) => old != DISCONNECTED, + } + } { + // See the discussion in 'try_recv' for why we yield + // control of this thread. + loop { + match self.queue.pop() { + mpsc::Data(..) => { + steals += 1; + } + mpsc::Empty | mpsc::Inconsistent => break, + } + } + } + } + + // Consumes ownership of the 'to_wake' field. + fn take_to_wake(&self) -> SignalToken { + let ptr = self.to_wake.load(Ordering::SeqCst); + self.to_wake.store(EMPTY, Ordering::SeqCst); + assert!(ptr != EMPTY); + unsafe { SignalToken::from_raw(ptr) } + } + + //////////////////////////////////////////////////////////////////////////// + // select implementation + //////////////////////////////////////////////////////////////////////////// + + // increment the count on the channel (used for selection) + fn bump(&self, amt: isize) -> isize { + match self.cnt.fetch_add(amt, Ordering::SeqCst) { + DISCONNECTED => { + self.cnt.store(DISCONNECTED, Ordering::SeqCst); + DISCONNECTED + } + n => n, + } + } + + // Cancels a previous thread waiting on this port, returning whether there's + // data on the port. + // + // This is similar to the stream implementation (hence fewer comments), but + // uses a different value for the "steals" variable. + pub fn abort_selection(&self, _was_upgrade: bool) -> bool { + // Before we do anything else, we bounce on this lock. The reason for + // doing this is to ensure that any upgrade-in-progress is gone and + // done with. Without this bounce, we can race with inherit_blocker + // about looking at and dealing with to_wake. Once we have acquired the + // lock, we are guaranteed that inherit_blocker is done. + { + let _guard = self.select_lock.lock().unwrap(); + } + + // Like the stream implementation, we want to make sure that the count + // on the channel goes non-negative. We don't know how negative the + // stream currently is, so instead of using a steal value of 1, we load + // the channel count and figure out what we should do to make it + // positive. + let steals = { + let cnt = self.cnt.load(Ordering::SeqCst); + if cnt < 0 && cnt != DISCONNECTED { -cnt } else { 0 } + }; + let prev = self.bump(steals + 1); + + if prev == DISCONNECTED { + assert_eq!(self.to_wake.load(Ordering::SeqCst), EMPTY); + true + } else { + let cur = prev + steals + 1; + assert!(cur >= 0); + if prev < 0 { + drop(self.take_to_wake()); + } else { + while self.to_wake.load(Ordering::SeqCst) != EMPTY { + thread::yield_now(); + } + } + unsafe { + // if the number of steals is -1, it was the pre-emptive -1 steal + // count from when we inherited a blocker. This is fine because + // we're just going to overwrite it with a real value. + let old = self.steals.get(); + assert!(*old == 0 || *old == -1); + *old = steals; + prev >= 0 + } + } + } +} + +impl Drop for Packet { + fn drop(&mut self) { + // Note that this load is not only an assert for correctness about + // disconnection, but also a proper fence before the read of + // `to_wake`, so this assert cannot be removed with also removing + // the `to_wake` assert. + assert_eq!(self.cnt.load(Ordering::SeqCst), DISCONNECTED); + assert_eq!(self.to_wake.load(Ordering::SeqCst), EMPTY); + assert_eq!(self.channels.load(Ordering::SeqCst), 0); + } +} -- cgit v1.2.3