diff options
Diffstat (limited to 'third_party/rust/tokio/src/task/task_local.rs')
-rw-r--r-- | third_party/rust/tokio/src/task/task_local.rs | 302 |
1 files changed, 302 insertions, 0 deletions
diff --git a/third_party/rust/tokio/src/task/task_local.rs b/third_party/rust/tokio/src/task/task_local.rs new file mode 100644 index 0000000000..949bbca3ee --- /dev/null +++ b/third_party/rust/tokio/src/task/task_local.rs @@ -0,0 +1,302 @@ +use pin_project_lite::pin_project; +use std::cell::RefCell; +use std::error::Error; +use std::future::Future; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{fmt, thread}; + +/// Declares a new task-local key of type [`tokio::task::LocalKey`]. +/// +/// # Syntax +/// +/// The macro wraps any number of static declarations and makes them local to the current task. +/// Publicity and attributes for each static is preserved. For example: +/// +/// # Examples +/// +/// ``` +/// # use tokio::task_local; +/// task_local! { +/// pub static ONE: u32; +/// +/// #[allow(unused)] +/// static TWO: f32; +/// } +/// # fn main() {} +/// ``` +/// +/// See [LocalKey documentation][`tokio::task::LocalKey`] for more +/// information. +/// +/// [`tokio::task::LocalKey`]: struct@crate::task::LocalKey +#[macro_export] +#[cfg_attr(docsrs, doc(cfg(feature = "rt")))] +macro_rules! task_local { + // empty (base case for the recursion) + () => {}; + + ($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty; $($rest:tt)*) => { + $crate::__task_local_inner!($(#[$attr])* $vis $name, $t); + $crate::task_local!($($rest)*); + }; + + ($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty) => { + $crate::__task_local_inner!($(#[$attr])* $vis $name, $t); + } +} + +#[doc(hidden)] +#[macro_export] +macro_rules! __task_local_inner { + ($(#[$attr:meta])* $vis:vis $name:ident, $t:ty) => { + $vis static $name: $crate::task::LocalKey<$t> = { + std::thread_local! { + static __KEY: std::cell::RefCell<Option<$t>> = std::cell::RefCell::new(None); + } + + $crate::task::LocalKey { inner: __KEY } + }; + }; +} + +/// A key for task-local data. +/// +/// This type is generated by the `task_local!` macro. +/// +/// Unlike [`std::thread::LocalKey`], `tokio::task::LocalKey` will +/// _not_ lazily initialize the value on first access. Instead, the +/// value is first initialized when the future containing +/// the task-local is first polled by a futures executor, like Tokio. +/// +/// # Examples +/// +/// ``` +/// # async fn dox() { +/// tokio::task_local! { +/// static NUMBER: u32; +/// } +/// +/// NUMBER.scope(1, async move { +/// assert_eq!(NUMBER.get(), 1); +/// }).await; +/// +/// NUMBER.scope(2, async move { +/// assert_eq!(NUMBER.get(), 2); +/// +/// NUMBER.scope(3, async move { +/// assert_eq!(NUMBER.get(), 3); +/// }).await; +/// }).await; +/// # } +/// ``` +/// [`std::thread::LocalKey`]: struct@std::thread::LocalKey +#[cfg_attr(docsrs, doc(cfg(feature = "rt")))] +pub struct LocalKey<T: 'static> { + #[doc(hidden)] + pub inner: thread::LocalKey<RefCell<Option<T>>>, +} + +impl<T: 'static> LocalKey<T> { + /// Sets a value `T` as the task-local value for the future `F`. + /// + /// On completion of `scope`, the task-local will be dropped. + /// + /// ### Examples + /// + /// ``` + /// # async fn dox() { + /// tokio::task_local! { + /// static NUMBER: u32; + /// } + /// + /// NUMBER.scope(1, async move { + /// println!("task local value: {}", NUMBER.get()); + /// }).await; + /// # } + /// ``` + pub fn scope<F>(&'static self, value: T, f: F) -> TaskLocalFuture<T, F> + where + F: Future, + { + TaskLocalFuture { + local: self, + slot: Some(value), + future: f, + _pinned: PhantomPinned, + } + } + + /// Sets a value `T` as the task-local value for the closure `F`. + /// + /// On completion of `scope`, the task-local will be dropped. + /// + /// ### Examples + /// + /// ``` + /// # async fn dox() { + /// tokio::task_local! { + /// static NUMBER: u32; + /// } + /// + /// NUMBER.sync_scope(1, || { + /// println!("task local value: {}", NUMBER.get()); + /// }); + /// # } + /// ``` + pub fn sync_scope<F, R>(&'static self, value: T, f: F) -> R + where + F: FnOnce() -> R, + { + let scope = TaskLocalFuture { + local: self, + slot: Some(value), + future: (), + _pinned: PhantomPinned, + }; + crate::pin!(scope); + scope.with_task(|_| f()) + } + + /// Accesses the current task-local and runs the provided closure. + /// + /// # Panics + /// + /// This function will panic if not called within the context + /// of a future containing a task-local with the corresponding key. + pub fn with<F, R>(&'static self, f: F) -> R + where + F: FnOnce(&T) -> R, + { + self.try_with(f).expect( + "cannot access a Task Local Storage value \ + without setting it via `LocalKey::set`", + ) + } + + /// Accesses the current task-local and runs the provided closure. + /// + /// If the task-local with the associated key is not present, this + /// method will return an `AccessError`. For a panicking variant, + /// see `with`. + pub fn try_with<F, R>(&'static self, f: F) -> Result<R, AccessError> + where + F: FnOnce(&T) -> R, + { + self.inner.with(|v| { + if let Some(val) = v.borrow().as_ref() { + Ok(f(val)) + } else { + Err(AccessError { _private: () }) + } + }) + } +} + +impl<T: Copy + 'static> LocalKey<T> { + /// Returns a copy of the task-local value + /// if the task-local value implements `Copy`. + pub fn get(&'static self) -> T { + self.with(|v| *v) + } +} + +impl<T: 'static> fmt::Debug for LocalKey<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("LocalKey { .. }") + } +} + +pin_project! { + /// A future that sets a value `T` of a task local for the future `F` during + /// its execution. + /// + /// The value of the task-local must be `'static` and will be dropped on the + /// completion of the future. + /// + /// Created by the function [`LocalKey::scope`](self::LocalKey::scope). + /// + /// ### Examples + /// + /// ``` + /// # async fn dox() { + /// tokio::task_local! { + /// static NUMBER: u32; + /// } + /// + /// NUMBER.scope(1, async move { + /// println!("task local value: {}", NUMBER.get()); + /// }).await; + /// # } + /// ``` + pub struct TaskLocalFuture<T, F> + where + T: 'static + { + local: &'static LocalKey<T>, + slot: Option<T>, + #[pin] + future: F, + #[pin] + _pinned: PhantomPinned, + } +} + +impl<T: 'static, F> TaskLocalFuture<T, F> { + fn with_task<F2: FnOnce(Pin<&mut F>) -> R, R>(self: Pin<&mut Self>, f: F2) -> R { + struct Guard<'a, T: 'static> { + local: &'static LocalKey<T>, + slot: &'a mut Option<T>, + prev: Option<T>, + } + + impl<T> Drop for Guard<'_, T> { + fn drop(&mut self) { + let value = self.local.inner.with(|c| c.replace(self.prev.take())); + *self.slot = value; + } + } + + let project = self.project(); + let val = project.slot.take(); + + let prev = project.local.inner.with(|c| c.replace(val)); + + let _guard = Guard { + prev, + slot: project.slot, + local: *project.local, + }; + + f(project.future) + } +} + +impl<T: 'static, F: Future> Future for TaskLocalFuture<T, F> { + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + self.with_task(|f| f.poll(cx)) + } +} + +/// An error returned by [`LocalKey::try_with`](method@LocalKey::try_with). +#[derive(Clone, Copy, Eq, PartialEq)] +pub struct AccessError { + _private: (), +} + +impl fmt::Debug for AccessError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AccessError").finish() + } +} + +impl fmt::Display for AccessError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt("task-local value not set", f) + } +} + +impl Error for AccessError {} |