diff options
Diffstat (limited to 'third_party/rust/tokio-util/src/sync')
9 files changed, 1360 insertions, 0 deletions
diff --git a/third_party/rust/tokio-util/src/sync/cancellation_token.rs b/third_party/rust/tokio-util/src/sync/cancellation_token.rs new file mode 100644 index 0000000000..2a6ef392bd --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/cancellation_token.rs @@ -0,0 +1,224 @@ +//! An asynchronously awaitable `CancellationToken`. +//! The token allows to signal a cancellation request to one or more tasks. +pub(crate) mod guard; +mod tree_node; + +use crate::loom::sync::Arc; +use core::future::Future; +use core::pin::Pin; +use core::task::{Context, Poll}; + +use guard::DropGuard; +use pin_project_lite::pin_project; + +/// A token which can be used to signal a cancellation request to one or more +/// tasks. +/// +/// Tasks can call [`CancellationToken::cancelled()`] in order to +/// obtain a Future which will be resolved when cancellation is requested. +/// +/// Cancellation can be requested through the [`CancellationToken::cancel`] method. +/// +/// # Examples +/// +/// ```no_run +/// use tokio::select; +/// use tokio_util::sync::CancellationToken; +/// +/// #[tokio::main] +/// async fn main() { +/// let token = CancellationToken::new(); +/// let cloned_token = token.clone(); +/// +/// let join_handle = tokio::spawn(async move { +/// // Wait for either cancellation or a very long time +/// select! { +/// _ = cloned_token.cancelled() => { +/// // The token was cancelled +/// 5 +/// } +/// _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => { +/// 99 +/// } +/// } +/// }); +/// +/// tokio::spawn(async move { +/// tokio::time::sleep(std::time::Duration::from_millis(10)).await; +/// token.cancel(); +/// }); +/// +/// assert_eq!(5, join_handle.await.unwrap()); +/// } +/// ``` +pub struct CancellationToken { + inner: Arc<tree_node::TreeNode>, +} + +pin_project! { + /// A Future that is resolved once the corresponding [`CancellationToken`] + /// is cancelled. + #[must_use = "futures do nothing unless polled"] + pub struct WaitForCancellationFuture<'a> { + cancellation_token: &'a CancellationToken, + #[pin] + future: tokio::sync::futures::Notified<'a>, + } +} + +// ===== impl CancellationToken ===== + +impl core::fmt::Debug for CancellationToken { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("CancellationToken") + .field("is_cancelled", &self.is_cancelled()) + .finish() + } +} + +impl Clone for CancellationToken { + fn clone(&self) -> Self { + tree_node::increase_handle_refcount(&self.inner); + CancellationToken { + inner: self.inner.clone(), + } + } +} + +impl Drop for CancellationToken { + fn drop(&mut self) { + tree_node::decrease_handle_refcount(&self.inner); + } +} + +impl Default for CancellationToken { + fn default() -> CancellationToken { + CancellationToken::new() + } +} + +impl CancellationToken { + /// Creates a new CancellationToken in the non-cancelled state. + pub fn new() -> CancellationToken { + CancellationToken { + inner: Arc::new(tree_node::TreeNode::new()), + } + } + + /// Creates a `CancellationToken` which will get cancelled whenever the + /// current token gets cancelled. + /// + /// If the current token is already cancelled, the child token will get + /// returned in cancelled state. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::select; + /// use tokio_util::sync::CancellationToken; + /// + /// #[tokio::main] + /// async fn main() { + /// let token = CancellationToken::new(); + /// let child_token = token.child_token(); + /// + /// let join_handle = tokio::spawn(async move { + /// // Wait for either cancellation or a very long time + /// select! { + /// _ = child_token.cancelled() => { + /// // The token was cancelled + /// 5 + /// } + /// _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => { + /// 99 + /// } + /// } + /// }); + /// + /// tokio::spawn(async move { + /// tokio::time::sleep(std::time::Duration::from_millis(10)).await; + /// token.cancel(); + /// }); + /// + /// assert_eq!(5, join_handle.await.unwrap()); + /// } + /// ``` + pub fn child_token(&self) -> CancellationToken { + CancellationToken { + inner: tree_node::child_node(&self.inner), + } + } + + /// Cancel the [`CancellationToken`] and all child tokens which had been + /// derived from it. + /// + /// This will wake up all tasks which are waiting for cancellation. + /// + /// Be aware that cancellation is not an atomic operation. It is possible + /// for another thread running in parallel with a call to `cancel` to first + /// receive `true` from `is_cancelled` on one child node, and then receive + /// `false` from `is_cancelled` on another child node. However, once the + /// call to `cancel` returns, all child nodes have been fully cancelled. + pub fn cancel(&self) { + tree_node::cancel(&self.inner); + } + + /// Returns `true` if the `CancellationToken` is cancelled. + pub fn is_cancelled(&self) -> bool { + tree_node::is_cancelled(&self.inner) + } + + /// Returns a `Future` that gets fulfilled when cancellation is requested. + /// + /// The future will complete immediately if the token is already cancelled + /// when this method is called. + /// + /// # Cancel safety + /// + /// This method is cancel safe. + pub fn cancelled(&self) -> WaitForCancellationFuture<'_> { + WaitForCancellationFuture { + cancellation_token: self, + future: self.inner.notified(), + } + } + + /// Creates a `DropGuard` for this token. + /// + /// Returned guard will cancel this token (and all its children) on drop + /// unless disarmed. + pub fn drop_guard(self) -> DropGuard { + DropGuard { inner: Some(self) } + } +} + +// ===== impl WaitForCancellationFuture ===== + +impl<'a> core::fmt::Debug for WaitForCancellationFuture<'a> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("WaitForCancellationFuture").finish() + } +} + +impl<'a> Future for WaitForCancellationFuture<'a> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let mut this = self.project(); + loop { + if this.cancellation_token.is_cancelled() { + return Poll::Ready(()); + } + + // No wakeups can be lost here because there is always a call to + // `is_cancelled` between the creation of the future and the call to + // `poll`, and the code that sets the cancelled flag does so before + // waking the `Notified`. + if this.future.as_mut().poll(cx).is_pending() { + return Poll::Pending; + } + + this.future.set(this.cancellation_token.inner.notified()); + } + } +} diff --git a/third_party/rust/tokio-util/src/sync/cancellation_token/guard.rs b/third_party/rust/tokio-util/src/sync/cancellation_token/guard.rs new file mode 100644 index 0000000000..54ed7ea2ed --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/cancellation_token/guard.rs @@ -0,0 +1,27 @@ +use crate::sync::CancellationToken; + +/// A wrapper for cancellation token which automatically cancels +/// it on drop. It is created using `drop_guard` method on the `CancellationToken`. +#[derive(Debug)] +pub struct DropGuard { + pub(super) inner: Option<CancellationToken>, +} + +impl DropGuard { + /// Returns stored cancellation token and removes this drop guard instance + /// (i.e. it will no longer cancel token). Other guards for this token + /// are not affected. + pub fn disarm(mut self) -> CancellationToken { + self.inner + .take() + .expect("`inner` can be only None in a destructor") + } +} + +impl Drop for DropGuard { + fn drop(&mut self) { + if let Some(inner) = &self.inner { + inner.cancel(); + } + } +} diff --git a/third_party/rust/tokio-util/src/sync/cancellation_token/tree_node.rs b/third_party/rust/tokio-util/src/sync/cancellation_token/tree_node.rs new file mode 100644 index 0000000000..b6cd698e23 --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/cancellation_token/tree_node.rs @@ -0,0 +1,373 @@ +//! This mod provides the logic for the inner tree structure of the CancellationToken. +//! +//! CancellationTokens are only light handles with references to TreeNode. +//! All the logic is actually implemented in the TreeNode. +//! +//! A TreeNode is part of the cancellation tree and may have one parent and an arbitrary number of +//! children. +//! +//! A TreeNode can receive the request to perform a cancellation through a CancellationToken. +//! This cancellation request will cancel the node and all of its descendants. +//! +//! As soon as a node cannot get cancelled any more (because it was already cancelled or it has no +//! more CancellationTokens pointing to it any more), it gets removed from the tree, to keep the +//! tree as small as possible. +//! +//! # Invariants +//! +//! Those invariants shall be true at any time. +//! +//! 1. A node that has no parents and no handles can no longer be cancelled. +//! This is important during both cancellation and refcounting. +//! +//! 2. If node B *is* or *was* a child of node A, then node B was created *after* node A. +//! This is important for deadlock safety, as it is used for lock order. +//! Node B can only become the child of node A in two ways: +//! - being created with `child_node()`, in which case it is trivially true that +//! node A already existed when node B was created +//! - being moved A->C->B to A->B because node C was removed in `decrease_handle_refcount()` +//! or `cancel()`. In this case the invariant still holds, as B was younger than C, and C +//! was younger than A, therefore B is also younger than A. +//! +//! 3. If two nodes are both unlocked and node A is the parent of node B, then node B is a child of +//! node A. It is important to always restore that invariant before dropping the lock of a node. +//! +//! # Deadlock safety +//! +//! We always lock in the order of creation time. We can prove this through invariant #2. +//! Specifically, through invariant #2, we know that we always have to lock a parent +//! before its child. +//! +use crate::loom::sync::{Arc, Mutex, MutexGuard}; + +/// A node of the cancellation tree structure +/// +/// The actual data it holds is wrapped inside a mutex for synchronization. +pub(crate) struct TreeNode { + inner: Mutex<Inner>, + waker: tokio::sync::Notify, +} +impl TreeNode { + pub(crate) fn new() -> Self { + Self { + inner: Mutex::new(Inner { + parent: None, + parent_idx: 0, + children: vec![], + is_cancelled: false, + num_handles: 1, + }), + waker: tokio::sync::Notify::new(), + } + } + + pub(crate) fn notified(&self) -> tokio::sync::futures::Notified<'_> { + self.waker.notified() + } +} + +/// The data contained inside a TreeNode. +/// +/// This struct exists so that the data of the node can be wrapped +/// in a Mutex. +struct Inner { + parent: Option<Arc<TreeNode>>, + parent_idx: usize, + children: Vec<Arc<TreeNode>>, + is_cancelled: bool, + num_handles: usize, +} + +/// Returns whether or not the node is cancelled +pub(crate) fn is_cancelled(node: &Arc<TreeNode>) -> bool { + node.inner.lock().unwrap().is_cancelled +} + +/// Creates a child node +pub(crate) fn child_node(parent: &Arc<TreeNode>) -> Arc<TreeNode> { + let mut locked_parent = parent.inner.lock().unwrap(); + + // Do not register as child if we are already cancelled. + // Cancelled trees can never be uncancelled and therefore + // need no connection to parents or children any more. + if locked_parent.is_cancelled { + return Arc::new(TreeNode { + inner: Mutex::new(Inner { + parent: None, + parent_idx: 0, + children: vec![], + is_cancelled: true, + num_handles: 1, + }), + waker: tokio::sync::Notify::new(), + }); + } + + let child = Arc::new(TreeNode { + inner: Mutex::new(Inner { + parent: Some(parent.clone()), + parent_idx: locked_parent.children.len(), + children: vec![], + is_cancelled: false, + num_handles: 1, + }), + waker: tokio::sync::Notify::new(), + }); + + locked_parent.children.push(child.clone()); + + child +} + +/// Disconnects the given parent from all of its children. +/// +/// Takes a reference to [Inner] to make sure the parent is already locked. +fn disconnect_children(node: &mut Inner) { + for child in std::mem::take(&mut node.children) { + let mut locked_child = child.inner.lock().unwrap(); + locked_child.parent_idx = 0; + locked_child.parent = None; + } +} + +/// Figures out the parent of the node and locks the node and its parent atomically. +/// +/// The basic principle of preventing deadlocks in the tree is +/// that we always lock the parent first, and then the child. +/// For more info look at *deadlock safety* and *invariant #2*. +/// +/// Sadly, it's impossible to figure out the parent of a node without +/// locking it. To then achieve locking order consistency, the node +/// has to be unlocked before the parent gets locked. +/// This leaves a small window where we already assume that we know the parent, +/// but neither the parent nor the node is locked. Therefore, the parent could change. +/// +/// To prevent that this problem leaks into the rest of the code, it is abstracted +/// in this function. +/// +/// The locked child and optionally its locked parent, if a parent exists, get passed +/// to the `func` argument via (node, None) or (node, Some(parent)). +fn with_locked_node_and_parent<F, Ret>(node: &Arc<TreeNode>, func: F) -> Ret +where + F: FnOnce(MutexGuard<'_, Inner>, Option<MutexGuard<'_, Inner>>) -> Ret, +{ + let mut potential_parent = { + let locked_node = node.inner.lock().unwrap(); + match locked_node.parent.clone() { + Some(parent) => parent, + // If we locked the node and its parent is `None`, we are in a valid state + // and can return. + None => return func(locked_node, None), + } + }; + + loop { + // Deadlock safety: + // + // Due to invariant #2, we know that we have to lock the parent first, and then the child. + // This is true even if the potential_parent is no longer the current parent or even its + // sibling, as the invariant still holds. + let locked_parent = potential_parent.inner.lock().unwrap(); + let locked_node = node.inner.lock().unwrap(); + + let actual_parent = match locked_node.parent.clone() { + Some(parent) => parent, + // If we locked the node and its parent is `None`, we are in a valid state + // and can return. + None => { + // Was the wrong parent, so unlock it before calling `func` + drop(locked_parent); + return func(locked_node, None); + } + }; + + // Loop until we managed to lock both the node and its parent + if Arc::ptr_eq(&actual_parent, &potential_parent) { + return func(locked_node, Some(locked_parent)); + } + + // Drop locked_parent before reassigning to potential_parent, + // as potential_parent is borrowed in it + drop(locked_node); + drop(locked_parent); + + potential_parent = actual_parent; + } +} + +/// Moves all children from `node` to `parent`. +/// +/// `parent` MUST have been a parent of the node when they both got locked, +/// otherwise there is a potential for a deadlock as invariant #2 would be violated. +/// +/// To aquire the locks for node and parent, use [with_locked_node_and_parent]. +fn move_children_to_parent(node: &mut Inner, parent: &mut Inner) { + // Pre-allocate in the parent, for performance + parent.children.reserve(node.children.len()); + + for child in std::mem::take(&mut node.children) { + { + let mut child_locked = child.inner.lock().unwrap(); + child_locked.parent = node.parent.clone(); + child_locked.parent_idx = parent.children.len(); + } + parent.children.push(child); + } +} + +/// Removes a child from the parent. +/// +/// `parent` MUST be the parent of `node`. +/// To aquire the locks for node and parent, use [with_locked_node_and_parent]. +fn remove_child(parent: &mut Inner, mut node: MutexGuard<'_, Inner>) { + // Query the position from where to remove a node + let pos = node.parent_idx; + node.parent = None; + node.parent_idx = 0; + + // Unlock node, so that only one child at a time is locked. + // Otherwise we would violate the lock order (see 'deadlock safety') as we + // don't know the creation order of the child nodes + drop(node); + + // If `node` is the last element in the list, we don't need any swapping + if parent.children.len() == pos + 1 { + parent.children.pop().unwrap(); + } else { + // If `node` is not the last element in the list, we need to + // replace it with the last element + let replacement_child = parent.children.pop().unwrap(); + replacement_child.inner.lock().unwrap().parent_idx = pos; + parent.children[pos] = replacement_child; + } + + let len = parent.children.len(); + if 4 * len <= parent.children.capacity() { + // equal to: + // parent.children.shrink_to(2 * len); + // but shrink_to was not yet stabilized in our minimal compatible version + let old_children = std::mem::replace(&mut parent.children, Vec::with_capacity(2 * len)); + parent.children.extend(old_children); + } +} + +/// Increases the reference count of handles. +pub(crate) fn increase_handle_refcount(node: &Arc<TreeNode>) { + let mut locked_node = node.inner.lock().unwrap(); + + // Once no handles are left over, the node gets detached from the tree. + // There should never be a new handle once all handles are dropped. + assert!(locked_node.num_handles > 0); + + locked_node.num_handles += 1; +} + +/// Decreases the reference count of handles. +/// +/// Once no handle is left, we can remove the node from the +/// tree and connect its parent directly to its children. +pub(crate) fn decrease_handle_refcount(node: &Arc<TreeNode>) { + let num_handles = { + let mut locked_node = node.inner.lock().unwrap(); + locked_node.num_handles -= 1; + locked_node.num_handles + }; + + if num_handles == 0 { + with_locked_node_and_parent(node, |mut node, parent| { + // Remove the node from the tree + match parent { + Some(mut parent) => { + // As we want to remove ourselves from the tree, + // we have to move the children to the parent, so that + // they still receive the cancellation event without us. + // Moving them does not violate invariant #1. + move_children_to_parent(&mut node, &mut parent); + + // Remove the node from the parent + remove_child(&mut parent, node); + } + None => { + // Due to invariant #1, we can assume that our + // children can no longer be cancelled through us. + // (as we now have neither a parent nor handles) + // Therefore we can disconnect them. + disconnect_children(&mut node); + } + } + }); + } +} + +/// Cancels a node and its children. +pub(crate) fn cancel(node: &Arc<TreeNode>) { + let mut locked_node = node.inner.lock().unwrap(); + + if locked_node.is_cancelled { + return; + } + + // One by one, adopt grandchildren and then cancel and detach the child + while let Some(child) = locked_node.children.pop() { + // This can't deadlock because the mutex we are already + // holding is the parent of child. + let mut locked_child = child.inner.lock().unwrap(); + + // Detach the child from node + // No need to modify node.children, as the child already got removed with `.pop` + locked_child.parent = None; + locked_child.parent_idx = 0; + + // If child is already cancelled, detaching is enough + if locked_child.is_cancelled { + continue; + } + + // Cancel or adopt grandchildren + while let Some(grandchild) = locked_child.children.pop() { + // This can't deadlock because the two mutexes we are already + // holding is the parent and grandparent of grandchild. + let mut locked_grandchild = grandchild.inner.lock().unwrap(); + + // Detach the grandchild + locked_grandchild.parent = None; + locked_grandchild.parent_idx = 0; + + // If grandchild is already cancelled, detaching is enough + if locked_grandchild.is_cancelled { + continue; + } + + // For performance reasons, only adopt grandchildren that have children. + // Otherwise, just cancel them right away, no need for another iteration. + if locked_grandchild.children.is_empty() { + // Cancel the grandchild + locked_grandchild.is_cancelled = true; + locked_grandchild.children = Vec::new(); + drop(locked_grandchild); + grandchild.waker.notify_waiters(); + } else { + // Otherwise, adopt grandchild + locked_grandchild.parent = Some(node.clone()); + locked_grandchild.parent_idx = locked_node.children.len(); + drop(locked_grandchild); + locked_node.children.push(grandchild); + } + } + + // Cancel the child + locked_child.is_cancelled = true; + locked_child.children = Vec::new(); + drop(locked_child); + child.waker.notify_waiters(); + + // Now the child is cancelled and detached and all its children are adopted. + // Just continue until all (including adopted) children are cancelled and detached. + } + + // Cancel the node itself. + locked_node.is_cancelled = true; + locked_node.children = Vec::new(); + drop(locked_node); + node.waker.notify_waiters(); +} diff --git a/third_party/rust/tokio-util/src/sync/mod.rs b/third_party/rust/tokio-util/src/sync/mod.rs new file mode 100644 index 0000000000..de392f0bb1 --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/mod.rs @@ -0,0 +1,13 @@ +//! Synchronization primitives + +mod cancellation_token; +pub use cancellation_token::{guard::DropGuard, CancellationToken, WaitForCancellationFuture}; + +mod mpsc; +pub use mpsc::{PollSendError, PollSender}; + +mod poll_semaphore; +pub use poll_semaphore::PollSemaphore; + +mod reusable_box; +pub use reusable_box::ReusableBoxFuture; diff --git a/third_party/rust/tokio-util/src/sync/mpsc.rs b/third_party/rust/tokio-util/src/sync/mpsc.rs new file mode 100644 index 0000000000..34a47c1891 --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/mpsc.rs @@ -0,0 +1,283 @@ +use futures_sink::Sink; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{fmt, mem}; +use tokio::sync::mpsc::OwnedPermit; +use tokio::sync::mpsc::Sender; + +use super::ReusableBoxFuture; + +/// Error returned by the `PollSender` when the channel is closed. +#[derive(Debug)] +pub struct PollSendError<T>(Option<T>); + +impl<T> PollSendError<T> { + /// Consumes the stored value, if any. + /// + /// If this error was encountered when calling `start_send`/`send_item`, this will be the item + /// that the caller attempted to send. Otherwise, it will be `None`. + pub fn into_inner(self) -> Option<T> { + self.0 + } +} + +impl<T> fmt::Display for PollSendError<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } +} + +impl<T: fmt::Debug> std::error::Error for PollSendError<T> {} + +#[derive(Debug)] +enum State<T> { + Idle(Sender<T>), + Acquiring, + ReadyToSend(OwnedPermit<T>), + Closed, +} + +/// A wrapper around [`mpsc::Sender`] that can be polled. +/// +/// [`mpsc::Sender`]: tokio::sync::mpsc::Sender +#[derive(Debug)] +pub struct PollSender<T> { + sender: Option<Sender<T>>, + state: State<T>, + acquire: ReusableBoxFuture<'static, Result<OwnedPermit<T>, PollSendError<T>>>, +} + +// Creates a future for acquiring a permit from the underlying channel. This is used to ensure +// there's capacity for a send to complete. +// +// By reusing the same async fn for both `Some` and `None`, we make sure every future passed to +// ReusableBoxFuture has the same underlying type, and hence the same size and alignment. +async fn make_acquire_future<T>( + data: Option<Sender<T>>, +) -> Result<OwnedPermit<T>, PollSendError<T>> { + match data { + Some(sender) => sender + .reserve_owned() + .await + .map_err(|_| PollSendError(None)), + None => unreachable!("this future should not be pollable in this state"), + } +} + +impl<T: Send + 'static> PollSender<T> { + /// Creates a new `PollSender`. + pub fn new(sender: Sender<T>) -> Self { + Self { + sender: Some(sender.clone()), + state: State::Idle(sender), + acquire: ReusableBoxFuture::new(make_acquire_future(None)), + } + } + + fn take_state(&mut self) -> State<T> { + mem::replace(&mut self.state, State::Closed) + } + + /// Attempts to prepare the sender to receive a value. + /// + /// This method must be called and return `Poll::Ready(Ok(()))` prior to each call to + /// `send_item`. + /// + /// This method returns `Poll::Ready` once the underlying channel is ready to receive a value, + /// by reserving a slot in the channel for the item to be sent. If this method returns + /// `Poll::Pending`, the current task is registered to be notified (via + /// `cx.waker().wake_by_ref()`) when `poll_reserve` should be called again. + /// + /// # Errors + /// + /// If the channel is closed, an error will be returned. This is a permanent state. + pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), PollSendError<T>>> { + loop { + let (result, next_state) = match self.take_state() { + State::Idle(sender) => { + // Start trying to acquire a permit to reserve a slot for our send, and + // immediately loop back around to poll it the first time. + self.acquire.set(make_acquire_future(Some(sender))); + (None, State::Acquiring) + } + State::Acquiring => match self.acquire.poll(cx) { + // Channel has capacity. + Poll::Ready(Ok(permit)) => { + (Some(Poll::Ready(Ok(()))), State::ReadyToSend(permit)) + } + // Channel is closed. + Poll::Ready(Err(e)) => (Some(Poll::Ready(Err(e))), State::Closed), + // Channel doesn't have capacity yet, so we need to wait. + Poll::Pending => (Some(Poll::Pending), State::Acquiring), + }, + // We're closed, either by choice or because the underlying sender was closed. + s @ State::Closed => (Some(Poll::Ready(Err(PollSendError(None)))), s), + // We're already ready to send an item. + s @ State::ReadyToSend(_) => (Some(Poll::Ready(Ok(()))), s), + }; + + self.state = next_state; + if let Some(result) = result { + return result; + } + } + } + + /// Sends an item to the channel. + /// + /// Before calling `send_item`, `poll_reserve` must be called with a successful return + /// value of `Poll::Ready(Ok(()))`. + /// + /// # Errors + /// + /// If the channel is closed, an error will be returned. This is a permanent state. + /// + /// # Panics + /// + /// If `poll_reserve` was not successfully called prior to calling `send_item`, then this method + /// will panic. + pub fn send_item(&mut self, value: T) -> Result<(), PollSendError<T>> { + let (result, next_state) = match self.take_state() { + State::Idle(_) | State::Acquiring => { + panic!("`send_item` called without first calling `poll_reserve`") + } + // We have a permit to send our item, so go ahead, which gets us our sender back. + State::ReadyToSend(permit) => (Ok(()), State::Idle(permit.send(value))), + // We're closed, either by choice or because the underlying sender was closed. + State::Closed => (Err(PollSendError(Some(value))), State::Closed), + }; + + // Handle deferred closing if `close` was called between `poll_reserve` and `send_item`. + self.state = if self.sender.is_some() { + next_state + } else { + State::Closed + }; + result + } + + /// Checks whether this sender is been closed. + /// + /// The underlying channel that this sender was wrapping may still be open. + pub fn is_closed(&self) -> bool { + matches!(self.state, State::Closed) || self.sender.is_none() + } + + /// Gets a reference to the `Sender` of the underlying channel. + /// + /// If `PollSender` has been closed, `None` is returned. The underlying channel that this sender + /// was wrapping may still be open. + pub fn get_ref(&self) -> Option<&Sender<T>> { + self.sender.as_ref() + } + + /// Closes this sender. + /// + /// No more messages will be able to be sent from this sender, but the underlying channel will + /// remain open until all senders have dropped, or until the [`Receiver`] closes the channel. + /// + /// If a slot was previously reserved by calling `poll_reserve`, then a final call can be made + /// to `send_item` in order to consume the reserved slot. After that, no further sends will be + /// possible. If you do not intend to send another item, you can release the reserved slot back + /// to the underlying sender by calling [`abort_send`]. + /// + /// [`abort_send`]: crate::sync::PollSender::abort_send + /// [`Receiver`]: tokio::sync::mpsc::Receiver + pub fn close(&mut self) { + // Mark ourselves officially closed by dropping our main sender. + self.sender = None; + + // If we're already idle, closed, or we haven't yet reserved a slot, we can quickly + // transition to the closed state. Otherwise, leave the existing permit in place for the + // caller if they want to complete the send. + match self.state { + State::Idle(_) => self.state = State::Closed, + State::Acquiring => { + self.acquire.set(make_acquire_future(None)); + self.state = State::Closed; + } + _ => {} + } + } + + /// Aborts the current in-progress send, if any. + /// + /// Returns `true` if a send was aborted. If the sender was closed prior to calling + /// `abort_send`, then the sender will remain in the closed state, otherwise the sender will be + /// ready to attempt another send. + pub fn abort_send(&mut self) -> bool { + // We may have been closed in the meantime, after a call to `poll_reserve` already + // succeeded. We'll check if `self.sender` is `None` to see if we should transition to the + // closed state when we actually abort a send, rather than resetting ourselves back to idle. + + let (result, next_state) = match self.take_state() { + // We're currently trying to reserve a slot to send into. + State::Acquiring => { + // Replacing the future drops the in-flight one. + self.acquire.set(make_acquire_future(None)); + + // If we haven't closed yet, we have to clone our stored sender since we have no way + // to get it back from the acquire future we just dropped. + let state = match self.sender.clone() { + Some(sender) => State::Idle(sender), + None => State::Closed, + }; + (true, state) + } + // We got the permit. If we haven't closed yet, get the sender back. + State::ReadyToSend(permit) => { + let state = if self.sender.is_some() { + State::Idle(permit.release()) + } else { + State::Closed + }; + (true, state) + } + s => (false, s), + }; + + self.state = next_state; + result + } +} + +impl<T> Clone for PollSender<T> { + /// Clones this `PollSender`. + /// + /// The resulting `PollSender` will have an initial state identical to calling `PollSender::new`. + fn clone(&self) -> PollSender<T> { + let (sender, state) = match self.sender.clone() { + Some(sender) => (Some(sender.clone()), State::Idle(sender)), + None => (None, State::Closed), + }; + + Self { + sender, + state, + // We don't use `make_acquire_future` here because our relaxed bounds on `T` are not + // compatible with the transitive bounds required by `Sender<T>`. + acquire: ReusableBoxFuture::new(async { unreachable!() }), + } + } +} + +impl<T: Send + 'static> Sink<T> for PollSender<T> { + type Error = PollSendError<T>; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Pin::into_inner(self).poll_reserve(cx) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { + Pin::into_inner(self).send_item(item) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Pin::into_inner(self).close(); + Poll::Ready(Ok(())) + } +} diff --git a/third_party/rust/tokio-util/src/sync/poll_semaphore.rs b/third_party/rust/tokio-util/src/sync/poll_semaphore.rs new file mode 100644 index 0000000000..d0b1dedc27 --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/poll_semaphore.rs @@ -0,0 +1,136 @@ +use futures_core::{ready, Stream}; +use std::fmt; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore, TryAcquireError}; + +use super::ReusableBoxFuture; + +/// A wrapper around [`Semaphore`] that provides a `poll_acquire` method. +/// +/// [`Semaphore`]: tokio::sync::Semaphore +pub struct PollSemaphore { + semaphore: Arc<Semaphore>, + permit_fut: Option<ReusableBoxFuture<'static, Result<OwnedSemaphorePermit, AcquireError>>>, +} + +impl PollSemaphore { + /// Create a new `PollSemaphore`. + pub fn new(semaphore: Arc<Semaphore>) -> Self { + Self { + semaphore, + permit_fut: None, + } + } + + /// Closes the semaphore. + pub fn close(&self) { + self.semaphore.close() + } + + /// Obtain a clone of the inner semaphore. + pub fn clone_inner(&self) -> Arc<Semaphore> { + self.semaphore.clone() + } + + /// Get back the inner semaphore. + pub fn into_inner(self) -> Arc<Semaphore> { + self.semaphore + } + + /// Poll to acquire a permit from the semaphore. + /// + /// This can return the following values: + /// + /// - `Poll::Pending` if a permit is not currently available. + /// - `Poll::Ready(Some(permit))` if a permit was acquired. + /// - `Poll::Ready(None)` if the semaphore has been closed. + /// + /// When this method returns `Poll::Pending`, the current task is scheduled + /// to receive a wakeup when a permit becomes available, or when the + /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. + pub fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> { + let permit_future = match self.permit_fut.as_mut() { + Some(fut) => fut, + None => { + // avoid allocations completely if we can grab a permit immediately + match Arc::clone(&self.semaphore).try_acquire_owned() { + Ok(permit) => return Poll::Ready(Some(permit)), + Err(TryAcquireError::Closed) => return Poll::Ready(None), + Err(TryAcquireError::NoPermits) => {} + } + + let next_fut = Arc::clone(&self.semaphore).acquire_owned(); + self.permit_fut + .get_or_insert(ReusableBoxFuture::new(next_fut)) + } + }; + + let result = ready!(permit_future.poll(cx)); + + let next_fut = Arc::clone(&self.semaphore).acquire_owned(); + permit_future.set(next_fut); + + match result { + Ok(permit) => Poll::Ready(Some(permit)), + Err(_closed) => { + self.permit_fut = None; + Poll::Ready(None) + } + } + } + + /// Returns the current number of available permits. + /// + /// This is equivalent to the [`Semaphore::available_permits`] method on the + /// `tokio::sync::Semaphore` type. + /// + /// [`Semaphore::available_permits`]: tokio::sync::Semaphore::available_permits + pub fn available_permits(&self) -> usize { + self.semaphore.available_permits() + } + + /// Adds `n` new permits to the semaphore. + /// + /// The maximum number of permits is `usize::MAX >> 3`, and this function + /// will panic if the limit is exceeded. + /// + /// This is equivalent to the [`Semaphore::add_permits`] method on the + /// `tokio::sync::Semaphore` type. + /// + /// [`Semaphore::add_permits`]: tokio::sync::Semaphore::add_permits + pub fn add_permits(&self, n: usize) { + self.semaphore.add_permits(n); + } +} + +impl Stream for PollSemaphore { + type Item = OwnedSemaphorePermit; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> { + Pin::into_inner(self).poll_acquire(cx) + } +} + +impl Clone for PollSemaphore { + fn clone(&self) -> PollSemaphore { + PollSemaphore::new(self.clone_inner()) + } +} + +impl fmt::Debug for PollSemaphore { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PollSemaphore") + .field("semaphore", &self.semaphore) + .finish() + } +} + +impl AsRef<Semaphore> for PollSemaphore { + fn as_ref(&self) -> &Semaphore { + &*self.semaphore + } +} diff --git a/third_party/rust/tokio-util/src/sync/reusable_box.rs b/third_party/rust/tokio-util/src/sync/reusable_box.rs new file mode 100644 index 0000000000..3204207db7 --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/reusable_box.rs @@ -0,0 +1,148 @@ +use std::alloc::Layout; +use std::future::Future; +use std::panic::AssertUnwindSafe; +use std::pin::Pin; +use std::ptr::{self, NonNull}; +use std::task::{Context, Poll}; +use std::{fmt, panic}; + +/// A reusable `Pin<Box<dyn Future<Output = T> + Send + 'a>>`. +/// +/// This type lets you replace the future stored in the box without +/// reallocating when the size and alignment permits this. +pub struct ReusableBoxFuture<'a, T> { + boxed: NonNull<dyn Future<Output = T> + Send + 'a>, +} + +impl<'a, T> ReusableBoxFuture<'a, T> { + /// Create a new `ReusableBoxFuture<T>` containing the provided future. + pub fn new<F>(future: F) -> Self + where + F: Future<Output = T> + Send + 'a, + { + let boxed: Box<dyn Future<Output = T> + Send + 'a> = Box::new(future); + + let boxed = NonNull::from(Box::leak(boxed)); + + Self { boxed } + } + + /// Replace the future currently stored in this box. + /// + /// This reallocates if and only if the layout of the provided future is + /// different from the layout of the currently stored future. + pub fn set<F>(&mut self, future: F) + where + F: Future<Output = T> + Send + 'a, + { + if let Err(future) = self.try_set(future) { + *self = Self::new(future); + } + } + + /// Replace the future currently stored in this box. + /// + /// This function never reallocates, but returns an error if the provided + /// future has a different size or alignment from the currently stored + /// future. + pub fn try_set<F>(&mut self, future: F) -> Result<(), F> + where + F: Future<Output = T> + Send + 'a, + { + // SAFETY: The pointer is not dangling. + let self_layout = { + let dyn_future: &(dyn Future<Output = T> + Send) = unsafe { self.boxed.as_ref() }; + Layout::for_value(dyn_future) + }; + + if Layout::new::<F>() == self_layout { + // SAFETY: We just checked that the layout of F is correct. + unsafe { + self.set_same_layout(future); + } + + Ok(()) + } else { + Err(future) + } + } + + /// Set the current future. + /// + /// # Safety + /// + /// This function requires that the layout of the provided future is the + /// same as `self.layout`. + unsafe fn set_same_layout<F>(&mut self, future: F) + where + F: Future<Output = T> + Send + 'a, + { + // Drop the existing future, catching any panics. + let result = panic::catch_unwind(AssertUnwindSafe(|| { + ptr::drop_in_place(self.boxed.as_ptr()); + })); + + // Overwrite the future behind the pointer. This is safe because the + // allocation was allocated with the same size and alignment as the type F. + let self_ptr: *mut F = self.boxed.as_ptr() as *mut F; + ptr::write(self_ptr, future); + + // Update the vtable of self.boxed. The pointer is not null because we + // just got it from self.boxed, which is not null. + self.boxed = NonNull::new_unchecked(self_ptr); + + // If the old future's destructor panicked, resume unwinding. + match result { + Ok(()) => {} + Err(payload) => { + panic::resume_unwind(payload); + } + } + } + + /// Get a pinned reference to the underlying future. + pub fn get_pin(&mut self) -> Pin<&mut (dyn Future<Output = T> + Send)> { + // SAFETY: The user of this box cannot move the box, and we do not move it + // either. + unsafe { Pin::new_unchecked(self.boxed.as_mut()) } + } + + /// Poll the future stored inside this box. + pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<T> { + self.get_pin().poll(cx) + } +} + +impl<T> Future for ReusableBoxFuture<'_, T> { + type Output = T; + + /// Poll the future stored inside this box. + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> { + Pin::into_inner(self).get_pin().poll(cx) + } +} + +// The future stored inside ReusableBoxFuture<'_, T> must be Send. +unsafe impl<T> Send for ReusableBoxFuture<'_, T> {} + +// The only method called on self.boxed is poll, which takes &mut self, so this +// struct being Sync does not permit any invalid access to the Future, even if +// the future is not Sync. +unsafe impl<T> Sync for ReusableBoxFuture<'_, T> {} + +// Just like a Pin<Box<dyn Future>> is always Unpin, so is this type. +impl<T> Unpin for ReusableBoxFuture<'_, T> {} + +impl<T> Drop for ReusableBoxFuture<'_, T> { + fn drop(&mut self) { + unsafe { + drop(Box::from_raw(self.boxed.as_ptr())); + } + } +} + +impl<T> fmt::Debug for ReusableBoxFuture<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ReusableBoxFuture").finish() + } +} diff --git a/third_party/rust/tokio-util/src/sync/tests/loom_cancellation_token.rs b/third_party/rust/tokio-util/src/sync/tests/loom_cancellation_token.rs new file mode 100644 index 0000000000..e9c9f3dd98 --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/tests/loom_cancellation_token.rs @@ -0,0 +1,155 @@ +use crate::sync::CancellationToken; + +use loom::{future::block_on, thread}; +use tokio_test::assert_ok; + +#[test] +fn cancel_token() { + loom::model(|| { + let token = CancellationToken::new(); + let token1 = token.clone(); + + let th1 = thread::spawn(move || { + block_on(async { + token1.cancelled().await; + }); + }); + + let th2 = thread::spawn(move || { + token.cancel(); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + }); +} + +#[test] +fn cancel_with_child() { + loom::model(|| { + let token = CancellationToken::new(); + let token1 = token.clone(); + let token2 = token.clone(); + let child_token = token.child_token(); + + let th1 = thread::spawn(move || { + block_on(async { + token1.cancelled().await; + }); + }); + + let th2 = thread::spawn(move || { + token2.cancel(); + }); + + let th3 = thread::spawn(move || { + block_on(async { + child_token.cancelled().await; + }); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn drop_token_no_child() { + loom::model(|| { + let token = CancellationToken::new(); + let token1 = token.clone(); + let token2 = token.clone(); + + let th1 = thread::spawn(move || { + drop(token1); + }); + + let th2 = thread::spawn(move || { + drop(token2); + }); + + let th3 = thread::spawn(move || { + drop(token); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn drop_token_with_childs() { + loom::model(|| { + let token1 = CancellationToken::new(); + let child_token1 = token1.child_token(); + let child_token2 = token1.child_token(); + + let th1 = thread::spawn(move || { + drop(token1); + }); + + let th2 = thread::spawn(move || { + drop(child_token1); + }); + + let th3 = thread::spawn(move || { + drop(child_token2); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn drop_and_cancel_token() { + loom::model(|| { + let token1 = CancellationToken::new(); + let token2 = token1.clone(); + let child_token = token1.child_token(); + + let th1 = thread::spawn(move || { + drop(token1); + }); + + let th2 = thread::spawn(move || { + token2.cancel(); + }); + + let th3 = thread::spawn(move || { + drop(child_token); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn cancel_parent_and_child() { + loom::model(|| { + let token1 = CancellationToken::new(); + let token2 = token1.clone(); + let child_token = token1.child_token(); + + let th1 = thread::spawn(move || { + drop(token1); + }); + + let th2 = thread::spawn(move || { + token2.cancel(); + }); + + let th3 = thread::spawn(move || { + child_token.cancel(); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} diff --git a/third_party/rust/tokio-util/src/sync/tests/mod.rs b/third_party/rust/tokio-util/src/sync/tests/mod.rs new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/tests/mod.rs @@ -0,0 +1 @@ + |