// Graceful shutdown: `Graceful` keeps driving `SpawnAll` until a caller-supplied
// signal future completes. It then switches to waiting on the drain channel
// created in `Graceful::new`.
//
// NOTE(review): generic parameter lists appear to be missing from this copy.
// Examples: `Graceful`/`State` are declared without parameters but use `F`,
// and bounds read `Into>`, `H2Exec<>::Future, B>`, `Watching, fn(...)>`.
// As shown, this text will not compile. TODO: restore the type parameters
// from the upstream source before relying on it.
use std::error::Error as StdError;

use pin_project::{pin_project, project};
use tokio::io::{AsyncRead, AsyncWrite};

use super::conn::{SpawnAll, UpgradeableConnection, Watcher};
use super::Accept;
use crate::body::{Body, Payload};
use crate::common::drain::{self, Draining, Signal, Watch, Watching};
use crate::common::exec::{H2Exec, NewSvcExec};
use crate::common::{task, Future, Pin, Poll, Unpin};
use crate::service::{HttpService, MakeServiceRef};

/// Future that keeps polling `SpawnAll` until the shutdown signal future
/// resolves, then waits for the resulting `Draining` future to complete.
///
/// Resolves to `crate::Result<()>`.
#[allow(missing_debug_implementations)]
#[pin_project]
pub struct Graceful {
    /// Current phase; pinned because both phases hold pinned futures.
    #[pin]
    state: State,
}

/// Internal two-phase state machine behind [`Graceful`].
#[pin_project]
pub(super) enum State {
    /// The signal has not fired yet: keep polling `spawn_all` and `signal`.
    Running {
        /// Both halves of the drain channel, created in `Graceful::new`.
        /// Stays `Some` until the signal fires. At that point the `Signal`
        /// half is taken to start draining.
        drain: Option<(Signal, Watch)>,
        /// Polled via `poll_watch` on every `Running` poll where the signal
        /// is still pending.
        #[pin]
        spawn_all: SpawnAll,
        /// Shutdown trigger. Once it yields `()`, the state moves to
        /// `Draining`.
        #[pin]
        signal: F,
    },
    /// The signal has fired: wait for the future returned by `Signal::drain`.
    Draining(Draining),
}

impl Graceful {
    /// Creates a `Graceful` in the `Running` state with a fresh drain channel.
    pub(super) fn new(spawn_all: SpawnAll, signal: F) -> Self {
        let drain = Some(drain::channel());
        Graceful {
            state: State::Running {
                drain,
                spawn_all,
                signal,
            },
        }
    }
}

impl Future for Graceful
where
    I: Accept,
    IE: Into>,
    IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    S: MakeServiceRef,
    S::Error: Into>,
    B: Payload,
    F: Future,
    E: H2Exec<>::Future, B>,
    E: NewSvcExec,
{
    type Output = crate::Result<()>;

    /// Drives the shutdown state machine.
    ///
    /// While `Running`, the result of `spawn_all.poll_watch` is returned
    /// directly as this future's result. While `Draining`, the drain
    /// future's `()` output is mapped to `Ok(())`.
    #[project]
    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll {
        let mut me = self.project();
        // Loop so that a Running -> Draining transition is followed
        // immediately by a poll of the new `Draining` state in the same call.
        loop {
            let next = {
                #[project]
                match me.state.as_mut().project() {
                    State::Running {
                        drain,
                        spawn_all,
                        signal,
                    } => match signal.poll(cx) {
                        Poll::Ready(()) => {
                            debug!("signal received, starting graceful shutdown");
                            // `drain` is `Some` from `new` and is only taken
                            // here, right before the state leaves `Running`.
                            // So this `expect` only fires on a broken
                            // invariant.
                            let sig = drain.take().expect("drain channel").0;
                            State::Draining(sig.drain())
                        }
                        Poll::Pending => {
                            // Hand `spawn_all` a watcher holding a clone of the
                            // `Watch` half. A new clone is made on every poll.
                            let watch = drain.as_ref().expect("drain channel").1.clone();
                            return spawn_all.poll_watch(cx, &GracefulWatcher(watch));
                        }
                    },
                    State::Draining(ref mut draining) => {
                        return Pin::new(draining).poll(cx).map(Ok);
                    }
                }
            };
            // Only reached from the `Poll::Ready(())` arm (every other arm
            // returns), so this always installs `State::Draining`.
            me.state.set(next);
        }
    }
}

/// `Watcher` passed to `SpawnAll::poll_watch` while running. It holds a clone
/// of the drain channel's `Watch` half.
#[allow(missing_debug_implementations)]
#[derive(Clone)]
pub struct GracefulWatcher(Watch);

impl Watcher for GracefulWatcher
where
    I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    S: HttpService,
    E: H2Exec,
{
    /// The connection wrapped by `Watch::watch`, with `on_drain` as its
    /// callback.
    type Future = Watching, fn(Pin<&mut UpgradeableConnection>)>;

    /// Wraps `conn` in a `Watching` future tied to a clone of this watcher's
    /// `Watch`, registering `on_drain` as the callback.
    ///
    /// NOTE(review): presumably `on_drain` is invoked once draining starts;
    /// confirm in `common::drain`.
    fn watch(&self, conn: UpgradeableConnection) -> Self::Future {
        self.0.clone().watch(conn, on_drain)
    }
}

/// Callback registered by `GracefulWatcher::watch`. It asks the pinned
/// connection to shut down gracefully via `graceful_shutdown`.
fn on_drain(conn: Pin<&mut UpgradeableConnection>)
where
    S: HttpService,
    S::Error: Into>,
    I: AsyncRead + AsyncWrite + Unpin,
    S::ResBody: Payload + 'static,
    E: H2Exec,
{
    conn.graceful_shutdown()
}