Diffstat (limited to 'third_party/rust/hyper/src/proto/h2')
-rw-r--r--  third_party/rust/hyper/src/proto/h2/client.rs | 450
-rw-r--r--  third_party/rust/hyper/src/proto/h2/mod.rs    | 471
-rw-r--r--  third_party/rust/hyper/src/proto/h2/ping.rs   | 555
-rw-r--r--  third_party/rust/hyper/src/proto/h2/server.rs | 548
4 files changed, 2024 insertions, 0 deletions
diff --git a/third_party/rust/hyper/src/proto/h2/client.rs b/third_party/rust/hyper/src/proto/h2/client.rs new file mode 100644 index 0000000000..bac8eceb3a --- /dev/null +++ b/third_party/rust/hyper/src/proto/h2/client.rs @@ -0,0 +1,450 @@ +use std::error::Error as StdError; +#[cfg(feature = "runtime")] +use std::time::Duration; + +use bytes::Bytes; +use futures_channel::{mpsc, oneshot}; +use futures_util::future::{self, Either, FutureExt as _, TryFutureExt as _}; +use futures_util::stream::StreamExt as _; +use h2::client::{Builder, SendRequest}; +use h2::SendStream; +use http::{Method, StatusCode}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tracing::{debug, trace, warn}; + +use super::{ping, H2Upgraded, PipeToSendStream, SendBuf}; +use crate::body::HttpBody; +use crate::client::dispatch::Callback; +use crate::common::{exec::Exec, task, Future, Never, Pin, Poll}; +use crate::ext::Protocol; +use crate::headers; +use crate::proto::h2::UpgradedSendStream; +use crate::proto::Dispatched; +use crate::upgrade::Upgraded; +use crate::{Body, Request, Response}; +use h2::client::ResponseFuture; + +type ClientRx<B> = crate::client::dispatch::Receiver<Request<B>, Response<Body>>; + +///// An mpsc channel is used to help notify the `Connection` task when *all* +///// other handles to it have been dropped, so that it can shutdown. +type ConnDropRef = mpsc::Sender<Never>; + +///// A oneshot channel watches the `Connection` task, and when it completes, +///// the "dispatch" task will be notified and can shutdown sooner. +type ConnEof = oneshot::Receiver<Never>; + +// Our defaults are chosen for the "majority" case, which usually are not +// resource constrained, and so the spec default of 64kb can be too limiting +// for performance. +const DEFAULT_CONN_WINDOW: u32 = 1024 * 1024 * 5; // 5mb +const DEFAULT_STREAM_WINDOW: u32 = 1024 * 1024 * 2; // 2mb +const DEFAULT_MAX_FRAME_SIZE: u32 = 1024 * 16; // 16kb +const DEFAULT_MAX_SEND_BUF_SIZE: usize = 1024 * 1024; // 1mb + +#[derive(Clone, Debug)] +pub(crate) struct Config { + pub(crate) adaptive_window: bool, + pub(crate) initial_conn_window_size: u32, + pub(crate) initial_stream_window_size: u32, + pub(crate) max_frame_size: u32, + #[cfg(feature = "runtime")] + pub(crate) keep_alive_interval: Option<Duration>, + #[cfg(feature = "runtime")] + pub(crate) keep_alive_timeout: Duration, + #[cfg(feature = "runtime")] + pub(crate) keep_alive_while_idle: bool, + pub(crate) max_concurrent_reset_streams: Option<usize>, + pub(crate) max_send_buffer_size: usize, +} + +impl Default for Config { + fn default() -> Config { + Config { + adaptive_window: false, + initial_conn_window_size: DEFAULT_CONN_WINDOW, + initial_stream_window_size: DEFAULT_STREAM_WINDOW, + max_frame_size: DEFAULT_MAX_FRAME_SIZE, + #[cfg(feature = "runtime")] + keep_alive_interval: None, + #[cfg(feature = "runtime")] + keep_alive_timeout: Duration::from_secs(20), + #[cfg(feature = "runtime")] + keep_alive_while_idle: false, + max_concurrent_reset_streams: None, + max_send_buffer_size: DEFAULT_MAX_SEND_BUF_SIZE, + } + } +} + +fn new_builder(config: &Config) -> Builder { + let mut builder = Builder::default(); + builder + .initial_window_size(config.initial_stream_window_size) + .initial_connection_window_size(config.initial_conn_window_size) + .max_frame_size(config.max_frame_size) + .max_send_buffer_size(config.max_send_buffer_size) + .enable_push(false); + if let Some(max) = config.max_concurrent_reset_streams { + builder.max_concurrent_reset_streams(max); + } + builder +} + +fn 
new_ping_config(config: &Config) -> ping::Config { + ping::Config { + bdp_initial_window: if config.adaptive_window { + Some(config.initial_stream_window_size) + } else { + None + }, + #[cfg(feature = "runtime")] + keep_alive_interval: config.keep_alive_interval, + #[cfg(feature = "runtime")] + keep_alive_timeout: config.keep_alive_timeout, + #[cfg(feature = "runtime")] + keep_alive_while_idle: config.keep_alive_while_idle, + } +} + +pub(crate) async fn handshake<T, B>( + io: T, + req_rx: ClientRx<B>, + config: &Config, + exec: Exec, +) -> crate::Result<ClientTask<B>> +where + T: AsyncRead + AsyncWrite + Send + Unpin + 'static, + B: HttpBody, + B::Data: Send + 'static, +{ + let (h2_tx, mut conn) = new_builder(config) + .handshake::<_, SendBuf<B::Data>>(io) + .await + .map_err(crate::Error::new_h2)?; + + // An mpsc channel is used entirely to detect when the + // 'Client' has been dropped. This is to get around a bug + // in h2 where dropping all SendRequests won't notify a + // parked Connection. + let (conn_drop_ref, rx) = mpsc::channel(1); + let (cancel_tx, conn_eof) = oneshot::channel(); + + let conn_drop_rx = rx.into_future().map(|(item, _rx)| { + if let Some(never) = item { + match never {} + } + }); + + let ping_config = new_ping_config(&config); + + let (conn, ping) = if ping_config.is_enabled() { + let pp = conn.ping_pong().expect("conn.ping_pong"); + let (recorder, mut ponger) = ping::channel(pp, ping_config); + + let conn = future::poll_fn(move |cx| { + match ponger.poll(cx) { + Poll::Ready(ping::Ponged::SizeUpdate(wnd)) => { + conn.set_target_window_size(wnd); + conn.set_initial_window_size(wnd)?; + } + #[cfg(feature = "runtime")] + Poll::Ready(ping::Ponged::KeepAliveTimedOut) => { + debug!("connection keep-alive timed out"); + return Poll::Ready(Ok(())); + } + Poll::Pending => {} + } + + Pin::new(&mut conn).poll(cx) + }); + (Either::Left(conn), recorder) + } else { + (Either::Right(conn), ping::disabled()) + }; + let conn = conn.map_err(|e| debug!("connection error: {}", e)); + + exec.execute(conn_task(conn, conn_drop_rx, cancel_tx)); + + Ok(ClientTask { + ping, + conn_drop_ref, + conn_eof, + executor: exec, + h2_tx, + req_rx, + fut_ctx: None, + }) +} + +async fn conn_task<C, D>(conn: C, drop_rx: D, cancel_tx: oneshot::Sender<Never>) +where + C: Future + Unpin, + D: Future<Output = ()> + Unpin, +{ + match future::select(conn, drop_rx).await { + Either::Left(_) => { + // ok or err, the `conn` has finished + } + Either::Right(((), conn)) => { + // mpsc has been dropped, hopefully polling + // the connection some more should start shutdown + // and then close + trace!("send_request dropped, starting conn shutdown"); + drop(cancel_tx); + let _ = conn.await; + } + } +} + +struct FutCtx<B> +where + B: HttpBody, +{ + is_connect: bool, + eos: bool, + fut: ResponseFuture, + body_tx: SendStream<SendBuf<B::Data>>, + body: B, + cb: Callback<Request<B>, Response<Body>>, +} + +impl<B: HttpBody> Unpin for FutCtx<B> {} + +pub(crate) struct ClientTask<B> +where + B: HttpBody, +{ + ping: ping::Recorder, + conn_drop_ref: ConnDropRef, + conn_eof: ConnEof, + executor: Exec, + h2_tx: SendRequest<SendBuf<B::Data>>, + req_rx: ClientRx<B>, + fut_ctx: Option<FutCtx<B>>, +} + +impl<B> ClientTask<B> +where + B: HttpBody + 'static, +{ + pub(crate) fn is_extended_connect_protocol_enabled(&self) -> bool { + self.h2_tx.is_extended_connect_protocol_enabled() + } +} + +impl<B> ClientTask<B> +where + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + fn 
poll_pipe(&mut self, f: FutCtx<B>, cx: &mut task::Context<'_>) { + let ping = self.ping.clone(); + let send_stream = if !f.is_connect { + if !f.eos { + let mut pipe = Box::pin(PipeToSendStream::new(f.body, f.body_tx)).map(|res| { + if let Err(e) = res { + debug!("client request body error: {}", e); + } + }); + + // eagerly see if the body pipe is ready and + // can thus skip allocating in the executor + match Pin::new(&mut pipe).poll(cx) { + Poll::Ready(_) => (), + Poll::Pending => { + let conn_drop_ref = self.conn_drop_ref.clone(); + // keep the ping recorder's knowledge of an + // "open stream" alive while this body is + // still sending... + let ping = ping.clone(); + let pipe = pipe.map(move |x| { + drop(conn_drop_ref); + drop(ping); + x + }); + // Clear send task + self.executor.execute(pipe); + } + } + } + + None + } else { + Some(f.body_tx) + }; + + let fut = f.fut.map(move |result| match result { + Ok(res) => { + // record that we got the response headers + ping.record_non_data(); + + let content_length = headers::content_length_parse_all(res.headers()); + if let (Some(mut send_stream), StatusCode::OK) = (send_stream, res.status()) { + if content_length.map_or(false, |len| len != 0) { + warn!("h2 connect response with non-zero body not supported"); + + send_stream.send_reset(h2::Reason::INTERNAL_ERROR); + return Err(( + crate::Error::new_h2(h2::Reason::INTERNAL_ERROR.into()), + None, + )); + } + let (parts, recv_stream) = res.into_parts(); + let mut res = Response::from_parts(parts, Body::empty()); + + let (pending, on_upgrade) = crate::upgrade::pending(); + let io = H2Upgraded { + ping, + send_stream: unsafe { UpgradedSendStream::new(send_stream) }, + recv_stream, + buf: Bytes::new(), + }; + let upgraded = Upgraded::new(io, Bytes::new()); + + pending.fulfill(upgraded); + res.extensions_mut().insert(on_upgrade); + + Ok(res) + } else { + let res = res.map(|stream| { + let ping = ping.for_stream(&stream); + crate::Body::h2(stream, content_length.into(), ping) + }); + Ok(res) + } + } + Err(err) => { + ping.ensure_not_timed_out().map_err(|e| (e, None))?; + + debug!("client response error: {}", err); + Err((crate::Error::new_h2(err), None)) + } + }); + self.executor.execute(f.cb.send_when(fut)); + } +} + +impl<B> Future for ClientTask<B> +where + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + type Output = crate::Result<Dispatched>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + loop { + match ready!(self.h2_tx.poll_ready(cx)) { + Ok(()) => (), + Err(err) => { + self.ping.ensure_not_timed_out()?; + return if err.reason() == Some(::h2::Reason::NO_ERROR) { + trace!("connection gracefully shutdown"); + Poll::Ready(Ok(Dispatched::Shutdown)) + } else { + Poll::Ready(Err(crate::Error::new_h2(err))) + }; + } + }; + + match self.fut_ctx.take() { + // If we were waiting on pending open + // continue where we left off. 
+ Some(f) => { + self.poll_pipe(f, cx); + continue; + } + None => (), + } + + match self.req_rx.poll_recv(cx) { + Poll::Ready(Some((req, cb))) => { + // check that future hasn't been canceled already + if cb.is_canceled() { + trace!("request callback is canceled"); + continue; + } + let (head, body) = req.into_parts(); + let mut req = ::http::Request::from_parts(head, ()); + super::strip_connection_headers(req.headers_mut(), true); + if let Some(len) = body.size_hint().exact() { + if len != 0 || headers::method_has_defined_payload_semantics(req.method()) { + headers::set_content_length_if_missing(req.headers_mut(), len); + } + } + + let is_connect = req.method() == Method::CONNECT; + let eos = body.is_end_stream(); + + if is_connect { + if headers::content_length_parse_all(req.headers()) + .map_or(false, |len| len != 0) + { + warn!("h2 connect request with non-zero body not supported"); + cb.send(Err(( + crate::Error::new_h2(h2::Reason::INTERNAL_ERROR.into()), + None, + ))); + continue; + } + } + + if let Some(protocol) = req.extensions_mut().remove::<Protocol>() { + req.extensions_mut().insert(protocol.into_inner()); + } + + let (fut, body_tx) = match self.h2_tx.send_request(req, !is_connect && eos) { + Ok(ok) => ok, + Err(err) => { + debug!("client send request error: {}", err); + cb.send(Err((crate::Error::new_h2(err), None))); + continue; + } + }; + + let f = FutCtx { + is_connect, + eos, + fut, + body_tx, + body, + cb, + }; + + // Check poll_ready() again. + // If the call to send_request() resulted in the new stream being pending open + // we have to wait for the open to complete before accepting new requests. + match self.h2_tx.poll_ready(cx) { + Poll::Pending => { + // Save Context + self.fut_ctx = Some(f); + return Poll::Pending; + } + Poll::Ready(Ok(())) => (), + Poll::Ready(Err(err)) => { + f.cb.send(Err((crate::Error::new_h2(err), None))); + continue; + } + } + self.poll_pipe(f, cx); + continue; + } + + Poll::Ready(None) => { + trace!("client::dispatch::Sender dropped"); + return Poll::Ready(Ok(Dispatched::Shutdown)); + } + + Poll::Pending => match ready!(Pin::new(&mut self.conn_eof).poll(cx)) { + Ok(never) => match never {}, + Err(_conn_is_eof) => { + trace!("connection task is closed, closing dispatch task"); + return Poll::Ready(Ok(Dispatched::Shutdown)); + } + }, + } + } + } +} diff --git a/third_party/rust/hyper/src/proto/h2/mod.rs b/third_party/rust/hyper/src/proto/h2/mod.rs new file mode 100644 index 0000000000..5857c919d1 --- /dev/null +++ b/third_party/rust/hyper/src/proto/h2/mod.rs @@ -0,0 +1,471 @@ +use bytes::{Buf, Bytes}; +use h2::{Reason, RecvStream, SendStream}; +use http::header::{HeaderName, CONNECTION, TE, TRAILER, TRANSFER_ENCODING, UPGRADE}; +use http::HeaderMap; +use pin_project_lite::pin_project; +use std::error::Error as StdError; +use std::io::{self, Cursor, IoSlice}; +use std::mem; +use std::task::Context; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tracing::{debug, trace, warn}; + +use crate::body::HttpBody; +use crate::common::{task, Future, Pin, Poll}; +use crate::proto::h2::ping::Recorder; + +pub(crate) mod ping; + +cfg_client! { + pub(crate) mod client; + pub(crate) use self::client::ClientTask; +} + +cfg_server! { + pub(crate) mod server; + pub(crate) use self::server::Server; +} + +/// Default initial stream window size defined in HTTP2 spec. 
+pub(crate) const SPEC_WINDOW_SIZE: u32 = 65_535; + +fn strip_connection_headers(headers: &mut HeaderMap, is_request: bool) { + // List of connection headers from: + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Connection + // + // TE headers are allowed in HTTP/2 requests as long as the value is "trailers", so they're + // tested separately. + let connection_headers = [ + HeaderName::from_lowercase(b"keep-alive").unwrap(), + HeaderName::from_lowercase(b"proxy-connection").unwrap(), + TRAILER, + TRANSFER_ENCODING, + UPGRADE, + ]; + + for header in connection_headers.iter() { + if headers.remove(header).is_some() { + warn!("Connection header illegal in HTTP/2: {}", header.as_str()); + } + } + + if is_request { + if headers + .get(TE) + .map(|te_header| te_header != "trailers") + .unwrap_or(false) + { + warn!("TE headers not set to \"trailers\" are illegal in HTTP/2 requests"); + headers.remove(TE); + } + } else if headers.remove(TE).is_some() { + warn!("TE headers illegal in HTTP/2 responses"); + } + + if let Some(header) = headers.remove(CONNECTION) { + warn!( + "Connection header illegal in HTTP/2: {}", + CONNECTION.as_str() + ); + let header_contents = header.to_str().unwrap(); + + // A `Connection` header may have a comma-separated list of names of other headers that + // are meant for only this specific connection. + // + // Iterate these names and remove them as headers. Connection-specific headers are + // forbidden in HTTP2, as that information has been moved into frame types of the h2 + // protocol. + for name in header_contents.split(',') { + let name = name.trim(); + headers.remove(name); + } + } +} + +// body adapters used by both Client and Server + +pin_project! { + struct PipeToSendStream<S> + where + S: HttpBody, + { + body_tx: SendStream<SendBuf<S::Data>>, + data_done: bool, + #[pin] + stream: S, + } +} + +impl<S> PipeToSendStream<S> +where + S: HttpBody, +{ + fn new(stream: S, tx: SendStream<SendBuf<S::Data>>) -> PipeToSendStream<S> { + PipeToSendStream { + body_tx: tx, + data_done: false, + stream, + } + } +} + +impl<S> Future for PipeToSendStream<S> +where + S: HttpBody, + S::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + type Output = crate::Result<()>; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + let mut me = self.project(); + loop { + if !*me.data_done { + // we don't have the next chunk of data yet, so just reserve 1 byte to make + // sure there's some capacity available. h2 will handle the capacity management + // for the actual body chunk. + me.body_tx.reserve_capacity(1); + + if me.body_tx.capacity() == 0 { + loop { + match ready!(me.body_tx.poll_capacity(cx)) { + Some(Ok(0)) => {} + Some(Ok(_)) => break, + Some(Err(e)) => { + return Poll::Ready(Err(crate::Error::new_body_write(e))) + } + None => { + // None means the stream is no longer in a + // streaming state, we either finished it + // somehow, or the remote reset us. + return Poll::Ready(Err(crate::Error::new_body_write( + "send stream capacity unexpectedly closed", + ))); + } + } + } + } else if let Poll::Ready(reason) = me + .body_tx + .poll_reset(cx) + .map_err(crate::Error::new_body_write)? 
+ { + debug!("stream received RST_STREAM: {:?}", reason); + return Poll::Ready(Err(crate::Error::new_body_write(::h2::Error::from( + reason, + )))); + } + + match ready!(me.stream.as_mut().poll_data(cx)) { + Some(Ok(chunk)) => { + let is_eos = me.stream.is_end_stream(); + trace!( + "send body chunk: {} bytes, eos={}", + chunk.remaining(), + is_eos, + ); + + let buf = SendBuf::Buf(chunk); + me.body_tx + .send_data(buf, is_eos) + .map_err(crate::Error::new_body_write)?; + + if is_eos { + return Poll::Ready(Ok(())); + } + } + Some(Err(e)) => return Poll::Ready(Err(me.body_tx.on_user_err(e))), + None => { + me.body_tx.reserve_capacity(0); + let is_eos = me.stream.is_end_stream(); + if is_eos { + return Poll::Ready(me.body_tx.send_eos_frame()); + } else { + *me.data_done = true; + // loop again to poll_trailers + } + } + } + } else { + if let Poll::Ready(reason) = me + .body_tx + .poll_reset(cx) + .map_err(crate::Error::new_body_write)? + { + debug!("stream received RST_STREAM: {:?}", reason); + return Poll::Ready(Err(crate::Error::new_body_write(::h2::Error::from( + reason, + )))); + } + + match ready!(me.stream.poll_trailers(cx)) { + Ok(Some(trailers)) => { + me.body_tx + .send_trailers(trailers) + .map_err(crate::Error::new_body_write)?; + return Poll::Ready(Ok(())); + } + Ok(None) => { + // There were no trailers, so send an empty DATA frame... + return Poll::Ready(me.body_tx.send_eos_frame()); + } + Err(e) => return Poll::Ready(Err(me.body_tx.on_user_err(e))), + } + } + } + } +} + +trait SendStreamExt { + fn on_user_err<E>(&mut self, err: E) -> crate::Error + where + E: Into<Box<dyn std::error::Error + Send + Sync>>; + fn send_eos_frame(&mut self) -> crate::Result<()>; +} + +impl<B: Buf> SendStreamExt for SendStream<SendBuf<B>> { + fn on_user_err<E>(&mut self, err: E) -> crate::Error + where + E: Into<Box<dyn std::error::Error + Send + Sync>>, + { + let err = crate::Error::new_user_body(err); + debug!("send body user stream error: {}", err); + self.send_reset(err.h2_reason()); + err + } + + fn send_eos_frame(&mut self) -> crate::Result<()> { + trace!("send body eos"); + self.send_data(SendBuf::None, true) + .map_err(crate::Error::new_body_write) + } +} + +#[repr(usize)] +enum SendBuf<B> { + Buf(B), + Cursor(Cursor<Box<[u8]>>), + None, +} + +impl<B: Buf> Buf for SendBuf<B> { + #[inline] + fn remaining(&self) -> usize { + match *self { + Self::Buf(ref b) => b.remaining(), + Self::Cursor(ref c) => Buf::remaining(c), + Self::None => 0, + } + } + + #[inline] + fn chunk(&self) -> &[u8] { + match *self { + Self::Buf(ref b) => b.chunk(), + Self::Cursor(ref c) => c.chunk(), + Self::None => &[], + } + } + + #[inline] + fn advance(&mut self, cnt: usize) { + match *self { + Self::Buf(ref mut b) => b.advance(cnt), + Self::Cursor(ref mut c) => c.advance(cnt), + Self::None => {} + } + } + + fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize { + match *self { + Self::Buf(ref b) => b.chunks_vectored(dst), + Self::Cursor(ref c) => c.chunks_vectored(dst), + Self::None => 0, + } + } +} + +struct H2Upgraded<B> +where + B: Buf, +{ + ping: Recorder, + send_stream: UpgradedSendStream<B>, + recv_stream: RecvStream, + buf: Bytes, +} + +impl<B> AsyncRead for H2Upgraded<B> +where + B: Buf, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + read_buf: &mut ReadBuf<'_>, + ) -> Poll<Result<(), io::Error>> { + if self.buf.is_empty() { + self.buf = loop { + match ready!(self.recv_stream.poll_data(cx)) { + None => return Poll::Ready(Ok(())), + Some(Ok(buf)) if buf.is_empty() && 
!self.recv_stream.is_end_stream() => { + continue + } + Some(Ok(buf)) => { + self.ping.record_data(buf.len()); + break buf; + } + Some(Err(e)) => { + return Poll::Ready(match e.reason() { + Some(Reason::NO_ERROR) | Some(Reason::CANCEL) => Ok(()), + Some(Reason::STREAM_CLOSED) => { + Err(io::Error::new(io::ErrorKind::BrokenPipe, e)) + } + _ => Err(h2_to_io_error(e)), + }) + } + } + }; + } + let cnt = std::cmp::min(self.buf.len(), read_buf.remaining()); + read_buf.put_slice(&self.buf[..cnt]); + self.buf.advance(cnt); + let _ = self.recv_stream.flow_control().release_capacity(cnt); + Poll::Ready(Ok(())) + } +} + +impl<B> AsyncWrite for H2Upgraded<B> +where + B: Buf, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, io::Error>> { + if buf.is_empty() { + return Poll::Ready(Ok(0)); + } + self.send_stream.reserve_capacity(buf.len()); + + // We ignore all errors returned by `poll_capacity` and `write`, as we + // will get the correct from `poll_reset` anyway. + let cnt = match ready!(self.send_stream.poll_capacity(cx)) { + None => Some(0), + Some(Ok(cnt)) => self + .send_stream + .write(&buf[..cnt], false) + .ok() + .map(|()| cnt), + Some(Err(_)) => None, + }; + + if let Some(cnt) = cnt { + return Poll::Ready(Ok(cnt)); + } + + Poll::Ready(Err(h2_to_io_error( + match ready!(self.send_stream.poll_reset(cx)) { + Ok(Reason::NO_ERROR) | Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => { + return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())) + } + Ok(reason) => reason.into(), + Err(e) => e, + }, + ))) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Result<(), io::Error>> { + if self.send_stream.write(&[], true).is_ok() { + return Poll::Ready(Ok(())) + } + + Poll::Ready(Err(h2_to_io_error( + match ready!(self.send_stream.poll_reset(cx)) { + Ok(Reason::NO_ERROR) => { + return Poll::Ready(Ok(())) + } + Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => { + return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())) + } + Ok(reason) => reason.into(), + Err(e) => e, + }, + ))) + } +} + +fn h2_to_io_error(e: h2::Error) -> io::Error { + if e.is_io() { + e.into_io().unwrap() + } else { + io::Error::new(io::ErrorKind::Other, e) + } +} + +struct UpgradedSendStream<B>(SendStream<SendBuf<Neutered<B>>>); + +impl<B> UpgradedSendStream<B> +where + B: Buf, +{ + unsafe fn new(inner: SendStream<SendBuf<B>>) -> Self { + assert_eq!(mem::size_of::<B>(), mem::size_of::<Neutered<B>>()); + Self(mem::transmute(inner)) + } + + fn reserve_capacity(&mut self, cnt: usize) { + unsafe { self.as_inner_unchecked().reserve_capacity(cnt) } + } + + fn poll_capacity(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<usize, h2::Error>>> { + unsafe { self.as_inner_unchecked().poll_capacity(cx) } + } + + fn poll_reset(&mut self, cx: &mut Context<'_>) -> Poll<Result<h2::Reason, h2::Error>> { + unsafe { self.as_inner_unchecked().poll_reset(cx) } + } + + fn write(&mut self, buf: &[u8], end_of_stream: bool) -> Result<(), io::Error> { + let send_buf = SendBuf::Cursor(Cursor::new(buf.into())); + unsafe { + self.as_inner_unchecked() + .send_data(send_buf, end_of_stream) + .map_err(h2_to_io_error) + } + } + + unsafe fn as_inner_unchecked(&mut self) -> &mut SendStream<SendBuf<B>> { + &mut *(&mut self.0 as *mut _ as *mut _) + } +} + +#[repr(transparent)] +struct Neutered<B> { + _inner: B, + impossible: Impossible, +} + 
+enum Impossible {} + +unsafe impl<B> Send for Neutered<B> {} + +impl<B> Buf for Neutered<B> { + fn remaining(&self) -> usize { + match self.impossible {} + } + + fn chunk(&self) -> &[u8] { + match self.impossible {} + } + + fn advance(&mut self, _cnt: usize) { + match self.impossible {} + } +} diff --git a/third_party/rust/hyper/src/proto/h2/ping.rs b/third_party/rust/hyper/src/proto/h2/ping.rs new file mode 100644 index 0000000000..1e8386497c --- /dev/null +++ b/third_party/rust/hyper/src/proto/h2/ping.rs @@ -0,0 +1,555 @@ +/// HTTP2 Ping usage +/// +/// hyper uses HTTP2 pings for two purposes: +/// +/// 1. Adaptive flow control using BDP +/// 2. Connection keep-alive +/// +/// Both cases are optional. +/// +/// # BDP Algorithm +/// +/// 1. When receiving a DATA frame, if a BDP ping isn't outstanding: +/// 1a. Record current time. +/// 1b. Send a BDP ping. +/// 2. Increment the number of received bytes. +/// 3. When the BDP ping ack is received: +/// 3a. Record duration from sent time. +/// 3b. Merge RTT with a running average. +/// 3c. Calculate bdp as bytes/rtt. +/// 3d. If bdp is over 2/3 max, set new max to bdp and update windows. + +#[cfg(feature = "runtime")] +use std::fmt; +#[cfg(feature = "runtime")] +use std::future::Future; +#[cfg(feature = "runtime")] +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::task::{self, Poll}; +use std::time::Duration; +#[cfg(not(feature = "runtime"))] +use std::time::Instant; + +use h2::{Ping, PingPong}; +#[cfg(feature = "runtime")] +use tokio::time::{Instant, Sleep}; +use tracing::{debug, trace}; + +type WindowSize = u32; + +pub(super) fn disabled() -> Recorder { + Recorder { shared: None } +} + +pub(super) fn channel(ping_pong: PingPong, config: Config) -> (Recorder, Ponger) { + debug_assert!( + config.is_enabled(), + "ping channel requires bdp or keep-alive config", + ); + + let bdp = config.bdp_initial_window.map(|wnd| Bdp { + bdp: wnd, + max_bandwidth: 0.0, + rtt: 0.0, + ping_delay: Duration::from_millis(100), + stable_count: 0, + }); + + let (bytes, next_bdp_at) = if bdp.is_some() { + (Some(0), Some(Instant::now())) + } else { + (None, None) + }; + + #[cfg(feature = "runtime")] + let keep_alive = config.keep_alive_interval.map(|interval| KeepAlive { + interval, + timeout: config.keep_alive_timeout, + while_idle: config.keep_alive_while_idle, + timer: Box::pin(tokio::time::sleep(interval)), + state: KeepAliveState::Init, + }); + + #[cfg(feature = "runtime")] + let last_read_at = keep_alive.as_ref().map(|_| Instant::now()); + + let shared = Arc::new(Mutex::new(Shared { + bytes, + #[cfg(feature = "runtime")] + last_read_at, + #[cfg(feature = "runtime")] + is_keep_alive_timed_out: false, + ping_pong, + ping_sent_at: None, + next_bdp_at, + })); + + ( + Recorder { + shared: Some(shared.clone()), + }, + Ponger { + bdp, + #[cfg(feature = "runtime")] + keep_alive, + shared, + }, + ) +} + +#[derive(Clone)] +pub(super) struct Config { + pub(super) bdp_initial_window: Option<WindowSize>, + /// If no frames are received in this amount of time, a PING frame is sent. + #[cfg(feature = "runtime")] + pub(super) keep_alive_interval: Option<Duration>, + /// After sending a keepalive PING, the connection will be closed if + /// a pong is not received in this amount of time. + #[cfg(feature = "runtime")] + pub(super) keep_alive_timeout: Duration, + /// If true, sends pings even when there are no active streams. 
+ #[cfg(feature = "runtime")] + pub(super) keep_alive_while_idle: bool, +} + +#[derive(Clone)] +pub(crate) struct Recorder { + shared: Option<Arc<Mutex<Shared>>>, +} + +pub(super) struct Ponger { + bdp: Option<Bdp>, + #[cfg(feature = "runtime")] + keep_alive: Option<KeepAlive>, + shared: Arc<Mutex<Shared>>, +} + +struct Shared { + ping_pong: PingPong, + ping_sent_at: Option<Instant>, + + // bdp + /// If `Some`, bdp is enabled, and this tracks how many bytes have been + /// read during the current sample. + bytes: Option<usize>, + /// We delay a variable amount of time between BDP pings. This allows us + /// to send less pings as the bandwidth stabilizes. + next_bdp_at: Option<Instant>, + + // keep-alive + /// If `Some`, keep-alive is enabled, and the Instant is how long ago + /// the connection read the last frame. + #[cfg(feature = "runtime")] + last_read_at: Option<Instant>, + + #[cfg(feature = "runtime")] + is_keep_alive_timed_out: bool, +} + +struct Bdp { + /// Current BDP in bytes + bdp: u32, + /// Largest bandwidth we've seen so far. + max_bandwidth: f64, + /// Round trip time in seconds + rtt: f64, + /// Delay the next ping by this amount. + /// + /// This will change depending on how stable the current bandwidth is. + ping_delay: Duration, + /// The count of ping round trips where BDP has stayed the same. + stable_count: u32, +} + +#[cfg(feature = "runtime")] +struct KeepAlive { + /// If no frames are received in this amount of time, a PING frame is sent. + interval: Duration, + /// After sending a keepalive PING, the connection will be closed if + /// a pong is not received in this amount of time. + timeout: Duration, + /// If true, sends pings even when there are no active streams. + while_idle: bool, + + state: KeepAliveState, + timer: Pin<Box<Sleep>>, +} + +#[cfg(feature = "runtime")] +enum KeepAliveState { + Init, + Scheduled, + PingSent, +} + +pub(super) enum Ponged { + SizeUpdate(WindowSize), + #[cfg(feature = "runtime")] + KeepAliveTimedOut, +} + +#[cfg(feature = "runtime")] +#[derive(Debug)] +pub(super) struct KeepAliveTimedOut; + +// ===== impl Config ===== + +impl Config { + pub(super) fn is_enabled(&self) -> bool { + #[cfg(feature = "runtime")] + { + self.bdp_initial_window.is_some() || self.keep_alive_interval.is_some() + } + + #[cfg(not(feature = "runtime"))] + { + self.bdp_initial_window.is_some() + } + } +} + +// ===== impl Recorder ===== + +impl Recorder { + pub(crate) fn record_data(&self, len: usize) { + let shared = if let Some(ref shared) = self.shared { + shared + } else { + return; + }; + + let mut locked = shared.lock().unwrap(); + + #[cfg(feature = "runtime")] + locked.update_last_read_at(); + + // are we ready to send another bdp ping? + // if not, we don't need to record bytes either + + if let Some(ref next_bdp_at) = locked.next_bdp_at { + if Instant::now() < *next_bdp_at { + return; + } else { + locked.next_bdp_at = None; + } + } + + if let Some(ref mut bytes) = locked.bytes { + *bytes += len; + } else { + // no need to send bdp ping if bdp is disabled + return; + } + + if !locked.is_ping_sent() { + locked.send_ping(); + } + } + + pub(crate) fn record_non_data(&self) { + #[cfg(feature = "runtime")] + { + let shared = if let Some(ref shared) = self.shared { + shared + } else { + return; + }; + + let mut locked = shared.lock().unwrap(); + + locked.update_last_read_at(); + } + } + + /// If the incoming stream is already closed, convert self into + /// a disabled reporter. 
+ #[cfg(feature = "client")] + pub(super) fn for_stream(self, stream: &h2::RecvStream) -> Self { + if stream.is_end_stream() { + disabled() + } else { + self + } + } + + pub(super) fn ensure_not_timed_out(&self) -> crate::Result<()> { + #[cfg(feature = "runtime")] + { + if let Some(ref shared) = self.shared { + let locked = shared.lock().unwrap(); + if locked.is_keep_alive_timed_out { + return Err(KeepAliveTimedOut.crate_error()); + } + } + } + + // else + Ok(()) + } +} + +// ===== impl Ponger ===== + +impl Ponger { + pub(super) fn poll(&mut self, cx: &mut task::Context<'_>) -> Poll<Ponged> { + let now = Instant::now(); + let mut locked = self.shared.lock().unwrap(); + #[cfg(feature = "runtime")] + let is_idle = self.is_idle(); + + #[cfg(feature = "runtime")] + { + if let Some(ref mut ka) = self.keep_alive { + ka.schedule(is_idle, &locked); + ka.maybe_ping(cx, &mut locked); + } + } + + if !locked.is_ping_sent() { + // XXX: this doesn't register a waker...? + return Poll::Pending; + } + + match locked.ping_pong.poll_pong(cx) { + Poll::Ready(Ok(_pong)) => { + let start = locked + .ping_sent_at + .expect("pong received implies ping_sent_at"); + locked.ping_sent_at = None; + let rtt = now - start; + trace!("recv pong"); + + #[cfg(feature = "runtime")] + { + if let Some(ref mut ka) = self.keep_alive { + locked.update_last_read_at(); + ka.schedule(is_idle, &locked); + } + } + + if let Some(ref mut bdp) = self.bdp { + let bytes = locked.bytes.expect("bdp enabled implies bytes"); + locked.bytes = Some(0); // reset + trace!("received BDP ack; bytes = {}, rtt = {:?}", bytes, rtt); + + let update = bdp.calculate(bytes, rtt); + locked.next_bdp_at = Some(now + bdp.ping_delay); + if let Some(update) = update { + return Poll::Ready(Ponged::SizeUpdate(update)) + } + } + } + Poll::Ready(Err(e)) => { + debug!("pong error: {}", e); + } + Poll::Pending => { + #[cfg(feature = "runtime")] + { + if let Some(ref mut ka) = self.keep_alive { + if let Err(KeepAliveTimedOut) = ka.maybe_timeout(cx) { + self.keep_alive = None; + locked.is_keep_alive_timed_out = true; + return Poll::Ready(Ponged::KeepAliveTimedOut); + } + } + } + } + } + + // XXX: this doesn't register a waker...? + Poll::Pending + } + + #[cfg(feature = "runtime")] + fn is_idle(&self) -> bool { + Arc::strong_count(&self.shared) <= 2 + } +} + +// ===== impl Shared ===== + +impl Shared { + fn send_ping(&mut self) { + match self.ping_pong.send_ping(Ping::opaque()) { + Ok(()) => { + self.ping_sent_at = Some(Instant::now()); + trace!("sent ping"); + } + Err(err) => { + debug!("error sending ping: {}", err); + } + } + } + + fn is_ping_sent(&self) -> bool { + self.ping_sent_at.is_some() + } + + #[cfg(feature = "runtime")] + fn update_last_read_at(&mut self) { + if self.last_read_at.is_some() { + self.last_read_at = Some(Instant::now()); + } + } + + #[cfg(feature = "runtime")] + fn last_read_at(&self) -> Instant { + self.last_read_at.expect("keep_alive expects last_read_at") + } +} + +// ===== impl Bdp ===== + +/// Any higher than this likely will be hitting the TCP flow control. +const BDP_LIMIT: usize = 1024 * 1024 * 16; + +impl Bdp { + fn calculate(&mut self, bytes: usize, rtt: Duration) -> Option<WindowSize> { + // No need to do any math if we're at the limit. + if self.bdp as usize == BDP_LIMIT { + self.stabilize_delay(); + return None; + } + + // average the rtt + let rtt = seconds(rtt); + if self.rtt == 0.0 { + // First sample means rtt is first rtt. + self.rtt = rtt; + } else { + // Weigh this rtt as 1/8 for a moving average. 
+ self.rtt += (rtt - self.rtt) * 0.125; + } + + // calculate the current bandwidth + let bw = (bytes as f64) / (self.rtt * 1.5); + trace!("current bandwidth = {:.1}B/s", bw); + + if bw < self.max_bandwidth { + // not a faster bandwidth, so don't update + self.stabilize_delay(); + return None; + } else { + self.max_bandwidth = bw; + } + + // if the current `bytes` sample is at least 2/3 the previous + // bdp, increase to double the current sample. + if bytes >= self.bdp as usize * 2 / 3 { + self.bdp = (bytes * 2).min(BDP_LIMIT) as WindowSize; + trace!("BDP increased to {}", self.bdp); + + self.stable_count = 0; + self.ping_delay /= 2; + Some(self.bdp) + } else { + self.stabilize_delay(); + None + } + } + + fn stabilize_delay(&mut self) { + if self.ping_delay < Duration::from_secs(10) { + self.stable_count += 1; + + if self.stable_count >= 2 { + self.ping_delay *= 4; + self.stable_count = 0; + } + } + } +} + +fn seconds(dur: Duration) -> f64 { + const NANOS_PER_SEC: f64 = 1_000_000_000.0; + let secs = dur.as_secs() as f64; + secs + (dur.subsec_nanos() as f64) / NANOS_PER_SEC +} + +// ===== impl KeepAlive ===== + +#[cfg(feature = "runtime")] +impl KeepAlive { + fn schedule(&mut self, is_idle: bool, shared: &Shared) { + match self.state { + KeepAliveState::Init => { + if !self.while_idle && is_idle { + return; + } + + self.state = KeepAliveState::Scheduled; + let interval = shared.last_read_at() + self.interval; + self.timer.as_mut().reset(interval); + } + KeepAliveState::PingSent => { + if shared.is_ping_sent() { + return; + } + + self.state = KeepAliveState::Scheduled; + let interval = shared.last_read_at() + self.interval; + self.timer.as_mut().reset(interval); + } + KeepAliveState::Scheduled => (), + } + } + + fn maybe_ping(&mut self, cx: &mut task::Context<'_>, shared: &mut Shared) { + match self.state { + KeepAliveState::Scheduled => { + if Pin::new(&mut self.timer).poll(cx).is_pending() { + return; + } + // check if we've received a frame while we were scheduled + if shared.last_read_at() + self.interval > self.timer.deadline() { + self.state = KeepAliveState::Init; + cx.waker().wake_by_ref(); // schedule us again + return; + } + trace!("keep-alive interval ({:?}) reached", self.interval); + shared.send_ping(); + self.state = KeepAliveState::PingSent; + let timeout = Instant::now() + self.timeout; + self.timer.as_mut().reset(timeout); + } + KeepAliveState::Init | KeepAliveState::PingSent => (), + } + } + + fn maybe_timeout(&mut self, cx: &mut task::Context<'_>) -> Result<(), KeepAliveTimedOut> { + match self.state { + KeepAliveState::PingSent => { + if Pin::new(&mut self.timer).poll(cx).is_pending() { + return Ok(()); + } + trace!("keep-alive timeout ({:?}) reached", self.timeout); + Err(KeepAliveTimedOut) + } + KeepAliveState::Init | KeepAliveState::Scheduled => Ok(()), + } + } +} + +// ===== impl KeepAliveTimedOut ===== + +#[cfg(feature = "runtime")] +impl KeepAliveTimedOut { + pub(super) fn crate_error(self) -> crate::Error { + crate::Error::new(crate::error::Kind::Http2).with(self) + } +} + +#[cfg(feature = "runtime")] +impl fmt::Display for KeepAliveTimedOut { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("keep-alive timed out") + } +} + +#[cfg(feature = "runtime")] +impl std::error::Error for KeepAliveTimedOut { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(&crate::error::TimedOut) + } +} diff --git a/third_party/rust/hyper/src/proto/h2/server.rs b/third_party/rust/hyper/src/proto/h2/server.rs new file mode 100644 index 
0000000000..d24e6bac5f --- /dev/null +++ b/third_party/rust/hyper/src/proto/h2/server.rs @@ -0,0 +1,548 @@ +use std::error::Error as StdError; +use std::marker::Unpin; +#[cfg(feature = "runtime")] +use std::time::Duration; + +use bytes::Bytes; +use h2::server::{Connection, Handshake, SendResponse}; +use h2::{Reason, RecvStream}; +use http::{Method, Request}; +use pin_project_lite::pin_project; +use tokio::io::{AsyncRead, AsyncWrite}; +use tracing::{debug, trace, warn}; + +use super::{ping, PipeToSendStream, SendBuf}; +use crate::body::HttpBody; +use crate::common::exec::ConnStreamExec; +use crate::common::{date, task, Future, Pin, Poll}; +use crate::ext::Protocol; +use crate::headers; +use crate::proto::h2::ping::Recorder; +use crate::proto::h2::{H2Upgraded, UpgradedSendStream}; +use crate::proto::Dispatched; +use crate::service::HttpService; + +use crate::upgrade::{OnUpgrade, Pending, Upgraded}; +use crate::{Body, Response}; + +// Our defaults are chosen for the "majority" case, which usually are not +// resource constrained, and so the spec default of 64kb can be too limiting +// for performance. +// +// At the same time, a server more often has multiple clients connected, and +// so is more likely to use more resources than a client would. +const DEFAULT_CONN_WINDOW: u32 = 1024 * 1024; // 1mb +const DEFAULT_STREAM_WINDOW: u32 = 1024 * 1024; // 1mb +const DEFAULT_MAX_FRAME_SIZE: u32 = 1024 * 16; // 16kb +const DEFAULT_MAX_SEND_BUF_SIZE: usize = 1024 * 400; // 400kb +// 16 MB "sane default" taken from golang http2 +const DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE: u32 = 16 << 20; + +#[derive(Clone, Debug)] +pub(crate) struct Config { + pub(crate) adaptive_window: bool, + pub(crate) initial_conn_window_size: u32, + pub(crate) initial_stream_window_size: u32, + pub(crate) max_frame_size: u32, + pub(crate) enable_connect_protocol: bool, + pub(crate) max_concurrent_streams: Option<u32>, + #[cfg(feature = "runtime")] + pub(crate) keep_alive_interval: Option<Duration>, + #[cfg(feature = "runtime")] + pub(crate) keep_alive_timeout: Duration, + pub(crate) max_send_buffer_size: usize, + pub(crate) max_header_list_size: u32, +} + +impl Default for Config { + fn default() -> Config { + Config { + adaptive_window: false, + initial_conn_window_size: DEFAULT_CONN_WINDOW, + initial_stream_window_size: DEFAULT_STREAM_WINDOW, + max_frame_size: DEFAULT_MAX_FRAME_SIZE, + enable_connect_protocol: false, + max_concurrent_streams: None, + #[cfg(feature = "runtime")] + keep_alive_interval: None, + #[cfg(feature = "runtime")] + keep_alive_timeout: Duration::from_secs(20), + max_send_buffer_size: DEFAULT_MAX_SEND_BUF_SIZE, + max_header_list_size: DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE, + } + } +} + +pin_project! 
{ + pub(crate) struct Server<T, S, B, E> + where + S: HttpService<Body>, + B: HttpBody, + { + exec: E, + service: S, + state: State<T, B>, + } +} + +enum State<T, B> +where + B: HttpBody, +{ + Handshaking { + ping_config: ping::Config, + hs: Handshake<T, SendBuf<B::Data>>, + }, + Serving(Serving<T, B>), + Closed, +} + +struct Serving<T, B> +where + B: HttpBody, +{ + ping: Option<(ping::Recorder, ping::Ponger)>, + conn: Connection<T, SendBuf<B::Data>>, + closing: Option<crate::Error>, +} + +impl<T, S, B, E> Server<T, S, B, E> +where + T: AsyncRead + AsyncWrite + Unpin, + S: HttpService<Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: HttpBody + 'static, + E: ConnStreamExec<S::Future, B>, +{ + pub(crate) fn new(io: T, service: S, config: &Config, exec: E) -> Server<T, S, B, E> { + let mut builder = h2::server::Builder::default(); + builder + .initial_window_size(config.initial_stream_window_size) + .initial_connection_window_size(config.initial_conn_window_size) + .max_frame_size(config.max_frame_size) + .max_header_list_size(config.max_header_list_size) + .max_send_buffer_size(config.max_send_buffer_size); + if let Some(max) = config.max_concurrent_streams { + builder.max_concurrent_streams(max); + } + if config.enable_connect_protocol { + builder.enable_connect_protocol(); + } + let handshake = builder.handshake(io); + + let bdp = if config.adaptive_window { + Some(config.initial_stream_window_size) + } else { + None + }; + + let ping_config = ping::Config { + bdp_initial_window: bdp, + #[cfg(feature = "runtime")] + keep_alive_interval: config.keep_alive_interval, + #[cfg(feature = "runtime")] + keep_alive_timeout: config.keep_alive_timeout, + // If keep-alive is enabled for servers, always enabled while + // idle, so it can more aggressively close dead connections. + #[cfg(feature = "runtime")] + keep_alive_while_idle: true, + }; + + Server { + exec, + state: State::Handshaking { + ping_config, + hs: handshake, + }, + service, + } + } + + pub(crate) fn graceful_shutdown(&mut self) { + trace!("graceful_shutdown"); + match self.state { + State::Handshaking { .. } => { + // fall-through, to replace state with Closed + } + State::Serving(ref mut srv) => { + if srv.closing.is_none() { + srv.conn.graceful_shutdown(); + } + return; + } + State::Closed => { + return; + } + } + self.state = State::Closed; + } +} + +impl<T, S, B, E> Future for Server<T, S, B, E> +where + T: AsyncRead + AsyncWrite + Unpin, + S: HttpService<Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: HttpBody + 'static, + E: ConnStreamExec<S::Future, B>, +{ + type Output = crate::Result<Dispatched>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + let me = &mut *self; + loop { + let next = match me.state { + State::Handshaking { + ref mut hs, + ref ping_config, + } => { + let mut conn = ready!(Pin::new(hs).poll(cx).map_err(crate::Error::new_h2))?; + let ping = if ping_config.is_enabled() { + let pp = conn.ping_pong().expect("conn.ping_pong"); + Some(ping::channel(pp, ping_config.clone())) + } else { + None + }; + State::Serving(Serving { + ping, + conn, + closing: None, + }) + } + State::Serving(ref mut srv) => { + ready!(srv.poll_server(cx, &mut me.service, &mut me.exec))?; + return Poll::Ready(Ok(Dispatched::Shutdown)); + } + State::Closed => { + // graceful_shutdown was called before handshaking finished, + // nothing to do here... 
+ return Poll::Ready(Ok(Dispatched::Shutdown)); + } + }; + me.state = next; + } + } +} + +impl<T, B> Serving<T, B> +where + T: AsyncRead + AsyncWrite + Unpin, + B: HttpBody + 'static, +{ + fn poll_server<S, E>( + &mut self, + cx: &mut task::Context<'_>, + service: &mut S, + exec: &mut E, + ) -> Poll<crate::Result<()>> + where + S: HttpService<Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + E: ConnStreamExec<S::Future, B>, + { + if self.closing.is_none() { + loop { + self.poll_ping(cx); + + // Check that the service is ready to accept a new request. + // + // - If not, just drive the connection some. + // - If ready, try to accept a new request from the connection. + match service.poll_ready(cx) { + Poll::Ready(Ok(())) => (), + Poll::Pending => { + // use `poll_closed` instead of `poll_accept`, + // in order to avoid accepting a request. + ready!(self.conn.poll_closed(cx).map_err(crate::Error::new_h2))?; + trace!("incoming connection complete"); + return Poll::Ready(Ok(())); + } + Poll::Ready(Err(err)) => { + let err = crate::Error::new_user_service(err); + debug!("service closed: {}", err); + + let reason = err.h2_reason(); + if reason == Reason::NO_ERROR { + // NO_ERROR is only used for graceful shutdowns... + trace!("interpreting NO_ERROR user error as graceful_shutdown"); + self.conn.graceful_shutdown(); + } else { + trace!("abruptly shutting down with {:?}", reason); + self.conn.abrupt_shutdown(reason); + } + self.closing = Some(err); + break; + } + } + + // When the service is ready, accepts an incoming request. + match ready!(self.conn.poll_accept(cx)) { + Some(Ok((req, mut respond))) => { + trace!("incoming request"); + let content_length = headers::content_length_parse_all(req.headers()); + let ping = self + .ping + .as_ref() + .map(|ping| ping.0.clone()) + .unwrap_or_else(ping::disabled); + + // Record the headers received + ping.record_non_data(); + + let is_connect = req.method() == Method::CONNECT; + let (mut parts, stream) = req.into_parts(); + let (mut req, connect_parts) = if !is_connect { + ( + Request::from_parts( + parts, + crate::Body::h2(stream, content_length.into(), ping), + ), + None, + ) + } else { + if content_length.map_or(false, |len| len != 0) { + warn!("h2 connect request with non-zero body not supported"); + respond.send_reset(h2::Reason::INTERNAL_ERROR); + return Poll::Ready(Ok(())); + } + let (pending, upgrade) = crate::upgrade::pending(); + debug_assert!(parts.extensions.get::<OnUpgrade>().is_none()); + parts.extensions.insert(upgrade); + ( + Request::from_parts(parts, crate::Body::empty()), + Some(ConnectParts { + pending, + ping, + recv_stream: stream, + }), + ) + }; + + if let Some(protocol) = req.extensions_mut().remove::<h2::ext::Protocol>() { + req.extensions_mut().insert(Protocol::from_inner(protocol)); + } + + let fut = H2Stream::new(service.call(req), connect_parts, respond); + exec.execute_h2stream(fut); + } + Some(Err(e)) => { + return Poll::Ready(Err(crate::Error::new_h2(e))); + } + None => { + // no more incoming streams... 
+ if let Some((ref ping, _)) = self.ping { + ping.ensure_not_timed_out()?; + } + + trace!("incoming connection complete"); + return Poll::Ready(Ok(())); + } + } + } + } + + debug_assert!( + self.closing.is_some(), + "poll_server broke loop without closing" + ); + + ready!(self.conn.poll_closed(cx).map_err(crate::Error::new_h2))?; + + Poll::Ready(Err(self.closing.take().expect("polled after error"))) + } + + fn poll_ping(&mut self, cx: &mut task::Context<'_>) { + if let Some((_, ref mut estimator)) = self.ping { + match estimator.poll(cx) { + Poll::Ready(ping::Ponged::SizeUpdate(wnd)) => { + self.conn.set_target_window_size(wnd); + let _ = self.conn.set_initial_window_size(wnd); + } + #[cfg(feature = "runtime")] + Poll::Ready(ping::Ponged::KeepAliveTimedOut) => { + debug!("keep-alive timed out, closing connection"); + self.conn.abrupt_shutdown(h2::Reason::NO_ERROR); + } + Poll::Pending => {} + } + } + } +} + +pin_project! { + #[allow(missing_debug_implementations)] + pub struct H2Stream<F, B> + where + B: HttpBody, + { + reply: SendResponse<SendBuf<B::Data>>, + #[pin] + state: H2StreamState<F, B>, + } +} + +pin_project! { + #[project = H2StreamStateProj] + enum H2StreamState<F, B> + where + B: HttpBody, + { + Service { + #[pin] + fut: F, + connect_parts: Option<ConnectParts>, + }, + Body { + #[pin] + pipe: PipeToSendStream<B>, + }, + } +} + +struct ConnectParts { + pending: Pending, + ping: Recorder, + recv_stream: RecvStream, +} + +impl<F, B> H2Stream<F, B> +where + B: HttpBody, +{ + fn new( + fut: F, + connect_parts: Option<ConnectParts>, + respond: SendResponse<SendBuf<B::Data>>, + ) -> H2Stream<F, B> { + H2Stream { + reply: respond, + state: H2StreamState::Service { fut, connect_parts }, + } + } +} + +macro_rules! reply { + ($me:expr, $res:expr, $eos:expr) => {{ + match $me.reply.send_response($res, $eos) { + Ok(tx) => tx, + Err(e) => { + debug!("send response error: {}", e); + $me.reply.send_reset(Reason::INTERNAL_ERROR); + return Poll::Ready(Err(crate::Error::new_h2(e))); + } + } + }}; +} + +impl<F, B, E> H2Stream<F, B> +where + F: Future<Output = Result<Response<B>, E>>, + B: HttpBody, + B::Data: 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + E: Into<Box<dyn StdError + Send + Sync>>, +{ + fn poll2(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + let mut me = self.project(); + loop { + let next = match me.state.as_mut().project() { + H2StreamStateProj::Service { + fut: h, + connect_parts, + } => { + let res = match h.poll(cx) { + Poll::Ready(Ok(r)) => r, + Poll::Pending => { + // Response is not yet ready, so we want to check if the client has sent a + // RST_STREAM frame which would cancel the current request. + if let Poll::Ready(reason) = + me.reply.poll_reset(cx).map_err(crate::Error::new_h2)? + { + debug!("stream received RST_STREAM: {:?}", reason); + return Poll::Ready(Err(crate::Error::new_h2(reason.into()))); + } + return Poll::Pending; + } + Poll::Ready(Err(e)) => { + let err = crate::Error::new_user_service(e); + warn!("http2 service errored: {}", err); + me.reply.send_reset(err.h2_reason()); + return Poll::Ready(Err(err)); + } + }; + + let (head, body) = res.into_parts(); + let mut res = ::http::Response::from_parts(head, ()); + super::strip_connection_headers(res.headers_mut(), false); + + // set Date header if it isn't already set... 
+ res.headers_mut() + .entry(::http::header::DATE) + .or_insert_with(date::update_and_header_value); + + if let Some(connect_parts) = connect_parts.take() { + if res.status().is_success() { + if headers::content_length_parse_all(res.headers()) + .map_or(false, |len| len != 0) + { + warn!("h2 successful response to CONNECT request with body not supported"); + me.reply.send_reset(h2::Reason::INTERNAL_ERROR); + return Poll::Ready(Err(crate::Error::new_user_header())); + } + let send_stream = reply!(me, res, false); + connect_parts.pending.fulfill(Upgraded::new( + H2Upgraded { + ping: connect_parts.ping, + recv_stream: connect_parts.recv_stream, + send_stream: unsafe { UpgradedSendStream::new(send_stream) }, + buf: Bytes::new(), + }, + Bytes::new(), + )); + return Poll::Ready(Ok(())); + } + } + + + if !body.is_end_stream() { + // automatically set Content-Length from body... + if let Some(len) = body.size_hint().exact() { + headers::set_content_length_if_missing(res.headers_mut(), len); + } + + let body_tx = reply!(me, res, false); + H2StreamState::Body { + pipe: PipeToSendStream::new(body, body_tx), + } + } else { + reply!(me, res, true); + return Poll::Ready(Ok(())); + } + } + H2StreamStateProj::Body { pipe } => { + return pipe.poll(cx); + } + }; + me.state.set(next); + } + } +} + +impl<F, B, E> Future for H2Stream<F, B> +where + F: Future<Output = Result<Response<B>, E>>, + B: HttpBody, + B::Data: 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + E: Into<Box<dyn StdError + Send + Sync>>, +{ + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + self.poll2(cx).map(|res| { + if let Err(e) = res { + debug!("stream error: {}", e); + } + }) + } +} |
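
Note: the client and server `Config` structs in this diff are internal; if I read hyper 0.14's public API correctly, they are populated through the `Client`/`Server` builders rather than constructed directly. A hedged sketch, assuming a hyper 0.14 dependency with the "client", "http2", "tcp", and "runtime" features enabled (builder method names are hyper's documented API, not part of this vendored diff):

use std::time::Duration;

fn main() {
    // Maps onto client::Config above: http2_adaptive_window enables the
    // BDP pings from ping.rs, and the keep_alive_* settings drive
    // ping::KeepAlive; the remaining fields keep their defaults.
    let _client = hyper::Client::builder()
        .http2_only(true)
        .http2_adaptive_window(true)
        .http2_keep_alive_interval(Duration::from_secs(30))
        .http2_keep_alive_timeout(Duration::from_secs(20))
        .http2_keep_alive_while_idle(false)
        .build_http::<hyper::Body>();
}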
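
The doc comment at the top of ping.rs describes the adaptive flow-control (BDP) algorithm: smooth the ping round-trip time with a 1/8-weight moving average, track the best observed bandwidth, and double the window whenever a sample fills at least 2/3 of the current estimate. The following is a small standalone sketch of that arithmetic, written against the description rather than lifted from the vendored code, so the field names and the 65_535-byte seed window are illustrative only:

use std::time::Duration;

const BDP_LIMIT: usize = 1024 * 1024 * 16; // same 16 MB cap as the vendored code

struct Bdp {
    bdp: u32,           // current window estimate in bytes
    max_bandwidth: f64, // largest bandwidth observed so far, bytes/sec
    rtt: f64,           // smoothed round-trip time in seconds
}

impl Bdp {
    // One sample: bytes received since the last BDP ping and that ping's
    // round-trip time. Returns a new window size when the estimate grows.
    fn calculate(&mut self, bytes: usize, sample_rtt: Duration) -> Option<u32> {
        let sample = sample_rtt.as_secs_f64();
        if self.rtt == 0.0 {
            self.rtt = sample; // first sample seeds the average
        } else {
            self.rtt += (sample - self.rtt) * 0.125; // 1/8-weight moving average
        }

        // Current bandwidth estimate; skip the update if it is not faster
        // than anything seen before.
        let bw = bytes as f64 / (self.rtt * 1.5);
        if bw < self.max_bandwidth {
            return None;
        }
        self.max_bandwidth = bw;

        // If this sample filled at least 2/3 of the current window,
        // double it (up to the cap) and report the new size.
        if bytes >= self.bdp as usize * 2 / 3 {
            self.bdp = (bytes * 2).min(BDP_LIMIT) as u32;
            Some(self.bdp)
        } else {
            None
        }
    }
}

fn main() {
    let mut bdp = Bdp { bdp: 65_535, max_bandwidth: 0.0, rtt: 0.0 };
    // 60 KB arrived during a 50 ms round trip: the window should grow.
    println!("{:?}", bdp.calculate(60_000, Duration::from_millis(50)));
}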