summaryrefslogtreecommitdiffstats
path: root/third_party/rust/tokio-util/src/sync/mpsc.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/tokio-util/src/sync/mpsc.rs')
-rw-r--r--third_party/rust/tokio-util/src/sync/mpsc.rs283
1 files changed, 283 insertions, 0 deletions
diff --git a/third_party/rust/tokio-util/src/sync/mpsc.rs b/third_party/rust/tokio-util/src/sync/mpsc.rs
new file mode 100644
index 0000000000..34a47c1891
--- /dev/null
+++ b/third_party/rust/tokio-util/src/sync/mpsc.rs
@@ -0,0 +1,283 @@
+use futures_sink::Sink;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use std::{fmt, mem};
+use tokio::sync::mpsc::OwnedPermit;
+use tokio::sync::mpsc::Sender;
+
+use super::ReusableBoxFuture;
+
+/// Error returned by the `PollSender` when the channel is closed.
+#[derive(Debug)]
+pub struct PollSendError<T>(Option<T>);
+
+impl<T> PollSendError<T> {
+ /// Consumes the stored value, if any.
+ ///
+ /// If this error was encountered when calling `start_send`/`send_item`, this will be the item
+ /// that the caller attempted to send. Otherwise, it will be `None`.
+ pub fn into_inner(self) -> Option<T> {
+ self.0
+ }
+}
+
+impl<T> fmt::Display for PollSendError<T> {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(fmt, "channel closed")
+ }
+}
+
+impl<T: fmt::Debug> std::error::Error for PollSendError<T> {}
+
+#[derive(Debug)]
+enum State<T> {
+ Idle(Sender<T>),
+ Acquiring,
+ ReadyToSend(OwnedPermit<T>),
+ Closed,
+}
+
+/// A wrapper around [`mpsc::Sender`] that can be polled.
+///
+/// [`mpsc::Sender`]: tokio::sync::mpsc::Sender
+#[derive(Debug)]
+pub struct PollSender<T> {
+ sender: Option<Sender<T>>,
+ state: State<T>,
+ acquire: ReusableBoxFuture<'static, Result<OwnedPermit<T>, PollSendError<T>>>,
+}
+
+// Creates a future for acquiring a permit from the underlying channel. This is used to ensure
+// there's capacity for a send to complete.
+//
+// By reusing the same async fn for both `Some` and `None`, we make sure every future passed to
+// ReusableBoxFuture has the same underlying type, and hence the same size and alignment.
+async fn make_acquire_future<T>(
+ data: Option<Sender<T>>,
+) -> Result<OwnedPermit<T>, PollSendError<T>> {
+ match data {
+ Some(sender) => sender
+ .reserve_owned()
+ .await
+ .map_err(|_| PollSendError(None)),
+ None => unreachable!("this future should not be pollable in this state"),
+ }
+}
+
+impl<T: Send + 'static> PollSender<T> {
+ /// Creates a new `PollSender`.
+ pub fn new(sender: Sender<T>) -> Self {
+ Self {
+ sender: Some(sender.clone()),
+ state: State::Idle(sender),
+ acquire: ReusableBoxFuture::new(make_acquire_future(None)),
+ }
+ }
+
+ fn take_state(&mut self) -> State<T> {
+ mem::replace(&mut self.state, State::Closed)
+ }
+
+ /// Attempts to prepare the sender to receive a value.
+ ///
+ /// This method must be called and return `Poll::Ready(Ok(()))` prior to each call to
+ /// `send_item`.
+ ///
+ /// This method returns `Poll::Ready` once the underlying channel is ready to receive a value,
+ /// by reserving a slot in the channel for the item to be sent. If this method returns
+ /// `Poll::Pending`, the current task is registered to be notified (via
+ /// `cx.waker().wake_by_ref()`) when `poll_reserve` should be called again.
+ ///
+ /// # Errors
+ ///
+ /// If the channel is closed, an error will be returned. This is a permanent state.
+ pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), PollSendError<T>>> {
+ loop {
+ let (result, next_state) = match self.take_state() {
+ State::Idle(sender) => {
+ // Start trying to acquire a permit to reserve a slot for our send, and
+ // immediately loop back around to poll it the first time.
+ self.acquire.set(make_acquire_future(Some(sender)));
+ (None, State::Acquiring)
+ }
+ State::Acquiring => match self.acquire.poll(cx) {
+ // Channel has capacity.
+ Poll::Ready(Ok(permit)) => {
+ (Some(Poll::Ready(Ok(()))), State::ReadyToSend(permit))
+ }
+ // Channel is closed.
+ Poll::Ready(Err(e)) => (Some(Poll::Ready(Err(e))), State::Closed),
+ // Channel doesn't have capacity yet, so we need to wait.
+ Poll::Pending => (Some(Poll::Pending), State::Acquiring),
+ },
+ // We're closed, either by choice or because the underlying sender was closed.
+ s @ State::Closed => (Some(Poll::Ready(Err(PollSendError(None)))), s),
+ // We're already ready to send an item.
+ s @ State::ReadyToSend(_) => (Some(Poll::Ready(Ok(()))), s),
+ };
+
+ self.state = next_state;
+ if let Some(result) = result {
+ return result;
+ }
+ }
+ }
+
+ /// Sends an item to the channel.
+ ///
+ /// Before calling `send_item`, `poll_reserve` must be called with a successful return
+ /// value of `Poll::Ready(Ok(()))`.
+ ///
+ /// # Errors
+ ///
+ /// If the channel is closed, an error will be returned. This is a permanent state.
+ ///
+ /// # Panics
+ ///
+ /// If `poll_reserve` was not successfully called prior to calling `send_item`, then this method
+ /// will panic.
+ pub fn send_item(&mut self, value: T) -> Result<(), PollSendError<T>> {
+ let (result, next_state) = match self.take_state() {
+ State::Idle(_) | State::Acquiring => {
+ panic!("`send_item` called without first calling `poll_reserve`")
+ }
+ // We have a permit to send our item, so go ahead, which gets us our sender back.
+ State::ReadyToSend(permit) => (Ok(()), State::Idle(permit.send(value))),
+ // We're closed, either by choice or because the underlying sender was closed.
+ State::Closed => (Err(PollSendError(Some(value))), State::Closed),
+ };
+
+ // Handle deferred closing if `close` was called between `poll_reserve` and `send_item`.
+ self.state = if self.sender.is_some() {
+ next_state
+ } else {
+ State::Closed
+ };
+ result
+ }
+
+ /// Checks whether this sender is been closed.
+ ///
+ /// The underlying channel that this sender was wrapping may still be open.
+ pub fn is_closed(&self) -> bool {
+ matches!(self.state, State::Closed) || self.sender.is_none()
+ }
+
+ /// Gets a reference to the `Sender` of the underlying channel.
+ ///
+ /// If `PollSender` has been closed, `None` is returned. The underlying channel that this sender
+ /// was wrapping may still be open.
+ pub fn get_ref(&self) -> Option<&Sender<T>> {
+ self.sender.as_ref()
+ }
+
+ /// Closes this sender.
+ ///
+ /// No more messages will be able to be sent from this sender, but the underlying channel will
+ /// remain open until all senders have dropped, or until the [`Receiver`] closes the channel.
+ ///
+ /// If a slot was previously reserved by calling `poll_reserve`, then a final call can be made
+ /// to `send_item` in order to consume the reserved slot. After that, no further sends will be
+ /// possible. If you do not intend to send another item, you can release the reserved slot back
+ /// to the underlying sender by calling [`abort_send`].
+ ///
+ /// [`abort_send`]: crate::sync::PollSender::abort_send
+ /// [`Receiver`]: tokio::sync::mpsc::Receiver
+ pub fn close(&mut self) {
+ // Mark ourselves officially closed by dropping our main sender.
+ self.sender = None;
+
+ // If we're already idle, closed, or we haven't yet reserved a slot, we can quickly
+ // transition to the closed state. Otherwise, leave the existing permit in place for the
+ // caller if they want to complete the send.
+ match self.state {
+ State::Idle(_) => self.state = State::Closed,
+ State::Acquiring => {
+ self.acquire.set(make_acquire_future(None));
+ self.state = State::Closed;
+ }
+ _ => {}
+ }
+ }
+
+ /// Aborts the current in-progress send, if any.
+ ///
+ /// Returns `true` if a send was aborted. If the sender was closed prior to calling
+ /// `abort_send`, then the sender will remain in the closed state, otherwise the sender will be
+ /// ready to attempt another send.
+ pub fn abort_send(&mut self) -> bool {
+ // We may have been closed in the meantime, after a call to `poll_reserve` already
+ // succeeded. We'll check if `self.sender` is `None` to see if we should transition to the
+ // closed state when we actually abort a send, rather than resetting ourselves back to idle.
+
+ let (result, next_state) = match self.take_state() {
+ // We're currently trying to reserve a slot to send into.
+ State::Acquiring => {
+ // Replacing the future drops the in-flight one.
+ self.acquire.set(make_acquire_future(None));
+
+ // If we haven't closed yet, we have to clone our stored sender since we have no way
+ // to get it back from the acquire future we just dropped.
+ let state = match self.sender.clone() {
+ Some(sender) => State::Idle(sender),
+ None => State::Closed,
+ };
+ (true, state)
+ }
+ // We got the permit. If we haven't closed yet, get the sender back.
+ State::ReadyToSend(permit) => {
+ let state = if self.sender.is_some() {
+ State::Idle(permit.release())
+ } else {
+ State::Closed
+ };
+ (true, state)
+ }
+ s => (false, s),
+ };
+
+ self.state = next_state;
+ result
+ }
+}
+
+impl<T> Clone for PollSender<T> {
+ /// Clones this `PollSender`.
+ ///
+ /// The resulting `PollSender` will have an initial state identical to calling `PollSender::new`.
+ fn clone(&self) -> PollSender<T> {
+ let (sender, state) = match self.sender.clone() {
+ Some(sender) => (Some(sender.clone()), State::Idle(sender)),
+ None => (None, State::Closed),
+ };
+
+ Self {
+ sender,
+ state,
+ // We don't use `make_acquire_future` here because our relaxed bounds on `T` are not
+ // compatible with the transitive bounds required by `Sender<T>`.
+ acquire: ReusableBoxFuture::new(async { unreachable!() }),
+ }
+ }
+}
+
+impl<T: Send + 'static> Sink<T> for PollSender<T> {
+ type Error = PollSendError<T>;
+
+ fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ Pin::into_inner(self).poll_reserve(cx)
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
+ Pin::into_inner(self).send_item(item)
+ }
+
+ fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ Pin::into_inner(self).close();
+ Poll::Ready(Ok(()))
+ }
+}