diff options
Diffstat (limited to 'third_party/rust/once_cell/src/imp_std.rs')
-rw-r--r-- | third_party/rust/once_cell/src/imp_std.rs | 425 |
1 files changed, 425 insertions, 0 deletions
diff --git a/third_party/rust/once_cell/src/imp_std.rs b/third_party/rust/once_cell/src/imp_std.rs new file mode 100644 index 0000000000..5761f0184b --- /dev/null +++ b/third_party/rust/once_cell/src/imp_std.rs @@ -0,0 +1,425 @@ +// There's a lot of scary concurrent code in this module, but it is copied from +// `std::sync::Once` with two changes: +// * no poisoning +// * init function can fail + +use std::{ + cell::{Cell, UnsafeCell}, + marker::PhantomData, + panic::{RefUnwindSafe, UnwindSafe}, + sync::atomic::{AtomicBool, AtomicPtr, Ordering}, + thread::{self, Thread}, +}; + +#[derive(Debug)] +pub(crate) struct OnceCell<T> { + // This `queue` field is the core of the implementation. It encodes two + // pieces of information: + // + // * The current state of the cell (`INCOMPLETE`, `RUNNING`, `COMPLETE`) + // * Linked list of threads waiting for the current cell. + // + // State is encoded in two low bits. Only `INCOMPLETE` and `RUNNING` states + // allow waiters. + queue: AtomicPtr<Waiter>, + _marker: PhantomData<*mut Waiter>, + value: UnsafeCell<Option<T>>, +} + +// Why do we need `T: Send`? +// Thread A creates a `OnceCell` and shares it with +// scoped thread B, which fills the cell, which is +// then destroyed by A. That is, destructor observes +// a sent value. +unsafe impl<T: Sync + Send> Sync for OnceCell<T> {} +unsafe impl<T: Send> Send for OnceCell<T> {} + +impl<T: RefUnwindSafe + UnwindSafe> RefUnwindSafe for OnceCell<T> {} +impl<T: UnwindSafe> UnwindSafe for OnceCell<T> {} + +impl<T> OnceCell<T> { + pub(crate) const fn new() -> OnceCell<T> { + OnceCell { + queue: AtomicPtr::new(INCOMPLETE_PTR), + _marker: PhantomData, + value: UnsafeCell::new(None), + } + } + + pub(crate) const fn with_value(value: T) -> OnceCell<T> { + OnceCell { + queue: AtomicPtr::new(COMPLETE_PTR), + _marker: PhantomData, + value: UnsafeCell::new(Some(value)), + } + } + + /// Safety: synchronizes with store to value via Release/(Acquire|SeqCst). + #[inline] + pub(crate) fn is_initialized(&self) -> bool { + // An `Acquire` load is enough because that makes all the initialization + // operations visible to us, and, this being a fast path, weaker + // ordering helps with performance. This `Acquire` synchronizes with + // `SeqCst` operations on the slow path. + self.queue.load(Ordering::Acquire) == COMPLETE_PTR + } + + /// Safety: synchronizes with store to value via SeqCst read from state, + /// writes value only once because we never get to INCOMPLETE state after a + /// successful write. + #[cold] + pub(crate) fn initialize<F, E>(&self, f: F) -> Result<(), E> + where + F: FnOnce() -> Result<T, E>, + { + let mut f = Some(f); + let mut res: Result<(), E> = Ok(()); + let slot: *mut Option<T> = self.value.get(); + initialize_or_wait( + &self.queue, + Some(&mut || { + let f = unsafe { crate::unwrap_unchecked(f.take()) }; + match f() { + Ok(value) => { + unsafe { *slot = Some(value) }; + true + } + Err(err) => { + res = Err(err); + false + } + } + }), + ); + res + } + + #[cold] + pub(crate) fn wait(&self) { + initialize_or_wait(&self.queue, None); + } + + /// Get the reference to the underlying value, without checking if the cell + /// is initialized. + /// + /// # Safety + /// + /// Caller must ensure that the cell is in initialized state, and that + /// the contents are acquired by (synchronized to) this thread. + pub(crate) unsafe fn get_unchecked(&self) -> &T { + debug_assert!(self.is_initialized()); + let slot = &*self.value.get(); + crate::unwrap_unchecked(slot.as_ref()) + } + + /// Gets the mutable reference to the underlying value. + /// Returns `None` if the cell is empty. + pub(crate) fn get_mut(&mut self) -> Option<&mut T> { + // Safe b/c we have a unique access. + unsafe { &mut *self.value.get() }.as_mut() + } + + /// Consumes this `OnceCell`, returning the wrapped value. + /// Returns `None` if the cell was empty. + #[inline] + pub(crate) fn into_inner(self) -> Option<T> { + // Because `into_inner` takes `self` by value, the compiler statically + // verifies that it is not currently borrowed. + // So, it is safe to move out `Option<T>`. + self.value.into_inner() + } +} + +// Three states that a OnceCell can be in, encoded into the lower bits of `queue` in +// the OnceCell structure. +const INCOMPLETE: usize = 0x0; +const RUNNING: usize = 0x1; +const COMPLETE: usize = 0x2; +const INCOMPLETE_PTR: *mut Waiter = INCOMPLETE as *mut Waiter; +const COMPLETE_PTR: *mut Waiter = COMPLETE as *mut Waiter; + +// Mask to learn about the state. All other bits are the queue of waiters if +// this is in the RUNNING state. +const STATE_MASK: usize = 0x3; + +/// Representation of a node in the linked list of waiters in the RUNNING state. +/// A waiters is stored on the stack of the waiting threads. +#[repr(align(4))] // Ensure the two lower bits are free to use as state bits. +struct Waiter { + thread: Cell<Option<Thread>>, + signaled: AtomicBool, + next: *mut Waiter, +} + +/// Drains and notifies the queue of waiters on drop. +struct Guard<'a> { + queue: &'a AtomicPtr<Waiter>, + new_queue: *mut Waiter, +} + +impl Drop for Guard<'_> { + fn drop(&mut self) { + let queue = self.queue.swap(self.new_queue, Ordering::AcqRel); + + let state = strict::addr(queue) & STATE_MASK; + assert_eq!(state, RUNNING); + + unsafe { + let mut waiter = strict::map_addr(queue, |q| q & !STATE_MASK); + while !waiter.is_null() { + let next = (*waiter).next; + let thread = (*waiter).thread.take().unwrap(); + (*waiter).signaled.store(true, Ordering::Release); + waiter = next; + thread.unpark(); + } + } + } +} + +// Corresponds to `std::sync::Once::call_inner`. +// +// Originally copied from std, but since modified to remove poisoning and to +// support wait. +// +// Note: this is intentionally monomorphic +#[inline(never)] +fn initialize_or_wait(queue: &AtomicPtr<Waiter>, mut init: Option<&mut dyn FnMut() -> bool>) { + let mut curr_queue = queue.load(Ordering::Acquire); + + loop { + let curr_state = strict::addr(curr_queue) & STATE_MASK; + match (curr_state, &mut init) { + (COMPLETE, _) => return, + (INCOMPLETE, Some(init)) => { + let exchange = queue.compare_exchange( + curr_queue, + strict::map_addr(curr_queue, |q| (q & !STATE_MASK) | RUNNING), + Ordering::Acquire, + Ordering::Acquire, + ); + if let Err(new_queue) = exchange { + curr_queue = new_queue; + continue; + } + let mut guard = Guard { queue, new_queue: INCOMPLETE_PTR }; + if init() { + guard.new_queue = COMPLETE_PTR; + } + return; + } + (INCOMPLETE, None) | (RUNNING, _) => { + wait(&queue, curr_queue); + curr_queue = queue.load(Ordering::Acquire); + } + _ => debug_assert!(false), + } + } +} + +fn wait(queue: &AtomicPtr<Waiter>, mut curr_queue: *mut Waiter) { + let curr_state = strict::addr(curr_queue) & STATE_MASK; + loop { + let node = Waiter { + thread: Cell::new(Some(thread::current())), + signaled: AtomicBool::new(false), + next: strict::map_addr(curr_queue, |q| q & !STATE_MASK), + }; + let me = &node as *const Waiter as *mut Waiter; + + let exchange = queue.compare_exchange( + curr_queue, + strict::map_addr(me, |q| q | curr_state), + Ordering::Release, + Ordering::Relaxed, + ); + if let Err(new_queue) = exchange { + if strict::addr(new_queue) & STATE_MASK != curr_state { + return; + } + curr_queue = new_queue; + continue; + } + + while !node.signaled.load(Ordering::Acquire) { + thread::park(); + } + break; + } +} + +// Polyfill of strict provenance from https://crates.io/crates/sptr. +// +// Use free-standing function rather than a trait to keep things simple and +// avoid any potential conflicts with future stabile std API. +mod strict { + #[must_use] + #[inline] + pub(crate) fn addr<T>(ptr: *mut T) -> usize + where + T: Sized, + { + // FIXME(strict_provenance_magic): I am magic and should be a compiler intrinsic. + // SAFETY: Pointer-to-integer transmutes are valid (if you are okay with losing the + // provenance). + unsafe { core::mem::transmute(ptr) } + } + + #[must_use] + #[inline] + pub(crate) fn with_addr<T>(ptr: *mut T, addr: usize) -> *mut T + where + T: Sized, + { + // FIXME(strict_provenance_magic): I am magic and should be a compiler intrinsic. + // + // In the mean-time, this operation is defined to be "as if" it was + // a wrapping_offset, so we can emulate it as such. This should properly + // restore pointer provenance even under today's compiler. + let self_addr = self::addr(ptr) as isize; + let dest_addr = addr as isize; + let offset = dest_addr.wrapping_sub(self_addr); + + // This is the canonical desugarring of this operation, + // but `pointer::cast` was only stabilized in 1.38. + // self.cast::<u8>().wrapping_offset(offset).cast::<T>() + (ptr as *mut u8).wrapping_offset(offset) as *mut T + } + + #[must_use] + #[inline] + pub(crate) fn map_addr<T>(ptr: *mut T, f: impl FnOnce(usize) -> usize) -> *mut T + where + T: Sized, + { + self::with_addr(ptr, f(addr(ptr))) + } +} + +// These test are snatched from std as well. +#[cfg(test)] +mod tests { + use std::panic; + use std::{sync::mpsc::channel, thread}; + + use super::OnceCell; + + impl<T> OnceCell<T> { + fn init(&self, f: impl FnOnce() -> T) { + enum Void {} + let _ = self.initialize(|| Ok::<T, Void>(f())); + } + } + + #[test] + fn smoke_once() { + static O: OnceCell<()> = OnceCell::new(); + let mut a = 0; + O.init(|| a += 1); + assert_eq!(a, 1); + O.init(|| a += 1); + assert_eq!(a, 1); + } + + #[test] + fn stampede_once() { + static O: OnceCell<()> = OnceCell::new(); + static mut RUN: bool = false; + + let (tx, rx) = channel(); + for _ in 0..10 { + let tx = tx.clone(); + thread::spawn(move || { + for _ in 0..4 { + thread::yield_now() + } + unsafe { + O.init(|| { + assert!(!RUN); + RUN = true; + }); + assert!(RUN); + } + tx.send(()).unwrap(); + }); + } + + unsafe { + O.init(|| { + assert!(!RUN); + RUN = true; + }); + assert!(RUN); + } + + for _ in 0..10 { + rx.recv().unwrap(); + } + } + + #[test] + fn poison_bad() { + static O: OnceCell<()> = OnceCell::new(); + + // poison the once + let t = panic::catch_unwind(|| { + O.init(|| panic!()); + }); + assert!(t.is_err()); + + // we can subvert poisoning, however + let mut called = false; + O.init(|| { + called = true; + }); + assert!(called); + + // once any success happens, we stop propagating the poison + O.init(|| {}); + } + + #[test] + fn wait_for_force_to_finish() { + static O: OnceCell<()> = OnceCell::new(); + + // poison the once + let t = panic::catch_unwind(|| { + O.init(|| panic!()); + }); + assert!(t.is_err()); + + // make sure someone's waiting inside the once via a force + let (tx1, rx1) = channel(); + let (tx2, rx2) = channel(); + let t1 = thread::spawn(move || { + O.init(|| { + tx1.send(()).unwrap(); + rx2.recv().unwrap(); + }); + }); + + rx1.recv().unwrap(); + + // put another waiter on the once + let t2 = thread::spawn(|| { + let mut called = false; + O.init(|| { + called = true; + }); + assert!(!called); + }); + + tx2.send(()).unwrap(); + + assert!(t1.join().is_ok()); + assert!(t2.join().is_ok()); + } + + #[test] + #[cfg(target_pointer_width = "64")] + fn test_size() { + use std::mem::size_of; + + assert_eq!(size_of::<OnceCell<u32>>(), 4 * size_of::<u32>()); + } +} |