//! A `Barrier` that provides `wait_timeout`. //! //! This implementation mirrors that of the Rust standard library. use crate::loom::sync::{Condvar, Mutex}; use std::fmt; use std::time::{Duration, Instant}; /// A barrier enables multiple threads to synchronize the beginning /// of some computation. /// /// # Examples /// /// ``` /// use std::sync::{Arc, Barrier}; /// use std::thread; /// /// let mut handles = Vec::with_capacity(10); /// let barrier = Arc::new(Barrier::new(10)); /// for _ in 0..10 { /// let c = Arc::clone(&barrier); /// // The same messages will be printed together. /// // You will NOT see any interleaving. /// handles.push(thread::spawn(move|| { /// println!("before wait"); /// c.wait(); /// println!("after wait"); /// })); /// } /// // Wait for other threads to finish. /// for handle in handles { /// handle.join().unwrap(); /// } /// ``` pub(crate) struct Barrier { lock: Mutex, cvar: Condvar, num_threads: usize, } // The inner state of a double barrier struct BarrierState { count: usize, generation_id: usize, } /// A `BarrierWaitResult` is returned by [`Barrier::wait()`] when all threads /// in the [`Barrier`] have rendezvoused. /// /// # Examples /// /// ``` /// use std::sync::Barrier; /// /// let barrier = Barrier::new(1); /// let barrier_wait_result = barrier.wait(); /// ``` pub(crate) struct BarrierWaitResult(bool); impl fmt::Debug for Barrier { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Barrier").finish_non_exhaustive() } } impl Barrier { /// Creates a new barrier that can block a given number of threads. /// /// A barrier will block `n`-1 threads which call [`wait()`] and then wake /// up all threads at once when the `n`th thread calls [`wait()`]. /// /// [`wait()`]: Barrier::wait /// /// # Examples /// /// ``` /// use std::sync::Barrier; /// /// let barrier = Barrier::new(10); /// ``` #[must_use] pub(crate) fn new(n: usize) -> Barrier { Barrier { lock: Mutex::new(BarrierState { count: 0, generation_id: 0, }), cvar: Condvar::new(), num_threads: n, } } /// Blocks the current thread until all threads have rendezvoused here. /// /// Barriers are re-usable after all threads have rendezvoused once, and can /// be used continuously. /// /// A single (arbitrary) thread will receive a [`BarrierWaitResult`] that /// returns `true` from [`BarrierWaitResult::is_leader()`] when returning /// from this function, and all other threads will receive a result that /// will return `false` from [`BarrierWaitResult::is_leader()`]. /// /// # Examples /// /// ``` /// use std::sync::{Arc, Barrier}; /// use std::thread; /// /// let mut handles = Vec::with_capacity(10); /// let barrier = Arc::new(Barrier::new(10)); /// for _ in 0..10 { /// let c = Arc::clone(&barrier); /// // The same messages will be printed together. /// // You will NOT see any interleaving. /// handles.push(thread::spawn(move|| { /// println!("before wait"); /// c.wait(); /// println!("after wait"); /// })); /// } /// // Wait for other threads to finish. /// for handle in handles { /// handle.join().unwrap(); /// } /// ``` pub(crate) fn wait(&self) -> BarrierWaitResult { let mut lock = self.lock.lock(); let local_gen = lock.generation_id; lock.count += 1; if lock.count < self.num_threads { // We need a while loop to guard against spurious wakeups. // https://en.wikipedia.org/wiki/Spurious_wakeup while local_gen == lock.generation_id { lock = self.cvar.wait(lock).unwrap(); } BarrierWaitResult(false) } else { lock.count = 0; lock.generation_id = lock.generation_id.wrapping_add(1); self.cvar.notify_all(); BarrierWaitResult(true) } } /// Blocks the current thread until all threads have rendezvoused here for /// at most `timeout` duration. pub(crate) fn wait_timeout(&self, timeout: Duration) -> Option { // This implementation mirrors `wait`, but with each blocking operation // replaced by a timeout-amenable alternative. let deadline = Instant::now() + timeout; // Acquire `self.lock` with at most `timeout` duration. let mut lock = loop { if let Some(guard) = self.lock.try_lock() { break guard; } else if Instant::now() > deadline { return None; } else { std::thread::yield_now(); } }; // Shrink the `timeout` to account for the time taken to acquire `lock`. let timeout = deadline.saturating_duration_since(Instant::now()); let local_gen = lock.generation_id; lock.count += 1; if lock.count < self.num_threads { // We need a while loop to guard against spurious wakeups. // https://en.wikipedia.org/wiki/Spurious_wakeup while local_gen == lock.generation_id { let (guard, timeout_result) = self.cvar.wait_timeout(lock, timeout).unwrap(); lock = guard; if timeout_result.timed_out() { return None; } } Some(BarrierWaitResult(false)) } else { lock.count = 0; lock.generation_id = lock.generation_id.wrapping_add(1); self.cvar.notify_all(); Some(BarrierWaitResult(true)) } } } impl fmt::Debug for BarrierWaitResult { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("BarrierWaitResult") .field("is_leader", &self.is_leader()) .finish() } } impl BarrierWaitResult { /// Returns `true` if this thread is the "leader thread" for the call to /// [`Barrier::wait()`]. /// /// Only one thread will have `true` returned from their result, all other /// threads will have `false` returned. /// /// # Examples /// /// ``` /// use std::sync::Barrier; /// /// let barrier = Barrier::new(1); /// let barrier_wait_result = barrier.wait(); /// println!("{:?}", barrier_wait_result.is_leader()); /// ``` #[must_use] pub(crate) fn is_leader(&self) -> bool { self.0 } }