summaryrefslogtreecommitdiffstats
path: root/third_party/rust/tokio/src/sync/barrier.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/tokio/src/sync/barrier.rs')
-rw-r--r--third_party/rust/tokio/src/sync/barrier.rs206
1 files changed, 206 insertions, 0 deletions
diff --git a/third_party/rust/tokio/src/sync/barrier.rs b/third_party/rust/tokio/src/sync/barrier.rs
new file mode 100644
index 0000000000..dfc76a40eb
--- /dev/null
+++ b/third_party/rust/tokio/src/sync/barrier.rs
@@ -0,0 +1,206 @@
+use crate::loom::sync::Mutex;
+use crate::sync::watch;
+#[cfg(all(tokio_unstable, feature = "tracing"))]
+use crate::util::trace;
+
+/// A barrier enables multiple tasks to synchronize the beginning of some computation.
+///
+/// ```
+/// # #[tokio::main]
+/// # async fn main() {
+/// use tokio::sync::Barrier;
+/// use std::sync::Arc;
+///
+/// let mut handles = Vec::with_capacity(10);
+/// let barrier = Arc::new(Barrier::new(10));
+/// for _ in 0..10 {
+/// let c = barrier.clone();
+/// // The same messages will be printed together.
+/// // You will NOT see any interleaving.
+/// handles.push(tokio::spawn(async move {
+/// println!("before wait");
+/// let wait_result = c.wait().await;
+/// println!("after wait");
+/// wait_result
+/// }));
+/// }
+///
+/// // Will not resolve until all "after wait" messages have been printed
+/// let mut num_leaders = 0;
+/// for handle in handles {
+/// let wait_result = handle.await.unwrap();
+/// if wait_result.is_leader() {
+/// num_leaders += 1;
+/// }
+/// }
+///
+/// // Exactly one barrier will resolve as the "leader"
+/// assert_eq!(num_leaders, 1);
+/// # }
+/// ```
+#[derive(Debug)]
+pub struct Barrier {
+ state: Mutex<BarrierState>,
+ wait: watch::Receiver<usize>,
+ n: usize,
+ #[cfg(all(tokio_unstable, feature = "tracing"))]
+ resource_span: tracing::Span,
+}
+
+#[derive(Debug)]
+struct BarrierState {
+ waker: watch::Sender<usize>,
+ arrived: usize,
+ generation: usize,
+}
+
+impl Barrier {
+ /// Creates a new barrier that can block a given number of tasks.
+ ///
+ /// A barrier will block `n`-1 tasks which call [`Barrier::wait`] and then wake up all
+ /// tasks at once when the `n`th task calls `wait`.
+ #[track_caller]
+ pub fn new(mut n: usize) -> Barrier {
+ let (waker, wait) = crate::sync::watch::channel(0);
+
+ if n == 0 {
+ // if n is 0, it's not clear what behavior the user wants.
+ // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every
+ // .wait() immediately unblocks, so we adopt that here as well.
+ n = 1;
+ }
+
+ #[cfg(all(tokio_unstable, feature = "tracing"))]
+ let resource_span = {
+ let location = std::panic::Location::caller();
+ let resource_span = tracing::trace_span!(
+ "runtime.resource",
+ concrete_type = "Barrier",
+ kind = "Sync",
+ loc.file = location.file(),
+ loc.line = location.line(),
+ loc.col = location.column(),
+ );
+
+ resource_span.in_scope(|| {
+ tracing::trace!(
+ target: "runtime::resource::state_update",
+ size = n,
+ );
+
+ tracing::trace!(
+ target: "runtime::resource::state_update",
+ arrived = 0,
+ )
+ });
+ resource_span
+ };
+
+ Barrier {
+ state: Mutex::new(BarrierState {
+ waker,
+ arrived: 0,
+ generation: 1,
+ }),
+ n,
+ wait,
+ #[cfg(all(tokio_unstable, feature = "tracing"))]
+ resource_span: resource_span,
+ }
+ }
+
+ /// Does not resolve until all tasks have rendezvoused here.
+ ///
+ /// Barriers are re-usable after all tasks have rendezvoused once, and can
+ /// be used continuously.
+ ///
+ /// A single (arbitrary) future will receive a [`BarrierWaitResult`] that returns `true` from
+ /// [`BarrierWaitResult::is_leader`] when returning from this function, and all other tasks
+ /// will receive a result that will return `false` from `is_leader`.
+ pub async fn wait(&self) -> BarrierWaitResult {
+ #[cfg(all(tokio_unstable, feature = "tracing"))]
+ return trace::async_op(
+ || self.wait_internal(),
+ self.resource_span.clone(),
+ "Barrier::wait",
+ "poll",
+ false,
+ )
+ .await;
+
+ #[cfg(any(not(tokio_unstable), not(feature = "tracing")))]
+ return self.wait_internal().await;
+ }
+ async fn wait_internal(&self) -> BarrierWaitResult {
+ // NOTE: we are taking a _synchronous_ lock here.
+ // It is okay to do so because the critical section is fast and never yields, so it cannot
+ // deadlock even if another future is concurrently holding the lock.
+ // It is _desireable_ to do so as synchronous Mutexes are, at least in theory, faster than
+ // the asynchronous counter-parts, so we should use them where possible [citation needed].
+ // NOTE: the extra scope here is so that the compiler doesn't think `state` is held across
+ // a yield point, and thus marks the returned future as !Send.
+ let generation = {
+ let mut state = self.state.lock();
+ let generation = state.generation;
+ state.arrived += 1;
+ #[cfg(all(tokio_unstable, feature = "tracing"))]
+ tracing::trace!(
+ target: "runtime::resource::state_update",
+ arrived = 1,
+ arrived.op = "add",
+ );
+ #[cfg(all(tokio_unstable, feature = "tracing"))]
+ tracing::trace!(
+ target: "runtime::resource::async_op::state_update",
+ arrived = true,
+ );
+ if state.arrived == self.n {
+ #[cfg(all(tokio_unstable, feature = "tracing"))]
+ tracing::trace!(
+ target: "runtime::resource::async_op::state_update",
+ is_leader = true,
+ );
+ // we are the leader for this generation
+ // wake everyone, increment the generation, and return
+ state
+ .waker
+ .send(state.generation)
+ .expect("there is at least one receiver");
+ state.arrived = 0;
+ state.generation += 1;
+ return BarrierWaitResult(true);
+ }
+
+ generation
+ };
+
+ // we're going to have to wait for the last of the generation to arrive
+ let mut wait = self.wait.clone();
+
+ loop {
+ let _ = wait.changed().await;
+
+ // note that the first time through the loop, this _will_ yield a generation
+ // immediately, since we cloned a receiver that has never seen any values.
+ if *wait.borrow() >= generation {
+ break;
+ }
+ }
+
+ BarrierWaitResult(false)
+ }
+}
+
+/// A `BarrierWaitResult` is returned by `wait` when all tasks in the `Barrier` have rendezvoused.
+#[derive(Debug, Clone)]
+pub struct BarrierWaitResult(bool);
+
+impl BarrierWaitResult {
+ /// Returns `true` if this task from wait is the "leader task".
+ ///
+ /// Only one task will have `true` returned from their result, all other tasks will have
+ /// `false` returned.
+ pub fn is_leader(&self) -> bool {
+ self.0
+ }
+}