use super::{IntoIter, IterMut, ThreadLocal}; use std::cell::UnsafeCell; use std::fmt; use std::panic::UnwindSafe; use std::sync::atomic::{AtomicUsize, Ordering}; use thread_id; use unreachable::{UncheckedOptionExt, UncheckedResultExt}; /// Wrapper around `ThreadLocal` which adds a fast path for a single thread. /// /// This has the same API as `ThreadLocal`, but will register the first thread /// that sets a value as its owner. All accesses by the owner will go through /// a special fast path which is much faster than the normal `ThreadLocal` path. pub struct CachedThreadLocal { owner: AtomicUsize, local: UnsafeCell>>, global: ThreadLocal, } // CachedThreadLocal is always Sync, even if T isn't unsafe impl Sync for CachedThreadLocal {} impl Default for CachedThreadLocal { fn default() -> CachedThreadLocal { CachedThreadLocal::new() } } impl CachedThreadLocal { /// Creates a new empty `CachedThreadLocal`. pub fn new() -> CachedThreadLocal { CachedThreadLocal { owner: AtomicUsize::new(0), local: UnsafeCell::new(None), global: ThreadLocal::new(), } } /// Returns the element for the current thread, if it exists. pub fn get(&self) -> Option<&T> { let id = thread_id::get(); let owner = self.owner.load(Ordering::Relaxed); if owner == id { return unsafe { Some((*self.local.get()).as_ref().unchecked_unwrap()) }; } if owner == 0 { return None; } self.global.get_fast(id) } /// Returns the element for the current thread, or creates it if it doesn't /// exist. #[inline(always)] pub fn get_or(&self, create: F) -> &T where F: FnOnce() -> T, { unsafe { self.get_or_try(|| Ok::(create())) .unchecked_unwrap_ok() } } /// Returns the element for the current thread, or creates it if it doesn't /// exist. If `create` fails, that error is returned and no element is /// added. pub fn get_or_try(&self, create: F) -> Result<&T, E> where F: FnOnce() -> Result, { let id = thread_id::get(); let owner = self.owner.load(Ordering::Relaxed); if owner == id { return Ok(unsafe { (*self.local.get()).as_ref().unchecked_unwrap() }); } self.get_or_try_slow(id, owner, create) } #[cold] #[inline(never)] fn get_or_try_slow(&self, id: usize, owner: usize, create: F) -> Result<&T, E> where F: FnOnce() -> Result, { if owner == 0 && self.owner.compare_and_swap(0, id, Ordering::Relaxed) == 0 { unsafe { (*self.local.get()) = Some(Box::new(create()?)); return Ok((*self.local.get()).as_ref().unchecked_unwrap()); } } match self.global.get_fast(id) { Some(x) => Ok(x), None => Ok(self.global.insert(id, Box::new(create()?), true)), } } /// Returns a mutable iterator over the local values of all threads. /// /// Since this call borrows the `ThreadLocal` mutably, this operation can /// be done safely---the mutable borrow statically guarantees no other /// threads are currently accessing their associated values. pub fn iter_mut(&mut self) -> CachedIterMut { CachedIterMut { local: unsafe { (*self.local.get()).as_mut().map(|x| &mut **x) }, global: self.global.iter_mut(), } } /// Removes all thread-specific values from the `ThreadLocal`, effectively /// reseting it to its original state. /// /// Since this call borrows the `ThreadLocal` mutably, this operation can /// be done safely---the mutable borrow statically guarantees no other /// threads are currently accessing their associated values. pub fn clear(&mut self) { *self = CachedThreadLocal::new(); } } impl IntoIterator for CachedThreadLocal { type Item = T; type IntoIter = CachedIntoIter; fn into_iter(self) -> CachedIntoIter { CachedIntoIter { local: unsafe { (*self.local.get()).take().map(|x| *x) }, global: self.global.into_iter(), } } } impl<'a, T: Send + 'a> IntoIterator for &'a mut CachedThreadLocal { type Item = &'a mut T; type IntoIter = CachedIterMut<'a, T>; fn into_iter(self) -> CachedIterMut<'a, T> { self.iter_mut() } } impl CachedThreadLocal { /// Returns the element for the current thread, or creates a default one if /// it doesn't exist. pub fn get_or_default(&self) -> &T { self.get_or(T::default) } } impl fmt::Debug for CachedThreadLocal { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "ThreadLocal {{ local_data: {:?} }}", self.get()) } } impl UnwindSafe for CachedThreadLocal {} /// Mutable iterator over the contents of a `CachedThreadLocal`. pub struct CachedIterMut<'a, T: Send + 'a> { local: Option<&'a mut T>, global: IterMut<'a, T>, } impl<'a, T: Send + 'a> Iterator for CachedIterMut<'a, T> { type Item = &'a mut T; fn next(&mut self) -> Option<&'a mut T> { self.local.take().or_else(|| self.global.next()) } fn size_hint(&self) -> (usize, Option) { let len = self.global.size_hint().0 + self.local.is_some() as usize; (len, Some(len)) } } impl<'a, T: Send + 'a> ExactSizeIterator for CachedIterMut<'a, T> {} /// An iterator that moves out of a `CachedThreadLocal`. pub struct CachedIntoIter { local: Option, global: IntoIter, } impl Iterator for CachedIntoIter { type Item = T; fn next(&mut self) -> Option { self.local.take().or_else(|| self.global.next()) } fn size_hint(&self) -> (usize, Option) { let len = self.global.size_hint().0 + self.local.is_some() as usize; (len, Some(len)) } } impl ExactSizeIterator for CachedIntoIter {}