use futures_util::future; use tokio::sync::{mpsc, oneshot}; use crate::common::{task, Future, Pin, Poll}; pub type RetryPromise = oneshot::Receiver)>>; pub type Promise = oneshot::Receiver>; pub fn channel() -> (Sender, Receiver) { let (tx, rx) = mpsc::unbounded_channel(); let (giver, taker) = want::new(); let tx = Sender { buffered_once: false, giver, inner: tx, }; let rx = Receiver { inner: rx, taker }; (tx, rx) } /// A bounded sender of requests and callbacks for when responses are ready. /// /// While the inner sender is unbounded, the Giver is used to determine /// if the Receiver is ready for another request. pub struct Sender { /// One message is always allowed, even if the Receiver hasn't asked /// for it yet. This boolean keeps track of whether we've sent one /// without notice. buffered_once: bool, /// The Giver helps watch that the the Receiver side has been polled /// when the queue is empty. This helps us know when a request and /// response have been fully processed, and a connection is ready /// for more. giver: want::Giver, /// Actually bounded by the Giver, plus `buffered_once`. inner: mpsc::UnboundedSender>, } /// An unbounded version. /// /// Cannot poll the Giver, but can still use it to determine if the Receiver /// has been dropped. However, this version can be cloned. pub struct UnboundedSender { /// Only used for `is_closed`, since mpsc::UnboundedSender cannot be checked. giver: want::SharedGiver, inner: mpsc::UnboundedSender>, } impl Sender { pub fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { self.giver .poll_want(cx) .map_err(|_| crate::Error::new_closed()) } pub fn is_ready(&self) -> bool { self.giver.is_wanting() } pub fn is_closed(&self) -> bool { self.giver.is_canceled() } fn can_send(&mut self) -> bool { if self.giver.give() || !self.buffered_once { // If the receiver is ready *now*, then of course we can send. // // If the receiver isn't ready yet, but we don't have anything // in the channel yet, then allow one message. self.buffered_once = true; true } else { false } } pub fn try_send(&mut self, val: T) -> Result, T> { if !self.can_send() { return Err(val); } let (tx, rx) = oneshot::channel(); self.inner .send(Envelope(Some((val, Callback::Retry(tx))))) .map(move |_| rx) .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0) } pub fn send(&mut self, val: T) -> Result, T> { if !self.can_send() { return Err(val); } let (tx, rx) = oneshot::channel(); self.inner .send(Envelope(Some((val, Callback::NoRetry(tx))))) .map(move |_| rx) .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0) } pub fn unbound(self) -> UnboundedSender { UnboundedSender { giver: self.giver.shared(), inner: self.inner, } } } impl UnboundedSender { pub fn is_ready(&self) -> bool { !self.giver.is_canceled() } pub fn is_closed(&self) -> bool { self.giver.is_canceled() } pub fn try_send(&mut self, val: T) -> Result, T> { let (tx, rx) = oneshot::channel(); self.inner .send(Envelope(Some((val, Callback::Retry(tx))))) .map(move |_| rx) .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0) } } impl Clone for UnboundedSender { fn clone(&self) -> Self { UnboundedSender { giver: self.giver.clone(), inner: self.inner.clone(), } } } pub struct Receiver { inner: mpsc::UnboundedReceiver>, taker: want::Taker, } impl Receiver { pub(crate) fn poll_next( &mut self, cx: &mut task::Context<'_>, ) -> Poll)>> { match self.inner.poll_recv(cx) { Poll::Ready(item) => { Poll::Ready(item.map(|mut env| env.0.take().expect("envelope not dropped"))) } Poll::Pending => { self.taker.want(); Poll::Pending } } } pub(crate) fn close(&mut self) { self.taker.cancel(); self.inner.close(); } pub(crate) fn try_recv(&mut self) -> Option<(T, Callback)> { match self.inner.try_recv() { Ok(mut env) => env.0.take(), Err(_) => None, } } } impl Drop for Receiver { fn drop(&mut self) { // Notify the giver about the closure first, before dropping // the mpsc::Receiver. self.taker.cancel(); } } struct Envelope(Option<(T, Callback)>); impl Drop for Envelope { fn drop(&mut self) { if let Some((val, cb)) = self.0.take() { cb.send(Err(( crate::Error::new_canceled().with("connection closed"), Some(val), ))); } } } pub enum Callback { Retry(oneshot::Sender)>>), NoRetry(oneshot::Sender>), } impl Callback { pub(crate) fn is_canceled(&self) -> bool { match *self { Callback::Retry(ref tx) => tx.is_closed(), Callback::NoRetry(ref tx) => tx.is_closed(), } } pub(crate) fn poll_canceled(&mut self, cx: &mut task::Context<'_>) -> Poll<()> { match *self { Callback::Retry(ref mut tx) => tx.poll_closed(cx), Callback::NoRetry(ref mut tx) => tx.poll_closed(cx), } } pub(crate) fn send(self, val: Result)>) { match self { Callback::Retry(tx) => { let _ = tx.send(val); } Callback::NoRetry(tx) => { let _ = tx.send(val.map_err(|e| e.0)); } } } pub(crate) fn send_when( self, mut when: impl Future)>> + Unpin, ) -> impl Future { let mut cb = Some(self); // "select" on this callback being canceled, and the future completing future::poll_fn(move |cx| { match Pin::new(&mut when).poll(cx) { Poll::Ready(Ok(res)) => { cb.take().expect("polled after complete").send(Ok(res)); Poll::Ready(()) } Poll::Pending => { // check if the callback is canceled ready!(cb.as_mut().unwrap().poll_canceled(cx)); trace!("send_when canceled"); Poll::Ready(()) } Poll::Ready(Err(err)) => { cb.take().expect("polled after complete").send(Err(err)); Poll::Ready(()) } } }) } } #[cfg(test)] mod tests { #[cfg(feature = "nightly")] extern crate test; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; use super::{channel, Callback, Receiver}; #[derive(Debug)] struct Custom(i32); impl Future for Receiver { type Output = Option<(T, Callback)>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.poll_next(cx) } } /// Helper to check if the future is ready after polling once. struct PollOnce<'a, F>(&'a mut F); impl Future for PollOnce<'_, F> where F: Future + Unpin, { type Output = Option<()>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match Pin::new(&mut self.0).poll(cx) { Poll::Ready(_) => Poll::Ready(Some(())), Poll::Pending => Poll::Ready(None), } } } #[tokio::test] async fn drop_receiver_sends_cancel_errors() { let _ = pretty_env_logger::try_init(); let (mut tx, mut rx) = channel::(); // must poll once for try_send to succeed assert!(PollOnce(&mut rx).await.is_none(), "rx empty"); let promise = tx.try_send(Custom(43)).unwrap(); drop(rx); let fulfilled = promise.await; let err = fulfilled .expect("fulfilled") .expect_err("promise should error"); match (err.0.kind(), err.1) { (&crate::error::Kind::Canceled, Some(_)) => (), e => panic!("expected Error::Cancel(_), found {:?}", e), } } #[tokio::test] async fn sender_checks_for_want_on_send() { let (mut tx, mut rx) = channel::(); // one is allowed to buffer, second is rejected let _ = tx.try_send(Custom(1)).expect("1 buffered"); tx.try_send(Custom(2)).expect_err("2 not ready"); assert!(PollOnce(&mut rx).await.is_some(), "rx once"); // Even though 1 has been popped, only 1 could be buffered for the // lifetime of the channel. tx.try_send(Custom(2)).expect_err("2 still not ready"); assert!(PollOnce(&mut rx).await.is_none(), "rx empty"); let _ = tx.try_send(Custom(2)).expect("2 ready"); } #[test] fn unbounded_sender_doesnt_bound_on_want() { let (tx, rx) = channel::(); let mut tx = tx.unbound(); let _ = tx.try_send(Custom(1)).unwrap(); let _ = tx.try_send(Custom(2)).unwrap(); let _ = tx.try_send(Custom(3)).unwrap(); drop(rx); let _ = tx.try_send(Custom(4)).unwrap_err(); } #[cfg(feature = "nightly")] #[bench] fn giver_queue_throughput(b: &mut test::Bencher) { use crate::{Body, Request, Response}; let mut rt = tokio::runtime::Builder::new() .enable_all() .basic_scheduler() .build() .unwrap(); let (mut tx, mut rx) = channel::, Response>(); b.iter(move || { let _ = tx.send(Request::default()).unwrap(); rt.block_on(async { loop { let poll_once = PollOnce(&mut rx); let opt = poll_once.await; if opt.is_none() { break; } } }); }) } #[cfg(feature = "nightly")] #[bench] fn giver_queue_not_ready(b: &mut test::Bencher) { let mut rt = tokio::runtime::Builder::new() .enable_all() .basic_scheduler() .build() .unwrap(); let (_tx, mut rx) = channel::(); b.iter(move || { rt.block_on(async { let poll_once = PollOnce(&mut rx); assert!(poll_once.await.is_none()); }); }) } #[cfg(feature = "nightly")] #[bench] fn giver_queue_cancel(b: &mut test::Bencher) { let (_tx, mut rx) = channel::(); b.iter(move || { rx.taker.cancel(); }) } }