use crate::loom::sync::Mutex; use crate::sync::watch; #[cfg(all(tokio_unstable, feature = "tracing"))] use crate::util::trace; /// A barrier enables multiple tasks to synchronize the beginning of some computation. /// /// ``` /// # #[tokio::main] /// # async fn main() { /// use tokio::sync::Barrier; /// use std::sync::Arc; /// /// let mut handles = Vec::with_capacity(10); /// let barrier = Arc::new(Barrier::new(10)); /// for _ in 0..10 { /// let c = barrier.clone(); /// // The same messages will be printed together. /// // You will NOT see any interleaving. /// handles.push(tokio::spawn(async move { /// println!("before wait"); /// let wait_result = c.wait().await; /// println!("after wait"); /// wait_result /// })); /// } /// /// // Will not resolve until all "after wait" messages have been printed /// let mut num_leaders = 0; /// for handle in handles { /// let wait_result = handle.await.unwrap(); /// if wait_result.is_leader() { /// num_leaders += 1; /// } /// } /// /// // Exactly one barrier will resolve as the "leader" /// assert_eq!(num_leaders, 1); /// # } /// ``` #[derive(Debug)] pub struct Barrier { state: Mutex, wait: watch::Receiver, n: usize, #[cfg(all(tokio_unstable, feature = "tracing"))] resource_span: tracing::Span, } #[derive(Debug)] struct BarrierState { waker: watch::Sender, arrived: usize, generation: usize, } impl Barrier { /// Creates a new barrier that can block a given number of tasks. /// /// A barrier will block `n`-1 tasks which call [`Barrier::wait`] and then wake up all /// tasks at once when the `n`th task calls `wait`. #[track_caller] pub fn new(mut n: usize) -> Barrier { let (waker, wait) = crate::sync::watch::channel(0); if n == 0 { // if n is 0, it's not clear what behavior the user wants. // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every // .wait() immediately unblocks, so we adopt that here as well. n = 1; } #[cfg(all(tokio_unstable, feature = "tracing"))] let resource_span = { let location = std::panic::Location::caller(); let resource_span = tracing::trace_span!( "runtime.resource", concrete_type = "Barrier", kind = "Sync", loc.file = location.file(), loc.line = location.line(), loc.col = location.column(), ); resource_span.in_scope(|| { tracing::trace!( target: "runtime::resource::state_update", size = n, ); tracing::trace!( target: "runtime::resource::state_update", arrived = 0, ) }); resource_span }; Barrier { state: Mutex::new(BarrierState { waker, arrived: 0, generation: 1, }), n, wait, #[cfg(all(tokio_unstable, feature = "tracing"))] resource_span: resource_span, } } /// Does not resolve until all tasks have rendezvoused here. /// /// Barriers are re-usable after all tasks have rendezvoused once, and can /// be used continuously. /// /// A single (arbitrary) future will receive a [`BarrierWaitResult`] that returns `true` from /// [`BarrierWaitResult::is_leader`] when returning from this function, and all other tasks /// will receive a result that will return `false` from `is_leader`. pub async fn wait(&self) -> BarrierWaitResult { #[cfg(all(tokio_unstable, feature = "tracing"))] return trace::async_op( || self.wait_internal(), self.resource_span.clone(), "Barrier::wait", "poll", false, ) .await; #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] return self.wait_internal().await; } async fn wait_internal(&self) -> BarrierWaitResult { // NOTE: we are taking a _synchronous_ lock here. // It is okay to do so because the critical section is fast and never yields, so it cannot // deadlock even if another future is concurrently holding the lock. // It is _desireable_ to do so as synchronous Mutexes are, at least in theory, faster than // the asynchronous counter-parts, so we should use them where possible [citation needed]. // NOTE: the extra scope here is so that the compiler doesn't think `state` is held across // a yield point, and thus marks the returned future as !Send. let generation = { let mut state = self.state.lock(); let generation = state.generation; state.arrived += 1; #[cfg(all(tokio_unstable, feature = "tracing"))] tracing::trace!( target: "runtime::resource::state_update", arrived = 1, arrived.op = "add", ); #[cfg(all(tokio_unstable, feature = "tracing"))] tracing::trace!( target: "runtime::resource::async_op::state_update", arrived = true, ); if state.arrived == self.n { #[cfg(all(tokio_unstable, feature = "tracing"))] tracing::trace!( target: "runtime::resource::async_op::state_update", is_leader = true, ); // we are the leader for this generation // wake everyone, increment the generation, and return state .waker .send(state.generation) .expect("there is at least one receiver"); state.arrived = 0; state.generation += 1; return BarrierWaitResult(true); } generation }; // we're going to have to wait for the last of the generation to arrive let mut wait = self.wait.clone(); loop { let _ = wait.changed().await; // note that the first time through the loop, this _will_ yield a generation // immediately, since we cloned a receiver that has never seen any values. if *wait.borrow() >= generation { break; } } BarrierWaitResult(false) } } /// A `BarrierWaitResult` is returned by `wait` when all tasks in the `Barrier` have rendezvoused. #[derive(Debug, Clone)] pub struct BarrierWaitResult(bool); impl BarrierWaitResult { /// Returns `true` if this task from wait is the "leader task". /// /// Only one task will have `true` returned from their result, all other tasks will have /// `false` returned. pub fn is_leader(&self) -> bool { self.0 } }