use futures_sink::Sink; use std::pin::Pin; use std::task::{Context, Poll}; use std::{fmt, mem}; use tokio::sync::mpsc::OwnedPermit; use tokio::sync::mpsc::Sender; use super::ReusableBoxFuture; /// Error returned by the `PollSender` when the channel is closed. #[derive(Debug)] pub struct PollSendError(Option); impl PollSendError { /// Consumes the stored value, if any. /// /// If this error was encountered when calling `start_send`/`send_item`, this will be the item /// that the caller attempted to send. Otherwise, it will be `None`. pub fn into_inner(self) -> Option { self.0 } } impl fmt::Display for PollSendError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { write!(fmt, "channel closed") } } impl std::error::Error for PollSendError {} #[derive(Debug)] enum State { Idle(Sender), Acquiring, ReadyToSend(OwnedPermit), Closed, } /// A wrapper around [`mpsc::Sender`] that can be polled. /// /// [`mpsc::Sender`]: tokio::sync::mpsc::Sender #[derive(Debug)] pub struct PollSender { sender: Option>, state: State, acquire: ReusableBoxFuture<'static, Result, PollSendError>>, } // Creates a future for acquiring a permit from the underlying channel. This is used to ensure // there's capacity for a send to complete. // // By reusing the same async fn for both `Some` and `None`, we make sure every future passed to // ReusableBoxFuture has the same underlying type, and hence the same size and alignment. async fn make_acquire_future( data: Option>, ) -> Result, PollSendError> { match data { Some(sender) => sender .reserve_owned() .await .map_err(|_| PollSendError(None)), None => unreachable!("this future should not be pollable in this state"), } } impl PollSender { /// Creates a new `PollSender`. pub fn new(sender: Sender) -> Self { Self { sender: Some(sender.clone()), state: State::Idle(sender), acquire: ReusableBoxFuture::new(make_acquire_future(None)), } } fn take_state(&mut self) -> State { mem::replace(&mut self.state, State::Closed) } /// Attempts to prepare the sender to receive a value. /// /// This method must be called and return `Poll::Ready(Ok(()))` prior to each call to /// `send_item`. /// /// This method returns `Poll::Ready` once the underlying channel is ready to receive a value, /// by reserving a slot in the channel for the item to be sent. If this method returns /// `Poll::Pending`, the current task is registered to be notified (via /// `cx.waker().wake_by_ref()`) when `poll_reserve` should be called again. /// /// # Errors /// /// If the channel is closed, an error will be returned. This is a permanent state. pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll>> { loop { let (result, next_state) = match self.take_state() { State::Idle(sender) => { // Start trying to acquire a permit to reserve a slot for our send, and // immediately loop back around to poll it the first time. self.acquire.set(make_acquire_future(Some(sender))); (None, State::Acquiring) } State::Acquiring => match self.acquire.poll(cx) { // Channel has capacity. Poll::Ready(Ok(permit)) => { (Some(Poll::Ready(Ok(()))), State::ReadyToSend(permit)) } // Channel is closed. Poll::Ready(Err(e)) => (Some(Poll::Ready(Err(e))), State::Closed), // Channel doesn't have capacity yet, so we need to wait. Poll::Pending => (Some(Poll::Pending), State::Acquiring), }, // We're closed, either by choice or because the underlying sender was closed. s @ State::Closed => (Some(Poll::Ready(Err(PollSendError(None)))), s), // We're already ready to send an item. s @ State::ReadyToSend(_) => (Some(Poll::Ready(Ok(()))), s), }; self.state = next_state; if let Some(result) = result { return result; } } } /// Sends an item to the channel. /// /// Before calling `send_item`, `poll_reserve` must be called with a successful return /// value of `Poll::Ready(Ok(()))`. /// /// # Errors /// /// If the channel is closed, an error will be returned. This is a permanent state. /// /// # Panics /// /// If `poll_reserve` was not successfully called prior to calling `send_item`, then this method /// will panic. pub fn send_item(&mut self, value: T) -> Result<(), PollSendError> { let (result, next_state) = match self.take_state() { State::Idle(_) | State::Acquiring => { panic!("`send_item` called without first calling `poll_reserve`") } // We have a permit to send our item, so go ahead, which gets us our sender back. State::ReadyToSend(permit) => (Ok(()), State::Idle(permit.send(value))), // We're closed, either by choice or because the underlying sender was closed. State::Closed => (Err(PollSendError(Some(value))), State::Closed), }; // Handle deferred closing if `close` was called between `poll_reserve` and `send_item`. self.state = if self.sender.is_some() { next_state } else { State::Closed }; result } /// Checks whether this sender is been closed. /// /// The underlying channel that this sender was wrapping may still be open. pub fn is_closed(&self) -> bool { matches!(self.state, State::Closed) || self.sender.is_none() } /// Gets a reference to the `Sender` of the underlying channel. /// /// If `PollSender` has been closed, `None` is returned. The underlying channel that this sender /// was wrapping may still be open. pub fn get_ref(&self) -> Option<&Sender> { self.sender.as_ref() } /// Closes this sender. /// /// No more messages will be able to be sent from this sender, but the underlying channel will /// remain open until all senders have dropped, or until the [`Receiver`] closes the channel. /// /// If a slot was previously reserved by calling `poll_reserve`, then a final call can be made /// to `send_item` in order to consume the reserved slot. After that, no further sends will be /// possible. If you do not intend to send another item, you can release the reserved slot back /// to the underlying sender by calling [`abort_send`]. /// /// [`abort_send`]: crate::sync::PollSender::abort_send /// [`Receiver`]: tokio::sync::mpsc::Receiver pub fn close(&mut self) { // Mark ourselves officially closed by dropping our main sender. self.sender = None; // If we're already idle, closed, or we haven't yet reserved a slot, we can quickly // transition to the closed state. Otherwise, leave the existing permit in place for the // caller if they want to complete the send. match self.state { State::Idle(_) => self.state = State::Closed, State::Acquiring => { self.acquire.set(make_acquire_future(None)); self.state = State::Closed; } _ => {} } } /// Aborts the current in-progress send, if any. /// /// Returns `true` if a send was aborted. If the sender was closed prior to calling /// `abort_send`, then the sender will remain in the closed state, otherwise the sender will be /// ready to attempt another send. pub fn abort_send(&mut self) -> bool { // We may have been closed in the meantime, after a call to `poll_reserve` already // succeeded. We'll check if `self.sender` is `None` to see if we should transition to the // closed state when we actually abort a send, rather than resetting ourselves back to idle. let (result, next_state) = match self.take_state() { // We're currently trying to reserve a slot to send into. State::Acquiring => { // Replacing the future drops the in-flight one. self.acquire.set(make_acquire_future(None)); // If we haven't closed yet, we have to clone our stored sender since we have no way // to get it back from the acquire future we just dropped. let state = match self.sender.clone() { Some(sender) => State::Idle(sender), None => State::Closed, }; (true, state) } // We got the permit. If we haven't closed yet, get the sender back. State::ReadyToSend(permit) => { let state = if self.sender.is_some() { State::Idle(permit.release()) } else { State::Closed }; (true, state) } s => (false, s), }; self.state = next_state; result } } impl Clone for PollSender { /// Clones this `PollSender`. /// /// The resulting `PollSender` will have an initial state identical to calling `PollSender::new`. fn clone(&self) -> PollSender { let (sender, state) = match self.sender.clone() { Some(sender) => (Some(sender.clone()), State::Idle(sender)), None => (None, State::Closed), }; Self { sender, state, // We don't use `make_acquire_future` here because our relaxed bounds on `T` are not // compatible with the transitive bounds required by `Sender`. acquire: ReusableBoxFuture::new(async { unreachable!() }), } } } impl Sink for PollSender { type Error = PollSendError; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::into_inner(self).poll_reserve(cx) } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { Pin::into_inner(self).send_item(item) } fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Pin::into_inner(self).close(); Poll::Ready(Ok(())) } }