diff options
Diffstat (limited to 'third_party/rust/hyper/src/common/drain.rs')
-rw-r--r-- | third_party/rust/hyper/src/common/drain.rs | 217 |
1 files changed, 217 insertions, 0 deletions
diff --git a/third_party/rust/hyper/src/common/drain.rs b/third_party/rust/hyper/src/common/drain.rs new file mode 100644 index 0000000000..174da876df --- /dev/null +++ b/third_party/rust/hyper/src/common/drain.rs @@ -0,0 +1,217 @@ +use std::mem; + +use pin_project_lite::pin_project; +use tokio::sync::watch; + +use super::{task, Future, Pin, Poll}; + +pub(crate) fn channel() -> (Signal, Watch) { + let (tx, rx) = watch::channel(()); + (Signal { tx }, Watch { rx }) +} + +pub(crate) struct Signal { + tx: watch::Sender<()>, +} + +pub(crate) struct Draining(Pin<Box<dyn Future<Output = ()> + Send + Sync>>); + +#[derive(Clone)] +pub(crate) struct Watch { + rx: watch::Receiver<()>, +} + +pin_project! { + #[allow(missing_debug_implementations)] + pub struct Watching<F, FN> { + #[pin] + future: F, + state: State<FN>, + watch: Pin<Box<dyn Future<Output = ()> + Send + Sync>>, + _rx: watch::Receiver<()>, + } +} + +enum State<F> { + Watch(F), + Draining, +} + +impl Signal { + pub(crate) fn drain(self) -> Draining { + let _ = self.tx.send(()); + Draining(Box::pin(async move { self.tx.closed().await })) + } +} + +impl Future for Draining { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + Pin::new(&mut self.as_mut().0).poll(cx) + } +} + +impl Watch { + pub(crate) fn watch<F, FN>(self, future: F, on_drain: FN) -> Watching<F, FN> + where + F: Future, + FN: FnOnce(Pin<&mut F>), + { + let Self { mut rx } = self; + let _rx = rx.clone(); + Watching { + future, + state: State::Watch(on_drain), + watch: Box::pin(async move { + let _ = rx.changed().await; + }), + // Keep the receiver alive until the future completes, so that + // dropping it can signal that draining has completed. + _rx, + } + } +} + +impl<F, FN> Future for Watching<F, FN> +where + F: Future, + FN: FnOnce(Pin<&mut F>), +{ + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + let mut me = self.project(); + loop { + match mem::replace(me.state, State::Draining) { + State::Watch(on_drain) => { + match Pin::new(&mut me.watch).poll(cx) { + Poll::Ready(()) => { + // Drain has been triggered! + on_drain(me.future.as_mut()); + } + Poll::Pending => { + *me.state = State::Watch(on_drain); + return me.future.poll(cx); + } + } + } + State::Draining => return me.future.poll(cx), + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + struct TestMe { + draining: bool, + finished: bool, + poll_cnt: usize, + } + + impl Future for TestMe { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> { + self.poll_cnt += 1; + if self.finished { + Poll::Ready(()) + } else { + Poll::Pending + } + } + } + + #[test] + fn watch() { + let mut mock = tokio_test::task::spawn(()); + mock.enter(|cx, _| { + let (tx, rx) = channel(); + let fut = TestMe { + draining: false, + finished: false, + poll_cnt: 0, + }; + + let mut watch = rx.watch(fut, |mut fut| { + fut.draining = true; + }); + + assert_eq!(watch.future.poll_cnt, 0); + + // First poll should poll the inner future + assert!(Pin::new(&mut watch).poll(cx).is_pending()); + assert_eq!(watch.future.poll_cnt, 1); + + // Second poll should poll the inner future again + assert!(Pin::new(&mut watch).poll(cx).is_pending()); + assert_eq!(watch.future.poll_cnt, 2); + + let mut draining = tx.drain(); + // Drain signaled, but needs another poll to be noticed. + assert!(!watch.future.draining); + assert_eq!(watch.future.poll_cnt, 2); + + // Now, poll after drain has been signaled. + assert!(Pin::new(&mut watch).poll(cx).is_pending()); + assert_eq!(watch.future.poll_cnt, 3); + assert!(watch.future.draining); + + // Draining is not ready until watcher completes + assert!(Pin::new(&mut draining).poll(cx).is_pending()); + + // Finishing up the watch future + watch.future.finished = true; + assert!(Pin::new(&mut watch).poll(cx).is_ready()); + assert_eq!(watch.future.poll_cnt, 4); + drop(watch); + + assert!(Pin::new(&mut draining).poll(cx).is_ready()); + }) + } + + #[test] + fn watch_clones() { + let mut mock = tokio_test::task::spawn(()); + mock.enter(|cx, _| { + let (tx, rx) = channel(); + + let fut1 = TestMe { + draining: false, + finished: false, + poll_cnt: 0, + }; + let fut2 = TestMe { + draining: false, + finished: false, + poll_cnt: 0, + }; + + let watch1 = rx.clone().watch(fut1, |mut fut| { + fut.draining = true; + }); + let watch2 = rx.watch(fut2, |mut fut| { + fut.draining = true; + }); + + let mut draining = tx.drain(); + + // Still 2 outstanding watchers + assert!(Pin::new(&mut draining).poll(cx).is_pending()); + + // drop 1 for whatever reason + drop(watch1); + + // Still not ready, 1 other watcher still pending + assert!(Pin::new(&mut draining).poll(cx).is_pending()); + + drop(watch2); + + // Now all watchers are gone, draining is complete + assert!(Pin::new(&mut draining).poll(cx).is_ready()); + }); + } +} |