summaryrefslogtreecommitdiffstats
path: root/third_party/rust/tokio/src/sync/barrier.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/tokio/src/sync/barrier.rs')
-rw-r--r--third_party/rust/tokio/src/sync/barrier.rs136
1 files changed, 136 insertions, 0 deletions
diff --git a/third_party/rust/tokio/src/sync/barrier.rs b/third_party/rust/tokio/src/sync/barrier.rs
new file mode 100644
index 0000000000..628633493a
--- /dev/null
+++ b/third_party/rust/tokio/src/sync/barrier.rs
@@ -0,0 +1,136 @@
+use crate::sync::watch;
+
+use std::sync::Mutex;
+
+/// A barrier enables multiple threads to synchronize the beginning of some computation.
+///
+/// ```
+/// # #[tokio::main]
+/// # async fn main() {
+/// use tokio::sync::Barrier;
+///
+/// use futures::future::join_all;
+/// use std::sync::Arc;
+///
+/// let mut handles = Vec::with_capacity(10);
+/// let barrier = Arc::new(Barrier::new(10));
+/// for _ in 0..10 {
+/// let c = barrier.clone();
+/// // The same messages will be printed together.
+/// // You will NOT see any interleaving.
+/// handles.push(async move {
+/// println!("before wait");
+/// let wr = c.wait().await;
+/// println!("after wait");
+/// wr
+/// });
+/// }
+/// // Will not resolve until all "before wait" messages have been printed
+/// let wrs = join_all(handles).await;
+/// // Exactly one barrier will resolve as the "leader"
+/// assert_eq!(wrs.into_iter().filter(|wr| wr.is_leader()).count(), 1);
+/// # }
+/// ```
+#[derive(Debug)]
+pub struct Barrier {
+ state: Mutex<BarrierState>,
+ wait: watch::Receiver<usize>,
+ n: usize,
+}
+
+#[derive(Debug)]
+struct BarrierState {
+ waker: watch::Sender<usize>,
+ arrived: usize,
+ generation: usize,
+}
+
+impl Barrier {
+ /// Creates a new barrier that can block a given number of threads.
+ ///
+ /// A barrier will block `n`-1 threads which call [`Barrier::wait`] and then wake up all
+ /// threads at once when the `n`th thread calls `wait`.
+ pub fn new(mut n: usize) -> Barrier {
+ let (waker, wait) = crate::sync::watch::channel(0);
+
+ if n == 0 {
+ // if n is 0, it's not clear what behavior the user wants.
+ // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every
+ // .wait() immediately unblocks, so we adopt that here as well.
+ n = 1;
+ }
+
+ Barrier {
+ state: Mutex::new(BarrierState {
+ waker,
+ arrived: 0,
+ generation: 1,
+ }),
+ n,
+ wait,
+ }
+ }
+
+ /// Does not resolve until all tasks have rendezvoused here.
+ ///
+ /// Barriers are re-usable after all threads have rendezvoused once, and can
+ /// be used continuously.
+ ///
+ /// A single (arbitrary) future will receive a [`BarrierWaitResult`] that returns `true` from
+ /// [`BarrierWaitResult::is_leader`] when returning from this function, and all other threads
+ /// will receive a result that will return `false` from `is_leader`.
+ pub async fn wait(&self) -> BarrierWaitResult {
+ // NOTE: we are taking a _synchronous_ lock here.
+ // It is okay to do so because the critical section is fast and never yields, so it cannot
+ // deadlock even if another future is concurrently holding the lock.
+ // It is _desireable_ to do so as synchronous Mutexes are, at least in theory, faster than
+ // the asynchronous counter-parts, so we should use them where possible [citation needed].
+ // NOTE: the extra scope here is so that the compiler doesn't think `state` is held across
+ // a yield point, and thus marks the returned future as !Send.
+ let generation = {
+ let mut state = self.state.lock().unwrap();
+ let generation = state.generation;
+ state.arrived += 1;
+ if state.arrived == self.n {
+ // we are the leader for this generation
+ // wake everyone, increment the generation, and return
+ state
+ .waker
+ .broadcast(state.generation)
+ .expect("there is at least one receiver");
+ state.arrived = 0;
+ state.generation += 1;
+ return BarrierWaitResult(true);
+ }
+
+ generation
+ };
+
+ // we're going to have to wait for the last of the generation to arrive
+ let mut wait = self.wait.clone();
+
+ loop {
+ // note that the first time through the loop, this _will_ yield a generation
+ // immediately, since we cloned a receiver that has never seen any values.
+ if wait.recv().await.expect("sender hasn't been closed") >= generation {
+ break;
+ }
+ }
+
+ BarrierWaitResult(false)
+ }
+}
+
+/// A `BarrierWaitResult` is returned by `wait` when all threads in the `Barrier` have rendezvoused.
+#[derive(Debug, Clone)]
+pub struct BarrierWaitResult(bool);
+
+impl BarrierWaitResult {
+ /// Returns `true` if this thread from wait is the "leader thread".
+ ///
+ /// Only one thread will have `true` returned from their result, all other threads will have
+ /// `false` returned.
+ pub fn is_leader(&self) -> bool {
+ self.0
+ }
+}