diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-28 14:29:10 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-28 14:29:10 +0000 |
commit | 2aa4a82499d4becd2284cdb482213d541b8804dd (patch) | |
tree | b80bf8bf13c3766139fbacc530efd0dd9d54394c /third_party/rust/hyper/src | |
parent | Initial commit. (diff) | |
download | firefox-2aa4a82499d4becd2284cdb482213d541b8804dd.tar.xz firefox-2aa4a82499d4becd2284cdb482213d541b8804dd.zip |
Adding upstream version 86.0.1.upstream/86.0.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/hyper/src')
53 files changed, 18488 insertions, 0 deletions
diff --git a/third_party/rust/hyper/src/body/aggregate.rs b/third_party/rust/hyper/src/body/aggregate.rs new file mode 100644 index 0000000000..97b6c2d91f --- /dev/null +++ b/third_party/rust/hyper/src/body/aggregate.rs @@ -0,0 +1,25 @@ +use bytes::Buf; + +use super::HttpBody; +use crate::common::buf::BufList; + +/// Aggregate the data buffers from a body asynchronously. +/// +/// The returned `impl Buf` groups the `Buf`s from the `HttpBody` without +/// copying them. This is ideal if you don't require a contiguous buffer. +pub async fn aggregate<T>(body: T) -> Result<impl Buf, T::Error> +where + T: HttpBody, +{ + let mut bufs = BufList::new(); + + futures_util::pin_mut!(body); + while let Some(buf) = body.data().await { + let buf = buf?; + if buf.has_remaining() { + bufs.push(buf); + } + } + + Ok(bufs) +} diff --git a/third_party/rust/hyper/src/body/body.rs b/third_party/rust/hyper/src/body/body.rs new file mode 100644 index 0000000000..315b8abb64 --- /dev/null +++ b/third_party/rust/hyper/src/body/body.rs @@ -0,0 +1,706 @@ +use std::borrow::Cow; +#[cfg(feature = "stream")] +use std::error::Error as StdError; +use std::fmt; + +use bytes::Bytes; +use futures_channel::{mpsc, oneshot}; +use futures_core::Stream; // for mpsc::Receiver +#[cfg(feature = "stream")] +use futures_util::TryStreamExt; +use http::HeaderMap; +use http_body::{Body as HttpBody, SizeHint}; + +use crate::common::{task, watch, Future, Never, Pin, Poll}; +use crate::proto::h2::ping; +use crate::proto::DecodedLength; +use crate::upgrade::OnUpgrade; + +type BodySender = mpsc::Sender<Result<Bytes, crate::Error>>; + +/// A stream of `Bytes`, used when receiving bodies. +/// +/// A good default [`HttpBody`](crate::body::HttpBody) to use in many +/// applications. +#[must_use = "streams do nothing unless polled"] +pub struct Body { + kind: Kind, + /// Keep the extra bits in an `Option<Box<Extra>>`, so that + /// Body stays small in the common case (no extras needed). + extra: Option<Box<Extra>>, +} + +enum Kind { + Once(Option<Bytes>), + Chan { + content_length: DecodedLength, + want_tx: watch::Sender, + rx: mpsc::Receiver<Result<Bytes, crate::Error>>, + }, + H2 { + ping: ping::Recorder, + content_length: DecodedLength, + recv: h2::RecvStream, + }, + // NOTE: This requires `Sync` because of how easy it is to use `await` + // while a borrow of a `Request<Body>` exists. + // + // See https://github.com/rust-lang/rust/issues/57017 + #[cfg(feature = "stream")] + Wrapped( + Pin<Box<dyn Stream<Item = Result<Bytes, Box<dyn StdError + Send + Sync>>> + Send + Sync>>, + ), +} + +struct Extra { + /// Allow the client to pass a future to delay the `Body` from returning + /// EOF. This allows the `Client` to try to put the idle connection + /// back into the pool before the body is "finished". + /// + /// The reason for this is so that creating a new request after finishing + /// streaming the body of a response could sometimes result in creating + /// a brand new connection, since the pool didn't know about the idle + /// connection yet. + delayed_eof: Option<DelayEof>, + on_upgrade: OnUpgrade, +} + +type DelayEofUntil = oneshot::Receiver<Never>; + +enum DelayEof { + /// Initial state, stream hasn't seen EOF yet. + NotEof(DelayEofUntil), + /// Transitions to this state once we've seen `poll` try to + /// return EOF (`None`). This future is then polled, and + /// when it completes, the Body finally returns EOF (`None`). + Eof(DelayEofUntil), +} + +/// A sender half used with `Body::channel()`. +/// +/// Useful when wanting to stream chunks from another thread. See +/// [`Body::channel`](Body::channel) for more. +#[must_use = "Sender does nothing unless sent on"] +pub struct Sender { + want_rx: watch::Receiver, + tx: BodySender, +} + +const WANT_PENDING: usize = 1; +const WANT_READY: usize = 2; + +impl Body { + /// Create an empty `Body` stream. + /// + /// # Example + /// + /// ``` + /// use hyper::{Body, Request}; + /// + /// // create a `GET /` request + /// let get = Request::new(Body::empty()); + /// ``` + #[inline] + pub fn empty() -> Body { + Body::new(Kind::Once(None)) + } + + /// Create a `Body` stream with an associated sender half. + /// + /// Useful when wanting to stream chunks from another thread. + #[inline] + pub fn channel() -> (Sender, Body) { + Self::new_channel(DecodedLength::CHUNKED, /*wanter =*/ false) + } + + pub(crate) fn new_channel(content_length: DecodedLength, wanter: bool) -> (Sender, Body) { + let (tx, rx) = mpsc::channel(0); + + // If wanter is true, `Sender::poll_ready()` won't becoming ready + // until the `Body` has been polled for data once. + let want = if wanter { WANT_PENDING } else { WANT_READY }; + + let (want_tx, want_rx) = watch::channel(want); + + let tx = Sender { want_rx, tx }; + let rx = Body::new(Kind::Chan { + content_length, + want_tx, + rx, + }); + + (tx, rx) + } + + /// Wrap a futures `Stream` in a box inside `Body`. + /// + /// # Example + /// + /// ``` + /// # use hyper::Body; + /// let chunks: Vec<Result<_, std::io::Error>> = vec![ + /// Ok("hello"), + /// Ok(" "), + /// Ok("world"), + /// ]; + /// + /// let stream = futures_util::stream::iter(chunks); + /// + /// let body = Body::wrap_stream(stream); + /// ``` + /// + /// # Optional + /// + /// This function requires enabling the `stream` feature in your + /// `Cargo.toml`. + #[cfg(feature = "stream")] + pub fn wrap_stream<S, O, E>(stream: S) -> Body + where + S: Stream<Item = Result<O, E>> + Send + Sync + 'static, + O: Into<Bytes> + 'static, + E: Into<Box<dyn StdError + Send + Sync>> + 'static, + { + let mapped = stream.map_ok(Into::into).map_err(Into::into); + Body::new(Kind::Wrapped(Box::pin(mapped))) + } + + /// Converts this `Body` into a `Future` of a pending HTTP upgrade. + /// + /// See [the `upgrade` module](crate::upgrade) for more. + pub fn on_upgrade(self) -> OnUpgrade { + self.extra + .map(|ex| ex.on_upgrade) + .unwrap_or_else(OnUpgrade::none) + } + + fn new(kind: Kind) -> Body { + Body { kind, extra: None } + } + + pub(crate) fn h2( + recv: h2::RecvStream, + content_length: DecodedLength, + ping: ping::Recorder, + ) -> Self { + let body = Body::new(Kind::H2 { + ping, + content_length, + recv, + }); + + body + } + + pub(crate) fn set_on_upgrade(&mut self, upgrade: OnUpgrade) { + debug_assert!(!upgrade.is_none(), "set_on_upgrade with empty upgrade"); + let extra = self.extra_mut(); + debug_assert!(extra.on_upgrade.is_none(), "set_on_upgrade twice"); + extra.on_upgrade = upgrade; + } + + pub(crate) fn delayed_eof(&mut self, fut: DelayEofUntil) { + self.extra_mut().delayed_eof = Some(DelayEof::NotEof(fut)); + } + + fn take_delayed_eof(&mut self) -> Option<DelayEof> { + self.extra + .as_mut() + .and_then(|extra| extra.delayed_eof.take()) + } + + fn extra_mut(&mut self) -> &mut Extra { + self.extra.get_or_insert_with(|| { + Box::new(Extra { + delayed_eof: None, + on_upgrade: OnUpgrade::none(), + }) + }) + } + + fn poll_eof(&mut self, cx: &mut task::Context<'_>) -> Poll<Option<crate::Result<Bytes>>> { + match self.take_delayed_eof() { + Some(DelayEof::NotEof(mut delay)) => match self.poll_inner(cx) { + ok @ Poll::Ready(Some(Ok(..))) | ok @ Poll::Pending => { + self.extra_mut().delayed_eof = Some(DelayEof::NotEof(delay)); + ok + } + Poll::Ready(None) => match Pin::new(&mut delay).poll(cx) { + Poll::Ready(Ok(never)) => match never {}, + Poll::Pending => { + self.extra_mut().delayed_eof = Some(DelayEof::Eof(delay)); + Poll::Pending + } + Poll::Ready(Err(_done)) => Poll::Ready(None), + }, + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), + }, + Some(DelayEof::Eof(mut delay)) => match Pin::new(&mut delay).poll(cx) { + Poll::Ready(Ok(never)) => match never {}, + Poll::Pending => { + self.extra_mut().delayed_eof = Some(DelayEof::Eof(delay)); + Poll::Pending + } + Poll::Ready(Err(_done)) => Poll::Ready(None), + }, + None => self.poll_inner(cx), + } + } + + fn poll_inner(&mut self, cx: &mut task::Context<'_>) -> Poll<Option<crate::Result<Bytes>>> { + match self.kind { + Kind::Once(ref mut val) => Poll::Ready(val.take().map(Ok)), + Kind::Chan { + content_length: ref mut len, + ref mut rx, + ref mut want_tx, + } => { + want_tx.send(WANT_READY); + + match ready!(Pin::new(rx).poll_next(cx)?) { + Some(chunk) => { + len.sub_if(chunk.len() as u64); + Poll::Ready(Some(Ok(chunk))) + } + None => Poll::Ready(None), + } + } + Kind::H2 { + ref ping, + recv: ref mut h2, + content_length: ref mut len, + } => match ready!(h2.poll_data(cx)) { + Some(Ok(bytes)) => { + let _ = h2.flow_control().release_capacity(bytes.len()); + len.sub_if(bytes.len() as u64); + ping.record_data(bytes.len()); + Poll::Ready(Some(Ok(bytes))) + } + Some(Err(e)) => Poll::Ready(Some(Err(crate::Error::new_body(e)))), + None => Poll::Ready(None), + }, + + #[cfg(feature = "stream")] + Kind::Wrapped(ref mut s) => match ready!(s.as_mut().poll_next(cx)) { + Some(res) => Poll::Ready(Some(res.map_err(crate::Error::new_body))), + None => Poll::Ready(None), + }, + } + } + + pub(super) fn take_full_data(&mut self) -> Option<Bytes> { + if let Kind::Once(ref mut chunk) = self.kind { + chunk.take() + } else { + None + } + } +} + +impl Default for Body { + /// Returns [`Body::empty()`](Body::empty). + #[inline] + fn default() -> Body { + Body::empty() + } +} + +impl HttpBody for Body { + type Data = Bytes; + type Error = crate::Error; + + fn poll_data( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<Option<Result<Self::Data, Self::Error>>> { + self.poll_eof(cx) + } + + fn poll_trailers( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<Result<Option<HeaderMap>, Self::Error>> { + match self.kind { + Kind::H2 { + recv: ref mut h2, + ref ping, + .. + } => match ready!(h2.poll_trailers(cx)) { + Ok(t) => { + ping.record_non_data(); + Poll::Ready(Ok(t)) + } + Err(e) => Poll::Ready(Err(crate::Error::new_h2(e))), + }, + _ => Poll::Ready(Ok(None)), + } + } + + fn is_end_stream(&self) -> bool { + match self.kind { + Kind::Once(ref val) => val.is_none(), + Kind::Chan { content_length, .. } => content_length == DecodedLength::ZERO, + Kind::H2 { recv: ref h2, .. } => h2.is_end_stream(), + #[cfg(feature = "stream")] + Kind::Wrapped(..) => false, + } + } + + fn size_hint(&self) -> SizeHint { + match self.kind { + Kind::Once(Some(ref val)) => SizeHint::with_exact(val.len() as u64), + Kind::Once(None) => SizeHint::with_exact(0), + #[cfg(feature = "stream")] + Kind::Wrapped(..) => SizeHint::default(), + Kind::Chan { content_length, .. } | Kind::H2 { content_length, .. } => { + let mut hint = SizeHint::default(); + + if let Some(content_length) = content_length.into_opt() { + hint.set_exact(content_length); + } + + hint + } + } + } +} + +impl fmt::Debug for Body { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + #[derive(Debug)] + struct Streaming; + #[derive(Debug)] + struct Empty; + #[derive(Debug)] + struct Full<'a>(&'a Bytes); + + let mut builder = f.debug_tuple("Body"); + match self.kind { + Kind::Once(None) => builder.field(&Empty), + Kind::Once(Some(ref chunk)) => builder.field(&Full(chunk)), + _ => builder.field(&Streaming), + }; + + builder.finish() + } +} + +/// # Optional +/// +/// This function requires enabling the `stream` feature in your +/// `Cargo.toml`. +#[cfg(feature = "stream")] +impl Stream for Body { + type Item = crate::Result<Bytes>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> { + HttpBody::poll_data(self, cx) + } +} + +/// # Optional +/// +/// This function requires enabling the `stream` feature in your +/// `Cargo.toml`. +#[cfg(feature = "stream")] +impl From<Box<dyn Stream<Item = Result<Bytes, Box<dyn StdError + Send + Sync>>> + Send + Sync>> + for Body +{ + #[inline] + fn from( + stream: Box< + dyn Stream<Item = Result<Bytes, Box<dyn StdError + Send + Sync>>> + Send + Sync, + >, + ) -> Body { + Body::new(Kind::Wrapped(stream.into())) + } +} + +impl From<Bytes> for Body { + #[inline] + fn from(chunk: Bytes) -> Body { + if chunk.is_empty() { + Body::empty() + } else { + Body::new(Kind::Once(Some(chunk))) + } + } +} + +impl From<Vec<u8>> for Body { + #[inline] + fn from(vec: Vec<u8>) -> Body { + Body::from(Bytes::from(vec)) + } +} + +impl From<&'static [u8]> for Body { + #[inline] + fn from(slice: &'static [u8]) -> Body { + Body::from(Bytes::from(slice)) + } +} + +impl From<Cow<'static, [u8]>> for Body { + #[inline] + fn from(cow: Cow<'static, [u8]>) -> Body { + match cow { + Cow::Borrowed(b) => Body::from(b), + Cow::Owned(o) => Body::from(o), + } + } +} + +impl From<String> for Body { + #[inline] + fn from(s: String) -> Body { + Body::from(Bytes::from(s.into_bytes())) + } +} + +impl From<&'static str> for Body { + #[inline] + fn from(slice: &'static str) -> Body { + Body::from(Bytes::from(slice.as_bytes())) + } +} + +impl From<Cow<'static, str>> for Body { + #[inline] + fn from(cow: Cow<'static, str>) -> Body { + match cow { + Cow::Borrowed(b) => Body::from(b), + Cow::Owned(o) => Body::from(o), + } + } +} + +impl Sender { + /// Check to see if this `Sender` can send more data. + pub fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + // Check if the receiver end has tried polling for the body yet + ready!(self.poll_want(cx)?); + self.tx + .poll_ready(cx) + .map_err(|_| crate::Error::new_closed()) + } + + fn poll_want(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + match self.want_rx.load(cx) { + WANT_READY => Poll::Ready(Ok(())), + WANT_PENDING => Poll::Pending, + watch::CLOSED => Poll::Ready(Err(crate::Error::new_closed())), + unexpected => unreachable!("want_rx value: {}", unexpected), + } + } + + async fn ready(&mut self) -> crate::Result<()> { + futures_util::future::poll_fn(|cx| self.poll_ready(cx)).await + } + + /// Send data on this channel when it is ready. + pub async fn send_data(&mut self, chunk: Bytes) -> crate::Result<()> { + self.ready().await?; + self.tx + .try_send(Ok(chunk)) + .map_err(|_| crate::Error::new_closed()) + } + + /// Try to send data on this channel. + /// + /// # Errors + /// + /// Returns `Err(Bytes)` if the channel could not (currently) accept + /// another `Bytes`. + /// + /// # Note + /// + /// This is mostly useful for when trying to send from some other thread + /// that doesn't have an async context. If in an async context, prefer + /// `send_data()` instead. + pub fn try_send_data(&mut self, chunk: Bytes) -> Result<(), Bytes> { + self.tx + .try_send(Ok(chunk)) + .map_err(|err| err.into_inner().expect("just sent Ok")) + } + + /// Aborts the body in an abnormal fashion. + pub fn abort(self) { + let _ = self + .tx + // clone so the send works even if buffer is full + .clone() + .try_send(Err(crate::Error::new_body_write_aborted())); + } + + pub(crate) fn send_error(&mut self, err: crate::Error) { + let _ = self.tx.try_send(Err(err)); + } +} + +impl fmt::Debug for Sender { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + #[derive(Debug)] + struct Open; + #[derive(Debug)] + struct Closed; + + let mut builder = f.debug_tuple("Sender"); + match self.want_rx.peek() { + watch::CLOSED => builder.field(&Closed), + _ => builder.field(&Open), + }; + + builder.finish() + } +} + +#[cfg(test)] +mod tests { + use std::mem; + use std::task::Poll; + + use super::{Body, DecodedLength, HttpBody, Sender, SizeHint}; + + #[test] + fn test_size_of() { + // These are mostly to help catch *accidentally* increasing + // the size by too much. + + let body_size = mem::size_of::<Body>(); + let body_expected_size = mem::size_of::<u64>() * 6; + assert!( + body_size <= body_expected_size, + "Body size = {} <= {}", + body_size, + body_expected_size, + ); + + assert_eq!(body_size, mem::size_of::<Option<Body>>(), "Option<Body>"); + + assert_eq!( + mem::size_of::<Sender>(), + mem::size_of::<usize>() * 4, + "Sender" + ); + + assert_eq!( + mem::size_of::<Sender>(), + mem::size_of::<Option<Sender>>(), + "Option<Sender>" + ); + } + + #[test] + fn size_hint() { + fn eq(body: Body, b: SizeHint, note: &str) { + let a = body.size_hint(); + assert_eq!(a.lower(), b.lower(), "lower for {:?}", note); + assert_eq!(a.upper(), b.upper(), "upper for {:?}", note); + } + + eq(Body::from("Hello"), SizeHint::with_exact(5), "from str"); + + eq(Body::empty(), SizeHint::with_exact(0), "empty"); + + eq(Body::channel().1, SizeHint::new(), "channel"); + + eq( + Body::new_channel(DecodedLength::new(4), /*wanter =*/ false).1, + SizeHint::with_exact(4), + "channel with length", + ); + } + + #[tokio::test] + async fn channel_abort() { + let (tx, mut rx) = Body::channel(); + + tx.abort(); + + let err = rx.data().await.unwrap().unwrap_err(); + assert!(err.is_body_write_aborted(), "{:?}", err); + } + + #[tokio::test] + async fn channel_abort_when_buffer_is_full() { + let (mut tx, mut rx) = Body::channel(); + + tx.try_send_data("chunk 1".into()).expect("send 1"); + // buffer is full, but can still send abort + tx.abort(); + + let chunk1 = rx.data().await.expect("item 1").expect("chunk 1"); + assert_eq!(chunk1, "chunk 1"); + + let err = rx.data().await.unwrap().unwrap_err(); + assert!(err.is_body_write_aborted(), "{:?}", err); + } + + #[test] + fn channel_buffers_one() { + let (mut tx, _rx) = Body::channel(); + + tx.try_send_data("chunk 1".into()).expect("send 1"); + + // buffer is now full + let chunk2 = tx.try_send_data("chunk 2".into()).expect_err("send 2"); + assert_eq!(chunk2, "chunk 2"); + } + + #[tokio::test] + async fn channel_empty() { + let (_, mut rx) = Body::channel(); + + assert!(rx.data().await.is_none()); + } + + #[test] + fn channel_ready() { + let (mut tx, _rx) = Body::new_channel(DecodedLength::CHUNKED, /*wanter = */ false); + + let mut tx_ready = tokio_test::task::spawn(tx.ready()); + + assert!(tx_ready.poll().is_ready(), "tx is ready immediately"); + } + + #[test] + fn channel_wanter() { + let (mut tx, mut rx) = Body::new_channel(DecodedLength::CHUNKED, /*wanter = */ true); + + let mut tx_ready = tokio_test::task::spawn(tx.ready()); + let mut rx_data = tokio_test::task::spawn(rx.data()); + + assert!( + tx_ready.poll().is_pending(), + "tx isn't ready before rx has been polled" + ); + + assert!(rx_data.poll().is_pending(), "poll rx.data"); + assert!(tx_ready.is_woken(), "rx poll wakes tx"); + + assert!( + tx_ready.poll().is_ready(), + "tx is ready after rx has been polled" + ); + } + + #[test] + fn channel_notices_closure() { + let (mut tx, rx) = Body::new_channel(DecodedLength::CHUNKED, /*wanter = */ true); + + let mut tx_ready = tokio_test::task::spawn(tx.ready()); + + assert!( + tx_ready.poll().is_pending(), + "tx isn't ready before rx has been polled" + ); + + drop(rx); + assert!(tx_ready.is_woken(), "dropping rx wakes tx"); + + match tx_ready.poll() { + Poll::Ready(Err(ref e)) if e.is_closed() => (), + unexpected => panic!("tx poll ready unexpected: {:?}", unexpected), + } + } +} diff --git a/third_party/rust/hyper/src/body/mod.rs b/third_party/rust/hyper/src/body/mod.rs new file mode 100644 index 0000000000..4693cbc8d1 --- /dev/null +++ b/third_party/rust/hyper/src/body/mod.rs @@ -0,0 +1,64 @@ +//! Streaming bodies for Requests and Responses +//! +//! For both [Clients](crate::client) and [Servers](crate::server), requests and +//! responses use streaming bodies, instead of complete buffering. This +//! allows applications to not use memory they don't need, and allows exerting +//! back-pressure on connections by only reading when asked. +//! +//! There are two pieces to this in hyper: +//! +//! - **The [`HttpBody`](HttpBody) trait** describes all possible bodies. +//! hyper allows any body type that implements `HttpBody`, allowing +//! applications to have fine-grained control over their streaming. +//! - **The [`Body`](Body) concrete type**, which is an implementation of +//! `HttpBody`, and returned by hyper as a "receive stream" (so, for server +//! requests and client responses). It is also a decent default implementation +//! if you don't have very custom needs of your send streams. + +pub use bytes::{Buf, Bytes}; +pub use http_body::Body as HttpBody; + +pub use self::aggregate::aggregate; +pub use self::body::{Body, Sender}; +pub use self::to_bytes::to_bytes; + +pub(crate) use self::payload::Payload; + +mod aggregate; +mod body; +mod payload; +mod to_bytes; + +/// An optimization to try to take a full body if immediately available. +/// +/// This is currently limited to *only* `hyper::Body`s. +pub(crate) fn take_full_data<T: Payload + 'static>(body: &mut T) -> Option<T::Data> { + use std::any::{Any, TypeId}; + + // This static type check can be optimized at compile-time. + if TypeId::of::<T>() == TypeId::of::<Body>() { + let mut full = (body as &mut dyn Any) + .downcast_mut::<Body>() + .expect("must be Body") + .take_full_data(); + // This second cast is required to make the type system happy. + // Without it, the compiler cannot reason that the type is actually + // `T::Data`. Oh wells. + // + // It's still a measurable win! + (&mut full as &mut dyn Any) + .downcast_mut::<Option<T::Data>>() + .expect("must be T::Data") + .take() + } else { + None + } +} + +fn _assert_send_sync() { + fn _assert_send<T: Send>() {} + fn _assert_sync<T: Sync>() {} + + _assert_send::<Body>(); + _assert_sync::<Body>(); +} diff --git a/third_party/rust/hyper/src/body/payload.rs b/third_party/rust/hyper/src/body/payload.rs new file mode 100644 index 0000000000..f24adad175 --- /dev/null +++ b/third_party/rust/hyper/src/body/payload.rs @@ -0,0 +1,139 @@ +use std::error::Error as StdError; + +use bytes::Buf; +use http::HeaderMap; + +use crate::common::{task, Pin, Poll}; +use http_body::{Body as HttpBody, SizeHint}; + +/// This trait represents a streaming body of a `Request` or `Response`. +/// +/// The built-in implementation of this trait is [`Body`](::Body), in case you +/// don't need to customize a send stream for your own application. +pub trait Payload: sealed::Sealed + Send + 'static { + /// A buffer of bytes representing a single chunk of a body. + type Data: Buf + Send; + + /// The error type of this stream. + type Error: Into<Box<dyn StdError + Send + Sync>>; + + /// Poll for a `Data` buffer. + /// + /// Similar to `Stream::poll_next`, this yields `Some(Data)` until + /// the body ends, when it yields `None`. + fn poll_data( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<Option<Result<Self::Data, Self::Error>>>; + + /// Poll for an optional **single** `HeaderMap` of trailers. + /// + /// This should **only** be called after `poll_data` has ended. + /// + /// Note: Trailers aren't currently used for HTTP/1, only for HTTP/2. + fn poll_trailers( + self: Pin<&mut Self>, + _cx: &mut task::Context<'_>, + ) -> Poll<Result<Option<HeaderMap>, Self::Error>> { + Poll::Ready(Ok(None)) + } + + /// A hint that the `Body` is complete, and doesn't need to be polled more. + /// + /// This can be useful to determine if the there is any body or trailers + /// without having to poll. An empty `Body` could return `true` and hyper + /// would be able to know that only the headers need to be sent. Or, it can + /// also be checked after each `poll_data` call, to allow hyper to try to + /// end the underlying stream with the last chunk, instead of needing to + /// send an extra `DATA` frame just to mark the stream as finished. + /// + /// As a hint, it is used to try to optimize, and thus is OK for a default + /// implementation to return `false`. + fn is_end_stream(&self) -> bool { + false + } + + /// Returns a `SizeHint` providing an upper and lower bound on the possible size. + /// + /// If there is an exact size of bytes known, this would allow hyper to + /// send a `Content-Length` header automatically, not needing to fall back to + /// `TransferEncoding: chunked`. + /// + /// This does not need to be kept updated after polls, it will only be + /// called once to create the headers. + fn size_hint(&self) -> SizeHint { + SizeHint::default() + } +} + +impl<T> Payload for T +where + T: HttpBody + Send + 'static, + T::Data: Send, + T::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + type Data = T::Data; + type Error = T::Error; + + fn poll_data( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<Option<Result<Self::Data, Self::Error>>> { + HttpBody::poll_data(self, cx) + } + + fn poll_trailers( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<Result<Option<HeaderMap>, Self::Error>> { + HttpBody::poll_trailers(self, cx) + } + + fn is_end_stream(&self) -> bool { + HttpBody::is_end_stream(self) + } + + fn size_hint(&self) -> SizeHint { + HttpBody::size_hint(self) + } +} + +impl<T> sealed::Sealed for T +where + T: HttpBody + Send + 'static, + T::Data: Send, + T::Error: Into<Box<dyn StdError + Send + Sync>>, +{ +} + +mod sealed { + pub trait Sealed {} +} + +/* +impl<E: Payload> Payload for Box<E> { + type Data = E::Data; + type Error = E::Error; + + fn poll_data(&mut self) -> Poll<Option<Self::Data>, Self::Error> { + (**self).poll_data() + } + + fn poll_trailers(&mut self) -> Poll<Option<HeaderMap>, Self::Error> { + (**self).poll_trailers() + } + + fn is_end_stream(&self) -> bool { + (**self).is_end_stream() + } + + fn content_length(&self) -> Option<u64> { + (**self).content_length() + } + + #[doc(hidden)] + fn __hyper_full_data(&mut self, arg: FullDataArg) -> FullDataRet<Self::Data> { + (**self).__hyper_full_data(arg) + } +} +*/ diff --git a/third_party/rust/hyper/src/body/to_bytes.rs b/third_party/rust/hyper/src/body/to_bytes.rs new file mode 100644 index 0000000000..4cce7857d7 --- /dev/null +++ b/third_party/rust/hyper/src/body/to_bytes.rs @@ -0,0 +1,40 @@ +use bytes::{Buf, BufMut, Bytes}; + +use super::HttpBody; + +/// Concatenate the buffers from a body into a single `Bytes` asynchronously. +/// +/// This may require copying the data into a single buffer. If you don't need +/// a contiguous buffer, prefer the [`aggregate`](crate::body::aggregate()) +/// function. +pub async fn to_bytes<T>(body: T) -> Result<Bytes, T::Error> +where + T: HttpBody, +{ + futures_util::pin_mut!(body); + + // If there's only 1 chunk, we can just return Buf::to_bytes() + let mut first = if let Some(buf) = body.data().await { + buf? + } else { + return Ok(Bytes::new()); + }; + + let second = if let Some(buf) = body.data().await { + buf? + } else { + return Ok(first.to_bytes()); + }; + + // With more than 1 buf, we gotta flatten into a Vec first. + let cap = first.remaining() + second.remaining() + body.size_hint().lower() as usize; + let mut vec = Vec::with_capacity(cap); + vec.put(first); + vec.put(second); + + while let Some(buf) = body.data().await { + vec.put(buf?); + } + + Ok(vec.into()) +} diff --git a/third_party/rust/hyper/src/client/conn.rs b/third_party/rust/hyper/src/client/conn.rs new file mode 100644 index 0000000000..81eaf8287f --- /dev/null +++ b/third_party/rust/hyper/src/client/conn.rs @@ -0,0 +1,698 @@ +//! Lower-level client connection API. +//! +//! The types in this module are to provide a lower-level API based around a +//! single connection. Connecting to a host, pooling connections, and the like +//! are not handled at this level. This module provides the building blocks to +//! customize those things externally. +//! +//! If don't have need to manage connections yourself, consider using the +//! higher-level [Client](super) API. + +use std::fmt; +use std::mem; +use std::sync::Arc; +#[cfg(feature = "runtime")] +use std::time::Duration; + +use bytes::Bytes; +use futures_util::future::{self, Either, FutureExt as _}; +use pin_project::{pin_project, project}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tower_service::Service; + +use super::dispatch; +use crate::body::Payload; +use crate::common::{task, BoxSendFuture, Exec, Executor, Future, Pin, Poll}; +use crate::proto; +use crate::upgrade::Upgraded; +use crate::{Body, Request, Response}; + +type Http1Dispatcher<T, B, R> = proto::dispatch::Dispatcher<proto::dispatch::Client<B>, B, T, R>; + +#[pin_project] +enum ProtoClient<T, B> +where + B: Payload, +{ + H1(#[pin] Http1Dispatcher<T, B, proto::h1::ClientTransaction>), + H2(#[pin] proto::h2::ClientTask<B>), +} + +/// Returns a handshake future over some IO. +/// +/// This is a shortcut for `Builder::new().handshake(io)`. +pub async fn handshake<T>( + io: T, +) -> crate::Result<(SendRequest<crate::Body>, Connection<T, crate::Body>)> +where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + Builder::new().handshake(io).await +} + +/// The sender side of an established connection. +pub struct SendRequest<B> { + dispatch: dispatch::Sender<Request<B>, Response<Body>>, +} + +/// A future that processes all HTTP state for the IO object. +/// +/// In most cases, this should just be spawned into an executor, so that it +/// can process incoming and outgoing messages, notice hangups, and the like. +#[must_use = "futures do nothing unless polled"] +pub struct Connection<T, B> +where + T: AsyncRead + AsyncWrite + Send + 'static, + B: Payload + 'static, +{ + inner: Option<ProtoClient<T, B>>, +} + +/// A builder to configure an HTTP connection. +/// +/// After setting options, the builder is used to create a handshake future. +#[derive(Clone, Debug)] +pub struct Builder { + pub(super) exec: Exec, + h1_writev: bool, + h1_title_case_headers: bool, + h1_read_buf_exact_size: Option<usize>, + h1_max_buf_size: Option<usize>, + http2: bool, + h2_builder: proto::h2::client::Config, +} + +/// A future returned by `SendRequest::send_request`. +/// +/// Yields a `Response` if successful. +#[must_use = "futures do nothing unless polled"] +pub struct ResponseFuture { + inner: ResponseFutureState, +} + +enum ResponseFutureState { + Waiting(dispatch::Promise<Response<Body>>), + // Option is to be able to `take()` it in `poll` + Error(Option<crate::Error>), +} + +/// Deconstructed parts of a `Connection`. +/// +/// This allows taking apart a `Connection` at a later time, in order to +/// reclaim the IO object, and additional related pieces. +#[derive(Debug)] +pub struct Parts<T> { + /// The original IO object used in the handshake. + pub io: T, + /// A buffer of bytes that have been read but not processed as HTTP. + /// + /// For instance, if the `Connection` is used for an HTTP upgrade request, + /// it is possible the server sent back the first bytes of the new protocol + /// along with the response upgrade. + /// + /// You will want to check for any existing bytes if you plan to continue + /// communicating on the IO object. + pub read_buf: Bytes, + _inner: (), +} + +// ========== internal client api + +// A `SendRequest` that can be cloned to send HTTP2 requests. +// private for now, probably not a great idea of a type... +#[must_use = "futures do nothing unless polled"] +pub(super) struct Http2SendRequest<B> { + dispatch: dispatch::UnboundedSender<Request<B>, Response<Body>>, +} + +// ===== impl SendRequest + +impl<B> SendRequest<B> { + /// Polls to determine whether this sender can be used yet for a request. + /// + /// If the associated connection is closed, this returns an Error. + pub fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + self.dispatch.poll_ready(cx) + } + + pub(super) fn when_ready(self) -> impl Future<Output = crate::Result<Self>> { + let mut me = Some(self); + future::poll_fn(move |cx| { + ready!(me.as_mut().unwrap().poll_ready(cx))?; + Poll::Ready(Ok(me.take().unwrap())) + }) + } + + pub(super) fn is_ready(&self) -> bool { + self.dispatch.is_ready() + } + + pub(super) fn is_closed(&self) -> bool { + self.dispatch.is_closed() + } + + pub(super) fn into_http2(self) -> Http2SendRequest<B> { + Http2SendRequest { + dispatch: self.dispatch.unbound(), + } + } +} + +impl<B> SendRequest<B> +where + B: Payload + 'static, +{ + /// Sends a `Request` on the associated connection. + /// + /// Returns a future that if successful, yields the `Response`. + /// + /// # Note + /// + /// There are some key differences in what automatic things the `Client` + /// does for you that will not be done here: + /// + /// - `Client` requires absolute-form `Uri`s, since the scheme and + /// authority are needed to connect. They aren't required here. + /// - Since the `Client` requires absolute-form `Uri`s, it can add + /// the `Host` header based on it. You must add a `Host` header yourself + /// before calling this method. + /// - Since absolute-form `Uri`s are not required, if received, they will + /// be serialized as-is. + /// + /// # Example + /// + /// ``` + /// # use http::header::HOST; + /// # use hyper::client::conn::SendRequest; + /// # use hyper::Body; + /// use hyper::Request; + /// + /// # async fn doc(mut tx: SendRequest<Body>) -> hyper::Result<()> { + /// // build a Request + /// let req = Request::builder() + /// .uri("/foo/bar") + /// .header(HOST, "hyper.rs") + /// .body(Body::empty()) + /// .unwrap(); + /// + /// // send it and await a Response + /// let res = tx.send_request(req).await?; + /// // assert the Response + /// assert!(res.status().is_success()); + /// # Ok(()) + /// # } + /// # fn main() {} + /// ``` + pub fn send_request(&mut self, req: Request<B>) -> ResponseFuture { + let inner = match self.dispatch.send(req) { + Ok(rx) => ResponseFutureState::Waiting(rx), + Err(_req) => { + debug!("connection was not ready"); + let err = crate::Error::new_canceled().with("connection was not ready"); + ResponseFutureState::Error(Some(err)) + } + }; + + ResponseFuture { inner } + } + + pub(crate) fn send_request_retryable( + &mut self, + req: Request<B>, + ) -> impl Future<Output = Result<Response<Body>, (crate::Error, Option<Request<B>>)>> + Unpin + where + B: Send, + { + match self.dispatch.try_send(req) { + Ok(rx) => { + Either::Left(rx.then(move |res| { + match res { + Ok(Ok(res)) => future::ok(res), + Ok(Err(err)) => future::err(err), + // this is definite bug if it happens, but it shouldn't happen! + Err(_) => panic!("dispatch dropped without returning error"), + } + })) + } + Err(req) => { + debug!("connection was not ready"); + let err = crate::Error::new_canceled().with("connection was not ready"); + Either::Right(future::err((err, Some(req)))) + } + } + } +} + +impl<B> Service<Request<B>> for SendRequest<B> +where + B: Payload + 'static, +{ + type Response = Response<Body>; + type Error = crate::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + self.poll_ready(cx) + } + + fn call(&mut self, req: Request<B>) -> Self::Future { + self.send_request(req) + } +} + +impl<B> fmt::Debug for SendRequest<B> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SendRequest").finish() + } +} + +// ===== impl Http2SendRequest + +impl<B> Http2SendRequest<B> { + pub(super) fn is_ready(&self) -> bool { + self.dispatch.is_ready() + } + + pub(super) fn is_closed(&self) -> bool { + self.dispatch.is_closed() + } +} + +impl<B> Http2SendRequest<B> +where + B: Payload + 'static, +{ + pub(super) fn send_request_retryable( + &mut self, + req: Request<B>, + ) -> impl Future<Output = Result<Response<Body>, (crate::Error, Option<Request<B>>)>> + where + B: Send, + { + match self.dispatch.try_send(req) { + Ok(rx) => { + Either::Left(rx.then(move |res| { + match res { + Ok(Ok(res)) => future::ok(res), + Ok(Err(err)) => future::err(err), + // this is definite bug if it happens, but it shouldn't happen! + Err(_) => panic!("dispatch dropped without returning error"), + } + })) + } + Err(req) => { + debug!("connection was not ready"); + let err = crate::Error::new_canceled().with("connection was not ready"); + Either::Right(future::err((err, Some(req)))) + } + } + } +} + +impl<B> fmt::Debug for Http2SendRequest<B> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Http2SendRequest").finish() + } +} + +impl<B> Clone for Http2SendRequest<B> { + fn clone(&self) -> Self { + Http2SendRequest { + dispatch: self.dispatch.clone(), + } + } +} + +// ===== impl Connection + +impl<T, B> Connection<T, B> +where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + B: Payload + Unpin + 'static, +{ + /// Return the inner IO object, and additional information. + /// + /// Only works for HTTP/1 connections. HTTP/2 connections will panic. + pub fn into_parts(self) -> Parts<T> { + let (io, read_buf, _) = match self.inner.expect("already upgraded") { + ProtoClient::H1(h1) => h1.into_inner(), + ProtoClient::H2(_h2) => { + panic!("http2 cannot into_inner"); + } + }; + + Parts { + io, + read_buf, + _inner: (), + } + } + + /// Poll the connection for completion, but without calling `shutdown` + /// on the underlying IO. + /// + /// This is useful to allow running a connection while doing an HTTP + /// upgrade. Once the upgrade is completed, the connection would be "done", + /// but it is not desired to actually shutdown the IO object. Instead you + /// would take it back using `into_parts`. + /// + /// Use [`poll_fn`](https://docs.rs/futures/0.1.25/futures/future/fn.poll_fn.html) + /// and [`try_ready!`](https://docs.rs/futures/0.1.25/futures/macro.try_ready.html) + /// to work with this function; or use the `without_shutdown` wrapper. + pub fn poll_without_shutdown(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + match *self.inner.as_mut().expect("already upgraded") { + ProtoClient::H1(ref mut h1) => h1.poll_without_shutdown(cx), + ProtoClient::H2(ref mut h2) => Pin::new(h2).poll(cx).map_ok(|_| ()), + } + } + + /// Prevent shutdown of the underlying IO object at the end of service the request, + /// instead run `into_parts`. This is a convenience wrapper over `poll_without_shutdown`. + pub fn without_shutdown(self) -> impl Future<Output = crate::Result<Parts<T>>> { + let mut conn = Some(self); + future::poll_fn(move |cx| -> Poll<crate::Result<Parts<T>>> { + ready!(conn.as_mut().unwrap().poll_without_shutdown(cx))?; + Poll::Ready(Ok(conn.take().unwrap().into_parts())) + }) + } +} + +impl<T, B> Future for Connection<T, B> +where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + B: Payload + 'static, +{ + type Output = crate::Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + match ready!(Pin::new(self.inner.as_mut().unwrap()).poll(cx))? { + proto::Dispatched::Shutdown => Poll::Ready(Ok(())), + proto::Dispatched::Upgrade(pending) => { + let h1 = match mem::replace(&mut self.inner, None) { + Some(ProtoClient::H1(h1)) => h1, + _ => unreachable!("Upgrade expects h1"), + }; + + let (io, buf, _) = h1.into_inner(); + pending.fulfill(Upgraded::new(io, buf)); + Poll::Ready(Ok(())) + } + } + } +} + +impl<T, B> fmt::Debug for Connection<T, B> +where + T: AsyncRead + AsyncWrite + fmt::Debug + Send + 'static, + B: Payload + 'static, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Connection").finish() + } +} + +// ===== impl Builder + +impl Builder { + /// Creates a new connection builder. + #[inline] + pub fn new() -> Builder { + Builder { + exec: Exec::Default, + h1_writev: true, + h1_read_buf_exact_size: None, + h1_title_case_headers: false, + h1_max_buf_size: None, + http2: false, + h2_builder: Default::default(), + } + } + + /// Provide an executor to execute background HTTP2 tasks. + pub fn executor<E>(&mut self, exec: E) -> &mut Builder + where + E: Executor<BoxSendFuture> + Send + Sync + 'static, + { + self.exec = Exec::Executor(Arc::new(exec)); + self + } + + pub(super) fn h1_writev(&mut self, enabled: bool) -> &mut Builder { + self.h1_writev = enabled; + self + } + + pub(super) fn h1_title_case_headers(&mut self, enabled: bool) -> &mut Builder { + self.h1_title_case_headers = enabled; + self + } + + pub(super) fn h1_read_buf_exact_size(&mut self, sz: Option<usize>) -> &mut Builder { + self.h1_read_buf_exact_size = sz; + self.h1_max_buf_size = None; + self + } + + pub(super) fn h1_max_buf_size(&mut self, max: usize) -> &mut Self { + assert!( + max >= proto::h1::MINIMUM_MAX_BUFFER_SIZE, + "the max_buf_size cannot be smaller than the minimum that h1 specifies." + ); + + self.h1_max_buf_size = Some(max); + self.h1_read_buf_exact_size = None; + self + } + + /// Sets whether HTTP2 is required. + /// + /// Default is false. + pub fn http2_only(&mut self, enabled: bool) -> &mut Builder { + self.http2 = enabled; + self + } + + /// Sets the [`SETTINGS_INITIAL_WINDOW_SIZE`][spec] option for HTTP2 + /// stream-level flow control. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + /// + /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_INITIAL_WINDOW_SIZE + pub fn http2_initial_stream_window_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self { + if let Some(sz) = sz.into() { + self.h2_builder.adaptive_window = false; + self.h2_builder.initial_stream_window_size = sz; + } + self + } + + /// Sets the max connection-level flow control for HTTP2 + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + pub fn http2_initial_connection_window_size( + &mut self, + sz: impl Into<Option<u32>>, + ) -> &mut Self { + if let Some(sz) = sz.into() { + self.h2_builder.adaptive_window = false; + self.h2_builder.initial_conn_window_size = sz; + } + self + } + + /// Sets whether to use an adaptive flow control. + /// + /// Enabling this will override the limits set in + /// `http2_initial_stream_window_size` and + /// `http2_initial_connection_window_size`. + pub fn http2_adaptive_window(&mut self, enabled: bool) -> &mut Self { + use proto::h2::SPEC_WINDOW_SIZE; + + self.h2_builder.adaptive_window = enabled; + if enabled { + self.h2_builder.initial_conn_window_size = SPEC_WINDOW_SIZE; + self.h2_builder.initial_stream_window_size = SPEC_WINDOW_SIZE; + } + self + } + + /// Sets an interval for HTTP2 Ping frames should be sent to keep a + /// connection alive. + /// + /// Pass `None` to disable HTTP2 keep-alive. + /// + /// Default is currently disabled. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + pub fn http2_keep_alive_interval( + &mut self, + interval: impl Into<Option<Duration>>, + ) -> &mut Self { + self.h2_builder.keep_alive_interval = interval.into(); + self + } + + /// Sets a timeout for receiving an acknowledgement of the keep-alive ping. + /// + /// If the ping is not acknowledged within the timeout, the connection will + /// be closed. Does nothing if `http2_keep_alive_interval` is disabled. + /// + /// Default is 20 seconds. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + pub fn http2_keep_alive_timeout(&mut self, timeout: Duration) -> &mut Self { + self.h2_builder.keep_alive_timeout = timeout; + self + } + + /// Sets whether HTTP2 keep-alive should apply while the connection is idle. + /// + /// If disabled, keep-alive pings are only sent while there are open + /// request/responses streams. If enabled, pings are also sent when no + /// streams are active. Does nothing if `http2_keep_alive_interval` is + /// disabled. + /// + /// Default is `false`. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + pub fn http2_keep_alive_while_idle(&mut self, enabled: bool) -> &mut Self { + self.h2_builder.keep_alive_while_idle = enabled; + self + } + + /// Constructs a connection with the configured options and IO. + pub fn handshake<T, B>( + &self, + io: T, + ) -> impl Future<Output = crate::Result<(SendRequest<B>, Connection<T, B>)>> + where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + B: Payload + 'static, + { + let opts = self.clone(); + + async move { + trace!("client handshake HTTP/{}", if opts.http2 { 2 } else { 1 }); + + let (tx, rx) = dispatch::channel(); + let proto = if !opts.http2 { + let mut conn = proto::Conn::new(io); + if !opts.h1_writev { + conn.set_write_strategy_flatten(); + } + if opts.h1_title_case_headers { + conn.set_title_case_headers(); + } + if let Some(sz) = opts.h1_read_buf_exact_size { + conn.set_read_buf_exact_size(sz); + } + if let Some(max) = opts.h1_max_buf_size { + conn.set_max_buf_size(max); + } + let cd = proto::h1::dispatch::Client::new(rx); + let dispatch = proto::h1::Dispatcher::new(cd, conn); + ProtoClient::H1(dispatch) + } else { + let h2 = proto::h2::client::handshake(io, rx, &opts.h2_builder, opts.exec.clone()) + .await?; + ProtoClient::H2(h2) + }; + + Ok(( + SendRequest { dispatch: tx }, + Connection { inner: Some(proto) }, + )) + } + } +} + +// ===== impl ResponseFuture + +impl Future for ResponseFuture { + type Output = crate::Result<Response<Body>>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + match self.inner { + ResponseFutureState::Waiting(ref mut rx) => { + Pin::new(rx).poll(cx).map(|res| match res { + Ok(Ok(resp)) => Ok(resp), + Ok(Err(err)) => Err(err), + // this is definite bug if it happens, but it shouldn't happen! + Err(_canceled) => panic!("dispatch dropped without returning error"), + }) + } + ResponseFutureState::Error(ref mut err) => { + Poll::Ready(Err(err.take().expect("polled after ready"))) + } + } + } +} + +impl fmt::Debug for ResponseFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ResponseFuture").finish() + } +} + +// ===== impl ProtoClient + +impl<T, B> Future for ProtoClient<T, B> +where + T: AsyncRead + AsyncWrite + Send + Unpin + 'static, + B: Payload + 'static, +{ + type Output = crate::Result<proto::Dispatched>; + + #[project] + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + #[project] + match self.project() { + ProtoClient::H1(c) => c.poll(cx), + ProtoClient::H2(c) => c.poll(cx), + } + } +} + +// assert trait markers + +trait AssertSend: Send {} +trait AssertSendSync: Send + Sync {} + +#[doc(hidden)] +impl<B: Send> AssertSendSync for SendRequest<B> {} + +#[doc(hidden)] +impl<T: Send, B: Send> AssertSend for Connection<T, B> +where + T: AsyncRead + AsyncWrite + Send + 'static, + B: Payload + 'static, +{ +} + +#[doc(hidden)] +impl<T: Send + Sync, B: Send + Sync> AssertSendSync for Connection<T, B> +where + T: AsyncRead + AsyncWrite + Send + 'static, + B: Payload + 'static, + B::Data: Send + Sync + 'static, +{ +} + +#[doc(hidden)] +impl AssertSendSync for Builder {} + +#[doc(hidden)] +impl AssertSend for ResponseFuture {} diff --git a/third_party/rust/hyper/src/client/connect/dns.rs b/third_party/rust/hyper/src/client/connect/dns.rs new file mode 100644 index 0000000000..acffb8b9e5 --- /dev/null +++ b/third_party/rust/hyper/src/client/connect/dns.rs @@ -0,0 +1,404 @@ +//! DNS Resolution used by the `HttpConnector`. +//! +//! This module contains: +//! +//! - A [`GaiResolver`](GaiResolver) that is the default resolver for the +//! `HttpConnector`. +//! - The `Name` type used as an argument to custom resolvers. +//! +//! # Resolvers are `Service`s +//! +//! A resolver is just a +//! `Service<Name, Response = impl Iterator<Item = IpAddr>>`. +//! +//! A simple resolver that ignores the name and always returns a specific +//! address: +//! +//! ```rust,ignore +//! use std::{convert::Infallible, iter, net::IpAddr}; +//! +//! let resolver = tower::service_fn(|_name| async { +//! Ok::<_, Infallible>(iter::once(IpAddr::from([127, 0, 0, 1]))) +//! }); +//! ``` +use std::error::Error; +use std::future::Future; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}; +use std::pin::Pin; +use std::str::FromStr; +use std::task::{self, Poll}; +use std::{fmt, io, vec}; + +use tokio::task::JoinHandle; +use tower_service::Service; + +pub(super) use self::sealed::Resolve; + +/// A domain name to resolve into IP addresses. +#[derive(Clone, Hash, Eq, PartialEq)] +pub struct Name { + host: String, +} + +/// A resolver using blocking `getaddrinfo` calls in a threadpool. +#[derive(Clone)] +pub struct GaiResolver { + _priv: (), +} + +/// An iterator of IP addresses returned from `getaddrinfo`. +pub struct GaiAddrs { + inner: IpAddrs, +} + +/// A future to resolve a name returned by `GaiResolver`. +pub struct GaiFuture { + inner: JoinHandle<Result<IpAddrs, io::Error>>, +} + +impl Name { + pub(super) fn new(host: String) -> Name { + Name { host } + } + + /// View the hostname as a string slice. + pub fn as_str(&self) -> &str { + &self.host + } +} + +impl fmt::Debug for Name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&self.host, f) + } +} + +impl fmt::Display for Name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.host, f) + } +} + +impl FromStr for Name { + type Err = InvalidNameError; + + fn from_str(host: &str) -> Result<Self, Self::Err> { + // Possibly add validation later + Ok(Name::new(host.to_owned())) + } +} + +/// Error indicating a given string was not a valid domain name. +#[derive(Debug)] +pub struct InvalidNameError(()); + +impl fmt::Display for InvalidNameError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("Not a valid domain name") + } +} + +impl Error for InvalidNameError {} + +impl GaiResolver { + /// Construct a new `GaiResolver`. + pub fn new() -> Self { + GaiResolver { _priv: () } + } +} + +impl Service<Name> for GaiResolver { + type Response = GaiAddrs; + type Error = io::Error; + type Future = GaiFuture; + + fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, name: Name) -> Self::Future { + let blocking = tokio::task::spawn_blocking(move || { + debug!("resolving host={:?}", name.host); + (&*name.host, 0) + .to_socket_addrs() + .map(|i| IpAddrs { iter: i }) + }); + + GaiFuture { inner: blocking } + } +} + +impl fmt::Debug for GaiResolver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("GaiResolver") + } +} + +impl Future for GaiFuture { + type Output = Result<GaiAddrs, io::Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + Pin::new(&mut self.inner).poll(cx).map(|res| match res { + Ok(Ok(addrs)) => Ok(GaiAddrs { inner: addrs }), + Ok(Err(err)) => Err(err), + Err(join_err) => panic!("gai background task failed: {:?}", join_err), + }) + } +} + +impl fmt::Debug for GaiFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("GaiFuture") + } +} + +impl Iterator for GaiAddrs { + type Item = IpAddr; + + fn next(&mut self) -> Option<Self::Item> { + self.inner.next().map(|sa| sa.ip()) + } +} + +impl fmt::Debug for GaiAddrs { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("GaiAddrs") + } +} + +pub(super) struct IpAddrs { + iter: vec::IntoIter<SocketAddr>, +} + +impl IpAddrs { + pub(super) fn new(addrs: Vec<SocketAddr>) -> Self { + IpAddrs { + iter: addrs.into_iter(), + } + } + + pub(super) fn try_parse(host: &str, port: u16) -> Option<IpAddrs> { + if let Ok(addr) = host.parse::<Ipv4Addr>() { + let addr = SocketAddrV4::new(addr, port); + return Some(IpAddrs { + iter: vec![SocketAddr::V4(addr)].into_iter(), + }); + } + let host = host.trim_start_matches('[').trim_end_matches(']'); + if let Ok(addr) = host.parse::<Ipv6Addr>() { + let addr = SocketAddrV6::new(addr, port, 0, 0); + return Some(IpAddrs { + iter: vec![SocketAddr::V6(addr)].into_iter(), + }); + } + None + } + + pub(super) fn split_by_preference(self, local_addr: Option<IpAddr>) -> (IpAddrs, IpAddrs) { + if let Some(local_addr) = local_addr { + let preferred = self + .iter + .filter(|addr| addr.is_ipv6() == local_addr.is_ipv6()) + .collect(); + + (IpAddrs::new(preferred), IpAddrs::new(vec![])) + } else { + let preferring_v6 = self + .iter + .as_slice() + .first() + .map(SocketAddr::is_ipv6) + .unwrap_or(false); + + let (preferred, fallback) = self + .iter + .partition::<Vec<_>, _>(|addr| addr.is_ipv6() == preferring_v6); + + (IpAddrs::new(preferred), IpAddrs::new(fallback)) + } + } + + pub(super) fn is_empty(&self) -> bool { + self.iter.as_slice().is_empty() + } + + pub(super) fn len(&self) -> usize { + self.iter.as_slice().len() + } +} + +impl Iterator for IpAddrs { + type Item = SocketAddr; + #[inline] + fn next(&mut self) -> Option<SocketAddr> { + self.iter.next() + } +} + +/* +/// A resolver using `getaddrinfo` calls via the `tokio_executor::threadpool::blocking` API. +/// +/// Unlike the `GaiResolver` this will not spawn dedicated threads, but only works when running on the +/// multi-threaded Tokio runtime. +#[cfg(feature = "runtime")] +#[derive(Clone, Debug)] +pub struct TokioThreadpoolGaiResolver(()); + +/// The future returned by `TokioThreadpoolGaiResolver`. +#[cfg(feature = "runtime")] +#[derive(Debug)] +pub struct TokioThreadpoolGaiFuture { + name: Name, +} + +#[cfg(feature = "runtime")] +impl TokioThreadpoolGaiResolver { + /// Creates a new DNS resolver that will use tokio threadpool's blocking + /// feature. + /// + /// **Requires** its futures to be run on the threadpool runtime. + pub fn new() -> Self { + TokioThreadpoolGaiResolver(()) + } +} + +#[cfg(feature = "runtime")] +impl Service<Name> for TokioThreadpoolGaiResolver { + type Response = GaiAddrs; + type Error = io::Error; + type Future = TokioThreadpoolGaiFuture; + + fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, name: Name) -> Self::Future { + TokioThreadpoolGaiFuture { name } + } +} + +#[cfg(feature = "runtime")] +impl Future for TokioThreadpoolGaiFuture { + type Output = Result<GaiAddrs, io::Error>; + + fn poll(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<Self::Output> { + match ready!(tokio_executor::threadpool::blocking(|| ( + self.name.as_str(), + 0 + ) + .to_socket_addrs())) + { + Ok(Ok(iter)) => Poll::Ready(Ok(GaiAddrs { + inner: IpAddrs { iter }, + })), + Ok(Err(e)) => Poll::Ready(Err(e)), + // a BlockingError, meaning not on a tokio_executor::threadpool :( + Err(e) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))), + } + } +} +*/ + +mod sealed { + use super::{IpAddr, Name}; + use crate::common::{task, Future, Poll}; + use tower_service::Service; + + // "Trait alias" for `Service<Name, Response = Addrs>` + pub trait Resolve { + type Addrs: Iterator<Item = IpAddr>; + type Error: Into<Box<dyn std::error::Error + Send + Sync>>; + type Future: Future<Output = Result<Self::Addrs, Self::Error>>; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>; + fn resolve(&mut self, name: Name) -> Self::Future; + } + + impl<S> Resolve for S + where + S: Service<Name>, + S::Response: Iterator<Item = IpAddr>, + S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + { + type Addrs = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + Service::poll_ready(self, cx) + } + + fn resolve(&mut self, name: Name) -> Self::Future { + Service::call(self, name) + } + } +} + +pub(crate) async fn resolve<R>(resolver: &mut R, name: Name) -> Result<R::Addrs, R::Error> +where + R: Resolve, +{ + futures_util::future::poll_fn(|cx| resolver.poll_ready(cx)).await?; + resolver.resolve(name).await +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{Ipv4Addr, Ipv6Addr}; + + #[test] + fn test_ip_addrs_split_by_preference() { + let v4_addr = (Ipv4Addr::new(127, 0, 0, 1), 80).into(); + let v6_addr = (Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 80).into(); + + let (mut preferred, mut fallback) = IpAddrs { + iter: vec![v4_addr, v6_addr].into_iter(), + } + .split_by_preference(None); + assert!(preferred.next().unwrap().is_ipv4()); + assert!(fallback.next().unwrap().is_ipv6()); + + let (mut preferred, mut fallback) = IpAddrs { + iter: vec![v6_addr, v4_addr].into_iter(), + } + .split_by_preference(None); + assert!(preferred.next().unwrap().is_ipv6()); + assert!(fallback.next().unwrap().is_ipv4()); + + let (mut preferred, fallback) = IpAddrs { + iter: vec![v4_addr, v6_addr].into_iter(), + } + .split_by_preference(Some(v4_addr.ip())); + assert!(preferred.next().unwrap().is_ipv4()); + assert!(fallback.is_empty()); + + let (mut preferred, fallback) = IpAddrs { + iter: vec![v4_addr, v6_addr].into_iter(), + } + .split_by_preference(Some(v6_addr.ip())); + assert!(preferred.next().unwrap().is_ipv6()); + assert!(fallback.is_empty()); + } + + #[test] + fn test_name_from_str() { + const DOMAIN: &str = "test.example.com"; + let name = Name::from_str(DOMAIN).expect("Should be a valid domain"); + assert_eq!(name.as_str(), DOMAIN); + assert_eq!(name.to_string(), DOMAIN); + } + + #[test] + fn ip_addrs_try_parse_v6() { + let dst = ::http::Uri::from_static("http://[::1]:8080/"); + + let mut addrs = + IpAddrs::try_parse(dst.host().expect("host"), dst.port_u16().expect("port")) + .expect("try_parse"); + + let expected = "[::1]:8080".parse::<SocketAddr>().expect("expected"); + + assert_eq!(addrs.next(), Some(expected)); + } +} diff --git a/third_party/rust/hyper/src/client/connect/http.rs b/third_party/rust/hyper/src/client/connect/http.rs new file mode 100644 index 0000000000..86137d7f1c --- /dev/null +++ b/third_party/rust/hyper/src/client/connect/http.rs @@ -0,0 +1,859 @@ +use std::error::Error as StdError; +use std::fmt; +use std::future::Future; +use std::io; +use std::marker::PhantomData; +use std::net::{IpAddr, SocketAddr}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{self, Poll}; +use std::time::Duration; + +use futures_util::future::Either; +use http::uri::{Scheme, Uri}; +use net2::TcpBuilder; +use pin_project::pin_project; +use tokio::net::TcpStream; +use tokio::time::Delay; + +use super::dns::{self, resolve, GaiResolver, Resolve}; +use super::{Connected, Connection}; +//#[cfg(feature = "runtime")] use super::dns::TokioThreadpoolGaiResolver; + +/// A connector for the `http` scheme. +/// +/// Performs DNS resolution in a thread pool, and then connects over TCP. +/// +/// # Note +/// +/// Sets the [`HttpInfo`](HttpInfo) value on responses, which includes +/// transport information such as the remote socket address used. +#[derive(Clone)] +pub struct HttpConnector<R = GaiResolver> { + config: Arc<Config>, + resolver: R, +} + +/// Extra information about the transport when an HttpConnector is used. +/// +/// # Example +/// +/// ``` +/// # async fn doc() -> hyper::Result<()> { +/// use hyper::Uri; +/// use hyper::client::{Client, connect::HttpInfo}; +/// +/// let client = Client::new(); +/// let uri = Uri::from_static("http://example.com"); +/// +/// let res = client.get(uri).await?; +/// res +/// .extensions() +/// .get::<HttpInfo>() +/// .map(|info| { +/// println!("remote addr = {}", info.remote_addr()); +/// }); +/// # Ok(()) +/// # } +/// ``` +/// +/// # Note +/// +/// If a different connector is used besides [`HttpConnector`](HttpConnector), +/// this value will not exist in the extensions. Consult that specific +/// connector to see what "extra" information it might provide to responses. +#[derive(Clone, Debug)] +pub struct HttpInfo { + remote_addr: SocketAddr, +} + +#[derive(Clone)] +struct Config { + connect_timeout: Option<Duration>, + enforce_http: bool, + happy_eyeballs_timeout: Option<Duration>, + keep_alive_timeout: Option<Duration>, + local_address: Option<IpAddr>, + nodelay: bool, + reuse_address: bool, + send_buffer_size: Option<usize>, + recv_buffer_size: Option<usize>, +} + +// ===== impl HttpConnector ===== + +impl HttpConnector { + /// Construct a new HttpConnector. + pub fn new() -> HttpConnector { + HttpConnector::new_with_resolver(GaiResolver::new()) + } +} + +/* +#[cfg(feature = "runtime")] +impl HttpConnector<TokioThreadpoolGaiResolver> { + /// Construct a new HttpConnector using the `TokioThreadpoolGaiResolver`. + /// + /// This resolver **requires** the threadpool runtime to be used. + pub fn new_with_tokio_threadpool_resolver() -> Self { + HttpConnector::new_with_resolver(TokioThreadpoolGaiResolver::new()) + } +} +*/ + +impl<R> HttpConnector<R> { + /// Construct a new HttpConnector. + /// + /// Takes a `Resolve` to handle DNS lookups. + pub fn new_with_resolver(resolver: R) -> HttpConnector<R> { + HttpConnector { + config: Arc::new(Config { + connect_timeout: None, + enforce_http: true, + happy_eyeballs_timeout: Some(Duration::from_millis(300)), + keep_alive_timeout: None, + local_address: None, + nodelay: false, + reuse_address: false, + send_buffer_size: None, + recv_buffer_size: None, + }), + resolver, + } + } + + /// Option to enforce all `Uri`s have the `http` scheme. + /// + /// Enabled by default. + #[inline] + pub fn enforce_http(&mut self, is_enforced: bool) { + self.config_mut().enforce_http = is_enforced; + } + + /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration. + /// + /// If `None`, the option will not be set. + /// + /// Default is `None`. + #[inline] + pub fn set_keepalive(&mut self, dur: Option<Duration>) { + self.config_mut().keep_alive_timeout = dur; + } + + /// Set that all sockets have `SO_NODELAY` set to the supplied value `nodelay`. + /// + /// Default is `false`. + #[inline] + pub fn set_nodelay(&mut self, nodelay: bool) { + self.config_mut().nodelay = nodelay; + } + + /// Sets the value of the SO_SNDBUF option on the socket. + #[inline] + pub fn set_send_buffer_size(&mut self, size: Option<usize>) { + self.config_mut().send_buffer_size = size; + } + + /// Sets the value of the SO_RCVBUF option on the socket. + #[inline] + pub fn set_recv_buffer_size(&mut self, size: Option<usize>) { + self.config_mut().recv_buffer_size = size; + } + + /// Set that all sockets are bound to the configured address before connection. + /// + /// If `None`, the sockets will not be bound. + /// + /// Default is `None`. + #[inline] + pub fn set_local_address(&mut self, addr: Option<IpAddr>) { + self.config_mut().local_address = addr; + } + + /// Set the connect timeout. + /// + /// If a domain resolves to multiple IP addresses, the timeout will be + /// evenly divided across them. + /// + /// Default is `None`. + #[inline] + pub fn set_connect_timeout(&mut self, dur: Option<Duration>) { + self.config_mut().connect_timeout = dur; + } + + /// Set timeout for [RFC 6555 (Happy Eyeballs)][RFC 6555] algorithm. + /// + /// If hostname resolves to both IPv4 and IPv6 addresses and connection + /// cannot be established using preferred address family before timeout + /// elapses, then connector will in parallel attempt connection using other + /// address family. + /// + /// If `None`, parallel connection attempts are disabled. + /// + /// Default is 300 milliseconds. + /// + /// [RFC 6555]: https://tools.ietf.org/html/rfc6555 + #[inline] + pub fn set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>) { + self.config_mut().happy_eyeballs_timeout = dur; + } + + /// Set that all socket have `SO_REUSEADDR` set to the supplied value `reuse_address`. + /// + /// Default is `false`. + #[inline] + pub fn set_reuse_address(&mut self, reuse_address: bool) -> &mut Self { + self.config_mut().reuse_address = reuse_address; + self + } + + // private + + fn config_mut(&mut self) -> &mut Config { + // If the are HttpConnector clones, this will clone the inner + // config. So mutating the config won't ever affect previous + // clones. + Arc::make_mut(&mut self.config) + } +} + +static INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http"; +static INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing"; +static INVALID_MISSING_HOST: &str = "invalid URL, host is missing"; + +// R: Debug required for now to allow adding it to debug output later... +impl<R: fmt::Debug> fmt::Debug for HttpConnector<R> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("HttpConnector").finish() + } +} + +impl<R> tower_service::Service<Uri> for HttpConnector<R> +where + R: Resolve + Clone + Send + Sync + 'static, + R::Future: Send, +{ + type Response = TcpStream; + type Error = ConnectError; + type Future = HttpConnecting<R>; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?; + Poll::Ready(Ok(())) + } + + fn call(&mut self, dst: Uri) -> Self::Future { + let mut self_ = self.clone(); + HttpConnecting { + fut: Box::pin(async move { self_.call_async(dst).await }), + _marker: PhantomData, + } + } +} + +impl<R> HttpConnector<R> +where + R: Resolve, +{ + async fn call_async(&mut self, dst: Uri) -> Result<TcpStream, ConnectError> { + trace!( + "Http::connect; scheme={:?}, host={:?}, port={:?}", + dst.scheme(), + dst.host(), + dst.port(), + ); + + if self.config.enforce_http { + if dst.scheme() != Some(&Scheme::HTTP) { + return Err(ConnectError { + msg: INVALID_NOT_HTTP.into(), + cause: None, + }); + } + } else if dst.scheme().is_none() { + return Err(ConnectError { + msg: INVALID_MISSING_SCHEME.into(), + cause: None, + }); + } + + let host = match dst.host() { + Some(s) => s, + None => { + return Err(ConnectError { + msg: INVALID_MISSING_HOST.into(), + cause: None, + }) + } + }; + let port = match dst.port() { + Some(port) => port.as_u16(), + None => { + if dst.scheme() == Some(&Scheme::HTTPS) { + 443 + } else { + 80 + } + } + }; + + let config = &self.config; + + // If the host is already an IP addr (v4 or v6), + // skip resolving the dns and start connecting right away. + let addrs = if let Some(addrs) = dns::IpAddrs::try_parse(host, port) { + addrs + } else { + let addrs = resolve(&mut self.resolver, dns::Name::new(host.into())) + .await + .map_err(ConnectError::dns)?; + let addrs = addrs.map(|addr| SocketAddr::new(addr, port)).collect(); + dns::IpAddrs::new(addrs) + }; + + let c = ConnectingTcp::new( + config.local_address, + addrs, + config.connect_timeout, + config.happy_eyeballs_timeout, + config.reuse_address, + ); + + let sock = c + .connect() + .await + .map_err(ConnectError::m("tcp connect error"))?; + + if let Some(dur) = config.keep_alive_timeout { + sock.set_keepalive(Some(dur)) + .map_err(ConnectError::m("tcp set_keepalive error"))?; + } + + if let Some(size) = config.send_buffer_size { + sock.set_send_buffer_size(size) + .map_err(ConnectError::m("tcp set_send_buffer_size error"))?; + } + + if let Some(size) = config.recv_buffer_size { + sock.set_recv_buffer_size(size) + .map_err(ConnectError::m("tcp set_recv_buffer_size error"))?; + } + + sock.set_nodelay(config.nodelay) + .map_err(ConnectError::m("tcp set_nodelay error"))?; + + Ok(sock) + } +} + +impl Connection for TcpStream { + fn connected(&self) -> Connected { + let connected = Connected::new(); + if let Ok(remote_addr) = self.peer_addr() { + connected.extra(HttpInfo { remote_addr }) + } else { + connected + } + } +} + +impl HttpInfo { + /// Get the remote address of the transport used. + pub fn remote_addr(&self) -> SocketAddr { + self.remote_addr + } +} + +// Not publicly exported (so missing_docs doesn't trigger). +// +// We return this `Future` instead of the `Pin<Box<dyn Future>>` directly +// so that users don't rely on it fitting in a `Pin<Box<dyn Future>>` slot +// (and thus we can change the type in the future). +#[must_use = "futures do nothing unless polled"] +#[pin_project] +#[allow(missing_debug_implementations)] +pub struct HttpConnecting<R> { + #[pin] + fut: BoxConnecting, + _marker: PhantomData<R>, +} + +type ConnectResult = Result<TcpStream, ConnectError>; +type BoxConnecting = Pin<Box<dyn Future<Output = ConnectResult> + Send>>; + +impl<R: Resolve> Future for HttpConnecting<R> { + type Output = ConnectResult; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + self.project().fut.poll(cx) + } +} + +// Not publicly exported (so missing_docs doesn't trigger). +pub struct ConnectError { + msg: Box<str>, + cause: Option<Box<dyn StdError + Send + Sync>>, +} + +impl ConnectError { + fn new<S, E>(msg: S, cause: E) -> ConnectError + where + S: Into<Box<str>>, + E: Into<Box<dyn StdError + Send + Sync>>, + { + ConnectError { + msg: msg.into(), + cause: Some(cause.into()), + } + } + + fn dns<E>(cause: E) -> ConnectError + where + E: Into<Box<dyn StdError + Send + Sync>>, + { + ConnectError::new("dns error", cause) + } + + fn m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError + where + S: Into<Box<str>>, + E: Into<Box<dyn StdError + Send + Sync>>, + { + move |cause| ConnectError::new(msg, cause) + } +} + +impl fmt::Debug for ConnectError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if let Some(ref cause) = self.cause { + f.debug_tuple("ConnectError") + .field(&self.msg) + .field(cause) + .finish() + } else { + self.msg.fmt(f) + } + } +} + +impl fmt::Display for ConnectError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&self.msg)?; + + if let Some(ref cause) = self.cause { + write!(f, ": {}", cause)?; + } + + Ok(()) + } +} + +impl StdError for ConnectError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + self.cause.as_ref().map(|e| &**e as _) + } +} + +struct ConnectingTcp { + local_addr: Option<IpAddr>, + preferred: ConnectingTcpRemote, + fallback: Option<ConnectingTcpFallback>, + reuse_address: bool, +} + +impl ConnectingTcp { + fn new( + local_addr: Option<IpAddr>, + remote_addrs: dns::IpAddrs, + connect_timeout: Option<Duration>, + fallback_timeout: Option<Duration>, + reuse_address: bool, + ) -> ConnectingTcp { + if let Some(fallback_timeout) = fallback_timeout { + let (preferred_addrs, fallback_addrs) = remote_addrs.split_by_preference(local_addr); + if fallback_addrs.is_empty() { + return ConnectingTcp { + local_addr, + preferred: ConnectingTcpRemote::new(preferred_addrs, connect_timeout), + fallback: None, + reuse_address, + }; + } + + ConnectingTcp { + local_addr, + preferred: ConnectingTcpRemote::new(preferred_addrs, connect_timeout), + fallback: Some(ConnectingTcpFallback { + delay: tokio::time::delay_for(fallback_timeout), + remote: ConnectingTcpRemote::new(fallback_addrs, connect_timeout), + }), + reuse_address, + } + } else { + ConnectingTcp { + local_addr, + preferred: ConnectingTcpRemote::new(remote_addrs, connect_timeout), + fallback: None, + reuse_address, + } + } + } +} + +struct ConnectingTcpFallback { + delay: Delay, + remote: ConnectingTcpRemote, +} + +struct ConnectingTcpRemote { + addrs: dns::IpAddrs, + connect_timeout: Option<Duration>, +} + +impl ConnectingTcpRemote { + fn new(addrs: dns::IpAddrs, connect_timeout: Option<Duration>) -> Self { + let connect_timeout = connect_timeout.map(|t| t / (addrs.len() as u32)); + + Self { + addrs, + connect_timeout, + } + } +} + +impl ConnectingTcpRemote { + async fn connect( + &mut self, + local_addr: &Option<IpAddr>, + reuse_address: bool, + ) -> io::Result<TcpStream> { + let mut err = None; + for addr in &mut self.addrs { + debug!("connecting to {}", addr); + match connect(&addr, local_addr, reuse_address, self.connect_timeout)?.await { + Ok(tcp) => { + debug!("connected to {}", addr); + return Ok(tcp); + } + Err(e) => { + trace!("connect error for {}: {:?}", addr, e); + err = Some(e); + } + } + } + + Err(err.take().expect("missing connect error")) + } +} + +fn connect( + addr: &SocketAddr, + local_addr: &Option<IpAddr>, + reuse_address: bool, + connect_timeout: Option<Duration>, +) -> io::Result<impl Future<Output = io::Result<TcpStream>>> { + let builder = match *addr { + SocketAddr::V4(_) => TcpBuilder::new_v4()?, + SocketAddr::V6(_) => TcpBuilder::new_v6()?, + }; + + if reuse_address { + builder.reuse_address(reuse_address)?; + } + + if let Some(ref local_addr) = *local_addr { + // Caller has requested this socket be bound before calling connect + builder.bind(SocketAddr::new(local_addr.clone(), 0))?; + } else if cfg!(windows) { + // Windows requires a socket be bound before calling connect + let any: SocketAddr = match *addr { + SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(), + SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(), + }; + builder.bind(any)?; + } + + let addr = *addr; + + let std_tcp = builder.to_tcp_stream()?; + + Ok(async move { + let connect = TcpStream::connect_std(std_tcp, &addr); + match connect_timeout { + Some(dur) => match tokio::time::timeout(dur, connect).await { + Ok(Ok(s)) => Ok(s), + Ok(Err(e)) => Err(e), + Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)), + }, + None => connect.await, + } + }) +} + +impl ConnectingTcp { + async fn connect(mut self) -> io::Result<TcpStream> { + let Self { + ref local_addr, + reuse_address, + .. + } = self; + match self.fallback { + None => self.preferred.connect(local_addr, reuse_address).await, + Some(mut fallback) => { + let preferred_fut = self.preferred.connect(local_addr, reuse_address); + futures_util::pin_mut!(preferred_fut); + + let fallback_fut = fallback.remote.connect(local_addr, reuse_address); + futures_util::pin_mut!(fallback_fut); + + let (result, future) = + match futures_util::future::select(preferred_fut, fallback.delay).await { + Either::Left((result, _fallback_delay)) => { + (result, Either::Right(fallback_fut)) + } + Either::Right(((), preferred_fut)) => { + // Delay is done, start polling both the preferred and the fallback + futures_util::future::select(preferred_fut, fallback_fut) + .await + .factor_first() + } + }; + + if result.is_err() { + // Fallback to the remaining future (could be preferred or fallback) + // if we get an error + future.await + } else { + result + } + } + } + } +} + +#[cfg(test)] +mod tests { + use std::io; + + use ::http::Uri; + + use super::super::sealed::{Connect, ConnectSvc}; + use super::HttpConnector; + + async fn connect<C>( + connector: C, + dst: Uri, + ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error> + where + C: Connect, + { + connector.connect(super::super::sealed::Internal, dst).await + } + + #[tokio::test] + async fn test_errors_enforce_http() { + let dst = "https://example.domain/foo/bar?baz".parse().unwrap(); + let connector = HttpConnector::new(); + + let err = connect(connector, dst).await.unwrap_err(); + assert_eq!(&*err.msg, super::INVALID_NOT_HTTP); + } + + #[tokio::test] + async fn test_errors_missing_scheme() { + let dst = "example.domain".parse().unwrap(); + let mut connector = HttpConnector::new(); + connector.enforce_http(false); + + let err = connect(connector, dst).await.unwrap_err(); + assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME); + } + + #[test] + #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)] + fn client_happy_eyeballs() { + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener}; + use std::time::{Duration, Instant}; + + use super::dns; + use super::ConnectingTcp; + + let _ = pretty_env_logger::try_init(); + let server4 = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = server4.local_addr().unwrap(); + let _server6 = TcpListener::bind(&format!("[::1]:{}", addr.port())).unwrap(); + let mut rt = tokio::runtime::Builder::new() + .enable_io() + .enable_time() + .basic_scheduler() + .build() + .unwrap(); + + let local_timeout = Duration::default(); + let unreachable_v4_timeout = measure_connect(unreachable_ipv4_addr()).1; + let unreachable_v6_timeout = measure_connect(unreachable_ipv6_addr()).1; + let fallback_timeout = std::cmp::max(unreachable_v4_timeout, unreachable_v6_timeout) + + Duration::from_millis(250); + + let scenarios = &[ + // Fast primary, without fallback. + (&[local_ipv4_addr()][..], 4, local_timeout, false), + (&[local_ipv6_addr()][..], 6, local_timeout, false), + // Fast primary, with (unused) fallback. + ( + &[local_ipv4_addr(), local_ipv6_addr()][..], + 4, + local_timeout, + false, + ), + ( + &[local_ipv6_addr(), local_ipv4_addr()][..], + 6, + local_timeout, + false, + ), + // Unreachable + fast primary, without fallback. + ( + &[unreachable_ipv4_addr(), local_ipv4_addr()][..], + 4, + unreachable_v4_timeout, + false, + ), + ( + &[unreachable_ipv6_addr(), local_ipv6_addr()][..], + 6, + unreachable_v6_timeout, + false, + ), + // Unreachable + fast primary, with (unused) fallback. + ( + &[ + unreachable_ipv4_addr(), + local_ipv4_addr(), + local_ipv6_addr(), + ][..], + 4, + unreachable_v4_timeout, + false, + ), + ( + &[ + unreachable_ipv6_addr(), + local_ipv6_addr(), + local_ipv4_addr(), + ][..], + 6, + unreachable_v6_timeout, + true, + ), + // Slow primary, with (used) fallback. + ( + &[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..], + 6, + fallback_timeout, + false, + ), + ( + &[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..], + 4, + fallback_timeout, + true, + ), + // Slow primary, with (used) unreachable + fast fallback. + ( + &[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..], + 6, + fallback_timeout + unreachable_v6_timeout, + false, + ), + ( + &[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..], + 4, + fallback_timeout + unreachable_v4_timeout, + true, + ), + ]; + + // Scenarios for IPv6 -> IPv4 fallback require that host can access IPv6 network. + // Otherwise, connection to "slow" IPv6 address will error-out immediately. + let ipv6_accessible = measure_connect(slow_ipv6_addr()).0; + + for &(hosts, family, timeout, needs_ipv6_access) in scenarios { + if needs_ipv6_access && !ipv6_accessible { + continue; + } + + let (start, stream) = rt + .block_on(async move { + let addrs = hosts + .iter() + .map(|host| (host.clone(), addr.port()).into()) + .collect(); + let connecting_tcp = ConnectingTcp::new( + None, + dns::IpAddrs::new(addrs), + None, + Some(fallback_timeout), + false, + ); + let start = Instant::now(); + Ok::<_, io::Error>((start, connecting_tcp.connect().await?)) + }) + .unwrap(); + let res = if stream.peer_addr().unwrap().is_ipv4() { + 4 + } else { + 6 + }; + let duration = start.elapsed(); + + // Allow actual duration to be +/- 150ms off. + let min_duration = if timeout >= Duration::from_millis(150) { + timeout - Duration::from_millis(150) + } else { + Duration::default() + }; + let max_duration = timeout + Duration::from_millis(150); + + assert_eq!(res, family); + assert!(duration >= min_duration); + assert!(duration <= max_duration); + } + + fn local_ipv4_addr() -> IpAddr { + Ipv4Addr::new(127, 0, 0, 1).into() + } + + fn local_ipv6_addr() -> IpAddr { + Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into() + } + + fn unreachable_ipv4_addr() -> IpAddr { + Ipv4Addr::new(127, 0, 0, 2).into() + } + + fn unreachable_ipv6_addr() -> IpAddr { + Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2).into() + } + + fn slow_ipv4_addr() -> IpAddr { + // RFC 6890 reserved IPv4 address. + Ipv4Addr::new(198, 18, 0, 25).into() + } + + fn slow_ipv6_addr() -> IpAddr { + // RFC 6890 reserved IPv6 address. + Ipv6Addr::new(2001, 2, 0, 0, 0, 0, 0, 254).into() + } + + fn measure_connect(addr: IpAddr) -> (bool, Duration) { + let start = Instant::now(); + let result = + std::net::TcpStream::connect_timeout(&(addr, 80).into(), Duration::from_secs(1)); + + let reachable = result.is_ok() || result.unwrap_err().kind() == io::ErrorKind::TimedOut; + let duration = start.elapsed(); + (reachable, duration) + } + } +} diff --git a/third_party/rust/hyper/src/client/connect/mod.rs b/third_party/rust/hyper/src/client/connect/mod.rs new file mode 100644 index 0000000000..01e2dcf9d3 --- /dev/null +++ b/third_party/rust/hyper/src/client/connect/mod.rs @@ -0,0 +1,381 @@ +//! Connectors used by the `Client`. +//! +//! This module contains: +//! +//! - A default [`HttpConnector`][] that does DNS resolution and establishes +//! connections over TCP. +//! - Types to build custom connectors. +//! +//! # Connectors +//! +//! A "connector" is a [`Service`][] that takes a [`Uri`][] destination, and +//! its `Response` is some type implementing [`AsyncRead`][], [`AsyncWrite`][], +//! and [`Connection`][]. +//! +//! ## Custom Connectors +//! +//! A simple connector that ignores the `Uri` destination and always returns +//! a TCP connection to the same address could be written like this: +//! +//! ```rust,ignore +//! let connector = tower::service_fn(|_dst| async { +//! tokio::net::TcpStream::connect("127.0.0.1:1337") +//! }) +//! ``` +//! +//! Or, fully written out: +//! +//! ``` +//! use std::{future::Future, net::SocketAddr, pin::Pin, task::{self, Poll}}; +//! use hyper::{service::Service, Uri}; +//! use tokio::net::TcpStream; +//! +//! #[derive(Clone)] +//! struct LocalConnector; +//! +//! impl Service<Uri> for LocalConnector { +//! type Response = TcpStream; +//! type Error = std::io::Error; +//! // We can't "name" an `async` generated future. +//! type Future = Pin<Box< +//! dyn Future<Output = Result<Self::Response, Self::Error>> + Send +//! >>; +//! +//! fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { +//! // This connector is always ready, but others might not be. +//! Poll::Ready(Ok(())) +//! } +//! +//! fn call(&mut self, _: Uri) -> Self::Future { +//! Box::pin(TcpStream::connect(SocketAddr::from(([127, 0, 0, 1], 1337)))) +//! } +//! } +//! ``` +//! +//! It's worth noting that for `TcpStream`s, the [`HttpConnector`][] is a +//! better starting place to extend from. +//! +//! Using either of the above connector examples, it can be used with the +//! `Client` like this: +//! +//! ``` +//! # let connector = hyper::client::HttpConnector::new(); +//! // let connector = ... +//! +//! let client = hyper::Client::builder() +//! .build::<_, hyper::Body>(connector); +//! ``` +//! +//! +//! [`HttpConnector`]: HttpConnector +//! [`Service`]: crate::service::Service +//! [`Uri`]: http::Uri +//! [`AsyncRead`]: tokio::io::AsyncRead +//! [`AsyncWrite`]: tokio::io::AsyncWrite +//! [`Connection`]: Connection +use std::fmt; + +use ::http::Response; + +#[cfg(feature = "tcp")] +pub mod dns; +#[cfg(feature = "tcp")] +mod http; +#[cfg(feature = "tcp")] +pub use self::http::{HttpConnector, HttpInfo}; +pub use self::sealed::Connect; + +/// Describes a type returned by a connector. +pub trait Connection { + /// Return metadata describing the connection. + fn connected(&self) -> Connected; +} + +/// Extra information about the connected transport. +/// +/// This can be used to inform recipients about things like if ALPN +/// was used, or if connected to an HTTP proxy. +#[derive(Debug)] +pub struct Connected { + pub(super) alpn: Alpn, + pub(super) is_proxied: bool, + pub(super) extra: Option<Extra>, +} + +pub(super) struct Extra(Box<dyn ExtraInner>); + +#[derive(Clone, Copy, Debug, PartialEq)] +pub(super) enum Alpn { + H2, + None, +} + +impl Connected { + /// Create new `Connected` type with empty metadata. + pub fn new() -> Connected { + Connected { + alpn: Alpn::None, + is_proxied: false, + extra: None, + } + } + + /// Set whether the connected transport is to an HTTP proxy. + /// + /// This setting will affect if HTTP/1 requests written on the transport + /// will have the request-target in absolute-form or origin-form: + /// + /// - When `proxy(false)`: + /// + /// ```http + /// GET /guide HTTP/1.1 + /// ``` + /// + /// - When `proxy(true)`: + /// + /// ```http + /// GET http://hyper.rs/guide HTTP/1.1 + /// ``` + /// + /// Default is `false`. + pub fn proxy(mut self, is_proxied: bool) -> Connected { + self.is_proxied = is_proxied; + self + } + + /// Set extra connection information to be set in the extensions of every `Response`. + pub fn extra<T: Clone + Send + Sync + 'static>(mut self, extra: T) -> Connected { + if let Some(prev) = self.extra { + self.extra = Some(Extra(Box::new(ExtraChain(prev.0, extra)))); + } else { + self.extra = Some(Extra(Box::new(ExtraEnvelope(extra)))); + } + self + } + + /// Set that the connected transport negotiated HTTP/2 as it's + /// next protocol. + pub fn negotiated_h2(mut self) -> Connected { + self.alpn = Alpn::H2; + self + } + + // Don't public expose that `Connected` is `Clone`, unsure if we want to + // keep that contract... + pub(super) fn clone(&self) -> Connected { + Connected { + alpn: self.alpn.clone(), + is_proxied: self.is_proxied, + extra: self.extra.clone(), + } + } +} + +// ===== impl Extra ===== + +impl Extra { + pub(super) fn set(&self, res: &mut Response<crate::Body>) { + self.0.set(res); + } +} + +impl Clone for Extra { + fn clone(&self) -> Extra { + Extra(self.0.clone_box()) + } +} + +impl fmt::Debug for Extra { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Extra").finish() + } +} + +trait ExtraInner: Send + Sync { + fn clone_box(&self) -> Box<dyn ExtraInner>; + fn set(&self, res: &mut Response<crate::Body>); +} + +// This indirection allows the `Connected` to have a type-erased "extra" value, +// while that type still knows its inner extra type. This allows the correct +// TypeId to be used when inserting into `res.extensions_mut()`. +#[derive(Clone)] +struct ExtraEnvelope<T>(T); + +impl<T> ExtraInner for ExtraEnvelope<T> +where + T: Clone + Send + Sync + 'static, +{ + fn clone_box(&self) -> Box<dyn ExtraInner> { + Box::new(self.clone()) + } + + fn set(&self, res: &mut Response<crate::Body>) { + res.extensions_mut().insert(self.0.clone()); + } +} + +struct ExtraChain<T>(Box<dyn ExtraInner>, T); + +impl<T: Clone> Clone for ExtraChain<T> { + fn clone(&self) -> Self { + ExtraChain(self.0.clone_box(), self.1.clone()) + } +} + +impl<T> ExtraInner for ExtraChain<T> +where + T: Clone + Send + Sync + 'static, +{ + fn clone_box(&self) -> Box<dyn ExtraInner> { + Box::new(self.clone()) + } + + fn set(&self, res: &mut Response<crate::Body>) { + self.0.set(res); + res.extensions_mut().insert(self.1.clone()); + } +} + +pub(super) mod sealed { + use std::error::Error as StdError; + + use ::http::Uri; + use tokio::io::{AsyncRead, AsyncWrite}; + + use super::Connection; + use crate::common::{Future, Unpin}; + + /// Connect to a destination, returning an IO transport. + /// + /// A connector receives a [`Uri`](::http::Uri) and returns a `Future` of the + /// ready connection. + /// + /// # Trait Alias + /// + /// This is really just an *alias* for the `tower::Service` trait, with + /// additional bounds set for convenience *inside* hyper. You don't actually + /// implement this trait, but `tower::Service<Uri>` instead. + // The `Sized` bound is to prevent creating `dyn Connect`, since they cannot + // fit the `Connect` bounds because of the blanket impl for `Service`. + pub trait Connect: Sealed + Sized { + #[doc(hidden)] + type _Svc: ConnectSvc; + #[doc(hidden)] + fn connect(self, internal_only: Internal, dst: Uri) -> <Self::_Svc as ConnectSvc>::Future; + } + + pub trait ConnectSvc { + type Connection: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static; + type Error: Into<Box<dyn StdError + Send + Sync>>; + type Future: Future<Output = Result<Self::Connection, Self::Error>> + Unpin + Send + 'static; + + fn connect(self, internal_only: Internal, dst: Uri) -> Self::Future; + } + + impl<S, T> Connect for S + where + S: tower_service::Service<Uri, Response = T> + Send + 'static, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + S::Future: Unpin + Send, + T: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static, + { + type _Svc = S; + + fn connect(self, _: Internal, dst: Uri) -> crate::service::Oneshot<S, Uri> { + crate::service::oneshot(self, dst) + } + } + + impl<S, T> ConnectSvc for S + where + S: tower_service::Service<Uri, Response = T> + Send + 'static, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + S::Future: Unpin + Send, + T: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static, + { + type Connection = T; + type Error = S::Error; + type Future = crate::service::Oneshot<S, Uri>; + + fn connect(self, _: Internal, dst: Uri) -> Self::Future { + crate::service::oneshot(self, dst) + } + } + + impl<S, T> Sealed for S + where + S: tower_service::Service<Uri, Response = T> + Send, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + S::Future: Unpin + Send, + T: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static, + { + } + + pub trait Sealed {} + #[allow(missing_debug_implementations)] + pub struct Internal; +} + +#[cfg(test)] +mod tests { + use super::Connected; + + #[derive(Clone, Debug, PartialEq)] + struct Ex1(usize); + + #[derive(Clone, Debug, PartialEq)] + struct Ex2(&'static str); + + #[derive(Clone, Debug, PartialEq)] + struct Ex3(&'static str); + + #[test] + fn test_connected_extra() { + let c1 = Connected::new().extra(Ex1(41)); + + let mut res1 = crate::Response::new(crate::Body::empty()); + + assert_eq!(res1.extensions().get::<Ex1>(), None); + + c1.extra.as_ref().expect("c1 extra").set(&mut res1); + + assert_eq!(res1.extensions().get::<Ex1>(), Some(&Ex1(41))); + } + + #[test] + fn test_connected_extra_chain() { + // If a user composes connectors and at each stage, there's "extra" + // info to attach, it shouldn't override the previous extras. + + let c1 = Connected::new() + .extra(Ex1(45)) + .extra(Ex2("zoom")) + .extra(Ex3("pew pew")); + + let mut res1 = crate::Response::new(crate::Body::empty()); + + assert_eq!(res1.extensions().get::<Ex1>(), None); + assert_eq!(res1.extensions().get::<Ex2>(), None); + assert_eq!(res1.extensions().get::<Ex3>(), None); + + c1.extra.as_ref().expect("c1 extra").set(&mut res1); + + assert_eq!(res1.extensions().get::<Ex1>(), Some(&Ex1(45))); + assert_eq!(res1.extensions().get::<Ex2>(), Some(&Ex2("zoom"))); + assert_eq!(res1.extensions().get::<Ex3>(), Some(&Ex3("pew pew"))); + + // Just like extensions, inserting the same type overrides previous type. + let c2 = Connected::new() + .extra(Ex1(33)) + .extra(Ex2("hiccup")) + .extra(Ex1(99)); + + let mut res2 = crate::Response::new(crate::Body::empty()); + + c2.extra.as_ref().expect("c2 extra").set(&mut res2); + + assert_eq!(res2.extensions().get::<Ex1>(), Some(&Ex1(99))); + assert_eq!(res2.extensions().get::<Ex2>(), Some(&Ex2("hiccup"))); + } +} diff --git a/third_party/rust/hyper/src/client/dispatch.rs b/third_party/rust/hyper/src/client/dispatch.rs new file mode 100644 index 0000000000..9a580f85ac --- /dev/null +++ b/third_party/rust/hyper/src/client/dispatch.rs @@ -0,0 +1,394 @@ +use futures_util::future; +use tokio::sync::{mpsc, oneshot}; + +use crate::common::{task, Future, Pin, Poll}; + +pub type RetryPromise<T, U> = oneshot::Receiver<Result<U, (crate::Error, Option<T>)>>; +pub type Promise<T> = oneshot::Receiver<Result<T, crate::Error>>; + +pub fn channel<T, U>() -> (Sender<T, U>, Receiver<T, U>) { + let (tx, rx) = mpsc::unbounded_channel(); + let (giver, taker) = want::new(); + let tx = Sender { + buffered_once: false, + giver, + inner: tx, + }; + let rx = Receiver { inner: rx, taker }; + (tx, rx) +} + +/// A bounded sender of requests and callbacks for when responses are ready. +/// +/// While the inner sender is unbounded, the Giver is used to determine +/// if the Receiver is ready for another request. +pub struct Sender<T, U> { + /// One message is always allowed, even if the Receiver hasn't asked + /// for it yet. This boolean keeps track of whether we've sent one + /// without notice. + buffered_once: bool, + /// The Giver helps watch that the the Receiver side has been polled + /// when the queue is empty. This helps us know when a request and + /// response have been fully processed, and a connection is ready + /// for more. + giver: want::Giver, + /// Actually bounded by the Giver, plus `buffered_once`. + inner: mpsc::UnboundedSender<Envelope<T, U>>, +} + +/// An unbounded version. +/// +/// Cannot poll the Giver, but can still use it to determine if the Receiver +/// has been dropped. However, this version can be cloned. +pub struct UnboundedSender<T, U> { + /// Only used for `is_closed`, since mpsc::UnboundedSender cannot be checked. + giver: want::SharedGiver, + inner: mpsc::UnboundedSender<Envelope<T, U>>, +} + +impl<T, U> Sender<T, U> { + pub fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + self.giver + .poll_want(cx) + .map_err(|_| crate::Error::new_closed()) + } + + pub fn is_ready(&self) -> bool { + self.giver.is_wanting() + } + + pub fn is_closed(&self) -> bool { + self.giver.is_canceled() + } + + fn can_send(&mut self) -> bool { + if self.giver.give() || !self.buffered_once { + // If the receiver is ready *now*, then of course we can send. + // + // If the receiver isn't ready yet, but we don't have anything + // in the channel yet, then allow one message. + self.buffered_once = true; + true + } else { + false + } + } + + pub fn try_send(&mut self, val: T) -> Result<RetryPromise<T, U>, T> { + if !self.can_send() { + return Err(val); + } + let (tx, rx) = oneshot::channel(); + self.inner + .send(Envelope(Some((val, Callback::Retry(tx))))) + .map(move |_| rx) + .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0) + } + + pub fn send(&mut self, val: T) -> Result<Promise<U>, T> { + if !self.can_send() { + return Err(val); + } + let (tx, rx) = oneshot::channel(); + self.inner + .send(Envelope(Some((val, Callback::NoRetry(tx))))) + .map(move |_| rx) + .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0) + } + + pub fn unbound(self) -> UnboundedSender<T, U> { + UnboundedSender { + giver: self.giver.shared(), + inner: self.inner, + } + } +} + +impl<T, U> UnboundedSender<T, U> { + pub fn is_ready(&self) -> bool { + !self.giver.is_canceled() + } + + pub fn is_closed(&self) -> bool { + self.giver.is_canceled() + } + + pub fn try_send(&mut self, val: T) -> Result<RetryPromise<T, U>, T> { + let (tx, rx) = oneshot::channel(); + self.inner + .send(Envelope(Some((val, Callback::Retry(tx))))) + .map(move |_| rx) + .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0) + } +} + +impl<T, U> Clone for UnboundedSender<T, U> { + fn clone(&self) -> Self { + UnboundedSender { + giver: self.giver.clone(), + inner: self.inner.clone(), + } + } +} + +pub struct Receiver<T, U> { + inner: mpsc::UnboundedReceiver<Envelope<T, U>>, + taker: want::Taker, +} + +impl<T, U> Receiver<T, U> { + pub(crate) fn poll_next( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Option<(T, Callback<T, U>)>> { + match self.inner.poll_recv(cx) { + Poll::Ready(item) => { + Poll::Ready(item.map(|mut env| env.0.take().expect("envelope not dropped"))) + } + Poll::Pending => { + self.taker.want(); + Poll::Pending + } + } + } + + pub(crate) fn close(&mut self) { + self.taker.cancel(); + self.inner.close(); + } + + pub(crate) fn try_recv(&mut self) -> Option<(T, Callback<T, U>)> { + match self.inner.try_recv() { + Ok(mut env) => env.0.take(), + Err(_) => None, + } + } +} + +impl<T, U> Drop for Receiver<T, U> { + fn drop(&mut self) { + // Notify the giver about the closure first, before dropping + // the mpsc::Receiver. + self.taker.cancel(); + } +} + +struct Envelope<T, U>(Option<(T, Callback<T, U>)>); + +impl<T, U> Drop for Envelope<T, U> { + fn drop(&mut self) { + if let Some((val, cb)) = self.0.take() { + cb.send(Err(( + crate::Error::new_canceled().with("connection closed"), + Some(val), + ))); + } + } +} + +pub enum Callback<T, U> { + Retry(oneshot::Sender<Result<U, (crate::Error, Option<T>)>>), + NoRetry(oneshot::Sender<Result<U, crate::Error>>), +} + +impl<T, U> Callback<T, U> { + pub(crate) fn is_canceled(&self) -> bool { + match *self { + Callback::Retry(ref tx) => tx.is_closed(), + Callback::NoRetry(ref tx) => tx.is_closed(), + } + } + + pub(crate) fn poll_canceled(&mut self, cx: &mut task::Context<'_>) -> Poll<()> { + match *self { + Callback::Retry(ref mut tx) => tx.poll_closed(cx), + Callback::NoRetry(ref mut tx) => tx.poll_closed(cx), + } + } + + pub(crate) fn send(self, val: Result<U, (crate::Error, Option<T>)>) { + match self { + Callback::Retry(tx) => { + let _ = tx.send(val); + } + Callback::NoRetry(tx) => { + let _ = tx.send(val.map_err(|e| e.0)); + } + } + } + + pub(crate) fn send_when( + self, + mut when: impl Future<Output = Result<U, (crate::Error, Option<T>)>> + Unpin, + ) -> impl Future<Output = ()> { + let mut cb = Some(self); + + // "select" on this callback being canceled, and the future completing + future::poll_fn(move |cx| { + match Pin::new(&mut when).poll(cx) { + Poll::Ready(Ok(res)) => { + cb.take().expect("polled after complete").send(Ok(res)); + Poll::Ready(()) + } + Poll::Pending => { + // check if the callback is canceled + ready!(cb.as_mut().unwrap().poll_canceled(cx)); + trace!("send_when canceled"); + Poll::Ready(()) + } + Poll::Ready(Err(err)) => { + cb.take().expect("polled after complete").send(Err(err)); + Poll::Ready(()) + } + } + }) + } +} + +#[cfg(test)] +mod tests { + #[cfg(feature = "nightly")] + extern crate test; + + use std::future::Future; + use std::pin::Pin; + use std::task::{Context, Poll}; + + use super::{channel, Callback, Receiver}; + + #[derive(Debug)] + struct Custom(i32); + + impl<T, U> Future for Receiver<T, U> { + type Output = Option<(T, Callback<T, U>)>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + self.poll_next(cx) + } + } + + /// Helper to check if the future is ready after polling once. + struct PollOnce<'a, F>(&'a mut F); + + impl<F, T> Future for PollOnce<'_, F> + where + F: Future<Output = T> + Unpin, + { + type Output = Option<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + match Pin::new(&mut self.0).poll(cx) { + Poll::Ready(_) => Poll::Ready(Some(())), + Poll::Pending => Poll::Ready(None), + } + } + } + + #[tokio::test] + async fn drop_receiver_sends_cancel_errors() { + let _ = pretty_env_logger::try_init(); + + let (mut tx, mut rx) = channel::<Custom, ()>(); + + // must poll once for try_send to succeed + assert!(PollOnce(&mut rx).await.is_none(), "rx empty"); + + let promise = tx.try_send(Custom(43)).unwrap(); + drop(rx); + + let fulfilled = promise.await; + let err = fulfilled + .expect("fulfilled") + .expect_err("promise should error"); + match (err.0.kind(), err.1) { + (&crate::error::Kind::Canceled, Some(_)) => (), + e => panic!("expected Error::Cancel(_), found {:?}", e), + } + } + + #[tokio::test] + async fn sender_checks_for_want_on_send() { + let (mut tx, mut rx) = channel::<Custom, ()>(); + + // one is allowed to buffer, second is rejected + let _ = tx.try_send(Custom(1)).expect("1 buffered"); + tx.try_send(Custom(2)).expect_err("2 not ready"); + + assert!(PollOnce(&mut rx).await.is_some(), "rx once"); + + // Even though 1 has been popped, only 1 could be buffered for the + // lifetime of the channel. + tx.try_send(Custom(2)).expect_err("2 still not ready"); + + assert!(PollOnce(&mut rx).await.is_none(), "rx empty"); + + let _ = tx.try_send(Custom(2)).expect("2 ready"); + } + + #[test] + fn unbounded_sender_doesnt_bound_on_want() { + let (tx, rx) = channel::<Custom, ()>(); + let mut tx = tx.unbound(); + + let _ = tx.try_send(Custom(1)).unwrap(); + let _ = tx.try_send(Custom(2)).unwrap(); + let _ = tx.try_send(Custom(3)).unwrap(); + + drop(rx); + + let _ = tx.try_send(Custom(4)).unwrap_err(); + } + + #[cfg(feature = "nightly")] + #[bench] + fn giver_queue_throughput(b: &mut test::Bencher) { + use crate::{Body, Request, Response}; + + let mut rt = tokio::runtime::Builder::new() + .enable_all() + .basic_scheduler() + .build() + .unwrap(); + let (mut tx, mut rx) = channel::<Request<Body>, Response<Body>>(); + + b.iter(move || { + let _ = tx.send(Request::default()).unwrap(); + rt.block_on(async { + loop { + let poll_once = PollOnce(&mut rx); + let opt = poll_once.await; + if opt.is_none() { + break; + } + } + }); + }) + } + + #[cfg(feature = "nightly")] + #[bench] + fn giver_queue_not_ready(b: &mut test::Bencher) { + let mut rt = tokio::runtime::Builder::new() + .enable_all() + .basic_scheduler() + .build() + .unwrap(); + let (_tx, mut rx) = channel::<i32, ()>(); + b.iter(move || { + rt.block_on(async { + let poll_once = PollOnce(&mut rx); + assert!(poll_once.await.is_none()); + }); + }) + } + + #[cfg(feature = "nightly")] + #[bench] + fn giver_queue_cancel(b: &mut test::Bencher) { + let (_tx, mut rx) = channel::<i32, ()>(); + + b.iter(move || { + rx.taker.cancel(); + }) + } +} diff --git a/third_party/rust/hyper/src/client/mod.rs b/third_party/rust/hyper/src/client/mod.rs new file mode 100644 index 0000000000..42fad34573 --- /dev/null +++ b/third_party/rust/hyper/src/client/mod.rs @@ -0,0 +1,1224 @@ +//! HTTP Client +//! +//! There are two levels of APIs provided for construct HTTP clients: +//! +//! - The higher-level [`Client`](Client) type. +//! - The lower-level [`conn`](conn) module. +//! +//! # Client +//! +//! The [`Client`](Client) is the main way to send HTTP requests to a server. +//! The default `Client` provides these things on top of the lower-level API: +//! +//! - A default **connector**, able to resolve hostnames and connect to +//! destinations over plain-text TCP. +//! - A **pool** of existing connections, allowing better performance when +//! making multiple requests to the same hostname. +//! - Automatic setting of the `Host` header, based on the request `Uri`. +//! - Automatic request **retries** when a pooled connection is closed by the +//! server before any bytes have been written. +//! +//! Many of these features can configured, by making use of +//! [`Client::builder`](Client::builder). +//! +//! ## Example +//! +//! For a small example program simply fetching a URL, take a look at the +//! [full client example](https://github.com/hyperium/hyper/blob/master/examples/client.rs). +//! +//! ``` +//! use hyper::{body::HttpBody as _, Client, Uri}; +//! +//! # #[cfg(feature = "tcp")] +//! # async fn fetch_httpbin() -> hyper::Result<()> { +//! let client = Client::new(); +//! +//! // Make a GET /ip to 'http://httpbin.org' +//! let res = client.get(Uri::from_static("http://httpbin.org/ip")).await?; +//! +//! // And then, if the request gets a response... +//! println!("status: {}", res.status()); +//! +//! // Concatenate the body stream into a single buffer... +//! let buf = hyper::body::to_bytes(res).await?; +//! +//! println!("body: {:?}", buf); +//! # Ok(()) +//! # } +//! # fn main () {} +//! ``` + +use std::fmt; +use std::mem; +use std::time::Duration; + +use futures_channel::oneshot; +use futures_util::future::{self, Either, FutureExt as _, TryFutureExt as _}; +use http::header::{HeaderValue, HOST}; +use http::uri::Scheme; +use http::{Method, Request, Response, Uri, Version}; + +use self::connect::{sealed::Connect, Alpn, Connected, Connection}; +use self::pool::{Key as PoolKey, Pool, Poolable, Pooled, Reservation}; +use crate::body::{Body, Payload}; +use crate::common::{lazy as hyper_lazy, task, BoxSendFuture, Executor, Future, Lazy, Pin, Poll}; + +#[cfg(feature = "tcp")] +pub use self::connect::HttpConnector; + +pub mod conn; +pub mod connect; +pub(crate) mod dispatch; +mod pool; +pub mod service; +#[cfg(test)] +mod tests; + +/// A Client to make outgoing HTTP requests. +pub struct Client<C, B = Body> { + config: Config, + conn_builder: conn::Builder, + connector: C, + pool: Pool<PoolClient<B>>, +} + +#[derive(Clone, Copy, Debug)] +struct Config { + retry_canceled_requests: bool, + set_host: bool, + ver: Ver, +} + +/// A `Future` that will resolve to an HTTP Response. +/// +/// This is returned by `Client::request` (and `Client::get`). +#[must_use = "futures do nothing unless polled"] +pub struct ResponseFuture { + inner: Pin<Box<dyn Future<Output = crate::Result<Response<Body>>> + Send>>, +} + +// ===== impl Client ===== + +#[cfg(feature = "tcp")] +impl Client<HttpConnector, Body> { + /// Create a new Client with the default [config](Builder). + /// + /// # Note + /// + /// The default connector does **not** handle TLS. Speaking to `https` + /// destinations will require [configuring a connector that implements + /// TLS](https://hyper.rs/guides/client/configuration). + #[inline] + pub fn new() -> Client<HttpConnector, Body> { + Builder::default().build_http() + } +} + +#[cfg(feature = "tcp")] +impl Default for Client<HttpConnector, Body> { + fn default() -> Client<HttpConnector, Body> { + Client::new() + } +} + +impl Client<(), Body> { + /// Create a builder to configure a new `Client`. + /// + /// # Example + /// + /// ``` + /// # #[cfg(feature = "runtime")] + /// # fn run () { + /// use std::time::Duration; + /// use hyper::Client; + /// + /// let client = Client::builder() + /// .pool_idle_timeout(Duration::from_secs(30)) + /// .http2_only(true) + /// .build_http(); + /// # let infer: Client<_, hyper::Body> = client; + /// # drop(infer); + /// # } + /// # fn main() {} + /// ``` + #[inline] + pub fn builder() -> Builder { + Builder::default() + } +} + +impl<C, B> Client<C, B> +where + C: Connect + Clone + Send + Sync + 'static, + B: Payload + Send + 'static, + B::Data: Send, +{ + /// Send a `GET` request to the supplied `Uri`. + /// + /// # Note + /// + /// This requires that the `Payload` type have a `Default` implementation. + /// It *should* return an "empty" version of itself, such that + /// `Payload::is_end_stream` is `true`. + /// + /// # Example + /// + /// ``` + /// # #[cfg(feature = "runtime")] + /// # fn run () { + /// use hyper::{Client, Uri}; + /// + /// let client = Client::new(); + /// + /// let future = client.get(Uri::from_static("http://httpbin.org/ip")); + /// # } + /// # fn main() {} + /// ``` + pub fn get(&self, uri: Uri) -> ResponseFuture + where + B: Default, + { + let body = B::default(); + if !body.is_end_stream() { + warn!("default Payload used for get() does not return true for is_end_stream"); + } + + let mut req = Request::new(body); + *req.uri_mut() = uri; + self.request(req) + } + + /// Send a constructed `Request` using this `Client`. + /// + /// # Example + /// + /// ``` + /// # #[cfg(feature = "runtime")] + /// # fn run () { + /// use hyper::{Body, Client, Request}; + /// + /// let client = Client::new(); + /// + /// let req = Request::builder() + /// .method("POST") + /// .uri("http://httpin.org/post") + /// .body(Body::from("Hallo!")) + /// .expect("request builder"); + /// + /// let future = client.request(req); + /// # } + /// # fn main() {} + /// ``` + pub fn request(&self, mut req: Request<B>) -> ResponseFuture { + let is_http_connect = req.method() == Method::CONNECT; + match req.version() { + Version::HTTP_11 => (), + Version::HTTP_10 => { + if is_http_connect { + warn!("CONNECT is not allowed for HTTP/1.0"); + return ResponseFuture::new(Box::new(future::err( + crate::Error::new_user_unsupported_request_method(), + ))); + } + } + other_h2 @ Version::HTTP_2 => { + if self.config.ver != Ver::Http2 { + return ResponseFuture::error_version(other_h2); + } + } + // completely unsupported HTTP version (like HTTP/0.9)! + other => return ResponseFuture::error_version(other), + }; + + let pool_key = match extract_domain(req.uri_mut(), is_http_connect) { + Ok(s) => s, + Err(err) => { + return ResponseFuture::new(Box::new(future::err(err))); + } + }; + + ResponseFuture::new(Box::new(self.retryably_send_request(req, pool_key))) + } + + fn retryably_send_request( + &self, + req: Request<B>, + pool_key: PoolKey, + ) -> impl Future<Output = crate::Result<Response<Body>>> { + let client = self.clone(); + let uri = req.uri().clone(); + + let mut send_fut = client.send_request(req, pool_key.clone()); + future::poll_fn(move |cx| loop { + match ready!(Pin::new(&mut send_fut).poll(cx)) { + Ok(resp) => return Poll::Ready(Ok(resp)), + Err(ClientError::Normal(err)) => return Poll::Ready(Err(err)), + Err(ClientError::Canceled { + connection_reused, + mut req, + reason, + }) => { + if !client.config.retry_canceled_requests || !connection_reused { + // if client disabled, don't retry + // a fresh connection means we definitely can't retry + return Poll::Ready(Err(reason)); + } + + trace!( + "unstarted request canceled, trying again (reason={:?})", + reason + ); + *req.uri_mut() = uri.clone(); + send_fut = client.send_request(req, pool_key.clone()); + } + } + }) + } + + fn send_request( + &self, + mut req: Request<B>, + pool_key: PoolKey, + ) -> impl Future<Output = Result<Response<Body>, ClientError<B>>> + Unpin { + let conn = self.connection_for(pool_key); + + let set_host = self.config.set_host; + let executor = self.conn_builder.exec.clone(); + conn.and_then(move |mut pooled| { + if pooled.is_http1() { + if set_host { + let uri = req.uri().clone(); + req.headers_mut().entry(HOST).or_insert_with(|| { + let hostname = uri.host().expect("authority implies host"); + if let Some(port) = uri.port() { + let s = format!("{}:{}", hostname, port); + HeaderValue::from_str(&s) + } else { + HeaderValue::from_str(hostname) + } + .expect("uri host is valid header value") + }); + } + + // CONNECT always sends authority-form, so check it first... + if req.method() == Method::CONNECT { + authority_form(req.uri_mut()); + } else if pooled.conn_info.is_proxied { + absolute_form(req.uri_mut()); + } else { + origin_form(req.uri_mut()); + }; + } else if req.method() == Method::CONNECT { + debug!("client does not support CONNECT requests over HTTP2"); + return Either::Left(future::err(ClientError::Normal( + crate::Error::new_user_unsupported_request_method(), + ))); + } + + let fut = pooled + .send_request_retryable(req) + .map_err(ClientError::map_with_reused(pooled.is_reused())); + + // If the Connector included 'extra' info, add to Response... + let extra_info = pooled.conn_info.extra.clone(); + let fut = fut.map_ok(move |mut res| { + if let Some(extra) = extra_info { + extra.set(&mut res); + } + res + }); + + // As of futures@0.1.21, there is a race condition in the mpsc + // channel, such that sending when the receiver is closing can + // result in the message being stuck inside the queue. It won't + // ever notify until the Sender side is dropped. + // + // To counteract this, we must check if our senders 'want' channel + // has been closed after having tried to send. If so, error out... + if pooled.is_closed() { + return Either::Right(Either::Left(fut)); + } + + Either::Right(Either::Right(fut.map_ok(move |mut res| { + // If pooled is HTTP/2, we can toss this reference immediately. + // + // when pooled is dropped, it will try to insert back into the + // pool. To delay that, spawn a future that completes once the + // sender is ready again. + // + // This *should* only be once the related `Connection` has polled + // for a new request to start. + // + // It won't be ready if there is a body to stream. + if pooled.is_http2() || !pooled.is_pool_enabled() || pooled.is_ready() { + drop(pooled); + } else if !res.body().is_end_stream() { + let (delayed_tx, delayed_rx) = oneshot::channel(); + res.body_mut().delayed_eof(delayed_rx); + let on_idle = future::poll_fn(move |cx| pooled.poll_ready(cx)).map(move |_| { + // At this point, `pooled` is dropped, and had a chance + // to insert into the pool (if conn was idle) + drop(delayed_tx); + }); + + executor.execute(on_idle); + } else { + // There's no body to delay, but the connection isn't + // ready yet. Only re-insert when it's ready + let on_idle = future::poll_fn(move |cx| pooled.poll_ready(cx)).map(|_| ()); + + executor.execute(on_idle); + } + res + }))) + }) + } + + fn connection_for( + &self, + pool_key: PoolKey, + ) -> impl Future<Output = Result<Pooled<PoolClient<B>>, ClientError<B>>> { + // This actually races 2 different futures to try to get a ready + // connection the fastest, and to reduce connection churn. + // + // - If the pool has an idle connection waiting, that's used + // immediately. + // - Otherwise, the Connector is asked to start connecting to + // the destination Uri. + // - Meanwhile, the pool Checkout is watching to see if any other + // request finishes and tries to insert an idle connection. + // - If a new connection is started, but the Checkout wins after + // (an idle connection became available first), the started + // connection future is spawned into the runtime to complete, + // and then be inserted into the pool as an idle connection. + let checkout = self.pool.checkout(pool_key.clone()); + let connect = self.connect_to(pool_key); + + let executor = self.conn_builder.exec.clone(); + // The order of the `select` is depended on below... + future::select(checkout, connect).then(move |either| match either { + // Checkout won, connect future may have been started or not. + // + // If it has, let it finish and insert back into the pool, + // so as to not waste the socket... + Either::Left((Ok(checked_out), connecting)) => { + // This depends on the `select` above having the correct + // order, such that if the checkout future were ready + // immediately, the connect future will never have been + // started. + // + // If it *wasn't* ready yet, then the connect future will + // have been started... + if connecting.started() { + let bg = connecting + .map_err(|err| { + trace!("background connect error: {}", err); + }) + .map(|_pooled| { + // dropping here should just place it in + // the Pool for us... + }); + // An execute error here isn't important, we're just trying + // to prevent a waste of a socket... + executor.execute(bg); + } + Either::Left(future::ok(checked_out)) + } + // Connect won, checkout can just be dropped. + Either::Right((Ok(connected), _checkout)) => Either::Left(future::ok(connected)), + // Either checkout or connect could get canceled: + // + // 1. Connect is canceled if this is HTTP/2 and there is + // an outstanding HTTP/2 connecting task. + // 2. Checkout is canceled if the pool cannot deliver an + // idle connection reliably. + // + // In both cases, we should just wait for the other future. + Either::Left((Err(err), connecting)) => Either::Right(Either::Left({ + if err.is_canceled() { + Either::Left(connecting.map_err(ClientError::Normal)) + } else { + Either::Right(future::err(ClientError::Normal(err))) + } + })), + Either::Right((Err(err), checkout)) => Either::Right(Either::Right({ + if err.is_canceled() { + Either::Left(checkout.map_err(ClientError::Normal)) + } else { + Either::Right(future::err(ClientError::Normal(err))) + } + })), + }) + } + + fn connect_to( + &self, + pool_key: PoolKey, + ) -> impl Lazy<Output = crate::Result<Pooled<PoolClient<B>>>> + Unpin { + let executor = self.conn_builder.exec.clone(); + let pool = self.pool.clone(); + let mut conn_builder = self.conn_builder.clone(); + let ver = self.config.ver; + let is_ver_h2 = ver == Ver::Http2; + let connector = self.connector.clone(); + let dst = domain_as_uri(pool_key.clone()); + hyper_lazy(move || { + // Try to take a "connecting lock". + // + // If the pool_key is for HTTP/2, and there is already a + // connection being established, then this can't take a + // second lock. The "connect_to" future is Canceled. + let connecting = match pool.connecting(&pool_key, ver) { + Some(lock) => lock, + None => { + let canceled = + crate::Error::new_canceled().with("HTTP/2 connection in progress"); + return Either::Right(future::err(canceled)); + } + }; + Either::Left( + connector + .connect(connect::sealed::Internal, dst) + .map_err(crate::Error::new_connect) + .and_then(move |io| { + let connected = io.connected(); + // If ALPN is h2 and we aren't http2_only already, + // then we need to convert our pool checkout into + // a single HTTP2 one. + let connecting = if connected.alpn == Alpn::H2 && !is_ver_h2 { + match connecting.alpn_h2(&pool) { + Some(lock) => { + trace!("ALPN negotiated h2, updating pool"); + lock + } + None => { + // Another connection has already upgraded, + // the pool checkout should finish up for us. + let canceled = crate::Error::new_canceled() + .with("ALPN upgraded to HTTP/2"); + return Either::Right(future::err(canceled)); + } + } + } else { + connecting + }; + let is_h2 = is_ver_h2 || connected.alpn == Alpn::H2; + Either::Left(Box::pin( + conn_builder + .http2_only(is_h2) + .handshake(io) + .and_then(move |(tx, conn)| { + trace!( + "handshake complete, spawning background dispatcher task" + ); + executor.execute( + conn.map_err(|e| debug!("client connection error: {}", e)) + .map(|_| ()), + ); + + // Wait for 'conn' to ready up before we + // declare this tx as usable + tx.when_ready() + }) + .map_ok(move |tx| { + pool.pooled( + connecting, + PoolClient { + conn_info: connected, + tx: if is_h2 { + PoolTx::Http2(tx.into_http2()) + } else { + PoolTx::Http1(tx) + }, + }, + ) + }), + )) + }), + ) + }) + } +} + +impl<C, B> tower_service::Service<Request<B>> for Client<C, B> +where + C: Connect + Clone + Send + Sync + 'static, + B: Payload + Send + 'static, + B::Data: Send, +{ + type Response = Response<Body>; + type Error = crate::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request<B>) -> Self::Future { + self.request(req) + } +} + +impl<C: Clone, B> Clone for Client<C, B> { + fn clone(&self) -> Client<C, B> { + Client { + config: self.config.clone(), + conn_builder: self.conn_builder.clone(), + connector: self.connector.clone(), + pool: self.pool.clone(), + } + } +} + +impl<C, B> fmt::Debug for Client<C, B> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Client").finish() + } +} + +// ===== impl ResponseFuture ===== + +impl ResponseFuture { + fn new(fut: Box<dyn Future<Output = crate::Result<Response<Body>>> + Send>) -> Self { + Self { inner: fut.into() } + } + + fn error_version(ver: Version) -> Self { + warn!("Request has unsupported version \"{:?}\"", ver); + ResponseFuture::new(Box::new(future::err( + crate::Error::new_user_unsupported_version(), + ))) + } +} + +impl fmt::Debug for ResponseFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("Future<Response>") + } +} + +impl Future for ResponseFuture { + type Output = crate::Result<Response<Body>>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + Pin::new(&mut self.inner).poll(cx) + } +} + +// ===== impl PoolClient ===== + +// FIXME: allow() required due to `impl Trait` leaking types to this lint +#[allow(missing_debug_implementations)] +struct PoolClient<B> { + conn_info: Connected, + tx: PoolTx<B>, +} + +enum PoolTx<B> { + Http1(conn::SendRequest<B>), + Http2(conn::Http2SendRequest<B>), +} + +impl<B> PoolClient<B> { + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + match self.tx { + PoolTx::Http1(ref mut tx) => tx.poll_ready(cx), + PoolTx::Http2(_) => Poll::Ready(Ok(())), + } + } + + fn is_http1(&self) -> bool { + !self.is_http2() + } + + fn is_http2(&self) -> bool { + match self.tx { + PoolTx::Http1(_) => false, + PoolTx::Http2(_) => true, + } + } + + fn is_ready(&self) -> bool { + match self.tx { + PoolTx::Http1(ref tx) => tx.is_ready(), + PoolTx::Http2(ref tx) => tx.is_ready(), + } + } + + fn is_closed(&self) -> bool { + match self.tx { + PoolTx::Http1(ref tx) => tx.is_closed(), + PoolTx::Http2(ref tx) => tx.is_closed(), + } + } +} + +impl<B: Payload + 'static> PoolClient<B> { + fn send_request_retryable( + &mut self, + req: Request<B>, + ) -> impl Future<Output = Result<Response<Body>, (crate::Error, Option<Request<B>>)>> + where + B: Send, + { + match self.tx { + PoolTx::Http1(ref mut tx) => Either::Left(tx.send_request_retryable(req)), + PoolTx::Http2(ref mut tx) => Either::Right(tx.send_request_retryable(req)), + } + } +} + +impl<B> Poolable for PoolClient<B> +where + B: Send + 'static, +{ + fn is_open(&self) -> bool { + match self.tx { + PoolTx::Http1(ref tx) => tx.is_ready(), + PoolTx::Http2(ref tx) => tx.is_ready(), + } + } + + fn reserve(self) -> Reservation<Self> { + match self.tx { + PoolTx::Http1(tx) => Reservation::Unique(PoolClient { + conn_info: self.conn_info, + tx: PoolTx::Http1(tx), + }), + PoolTx::Http2(tx) => { + let b = PoolClient { + conn_info: self.conn_info.clone(), + tx: PoolTx::Http2(tx.clone()), + }; + let a = PoolClient { + conn_info: self.conn_info, + tx: PoolTx::Http2(tx), + }; + Reservation::Shared(a, b) + } + } + } + + fn can_share(&self) -> bool { + self.is_http2() + } +} + +// ===== impl ClientError ===== + +// FIXME: allow() required due to `impl Trait` leaking types to this lint +#[allow(missing_debug_implementations)] +enum ClientError<B> { + Normal(crate::Error), + Canceled { + connection_reused: bool, + req: Request<B>, + reason: crate::Error, + }, +} + +impl<B> ClientError<B> { + fn map_with_reused(conn_reused: bool) -> impl Fn((crate::Error, Option<Request<B>>)) -> Self { + move |(err, orig_req)| { + if let Some(req) = orig_req { + ClientError::Canceled { + connection_reused: conn_reused, + reason: err, + req, + } + } else { + ClientError::Normal(err) + } + } + } +} + +/// A marker to identify what version a pooled connection is. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +enum Ver { + Auto, + Http2, +} + +fn origin_form(uri: &mut Uri) { + let path = match uri.path_and_query() { + Some(path) if path.as_str() != "/" => { + let mut parts = ::http::uri::Parts::default(); + parts.path_and_query = Some(path.clone()); + Uri::from_parts(parts).expect("path is valid uri") + } + _none_or_just_slash => { + debug_assert!(Uri::default() == "/"); + Uri::default() + } + }; + *uri = path +} + +fn absolute_form(uri: &mut Uri) { + debug_assert!(uri.scheme().is_some(), "absolute_form needs a scheme"); + debug_assert!( + uri.authority().is_some(), + "absolute_form needs an authority" + ); + // If the URI is to HTTPS, and the connector claimed to be a proxy, + // then it *should* have tunneled, and so we don't want to send + // absolute-form in that case. + if uri.scheme() == Some(&Scheme::HTTPS) { + origin_form(uri); + } +} + +fn authority_form(uri: &mut Uri) { + if log_enabled!(::log::Level::Warn) { + if let Some(path) = uri.path_and_query() { + // `https://hyper.rs` would parse with `/` path, don't + // annoy people about that... + if path != "/" { + warn!("HTTP/1.1 CONNECT request stripping path: {:?}", path); + } + } + } + *uri = match uri.authority() { + Some(auth) => { + let mut parts = ::http::uri::Parts::default(); + parts.authority = Some(auth.clone()); + Uri::from_parts(parts).expect("authority is valid") + } + None => { + unreachable!("authority_form with relative uri"); + } + }; +} + +fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> crate::Result<PoolKey> { + let uri_clone = uri.clone(); + match (uri_clone.scheme(), uri_clone.authority()) { + (Some(scheme), Some(auth)) => Ok((scheme.clone(), auth.clone())), + (None, Some(auth)) if is_http_connect => { + let scheme = match auth.port_u16() { + Some(443) => { + set_scheme(uri, Scheme::HTTPS); + Scheme::HTTPS + } + _ => { + set_scheme(uri, Scheme::HTTP); + Scheme::HTTP + } + }; + Ok((scheme, auth.clone())) + } + _ => { + debug!("Client requires absolute-form URIs, received: {:?}", uri); + Err(crate::Error::new_user_absolute_uri_required()) + } + } +} + +fn domain_as_uri((scheme, auth): PoolKey) -> Uri { + http::uri::Builder::new() + .scheme(scheme) + .authority(auth) + .path_and_query("/") + .build() + .expect("domain is valid Uri") +} + +fn set_scheme(uri: &mut Uri, scheme: Scheme) { + debug_assert!( + uri.scheme().is_none(), + "set_scheme expects no existing scheme" + ); + let old = mem::replace(uri, Uri::default()); + let mut parts: ::http::uri::Parts = old.into(); + parts.scheme = Some(scheme); + parts.path_and_query = Some("/".parse().expect("slash is a valid path")); + *uri = Uri::from_parts(parts).expect("scheme is valid"); +} + +/// A builder to configure a new [`Client`](Client). +/// +/// # Example +/// +/// ``` +/// # #[cfg(feature = "runtime")] +/// # fn run () { +/// use std::time::Duration; +/// use hyper::Client; +/// +/// let client = Client::builder() +/// .pool_idle_timeout(Duration::from_secs(30)) +/// .http2_only(true) +/// .build_http(); +/// # let infer: Client<_, hyper::Body> = client; +/// # drop(infer); +/// # } +/// # fn main() {} +/// ``` +#[derive(Clone)] +pub struct Builder { + client_config: Config, + conn_builder: conn::Builder, + pool_config: pool::Config, +} + +impl Default for Builder { + fn default() -> Self { + Self { + client_config: Config { + retry_canceled_requests: true, + set_host: true, + ver: Ver::Auto, + }, + conn_builder: conn::Builder::new(), + pool_config: pool::Config { + idle_timeout: Some(Duration::from_secs(90)), + max_idle_per_host: std::usize::MAX, + }, + } + } +} + +impl Builder { + #[doc(hidden)] + #[deprecated( + note = "name is confusing, to disable the connection pool, call pool_max_idle_per_host(0)" + )] + pub fn keep_alive(&mut self, val: bool) -> &mut Self { + if !val { + // disable + self.pool_max_idle_per_host(0) + } else if self.pool_config.max_idle_per_host == 0 { + // enable + self.pool_max_idle_per_host(std::usize::MAX) + } else { + // already enabled + self + } + } + + #[doc(hidden)] + #[deprecated(note = "renamed to `pool_idle_timeout`")] + pub fn keep_alive_timeout<D>(&mut self, val: D) -> &mut Self + where + D: Into<Option<Duration>>, + { + self.pool_idle_timeout(val) + } + + /// Set an optional timeout for idle sockets being kept-alive. + /// + /// Pass `None` to disable timeout. + /// + /// Default is 90 seconds. + pub fn pool_idle_timeout<D>(&mut self, val: D) -> &mut Self + where + D: Into<Option<Duration>>, + { + self.pool_config.idle_timeout = val.into(); + self + } + + #[doc(hidden)] + #[deprecated(note = "renamed to `pool_max_idle_per_host`")] + pub fn max_idle_per_host(&mut self, max_idle: usize) -> &mut Self { + self.pool_config.max_idle_per_host = max_idle; + self + } + + /// Sets the maximum idle connection per host allowed in the pool. + /// + /// Default is `usize::MAX` (no limit). + pub fn pool_max_idle_per_host(&mut self, max_idle: usize) -> &mut Self { + self.pool_config.max_idle_per_host = max_idle; + self + } + + // HTTP/1 options + + /// Set whether HTTP/1 connections should try to use vectored writes, + /// or always flatten into a single buffer. + /// + /// Note that setting this to false may mean more copies of body data, + /// but may also improve performance when an IO transport doesn't + /// support vectored writes well, such as most TLS implementations. + /// + /// Default is `true`. + pub fn http1_writev(&mut self, val: bool) -> &mut Self { + self.conn_builder.h1_writev(val); + self + } + + /// Sets the exact size of the read buffer to *always* use. + /// + /// Note that setting this option unsets the `http1_max_buf_size` option. + /// + /// Default is an adaptive read buffer. + pub fn http1_read_buf_exact_size(&mut self, sz: usize) -> &mut Self { + self.conn_builder.h1_read_buf_exact_size(Some(sz)); + self + } + + /// Set the maximum buffer size for the connection. + /// + /// Default is ~400kb. + /// + /// Note that setting this option unsets the `http1_read_exact_buf_size` option. + /// + /// # Panics + /// + /// The minimum value allowed is 8192. This method panics if the passed `max` is less than the minimum. + pub fn http1_max_buf_size(&mut self, max: usize) -> &mut Self { + self.conn_builder.h1_max_buf_size(max); + self + } + + /// Set whether HTTP/1 connections will write header names as title case at + /// the socket level. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + pub fn http1_title_case_headers(&mut self, val: bool) -> &mut Self { + self.conn_builder.h1_title_case_headers(val); + self + } + + /// Set whether the connection **must** use HTTP/2. + /// + /// The destination must either allow HTTP2 Prior Knowledge, or the + /// `Connect` should be configured to do use ALPN to upgrade to `h2` + /// as part of the connection process. This will not make the `Client` + /// utilize ALPN by itself. + /// + /// Note that setting this to true prevents HTTP/1 from being allowed. + /// + /// Default is false. + pub fn http2_only(&mut self, val: bool) -> &mut Self { + self.client_config.ver = if val { Ver::Http2 } else { Ver::Auto }; + self + } + + /// Sets the [`SETTINGS_INITIAL_WINDOW_SIZE`][spec] option for HTTP2 + /// stream-level flow control. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + /// + /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_INITIAL_WINDOW_SIZE + pub fn http2_initial_stream_window_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self { + self.conn_builder + .http2_initial_stream_window_size(sz.into()); + self + } + + /// Sets the max connection-level flow control for HTTP2 + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + pub fn http2_initial_connection_window_size( + &mut self, + sz: impl Into<Option<u32>>, + ) -> &mut Self { + self.conn_builder + .http2_initial_connection_window_size(sz.into()); + self + } + + /// Sets whether to use an adaptive flow control. + /// + /// Enabling this will override the limits set in + /// `http2_initial_stream_window_size` and + /// `http2_initial_connection_window_size`. + pub fn http2_adaptive_window(&mut self, enabled: bool) -> &mut Self { + self.conn_builder.http2_adaptive_window(enabled); + self + } + + /// Sets an interval for HTTP2 Ping frames should be sent to keep a + /// connection alive. + /// + /// Pass `None` to disable HTTP2 keep-alive. + /// + /// Default is currently disabled. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + pub fn http2_keep_alive_interval( + &mut self, + interval: impl Into<Option<Duration>>, + ) -> &mut Self { + self.conn_builder.http2_keep_alive_interval(interval); + self + } + + /// Sets a timeout for receiving an acknowledgement of the keep-alive ping. + /// + /// If the ping is not acknowledged within the timeout, the connection will + /// be closed. Does nothing if `http2_keep_alive_interval` is disabled. + /// + /// Default is 20 seconds. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + pub fn http2_keep_alive_timeout(&mut self, timeout: Duration) -> &mut Self { + self.conn_builder.http2_keep_alive_timeout(timeout); + self + } + + /// Sets whether HTTP2 keep-alive should apply while the connection is idle. + /// + /// If disabled, keep-alive pings are only sent while there are open + /// request/responses streams. If enabled, pings are also sent when no + /// streams are active. Does nothing if `http2_keep_alive_interval` is + /// disabled. + /// + /// Default is `false`. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + pub fn http2_keep_alive_while_idle(&mut self, enabled: bool) -> &mut Self { + self.conn_builder.http2_keep_alive_while_idle(enabled); + self + } + + /// Set whether to retry requests that get disrupted before ever starting + /// to write. + /// + /// This means a request that is queued, and gets given an idle, reused + /// connection, and then encounters an error immediately as the idle + /// connection was found to be unusable. + /// + /// When this is set to `false`, the related `ResponseFuture` would instead + /// resolve to an `Error::Cancel`. + /// + /// Default is `true`. + #[inline] + pub fn retry_canceled_requests(&mut self, val: bool) -> &mut Self { + self.client_config.retry_canceled_requests = val; + self + } + + /// Set whether to automatically add the `Host` header to requests. + /// + /// If true, and a request does not include a `Host` header, one will be + /// added automatically, derived from the authority of the `Uri`. + /// + /// Default is `true`. + #[inline] + pub fn set_host(&mut self, val: bool) -> &mut Self { + self.client_config.set_host = val; + self + } + + /// Provide an executor to execute background `Connection` tasks. + pub fn executor<E>(&mut self, exec: E) -> &mut Self + where + E: Executor<BoxSendFuture> + Send + Sync + 'static, + { + self.conn_builder.executor(exec); + self + } + + /// Builder a client with this configuration and the default `HttpConnector`. + #[cfg(feature = "tcp")] + pub fn build_http<B>(&self) -> Client<HttpConnector, B> + where + B: Payload + Send, + B::Data: Send, + { + let mut connector = HttpConnector::new(); + if self.pool_config.is_enabled() { + connector.set_keepalive(self.pool_config.idle_timeout); + } + self.build(connector) + } + + /// Combine the configuration of this builder with a connector to create a `Client`. + pub fn build<C, B>(&self, connector: C) -> Client<C, B> + where + C: Connect + Clone, + B: Payload + Send, + B::Data: Send, + { + Client { + config: self.client_config, + conn_builder: self.conn_builder.clone(), + connector, + pool: Pool::new(self.pool_config, &self.conn_builder.exec), + } + } +} + +impl fmt::Debug for Builder { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Builder") + .field("client_config", &self.client_config) + .field("conn_builder", &self.conn_builder) + .field("pool_config", &self.pool_config) + .finish() + } +} + +#[cfg(test)] +mod unit_tests { + use super::*; + + #[test] + fn set_relative_uri_with_implicit_path() { + let mut uri = "http://hyper.rs".parse().unwrap(); + origin_form(&mut uri); + assert_eq!(uri.to_string(), "/"); + } + + #[test] + fn test_origin_form() { + let mut uri = "http://hyper.rs/guides".parse().unwrap(); + origin_form(&mut uri); + assert_eq!(uri.to_string(), "/guides"); + + let mut uri = "http://hyper.rs/guides?foo=bar".parse().unwrap(); + origin_form(&mut uri); + assert_eq!(uri.to_string(), "/guides?foo=bar"); + } + + #[test] + fn test_absolute_form() { + let mut uri = "http://hyper.rs/guides".parse().unwrap(); + absolute_form(&mut uri); + assert_eq!(uri.to_string(), "http://hyper.rs/guides"); + + let mut uri = "https://hyper.rs/guides".parse().unwrap(); + absolute_form(&mut uri); + assert_eq!(uri.to_string(), "/guides"); + } + + #[test] + fn test_authority_form() { + let _ = pretty_env_logger::try_init(); + + let mut uri = "http://hyper.rs".parse().unwrap(); + authority_form(&mut uri); + assert_eq!(uri.to_string(), "hyper.rs"); + + let mut uri = "hyper.rs".parse().unwrap(); + authority_form(&mut uri); + assert_eq!(uri.to_string(), "hyper.rs"); + } + + #[test] + fn test_extract_domain_connect_no_port() { + let mut uri = "hyper.rs".parse().unwrap(); + let (scheme, host) = extract_domain(&mut uri, true).expect("extract domain"); + assert_eq!(scheme, *"http"); + assert_eq!(host, "hyper.rs"); + } +} diff --git a/third_party/rust/hyper/src/client/pool.rs b/third_party/rust/hyper/src/client/pool.rs new file mode 100644 index 0000000000..8c1ee24c0d --- /dev/null +++ b/third_party/rust/hyper/src/client/pool.rs @@ -0,0 +1,1013 @@ +use std::collections::{HashMap, HashSet, VecDeque}; +use std::fmt; +use std::ops::{Deref, DerefMut}; +use std::sync::{Arc, Mutex, Weak}; + +#[cfg(not(feature = "runtime"))] +use std::time::{Duration, Instant}; + +use futures_channel::oneshot; +#[cfg(feature = "runtime")] +use tokio::time::{Duration, Instant, Interval}; + +use super::Ver; +use crate::common::{task, Exec, Future, Pin, Poll, Unpin}; + +// FIXME: allow() required due to `impl Trait` leaking types to this lint +#[allow(missing_debug_implementations)] +pub(super) struct Pool<T> { + // If the pool is disabled, this is None. + inner: Option<Arc<Mutex<PoolInner<T>>>>, +} + +// Before using a pooled connection, make sure the sender is not dead. +// +// This is a trait to allow the `client::pool::tests` to work for `i32`. +// +// See https://github.com/hyperium/hyper/issues/1429 +pub(super) trait Poolable: Unpin + Send + Sized + 'static { + fn is_open(&self) -> bool; + /// Reserve this connection. + /// + /// Allows for HTTP/2 to return a shared reservation. + fn reserve(self) -> Reservation<Self>; + fn can_share(&self) -> bool; +} + +/// When checking out a pooled connection, it might be that the connection +/// only supports a single reservation, or it might be usable for many. +/// +/// Specifically, HTTP/1 requires a unique reservation, but HTTP/2 can be +/// used for multiple requests. +// FIXME: allow() required due to `impl Trait` leaking types to this lint +#[allow(missing_debug_implementations)] +pub(super) enum Reservation<T> { + /// This connection could be used multiple times, the first one will be + /// reinserted into the `idle` pool, and the second will be given to + /// the `Checkout`. + Shared(T, T), + /// This connection requires unique access. It will be returned after + /// use is complete. + Unique(T), +} + +/// Simple type alias in case the key type needs to be adjusted. +pub(super) type Key = (http::uri::Scheme, http::uri::Authority); //Arc<String>; + +struct PoolInner<T> { + // A flag that a connection is being established, and the connection + // should be shared. This prevents making multiple HTTP/2 connections + // to the same host. + connecting: HashSet<Key>, + // These are internal Conns sitting in the event loop in the KeepAlive + // state, waiting to receive a new Request to send on the socket. + idle: HashMap<Key, Vec<Idle<T>>>, + max_idle_per_host: usize, + // These are outstanding Checkouts that are waiting for a socket to be + // able to send a Request one. This is used when "racing" for a new + // connection. + // + // The Client starts 2 tasks, 1 to connect a new socket, and 1 to wait + // for the Pool to receive an idle Conn. When a Conn becomes idle, + // this list is checked for any parked Checkouts, and tries to notify + // them that the Conn could be used instead of waiting for a brand new + // connection. + waiters: HashMap<Key, VecDeque<oneshot::Sender<T>>>, + // A oneshot channel is used to allow the interval to be notified when + // the Pool completely drops. That way, the interval can cancel immediately. + #[cfg(feature = "runtime")] + idle_interval_ref: Option<oneshot::Sender<crate::common::Never>>, + #[cfg(feature = "runtime")] + exec: Exec, + timeout: Option<Duration>, +} + +// This is because `Weak::new()` *allocates* space for `T`, even if it +// doesn't need it! +struct WeakOpt<T>(Option<Weak<T>>); + +#[derive(Clone, Copy, Debug)] +pub(super) struct Config { + pub(super) idle_timeout: Option<Duration>, + pub(super) max_idle_per_host: usize, +} + +impl Config { + pub(super) fn is_enabled(&self) -> bool { + self.max_idle_per_host > 0 + } +} + +impl<T> Pool<T> { + pub fn new(config: Config, __exec: &Exec) -> Pool<T> { + let inner = if config.is_enabled() { + Some(Arc::new(Mutex::new(PoolInner { + connecting: HashSet::new(), + idle: HashMap::new(), + #[cfg(feature = "runtime")] + idle_interval_ref: None, + max_idle_per_host: config.max_idle_per_host, + waiters: HashMap::new(), + #[cfg(feature = "runtime")] + exec: __exec.clone(), + timeout: config.idle_timeout, + }))) + } else { + None + }; + + Pool { inner } + } + + fn is_enabled(&self) -> bool { + self.inner.is_some() + } + + #[cfg(test)] + pub(super) fn no_timer(&self) { + // Prevent an actual interval from being created for this pool... + #[cfg(feature = "runtime")] + { + let mut inner = self.inner.as_ref().unwrap().lock().unwrap(); + assert!(inner.idle_interval_ref.is_none(), "timer already spawned"); + let (tx, _) = oneshot::channel(); + inner.idle_interval_ref = Some(tx); + } + } +} + +impl<T: Poolable> Pool<T> { + /// Returns a `Checkout` which is a future that resolves if an idle + /// connection becomes available. + pub fn checkout(&self, key: Key) -> Checkout<T> { + Checkout { + key, + pool: self.clone(), + waiter: None, + } + } + + /// Ensure that there is only ever 1 connecting task for HTTP/2 + /// connections. This does nothing for HTTP/1. + pub(super) fn connecting(&self, key: &Key, ver: Ver) -> Option<Connecting<T>> { + if ver == Ver::Http2 { + if let Some(ref enabled) = self.inner { + let mut inner = enabled.lock().unwrap(); + return if inner.connecting.insert(key.clone()) { + let connecting = Connecting { + key: key.clone(), + pool: WeakOpt::downgrade(enabled), + }; + Some(connecting) + } else { + trace!("HTTP/2 connecting already in progress for {:?}", key); + None + }; + } + } + + // else + Some(Connecting { + key: key.clone(), + // in HTTP/1's case, there is never a lock, so we don't + // need to do anything in Drop. + pool: WeakOpt::none(), + }) + } + + #[cfg(test)] + fn locked(&self) -> std::sync::MutexGuard<'_, PoolInner<T>> { + self.inner.as_ref().expect("enabled").lock().expect("lock") + } + + /* Used in client/tests.rs... + #[cfg(feature = "runtime")] + #[cfg(test)] + pub(super) fn h1_key(&self, s: &str) -> Key { + Arc::new(s.to_string()) + } + + #[cfg(feature = "runtime")] + #[cfg(test)] + pub(super) fn idle_count(&self, key: &Key) -> usize { + self + .locked() + .idle + .get(key) + .map(|list| list.len()) + .unwrap_or(0) + } + */ + + pub(super) fn pooled(&self, mut connecting: Connecting<T>, value: T) -> Pooled<T> { + let (value, pool_ref) = if let Some(ref enabled) = self.inner { + match value.reserve() { + Reservation::Shared(to_insert, to_return) => { + let mut inner = enabled.lock().unwrap(); + inner.put(connecting.key.clone(), to_insert, enabled); + // Do this here instead of Drop for Connecting because we + // already have a lock, no need to lock the mutex twice. + inner.connected(&connecting.key); + // prevent the Drop of Connecting from repeating inner.connected() + connecting.pool = WeakOpt::none(); + + // Shared reservations don't need a reference to the pool, + // since the pool always keeps a copy. + (to_return, WeakOpt::none()) + } + Reservation::Unique(value) => { + // Unique reservations must take a reference to the pool + // since they hope to reinsert once the reservation is + // completed + (value, WeakOpt::downgrade(enabled)) + } + } + } else { + // If pool is not enabled, skip all the things... + + // The Connecting should have had no pool ref + debug_assert!(connecting.pool.upgrade().is_none()); + + (value, WeakOpt::none()) + }; + Pooled { + key: connecting.key.clone(), + is_reused: false, + pool: pool_ref, + value: Some(value), + } + } + + fn reuse(&self, key: &Key, value: T) -> Pooled<T> { + debug!("reuse idle connection for {:?}", key); + // TODO: unhack this + // In Pool::pooled(), which is used for inserting brand new connections, + // there's some code that adjusts the pool reference taken depending + // on if the Reservation can be shared or is unique. By the time + // reuse() is called, the reservation has already been made, and + // we just have the final value, without knowledge of if this is + // unique or shared. So, the hack is to just assume Ver::Http2 means + // shared... :( + let mut pool_ref = WeakOpt::none(); + if !value.can_share() { + if let Some(ref enabled) = self.inner { + pool_ref = WeakOpt::downgrade(enabled); + } + } + + Pooled { + is_reused: true, + key: key.clone(), + pool: pool_ref, + value: Some(value), + } + } +} + +/// Pop off this list, looking for a usable connection that hasn't expired. +struct IdlePopper<'a, T> { + key: &'a Key, + list: &'a mut Vec<Idle<T>>, +} + +impl<'a, T: Poolable + 'a> IdlePopper<'a, T> { + fn pop(self, expiration: &Expiration) -> Option<Idle<T>> { + while let Some(entry) = self.list.pop() { + // If the connection has been closed, or is older than our idle + // timeout, simply drop it and keep looking... + if !entry.value.is_open() { + trace!("removing closed connection for {:?}", self.key); + continue; + } + // TODO: Actually, since the `idle` list is pushed to the end always, + // that would imply that if *this* entry is expired, then anything + // "earlier" in the list would *have* to be expired also... Right? + // + // In that case, we could just break out of the loop and drop the + // whole list... + if expiration.expires(entry.idle_at) { + trace!("removing expired connection for {:?}", self.key); + continue; + } + + let value = match entry.value.reserve() { + Reservation::Shared(to_reinsert, to_checkout) => { + self.list.push(Idle { + idle_at: Instant::now(), + value: to_reinsert, + }); + to_checkout + } + Reservation::Unique(unique) => unique, + }; + + return Some(Idle { + idle_at: entry.idle_at, + value, + }); + } + + None + } +} + +impl<T: Poolable> PoolInner<T> { + fn put(&mut self, key: Key, value: T, __pool_ref: &Arc<Mutex<PoolInner<T>>>) { + if value.can_share() && self.idle.contains_key(&key) { + trace!("put; existing idle HTTP/2 connection for {:?}", key); + return; + } + trace!("put; add idle connection for {:?}", key); + let mut remove_waiters = false; + let mut value = Some(value); + if let Some(waiters) = self.waiters.get_mut(&key) { + while let Some(tx) = waiters.pop_front() { + if !tx.is_canceled() { + let reserved = value.take().expect("value already sent"); + let reserved = match reserved.reserve() { + Reservation::Shared(to_keep, to_send) => { + value = Some(to_keep); + to_send + } + Reservation::Unique(uniq) => uniq, + }; + match tx.send(reserved) { + Ok(()) => { + if value.is_none() { + break; + } else { + continue; + } + } + Err(e) => { + value = Some(e); + } + } + } + + trace!("put; removing canceled waiter for {:?}", key); + } + remove_waiters = waiters.is_empty(); + } + if remove_waiters { + self.waiters.remove(&key); + } + + match value { + Some(value) => { + // borrow-check scope... + { + let idle_list = self.idle.entry(key.clone()).or_insert_with(Vec::new); + if self.max_idle_per_host <= idle_list.len() { + trace!("max idle per host for {:?}, dropping connection", key); + return; + } + + debug!("pooling idle connection for {:?}", key); + idle_list.push(Idle { + value, + idle_at: Instant::now(), + }); + } + + #[cfg(feature = "runtime")] + { + self.spawn_idle_interval(__pool_ref); + } + } + None => trace!("put; found waiter for {:?}", key), + } + } + + /// A `Connecting` task is complete. Not necessarily successfully, + /// but the lock is going away, so clean up. + fn connected(&mut self, key: &Key) { + let existed = self.connecting.remove(key); + debug_assert!(existed, "Connecting dropped, key not in pool.connecting"); + // cancel any waiters. if there are any, it's because + // this Connecting task didn't complete successfully. + // those waiters would never receive a connection. + self.waiters.remove(key); + } + + #[cfg(feature = "runtime")] + fn spawn_idle_interval(&mut self, pool_ref: &Arc<Mutex<PoolInner<T>>>) { + let (dur, rx) = { + if self.idle_interval_ref.is_some() { + return; + } + + if let Some(dur) = self.timeout { + let (tx, rx) = oneshot::channel(); + self.idle_interval_ref = Some(tx); + (dur, rx) + } else { + return; + } + }; + + let interval = IdleTask { + interval: tokio::time::interval(dur), + pool: WeakOpt::downgrade(pool_ref), + pool_drop_notifier: rx, + }; + + self.exec.execute(interval); + } +} + +impl<T> PoolInner<T> { + /// Any `FutureResponse`s that were created will have made a `Checkout`, + /// and possibly inserted into the pool that it is waiting for an idle + /// connection. If a user ever dropped that future, we need to clean out + /// those parked senders. + fn clean_waiters(&mut self, key: &Key) { + let mut remove_waiters = false; + if let Some(waiters) = self.waiters.get_mut(key) { + waiters.retain(|tx| !tx.is_canceled()); + remove_waiters = waiters.is_empty(); + } + if remove_waiters { + self.waiters.remove(key); + } + } +} + +#[cfg(feature = "runtime")] +impl<T: Poolable> PoolInner<T> { + /// This should *only* be called by the IdleTask + fn clear_expired(&mut self) { + let dur = self.timeout.expect("interval assumes timeout"); + + let now = Instant::now(); + //self.last_idle_check_at = now; + + self.idle.retain(|key, values| { + values.retain(|entry| { + if !entry.value.is_open() { + trace!("idle interval evicting closed for {:?}", key); + return false; + } + if now - entry.idle_at > dur { + trace!("idle interval evicting expired for {:?}", key); + return false; + } + + // Otherwise, keep this value... + true + }); + + // returning false evicts this key/val + !values.is_empty() + }); + } +} + +impl<T> Clone for Pool<T> { + fn clone(&self) -> Pool<T> { + Pool { + inner: self.inner.clone(), + } + } +} + +/// A wrapped poolable value that tries to reinsert to the Pool on Drop. +// Note: The bounds `T: Poolable` is needed for the Drop impl. +pub(super) struct Pooled<T: Poolable> { + value: Option<T>, + is_reused: bool, + key: Key, + pool: WeakOpt<Mutex<PoolInner<T>>>, +} + +impl<T: Poolable> Pooled<T> { + pub fn is_reused(&self) -> bool { + self.is_reused + } + + pub fn is_pool_enabled(&self) -> bool { + self.pool.0.is_some() + } + + fn as_ref(&self) -> &T { + self.value.as_ref().expect("not dropped") + } + + fn as_mut(&mut self) -> &mut T { + self.value.as_mut().expect("not dropped") + } +} + +impl<T: Poolable> Deref for Pooled<T> { + type Target = T; + fn deref(&self) -> &T { + self.as_ref() + } +} + +impl<T: Poolable> DerefMut for Pooled<T> { + fn deref_mut(&mut self) -> &mut T { + self.as_mut() + } +} + +impl<T: Poolable> Drop for Pooled<T> { + fn drop(&mut self) { + if let Some(value) = self.value.take() { + if !value.is_open() { + // If we *already* know the connection is done here, + // it shouldn't be re-inserted back into the pool. + return; + } + + if let Some(pool) = self.pool.upgrade() { + if let Ok(mut inner) = pool.lock() { + inner.put(self.key.clone(), value, &pool); + } + } else if !value.can_share() { + trace!("pool dropped, dropping pooled ({:?})", self.key); + } + // Ver::Http2 is already in the Pool (or dead), so we wouldn't + // have an actual reference to the Pool. + } + } +} + +impl<T: Poolable> fmt::Debug for Pooled<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Pooled").field("key", &self.key).finish() + } +} + +struct Idle<T> { + idle_at: Instant, + value: T, +} + +// FIXME: allow() required due to `impl Trait` leaking types to this lint +#[allow(missing_debug_implementations)] +pub(super) struct Checkout<T> { + key: Key, + pool: Pool<T>, + waiter: Option<oneshot::Receiver<T>>, +} + +impl<T: Poolable> Checkout<T> { + fn poll_waiter( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Option<crate::Result<Pooled<T>>>> { + static CANCELED: &str = "pool checkout failed"; + if let Some(mut rx) = self.waiter.take() { + match Pin::new(&mut rx).poll(cx) { + Poll::Ready(Ok(value)) => { + if value.is_open() { + Poll::Ready(Some(Ok(self.pool.reuse(&self.key, value)))) + } else { + Poll::Ready(Some(Err(crate::Error::new_canceled().with(CANCELED)))) + } + } + Poll::Pending => { + self.waiter = Some(rx); + Poll::Pending + } + Poll::Ready(Err(_canceled)) => { + Poll::Ready(Some(Err(crate::Error::new_canceled().with(CANCELED)))) + } + } + } else { + Poll::Ready(None) + } + } + + fn checkout(&mut self, cx: &mut task::Context<'_>) -> Option<Pooled<T>> { + let entry = { + let mut inner = self.pool.inner.as_ref()?.lock().unwrap(); + let expiration = Expiration::new(inner.timeout); + let maybe_entry = inner.idle.get_mut(&self.key).and_then(|list| { + trace!("take? {:?}: expiration = {:?}", self.key, expiration.0); + // A block to end the mutable borrow on list, + // so the map below can check is_empty() + { + let popper = IdlePopper { + key: &self.key, + list, + }; + popper.pop(&expiration) + } + .map(|e| (e, list.is_empty())) + }); + + let (entry, empty) = if let Some((e, empty)) = maybe_entry { + (Some(e), empty) + } else { + // No entry found means nuke the list for sure. + (None, true) + }; + if empty { + //TODO: This could be done with the HashMap::entry API instead. + inner.idle.remove(&self.key); + } + + if entry.is_none() && self.waiter.is_none() { + let (tx, mut rx) = oneshot::channel(); + trace!("checkout waiting for idle connection: {:?}", self.key); + inner + .waiters + .entry(self.key.clone()) + .or_insert_with(VecDeque::new) + .push_back(tx); + + // register the waker with this oneshot + assert!(Pin::new(&mut rx).poll(cx).is_pending()); + self.waiter = Some(rx); + } + + entry + }; + + entry.map(|e| self.pool.reuse(&self.key, e.value)) + } +} + +impl<T: Poolable> Future for Checkout<T> { + type Output = crate::Result<Pooled<T>>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + if let Some(pooled) = ready!(self.poll_waiter(cx)?) { + return Poll::Ready(Ok(pooled)); + } + + if let Some(pooled) = self.checkout(cx) { + Poll::Ready(Ok(pooled)) + } else if !self.pool.is_enabled() { + Poll::Ready(Err(crate::Error::new_canceled().with("pool is disabled"))) + } else { + // There's a new waiter, already registered in self.checkout() + debug_assert!(self.waiter.is_some()); + Poll::Pending + } + } +} + +impl<T> Drop for Checkout<T> { + fn drop(&mut self) { + if self.waiter.take().is_some() { + trace!("checkout dropped for {:?}", self.key); + if let Some(Ok(mut inner)) = self.pool.inner.as_ref().map(|i| i.lock()) { + inner.clean_waiters(&self.key); + } + } + } +} + +// FIXME: allow() required due to `impl Trait` leaking types to this lint +#[allow(missing_debug_implementations)] +pub(super) struct Connecting<T: Poolable> { + key: Key, + pool: WeakOpt<Mutex<PoolInner<T>>>, +} + +impl<T: Poolable> Connecting<T> { + pub(super) fn alpn_h2(self, pool: &Pool<T>) -> Option<Self> { + debug_assert!( + self.pool.0.is_none(), + "Connecting::alpn_h2 but already Http2" + ); + + pool.connecting(&self.key, Ver::Http2) + } +} + +impl<T: Poolable> Drop for Connecting<T> { + fn drop(&mut self) { + if let Some(pool) = self.pool.upgrade() { + // No need to panic on drop, that could abort! + if let Ok(mut inner) = pool.lock() { + inner.connected(&self.key); + } + } + } +} + +struct Expiration(Option<Duration>); + +impl Expiration { + fn new(dur: Option<Duration>) -> Expiration { + Expiration(dur) + } + + fn expires(&self, instant: Instant) -> bool { + match self.0 { + Some(timeout) => instant.elapsed() > timeout, + None => false, + } + } +} + +#[cfg(feature = "runtime")] +struct IdleTask<T> { + interval: Interval, + pool: WeakOpt<Mutex<PoolInner<T>>>, + // This allows the IdleTask to be notified as soon as the entire + // Pool is fully dropped, and shutdown. This channel is never sent on, + // but Err(Canceled) will be received when the Pool is dropped. + pool_drop_notifier: oneshot::Receiver<crate::common::Never>, +} + +#[cfg(feature = "runtime")] +impl<T: Poolable + 'static> Future for IdleTask<T> { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + loop { + match Pin::new(&mut self.pool_drop_notifier).poll(cx) { + Poll::Ready(Ok(n)) => match n {}, + Poll::Pending => (), + Poll::Ready(Err(_canceled)) => { + trace!("pool closed, canceling idle interval"); + return Poll::Ready(()); + } + } + + ready!(self.interval.poll_tick(cx)); + + if let Some(inner) = self.pool.upgrade() { + if let Ok(mut inner) = inner.lock() { + trace!("idle interval checking for expired"); + inner.clear_expired(); + continue; + } + } + return Poll::Ready(()); + } + } +} + +impl<T> WeakOpt<T> { + fn none() -> Self { + WeakOpt(None) + } + + fn downgrade(arc: &Arc<T>) -> Self { + WeakOpt(Some(Arc::downgrade(arc))) + } + + fn upgrade(&self) -> Option<Arc<T>> { + self.0.as_ref().and_then(Weak::upgrade) + } +} + +#[cfg(test)] +mod tests { + use std::task::Poll; + use std::time::Duration; + + use super::{Connecting, Key, Pool, Poolable, Reservation, WeakOpt}; + use crate::common::{task, Exec, Future, Pin}; + + /// Test unique reservations. + #[derive(Debug, PartialEq, Eq)] + struct Uniq<T>(T); + + impl<T: Send + 'static + Unpin> Poolable for Uniq<T> { + fn is_open(&self) -> bool { + true + } + + fn reserve(self) -> Reservation<Self> { + Reservation::Unique(self) + } + + fn can_share(&self) -> bool { + false + } + } + + fn c<T: Poolable>(key: Key) -> Connecting<T> { + Connecting { + key, + pool: WeakOpt::none(), + } + } + + fn host_key(s: &str) -> Key { + (http::uri::Scheme::HTTP, s.parse().expect("host key")) + } + + fn pool_no_timer<T>() -> Pool<T> { + pool_max_idle_no_timer(::std::usize::MAX) + } + + fn pool_max_idle_no_timer<T>(max_idle: usize) -> Pool<T> { + let pool = Pool::new( + super::Config { + idle_timeout: Some(Duration::from_millis(100)), + max_idle_per_host: max_idle, + }, + &Exec::Default, + ); + pool.no_timer(); + pool + } + + #[tokio::test] + async fn test_pool_checkout_smoke() { + let pool = pool_no_timer(); + let key = host_key("foo"); + let pooled = pool.pooled(c(key.clone()), Uniq(41)); + + drop(pooled); + + match pool.checkout(key).await { + Ok(pooled) => assert_eq!(*pooled, Uniq(41)), + Err(_) => panic!("not ready"), + }; + } + + /// Helper to check if the future is ready after polling once. + struct PollOnce<'a, F>(&'a mut F); + + impl<F, T, U> Future for PollOnce<'_, F> + where + F: Future<Output = Result<T, U>> + Unpin, + { + type Output = Option<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + match Pin::new(&mut self.0).poll(cx) { + Poll::Ready(Ok(_)) => Poll::Ready(Some(())), + Poll::Ready(Err(_)) => Poll::Ready(Some(())), + Poll::Pending => Poll::Ready(None), + } + } + } + + #[tokio::test] + async fn test_pool_checkout_returns_none_if_expired() { + let pool = pool_no_timer(); + let key = host_key("foo"); + let pooled = pool.pooled(c(key.clone()), Uniq(41)); + + drop(pooled); + tokio::time::delay_for(pool.locked().timeout.unwrap()).await; + let mut checkout = pool.checkout(key); + let poll_once = PollOnce(&mut checkout); + let is_not_ready = poll_once.await.is_none(); + assert!(is_not_ready); + } + + #[cfg(feature = "runtime")] + #[tokio::test] + async fn test_pool_checkout_removes_expired() { + let pool = pool_no_timer(); + let key = host_key("foo"); + + pool.pooled(c(key.clone()), Uniq(41)); + pool.pooled(c(key.clone()), Uniq(5)); + pool.pooled(c(key.clone()), Uniq(99)); + + assert_eq!( + pool.locked().idle.get(&key).map(|entries| entries.len()), + Some(3) + ); + tokio::time::delay_for(pool.locked().timeout.unwrap()).await; + + let mut checkout = pool.checkout(key.clone()); + let poll_once = PollOnce(&mut checkout); + // checkout.await should clean out the expired + poll_once.await; + assert!(pool.locked().idle.get(&key).is_none()); + } + + #[test] + fn test_pool_max_idle_per_host() { + let pool = pool_max_idle_no_timer(2); + let key = host_key("foo"); + + pool.pooled(c(key.clone()), Uniq(41)); + pool.pooled(c(key.clone()), Uniq(5)); + pool.pooled(c(key.clone()), Uniq(99)); + + // pooled and dropped 3, max_idle should only allow 2 + assert_eq!( + pool.locked().idle.get(&key).map(|entries| entries.len()), + Some(2) + ); + } + + #[cfg(feature = "runtime")] + #[tokio::test] + async fn test_pool_timer_removes_expired() { + let _ = pretty_env_logger::try_init(); + tokio::time::pause(); + + let pool = Pool::new( + super::Config { + idle_timeout: Some(Duration::from_millis(10)), + max_idle_per_host: std::usize::MAX, + }, + &Exec::Default, + ); + + let key = host_key("foo"); + + pool.pooled(c(key.clone()), Uniq(41)); + pool.pooled(c(key.clone()), Uniq(5)); + pool.pooled(c(key.clone()), Uniq(99)); + + assert_eq!( + pool.locked().idle.get(&key).map(|entries| entries.len()), + Some(3) + ); + + // Let the timer tick passed the expiration... + tokio::time::advance(Duration::from_millis(30)).await; + // Yield so the Interval can reap... + tokio::task::yield_now().await; + + assert!(pool.locked().idle.get(&key).is_none()); + } + + #[tokio::test] + async fn test_pool_checkout_task_unparked() { + use futures_util::future::join; + use futures_util::FutureExt; + + let pool = pool_no_timer(); + let key = host_key("foo"); + let pooled = pool.pooled(c(key.clone()), Uniq(41)); + + let checkout = join(pool.checkout(key), async { + // the checkout future will park first, + // and then this lazy future will be polled, which will insert + // the pooled back into the pool + // + // this test makes sure that doing so will unpark the checkout + drop(pooled); + }) + .map(|(entry, _)| entry); + + assert_eq!(*checkout.await.unwrap(), Uniq(41)); + } + + #[tokio::test] + async fn test_pool_checkout_drop_cleans_up_waiters() { + let pool = pool_no_timer::<Uniq<i32>>(); + let key = host_key("foo"); + + let mut checkout1 = pool.checkout(key.clone()); + let mut checkout2 = pool.checkout(key.clone()); + + let poll_once1 = PollOnce(&mut checkout1); + let poll_once2 = PollOnce(&mut checkout2); + + // first poll needed to get into Pool's parked + poll_once1.await; + assert_eq!(pool.locked().waiters.get(&key).unwrap().len(), 1); + poll_once2.await; + assert_eq!(pool.locked().waiters.get(&key).unwrap().len(), 2); + + // on drop, clean up Pool + drop(checkout1); + assert_eq!(pool.locked().waiters.get(&key).unwrap().len(), 1); + + drop(checkout2); + assert!(pool.locked().waiters.get(&key).is_none()); + } + + #[derive(Debug)] + struct CanClose { + val: i32, + closed: bool, + } + + impl Poolable for CanClose { + fn is_open(&self) -> bool { + !self.closed + } + + fn reserve(self) -> Reservation<Self> { + Reservation::Unique(self) + } + + fn can_share(&self) -> bool { + false + } + } + + #[test] + fn pooled_drop_if_closed_doesnt_reinsert() { + let pool = pool_no_timer(); + let key = host_key("foo"); + pool.pooled( + c(key.clone()), + CanClose { + val: 57, + closed: true, + }, + ); + + assert!(!pool.locked().idle.contains_key(&key)); + } +} diff --git a/third_party/rust/hyper/src/client/service.rs b/third_party/rust/hyper/src/client/service.rs new file mode 100644 index 0000000000..ef3a9babb2 --- /dev/null +++ b/third_party/rust/hyper/src/client/service.rs @@ -0,0 +1,86 @@ +//! Utilities used to interact with the Tower ecosystem. +//! +//! This module provides `Connect` which hook-ins into the Tower ecosystem. + +use std::error::Error as StdError; +use std::future::Future; +use std::marker::PhantomData; + +use super::conn::{Builder, SendRequest}; +use crate::{ + body::Payload, + common::{task, Pin, Poll}, + service::{MakeConnection, Service}, +}; + +/// Creates a connection via `SendRequest`. +/// +/// This accepts a `hyper::client::conn::Builder` and provides +/// a `MakeService` implementation to create connections from some +/// target `T`. +#[derive(Debug)] +pub struct Connect<C, B, T> { + inner: C, + builder: Builder, + _pd: PhantomData<fn(T, B)>, +} + +impl<C, B, T> Connect<C, B, T> { + /// Create a new `Connect` with some inner connector `C` and a connection + /// builder. + pub fn new(inner: C, builder: Builder) -> Self { + Self { + inner, + builder, + _pd: PhantomData, + } + } +} + +impl<C, B, T> Service<T> for Connect<C, B, T> +where + C: MakeConnection<T>, + C::Connection: Unpin + Send + 'static, + C::Future: Send + 'static, + C::Error: Into<Box<dyn StdError + Send + Sync>> + Send, + B: Payload + Unpin + 'static, + B::Data: Unpin, +{ + type Response = SendRequest<B>; + type Error = crate::Error; + type Future = + Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner + .poll_ready(cx) + .map_err(|e| crate::Error::new(crate::error::Kind::Connect).with(e.into())) + } + + fn call(&mut self, req: T) -> Self::Future { + let builder = self.builder.clone(); + let io = self.inner.make_connection(req); + + let fut = async move { + match io.await { + Ok(io) => match builder.handshake(io).await { + Ok((sr, conn)) => { + builder.exec.execute(async move { + if let Err(e) = conn.await { + debug!("connection error: {:?}", e); + } + }); + Ok(sr) + } + Err(e) => Err(e), + }, + Err(e) => { + let err = crate::Error::new(crate::error::Kind::Connect).with(e.into()); + Err(err) + } + } + }; + + Box::pin(fut) + } +} diff --git a/third_party/rust/hyper/src/client/tests.rs b/third_party/rust/hyper/src/client/tests.rs new file mode 100644 index 0000000000..e955cb60c6 --- /dev/null +++ b/third_party/rust/hyper/src/client/tests.rs @@ -0,0 +1,286 @@ +use std::io; + +use futures_util::future; +use tokio::net::TcpStream; + +use super::Client; + +#[tokio::test] +async fn client_connect_uri_argument() { + let connector = tower_util::service_fn(|dst: http::Uri| { + assert_eq!(dst.scheme(), Some(&http::uri::Scheme::HTTP)); + assert_eq!(dst.host(), Some("example.local")); + assert_eq!(dst.port(), None); + assert_eq!(dst.path(), "/", "path should be removed"); + + future::err::<TcpStream, _>(io::Error::new(io::ErrorKind::Other, "expect me")) + }); + + let client = Client::builder().build::<_, crate::Body>(connector); + let _ = client + .get("http://example.local/and/a/path".parse().unwrap()) + .await + .expect_err("response should fail"); +} + +/* +// FIXME: re-implement tests with `async/await` +#[test] +fn retryable_request() { + let _ = pretty_env_logger::try_init(); + + let mut rt = Runtime::new().expect("new rt"); + let mut connector = MockConnector::new(); + + let sock1 = connector.mock("http://mock.local"); + let sock2 = connector.mock("http://mock.local"); + + let client = Client::builder() + .build::<_, crate::Body>(connector); + + client.pool.no_timer(); + + { + + let req = Request::builder() + .uri("http://mock.local/a") + .body(Default::default()) + .unwrap(); + let res1 = client.request(req); + let srv1 = poll_fn(|| { + try_ready!(sock1.read(&mut [0u8; 512])); + try_ready!(sock1.write(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")); + Ok(Async::Ready(())) + }).map_err(|e: std::io::Error| panic!("srv1 poll_fn error: {}", e)); + rt.block_on(res1.join(srv1)).expect("res1"); + } + drop(sock1); + + let req = Request::builder() + .uri("http://mock.local/b") + .body(Default::default()) + .unwrap(); + let res2 = client.request(req) + .map(|res| { + assert_eq!(res.status().as_u16(), 222); + }); + let srv2 = poll_fn(|| { + try_ready!(sock2.read(&mut [0u8; 512])); + try_ready!(sock2.write(b"HTTP/1.1 222 OK\r\nContent-Length: 0\r\n\r\n")); + Ok(Async::Ready(())) + }).map_err(|e: std::io::Error| panic!("srv2 poll_fn error: {}", e)); + + rt.block_on(res2.join(srv2)).expect("res2"); +} + +#[test] +fn conn_reset_after_write() { + let _ = pretty_env_logger::try_init(); + + let mut rt = Runtime::new().expect("new rt"); + let mut connector = MockConnector::new(); + + let sock1 = connector.mock("http://mock.local"); + + let client = Client::builder() + .build::<_, crate::Body>(connector); + + client.pool.no_timer(); + + { + let req = Request::builder() + .uri("http://mock.local/a") + .body(Default::default()) + .unwrap(); + let res1 = client.request(req); + let srv1 = poll_fn(|| { + try_ready!(sock1.read(&mut [0u8; 512])); + try_ready!(sock1.write(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")); + Ok(Async::Ready(())) + }).map_err(|e: std::io::Error| panic!("srv1 poll_fn error: {}", e)); + rt.block_on(res1.join(srv1)).expect("res1"); + } + + let req = Request::builder() + .uri("http://mock.local/a") + .body(Default::default()) + .unwrap(); + let res2 = client.request(req); + let mut sock1 = Some(sock1); + let srv2 = poll_fn(|| { + // We purposefully keep the socket open until the client + // has written the second request, and THEN disconnect. + // + // Not because we expect servers to be jerks, but to trigger + // state where we write on an assumedly good connection, and + // only reset the close AFTER we wrote bytes. + try_ready!(sock1.as_mut().unwrap().read(&mut [0u8; 512])); + sock1.take(); + Ok(Async::Ready(())) + }).map_err(|e: std::io::Error| panic!("srv2 poll_fn error: {}", e)); + let err = rt.block_on(res2.join(srv2)).expect_err("res2"); + assert!(err.is_incomplete_message(), "{:?}", err); +} + +#[test] +fn checkout_win_allows_connect_future_to_be_pooled() { + let _ = pretty_env_logger::try_init(); + + let mut rt = Runtime::new().expect("new rt"); + let mut connector = MockConnector::new(); + + + let (tx, rx) = oneshot::channel::<()>(); + let sock1 = connector.mock("http://mock.local"); + let sock2 = connector.mock_fut("http://mock.local", rx); + + let client = Client::builder() + .build::<_, crate::Body>(connector); + + client.pool.no_timer(); + + let uri = "http://mock.local/a".parse::<crate::Uri>().expect("uri parse"); + + // First request just sets us up to have a connection able to be put + // back in the pool. *However*, it doesn't insert immediately. The + // body has 1 pending byte, and we will only drain in request 2, once + // the connect future has been started. + let mut body = { + let res1 = client.get(uri.clone()) + .map(|res| res.into_body().concat2()); + let srv1 = poll_fn(|| { + try_ready!(sock1.read(&mut [0u8; 512])); + // Chunked is used so as to force 2 body reads. + try_ready!(sock1.write(b"\ + HTTP/1.1 200 OK\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + 1\r\nx\r\n\ + 0\r\n\r\n\ + ")); + Ok(Async::Ready(())) + }).map_err(|e: std::io::Error| panic!("srv1 poll_fn error: {}", e)); + + rt.block_on(res1.join(srv1)).expect("res1").0 + }; + + + // The second request triggers the only mocked connect future, but then + // the drained body allows the first socket to go back to the pool, + // "winning" the checkout race. + { + let res2 = client.get(uri.clone()); + let drain = poll_fn(move || { + body.poll() + }); + let srv2 = poll_fn(|| { + try_ready!(sock1.read(&mut [0u8; 512])); + try_ready!(sock1.write(b"HTTP/1.1 200 OK\r\nConnection: close\r\n\r\nx")); + Ok(Async::Ready(())) + }).map_err(|e: std::io::Error| panic!("srv2 poll_fn error: {}", e)); + + rt.block_on(res2.join(drain).join(srv2)).expect("res2"); + } + + // "Release" the mocked connect future, and let the runtime spin once so + // it's all setup... + { + let mut tx = Some(tx); + let client = &client; + let key = client.pool.h1_key("http://mock.local"); + let mut tick_cnt = 0; + let fut = poll_fn(move || { + tx.take(); + + if client.pool.idle_count(&key) == 0 { + tick_cnt += 1; + assert!(tick_cnt < 10, "ticked too many times waiting for idle"); + trace!("no idle yet; tick count: {}", tick_cnt); + ::futures::task::current().notify(); + Ok(Async::NotReady) + } else { + Ok::<_, ()>(Async::Ready(())) + } + }); + rt.block_on(fut).unwrap(); + } + + // Third request just tests out that the "loser" connection was pooled. If + // it isn't, this will panic since the MockConnector doesn't have any more + // mocks to give out. + { + let res3 = client.get(uri); + let srv3 = poll_fn(|| { + try_ready!(sock2.read(&mut [0u8; 512])); + try_ready!(sock2.write(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")); + Ok(Async::Ready(())) + }).map_err(|e: std::io::Error| panic!("srv3 poll_fn error: {}", e)); + + rt.block_on(res3.join(srv3)).expect("res3"); + } +} + +#[cfg(feature = "nightly")] +#[bench] +fn bench_http1_get_0b(b: &mut test::Bencher) { + let _ = pretty_env_logger::try_init(); + + let mut rt = Runtime::new().expect("new rt"); + let mut connector = MockConnector::new(); + + + let client = Client::builder() + .build::<_, crate::Body>(connector.clone()); + + client.pool.no_timer(); + + let uri = Uri::from_static("http://mock.local/a"); + + b.iter(move || { + let sock1 = connector.mock("http://mock.local"); + let res1 = client + .get(uri.clone()) + .and_then(|res| { + res.into_body().for_each(|_| Ok(())) + }); + let srv1 = poll_fn(|| { + try_ready!(sock1.read(&mut [0u8; 512])); + try_ready!(sock1.write(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")); + Ok(Async::Ready(())) + }).map_err(|e: std::io::Error| panic!("srv1 poll_fn error: {}", e)); + rt.block_on(res1.join(srv1)).expect("res1"); + }); +} + +#[cfg(feature = "nightly")] +#[bench] +fn bench_http1_get_10b(b: &mut test::Bencher) { + let _ = pretty_env_logger::try_init(); + + let mut rt = Runtime::new().expect("new rt"); + let mut connector = MockConnector::new(); + + + let client = Client::builder() + .build::<_, crate::Body>(connector.clone()); + + client.pool.no_timer(); + + let uri = Uri::from_static("http://mock.local/a"); + + b.iter(move || { + let sock1 = connector.mock("http://mock.local"); + let res1 = client + .get(uri.clone()) + .and_then(|res| { + res.into_body().for_each(|_| Ok(())) + }); + let srv1 = poll_fn(|| { + try_ready!(sock1.read(&mut [0u8; 512])); + try_ready!(sock1.write(b"HTTP/1.1 200 OK\r\nContent-Length: 10\r\n\r\n0123456789")); + Ok(Async::Ready(())) + }).map_err(|e: std::io::Error| panic!("srv1 poll_fn error: {}", e)); + rt.block_on(res1.join(srv1)).expect("res1"); + }); +} +*/ diff --git a/third_party/rust/hyper/src/common/buf.rs b/third_party/rust/hyper/src/common/buf.rs new file mode 100644 index 0000000000..8f71b7bbad --- /dev/null +++ b/third_party/rust/hyper/src/common/buf.rs @@ -0,0 +1,72 @@ +use std::collections::VecDeque; +use std::io::IoSlice; + +use bytes::Buf; + +pub(crate) struct BufList<T> { + bufs: VecDeque<T>, +} + +impl<T: Buf> BufList<T> { + pub(crate) fn new() -> BufList<T> { + BufList { + bufs: VecDeque::new(), + } + } + + #[inline] + pub(crate) fn push(&mut self, buf: T) { + debug_assert!(buf.has_remaining()); + self.bufs.push_back(buf); + } + + #[inline] + pub(crate) fn bufs_cnt(&self) -> usize { + self.bufs.len() + } +} + +impl<T: Buf> Buf for BufList<T> { + #[inline] + fn remaining(&self) -> usize { + self.bufs.iter().map(|buf| buf.remaining()).sum() + } + + #[inline] + fn bytes(&self) -> &[u8] { + self.bufs.front().map(Buf::bytes).unwrap_or_default() + } + + #[inline] + fn advance(&mut self, mut cnt: usize) { + while cnt > 0 { + { + let front = &mut self.bufs[0]; + let rem = front.remaining(); + if rem > cnt { + front.advance(cnt); + return; + } else { + front.advance(rem); + cnt -= rem; + } + } + self.bufs.pop_front(); + } + } + + #[inline] + fn bytes_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { + if dst.is_empty() { + return 0; + } + let mut vecs = 0; + for buf in &self.bufs { + vecs += buf.bytes_vectored(&mut dst[vecs..]); + if vecs == dst.len() { + break; + } + } + vecs + } +} diff --git a/third_party/rust/hyper/src/common/drain.rs b/third_party/rust/hyper/src/common/drain.rs new file mode 100644 index 0000000000..7abb9f9ded --- /dev/null +++ b/third_party/rust/hyper/src/common/drain.rs @@ -0,0 +1,232 @@ +use std::mem; + +use pin_project::pin_project; +use tokio::sync::{mpsc, watch}; + +use super::{task, Future, Never, Pin, Poll}; + +// Sentinel value signaling that the watch is still open +#[derive(Clone, Copy)] +enum Action { + Open, + // Closed isn't sent via the `Action` type, but rather once + // the watch::Sender is dropped. +} + +pub fn channel() -> (Signal, Watch) { + let (tx, rx) = watch::channel(Action::Open); + let (drained_tx, drained_rx) = mpsc::channel(1); + ( + Signal { + drained_rx, + _tx: tx, + }, + Watch { drained_tx, rx }, + ) +} + +pub struct Signal { + drained_rx: mpsc::Receiver<Never>, + _tx: watch::Sender<Action>, +} + +pub struct Draining { + drained_rx: mpsc::Receiver<Never>, +} + +#[derive(Clone)] +pub struct Watch { + drained_tx: mpsc::Sender<Never>, + rx: watch::Receiver<Action>, +} + +#[allow(missing_debug_implementations)] +#[pin_project] +pub struct Watching<F, FN> { + #[pin] + future: F, + state: State<FN>, + watch: Watch, +} + +enum State<F> { + Watch(F), + Draining, +} + +impl Signal { + pub fn drain(self) -> Draining { + // Simply dropping `self.tx` will signal the watchers + Draining { + drained_rx: self.drained_rx, + } + } +} + +impl Future for Draining { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + match ready!(self.drained_rx.poll_recv(cx)) { + Some(never) => match never {}, + None => Poll::Ready(()), + } + } +} + +impl Watch { + pub fn watch<F, FN>(self, future: F, on_drain: FN) -> Watching<F, FN> + where + F: Future, + FN: FnOnce(Pin<&mut F>), + { + Watching { + future, + state: State::Watch(on_drain), + watch: self, + } + } +} + +impl<F, FN> Future for Watching<F, FN> +where + F: Future, + FN: FnOnce(Pin<&mut F>), +{ + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + let mut me = self.project(); + loop { + match mem::replace(me.state, State::Draining) { + State::Watch(on_drain) => { + match me.watch.rx.poll_recv_ref(cx) { + Poll::Ready(None) => { + // Drain has been triggered! + on_drain(me.future.as_mut()); + } + Poll::Ready(Some(_ /*State::Open*/)) | Poll::Pending => { + *me.state = State::Watch(on_drain); + return me.future.poll(cx); + } + } + } + State::Draining => return me.future.poll(cx), + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + struct TestMe { + draining: bool, + finished: bool, + poll_cnt: usize, + } + + impl Future for TestMe { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> { + self.poll_cnt += 1; + if self.finished { + Poll::Ready(()) + } else { + Poll::Pending + } + } + } + + #[test] + fn watch() { + let mut mock = tokio_test::task::spawn(()); + mock.enter(|cx, _| { + let (tx, rx) = channel(); + let fut = TestMe { + draining: false, + finished: false, + poll_cnt: 0, + }; + + let mut watch = rx.watch(fut, |mut fut| { + fut.draining = true; + }); + + assert_eq!(watch.future.poll_cnt, 0); + + // First poll should poll the inner future + assert!(Pin::new(&mut watch).poll(cx).is_pending()); + assert_eq!(watch.future.poll_cnt, 1); + + // Second poll should poll the inner future again + assert!(Pin::new(&mut watch).poll(cx).is_pending()); + assert_eq!(watch.future.poll_cnt, 2); + + let mut draining = tx.drain(); + // Drain signaled, but needs another poll to be noticed. + assert!(!watch.future.draining); + assert_eq!(watch.future.poll_cnt, 2); + + // Now, poll after drain has been signaled. + assert!(Pin::new(&mut watch).poll(cx).is_pending()); + assert_eq!(watch.future.poll_cnt, 3); + assert!(watch.future.draining); + + // Draining is not ready until watcher completes + assert!(Pin::new(&mut draining).poll(cx).is_pending()); + + // Finishing up the watch future + watch.future.finished = true; + assert!(Pin::new(&mut watch).poll(cx).is_ready()); + assert_eq!(watch.future.poll_cnt, 4); + drop(watch); + + assert!(Pin::new(&mut draining).poll(cx).is_ready()); + }) + } + + #[test] + fn watch_clones() { + let mut mock = tokio_test::task::spawn(()); + mock.enter(|cx, _| { + let (tx, rx) = channel(); + + let fut1 = TestMe { + draining: false, + finished: false, + poll_cnt: 0, + }; + let fut2 = TestMe { + draining: false, + finished: false, + poll_cnt: 0, + }; + + let watch1 = rx.clone().watch(fut1, |mut fut| { + fut.draining = true; + }); + let watch2 = rx.watch(fut2, |mut fut| { + fut.draining = true; + }); + + let mut draining = tx.drain(); + + // Still 2 outstanding watchers + assert!(Pin::new(&mut draining).poll(cx).is_pending()); + + // drop 1 for whatever reason + drop(watch1); + + // Still not ready, 1 other watcher still pending + assert!(Pin::new(&mut draining).poll(cx).is_pending()); + + drop(watch2); + + // Now all watchers are gone, draining is complete + assert!(Pin::new(&mut draining).poll(cx).is_ready()); + }); + } +} diff --git a/third_party/rust/hyper/src/common/exec.rs b/third_party/rust/hyper/src/common/exec.rs new file mode 100644 index 0000000000..94ad4610a2 --- /dev/null +++ b/third_party/rust/hyper/src/common/exec.rs @@ -0,0 +1,111 @@ +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +use crate::body::{Body, Payload}; +use crate::proto::h2::server::H2Stream; +use crate::server::conn::spawn_all::{NewSvcTask, Watcher}; +use crate::service::HttpService; + +/// An executor of futures. +pub trait Executor<Fut> { + /// Place the future into the executor to be run. + fn execute(&self, fut: Fut); +} + +pub trait H2Exec<F, B: Payload>: Clone { + fn execute_h2stream(&mut self, fut: H2Stream<F, B>); +} + +pub trait NewSvcExec<I, N, S: HttpService<Body>, E, W: Watcher<I, S, E>>: Clone { + fn execute_new_svc(&mut self, fut: NewSvcTask<I, N, S, E, W>); +} + +pub type BoxSendFuture = Pin<Box<dyn Future<Output = ()> + Send>>; + +// Either the user provides an executor for background tasks, or we use +// `tokio::spawn`. +#[derive(Clone)] +pub enum Exec { + Default, + Executor(Arc<dyn Executor<BoxSendFuture> + Send + Sync>), +} + +// ===== impl Exec ===== + +impl Exec { + pub(crate) fn execute<F>(&self, fut: F) + where + F: Future<Output = ()> + Send + 'static, + { + match *self { + Exec::Default => { + #[cfg(feature = "tcp")] + { + tokio::task::spawn(fut); + } + #[cfg(not(feature = "tcp"))] + { + // If no runtime, we need an executor! + panic!("executor must be set") + } + } + Exec::Executor(ref e) => { + e.execute(Box::pin(fut)); + } + } + } +} + +impl fmt::Debug for Exec { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Exec").finish() + } +} + +impl<F, B> H2Exec<F, B> for Exec +where + H2Stream<F, B>: Future<Output = ()> + Send + 'static, + B: Payload, +{ + fn execute_h2stream(&mut self, fut: H2Stream<F, B>) { + self.execute(fut) + } +} + +impl<I, N, S, E, W> NewSvcExec<I, N, S, E, W> for Exec +where + NewSvcTask<I, N, S, E, W>: Future<Output = ()> + Send + 'static, + S: HttpService<Body>, + W: Watcher<I, S, E>, +{ + fn execute_new_svc(&mut self, fut: NewSvcTask<I, N, S, E, W>) { + self.execute(fut) + } +} + +// ==== impl Executor ===== + +impl<E, F, B> H2Exec<F, B> for E +where + E: Executor<H2Stream<F, B>> + Clone, + H2Stream<F, B>: Future<Output = ()>, + B: Payload, +{ + fn execute_h2stream(&mut self, fut: H2Stream<F, B>) { + self.execute(fut) + } +} + +impl<I, N, S, E, W> NewSvcExec<I, N, S, E, W> for E +where + E: Executor<NewSvcTask<I, N, S, E, W>> + Clone, + NewSvcTask<I, N, S, E, W>: Future<Output = ()>, + S: HttpService<Body>, + W: Watcher<I, S, E>, +{ + fn execute_new_svc(&mut self, fut: NewSvcTask<I, N, S, E, W>) { + self.execute(fut) + } +} diff --git a/third_party/rust/hyper/src/common/io/mod.rs b/third_party/rust/hyper/src/common/io/mod.rs new file mode 100644 index 0000000000..2e6d506153 --- /dev/null +++ b/third_party/rust/hyper/src/common/io/mod.rs @@ -0,0 +1,3 @@ +mod rewind; + +pub(crate) use self::rewind::Rewind; diff --git a/third_party/rust/hyper/src/common/io/rewind.rs b/third_party/rust/hyper/src/common/io/rewind.rs new file mode 100644 index 0000000000..14650697c3 --- /dev/null +++ b/third_party/rust/hyper/src/common/io/rewind.rs @@ -0,0 +1,153 @@ +use std::marker::Unpin; +use std::{cmp, io}; + +use bytes::{Buf, Bytes}; +use tokio::io::{AsyncRead, AsyncWrite}; + +use crate::common::{task, Pin, Poll}; + +/// Combine a buffer with an IO, rewinding reads to use the buffer. +#[derive(Debug)] +pub(crate) struct Rewind<T> { + pre: Option<Bytes>, + inner: T, +} + +impl<T> Rewind<T> { + pub(crate) fn new(io: T) -> Self { + Rewind { + pre: None, + inner: io, + } + } + + pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self { + Rewind { + pre: Some(buf), + inner: io, + } + } + + pub(crate) fn rewind(&mut self, bs: Bytes) { + debug_assert!(self.pre.is_none()); + self.pre = Some(bs); + } + + pub(crate) fn into_inner(self) -> (T, Bytes) { + (self.inner, self.pre.unwrap_or_else(Bytes::new)) + } + + pub(crate) fn get_mut(&mut self) -> &mut T { + &mut self.inner + } +} + +impl<T> AsyncRead for Rewind<T> +where + T: AsyncRead + Unpin, +{ + #[inline] + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool { + self.inner.prepare_uninitialized_buffer(buf) + } + + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut [u8], + ) -> Poll<io::Result<usize>> { + if let Some(mut prefix) = self.pre.take() { + // If there are no remaining bytes, let the bytes get dropped. + if !prefix.is_empty() { + let copy_len = cmp::min(prefix.len(), buf.len()); + prefix.copy_to_slice(&mut buf[..copy_len]); + // Put back whats left + if !prefix.is_empty() { + self.pre = Some(prefix); + } + + return Poll::Ready(Ok(copy_len)); + } + } + Pin::new(&mut self.inner).poll_read(cx, buf) + } +} + +impl<T> AsyncWrite for Rewind<T> +where + T: AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.inner).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } + + #[inline] + fn poll_write_buf<B: Buf>( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut B, + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.inner).poll_write_buf(cx, buf) + } +} + +#[cfg(test)] +mod tests { + // FIXME: re-implement tests with `async/await`, this import should + // trigger a warning to remind us + use super::Rewind; + use bytes::Bytes; + use tokio::io::AsyncReadExt; + + #[tokio::test] + async fn partial_rewind() { + let underlying = [104, 101, 108, 108, 111]; + + let mock = tokio_test::io::Builder::new().read(&underlying).build(); + + let mut stream = Rewind::new(mock); + + // Read off some bytes, ensure we filled o1 + let mut buf = [0; 2]; + stream.read_exact(&mut buf).await.expect("read1"); + + // Rewind the stream so that it is as if we never read in the first place. + stream.rewind(Bytes::copy_from_slice(&buf[..])); + + let mut buf = [0; 5]; + stream.read_exact(&mut buf).await.expect("read1"); + + // At this point we should have read everything that was in the MockStream + assert_eq!(&buf, &underlying); + } + + #[tokio::test] + async fn full_rewind() { + let underlying = [104, 101, 108, 108, 111]; + + let mock = tokio_test::io::Builder::new().read(&underlying).build(); + + let mut stream = Rewind::new(mock); + + let mut buf = [0; 5]; + stream.read_exact(&mut buf).await.expect("read1"); + + // Rewind the stream so that it is as if we never read in the first place. + stream.rewind(Bytes::copy_from_slice(&buf[..])); + + let mut buf = [0; 5]; + stream.read_exact(&mut buf).await.expect("read1"); + } +} diff --git a/third_party/rust/hyper/src/common/lazy.rs b/third_party/rust/hyper/src/common/lazy.rs new file mode 100644 index 0000000000..4d2e322c2c --- /dev/null +++ b/third_party/rust/hyper/src/common/lazy.rs @@ -0,0 +1,69 @@ +use std::mem; + +use super::{task, Future, Pin, Poll}; + +pub(crate) trait Started: Future { + fn started(&self) -> bool; +} + +pub(crate) fn lazy<F, R>(func: F) -> Lazy<F, R> +where + F: FnOnce() -> R, + R: Future + Unpin, +{ + Lazy { + inner: Inner::Init(func), + } +} + +// FIXME: allow() required due to `impl Trait` leaking types to this lint +#[allow(missing_debug_implementations)] +pub(crate) struct Lazy<F, R> { + inner: Inner<F, R>, +} + +enum Inner<F, R> { + Init(F), + Fut(R), + Empty, +} + +impl<F, R> Started for Lazy<F, R> +where + F: FnOnce() -> R, + R: Future + Unpin, +{ + fn started(&self) -> bool { + match self.inner { + Inner::Init(_) => false, + Inner::Fut(_) | Inner::Empty => true, + } + } +} + +impl<F, R> Future for Lazy<F, R> +where + F: FnOnce() -> R, + R: Future + Unpin, +{ + type Output = R::Output; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + if let Inner::Fut(ref mut f) = self.inner { + return Pin::new(f).poll(cx); + } + + match mem::replace(&mut self.inner, Inner::Empty) { + Inner::Init(func) => { + let mut fut = func(); + let ret = Pin::new(&mut fut).poll(cx); + self.inner = Inner::Fut(fut); + ret + } + _ => unreachable!("lazy state wrong"), + } + } +} + +// The closure `F` is never pinned +impl<F, R: Unpin> Unpin for Lazy<F, R> {} diff --git a/third_party/rust/hyper/src/common/mod.rs b/third_party/rust/hyper/src/common/mod.rs new file mode 100644 index 0000000000..e436fe5e2d --- /dev/null +++ b/third_party/rust/hyper/src/common/mod.rs @@ -0,0 +1,26 @@ +macro_rules! ready { + ($e:expr) => { + match $e { + std::task::Poll::Ready(v) => v, + std::task::Poll::Pending => return std::task::Poll::Pending, + } + }; +} + +pub(crate) mod buf; +pub(crate) mod drain; +pub(crate) mod exec; +pub(crate) mod io; +mod lazy; +mod never; +pub(crate) mod task; +pub(crate) mod watch; + +pub use self::exec::Executor; +pub(crate) use self::exec::{BoxSendFuture, Exec}; +pub(crate) use self::lazy::{lazy, Started as Lazy}; +pub use self::never::Never; +pub(crate) use self::task::Poll; + +// group up types normally needed for `Future` +pub(crate) use std::{future::Future, marker::Unpin, pin::Pin}; diff --git a/third_party/rust/hyper/src/common/never.rs b/third_party/rust/hyper/src/common/never.rs new file mode 100644 index 0000000000..f4fdb95ddd --- /dev/null +++ b/third_party/rust/hyper/src/common/never.rs @@ -0,0 +1,21 @@ +//! An uninhabitable type meaning it can never happen. +//! +//! To be replaced with `!` once it is stable. + +use std::error::Error; +use std::fmt; + +#[derive(Debug)] +pub enum Never {} + +impl fmt::Display for Never { + fn fmt(&self, _: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self {} + } +} + +impl Error for Never { + fn description(&self) -> &str { + match *self {} + } +} diff --git a/third_party/rust/hyper/src/common/task.rs b/third_party/rust/hyper/src/common/task.rs new file mode 100644 index 0000000000..bfccfe3bfe --- /dev/null +++ b/third_party/rust/hyper/src/common/task.rs @@ -0,0 +1,10 @@ +use super::Never; +pub(crate) use std::task::{Context, Poll}; + +/// A function to help "yield" a future, such that it is re-scheduled immediately. +/// +/// Useful for spin counts, so a future doesn't hog too much time. +pub(crate) fn yield_now(cx: &mut Context<'_>) -> Poll<Never> { + cx.waker().wake_by_ref(); + Poll::Pending +} diff --git a/third_party/rust/hyper/src/common/watch.rs b/third_party/rust/hyper/src/common/watch.rs new file mode 100644 index 0000000000..ba17d551cb --- /dev/null +++ b/third_party/rust/hyper/src/common/watch.rs @@ -0,0 +1,73 @@ +//! An SPSC broadcast channel. +//! +//! - The value can only be a `usize`. +//! - The consumer is only notified if the value is different. +//! - The value `0` is reserved for closed. + +use futures_util::task::AtomicWaker; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; +use std::task; + +type Value = usize; + +pub(crate) const CLOSED: usize = 0; + +pub(crate) fn channel(initial: Value) -> (Sender, Receiver) { + debug_assert!( + initial != CLOSED, + "watch::channel initial state of 0 is reserved" + ); + + let shared = Arc::new(Shared { + value: AtomicUsize::new(initial), + waker: AtomicWaker::new(), + }); + + ( + Sender { + shared: shared.clone(), + }, + Receiver { shared }, + ) +} + +pub(crate) struct Sender { + shared: Arc<Shared>, +} + +pub(crate) struct Receiver { + shared: Arc<Shared>, +} + +struct Shared { + value: AtomicUsize, + waker: AtomicWaker, +} + +impl Sender { + pub(crate) fn send(&mut self, value: Value) { + if self.shared.value.swap(value, Ordering::SeqCst) != value { + self.shared.waker.wake(); + } + } +} + +impl Drop for Sender { + fn drop(&mut self) { + self.send(CLOSED); + } +} + +impl Receiver { + pub(crate) fn load(&mut self, cx: &mut task::Context<'_>) -> Value { + self.shared.waker.register(cx.waker()); + self.shared.value.load(Ordering::SeqCst) + } + + pub(crate) fn peek(&self) -> Value { + self.shared.value.load(Ordering::Relaxed) + } +} diff --git a/third_party/rust/hyper/src/error.rs b/third_party/rust/hyper/src/error.rs new file mode 100644 index 0000000000..99a994b879 --- /dev/null +++ b/third_party/rust/hyper/src/error.rs @@ -0,0 +1,454 @@ +//! Error and Result module. +use std::error::Error as StdError; +use std::fmt; +use std::io; + +/// Result type often returned from methods that can have hyper `Error`s. +pub type Result<T> = std::result::Result<T, Error>; + +type Cause = Box<dyn StdError + Send + Sync>; + +/// Represents errors that can occur handling HTTP streams. +pub struct Error { + inner: Box<ErrorImpl>, +} + +struct ErrorImpl { + kind: Kind, + cause: Option<Cause>, +} + +#[derive(Debug, PartialEq)] +pub(crate) enum Kind { + Parse(Parse), + User(User), + /// A message reached EOF, but is not complete. + IncompleteMessage, + /// A connection received a message (or bytes) when not waiting for one. + UnexpectedMessage, + /// A pending item was dropped before ever being processed. + Canceled, + /// Indicates a channel (client or body sender) is closed. + ChannelClosed, + /// An `io::Error` that occurred while trying to read or write to a network stream. + Io, + /// Error occurred while connecting. + Connect, + /// Error creating a TcpListener. + #[cfg(feature = "tcp")] + Listen, + /// Error accepting on an Incoming stream. + Accept, + /// Error while reading a body from connection. + Body, + /// Error while writing a body to connection. + BodyWrite, + /// The body write was aborted. + BodyWriteAborted, + /// Error calling AsyncWrite::shutdown() + Shutdown, + + /// A general error from h2. + Http2, +} + +#[derive(Debug, PartialEq)] +pub(crate) enum Parse { + Method, + Version, + VersionH2, + Uri, + Header, + TooLarge, + Status, +} + +#[derive(Debug, PartialEq)] +pub(crate) enum User { + /// Error calling user's Payload::poll_data(). + Body, + /// Error calling user's MakeService. + MakeService, + /// Error from future of user's Service. + Service, + /// User tried to send a certain header in an unexpected context. + /// + /// For example, sending both `content-length` and `transfer-encoding`. + UnexpectedHeader, + /// User tried to create a Request with bad version. + UnsupportedVersion, + /// User tried to create a CONNECT Request with the Client. + UnsupportedRequestMethod, + /// User tried to respond with a 1xx (not 101) response code. + UnsupportedStatusCode, + /// User tried to send a Request with Client with non-absolute URI. + AbsoluteUriRequired, + + /// User tried polling for an upgrade that doesn't exist. + NoUpgrade, + + /// User polled for an upgrade, but low-level API is not using upgrades. + ManualUpgrade, +} + +// Sentinel type to indicate the error was caused by a timeout. +#[derive(Debug)] +pub(crate) struct TimedOut; + +impl Error { + /// Returns true if this was an HTTP parse error. + pub fn is_parse(&self) -> bool { + match self.inner.kind { + Kind::Parse(_) => true, + _ => false, + } + } + + /// Returns true if this error was caused by user code. + pub fn is_user(&self) -> bool { + match self.inner.kind { + Kind::User(_) => true, + _ => false, + } + } + + /// Returns true if this was about a `Request` that was canceled. + pub fn is_canceled(&self) -> bool { + self.inner.kind == Kind::Canceled + } + + /// Returns true if a sender's channel is closed. + pub fn is_closed(&self) -> bool { + self.inner.kind == Kind::ChannelClosed + } + + /// Returns true if this was an error from `Connect`. + pub fn is_connect(&self) -> bool { + self.inner.kind == Kind::Connect + } + + /// Returns true if the connection closed before a message could complete. + pub fn is_incomplete_message(&self) -> bool { + self.inner.kind == Kind::IncompleteMessage + } + + /// Returns true if the body write was aborted. + pub fn is_body_write_aborted(&self) -> bool { + self.inner.kind == Kind::BodyWriteAborted + } + + /// Returns true if the error was caused by a timeout. + pub fn is_timeout(&self) -> bool { + self.find_source::<TimedOut>().is_some() + } + + /// Consumes the error, returning its cause. + pub fn into_cause(self) -> Option<Box<dyn StdError + Send + Sync>> { + self.inner.cause + } + + pub(crate) fn new(kind: Kind) -> Error { + Error { + inner: Box::new(ErrorImpl { kind, cause: None }), + } + } + + pub(crate) fn with<C: Into<Cause>>(mut self, cause: C) -> Error { + self.inner.cause = Some(cause.into()); + self + } + + pub(crate) fn kind(&self) -> &Kind { + &self.inner.kind + } + + fn find_source<E: StdError + 'static>(&self) -> Option<&E> { + let mut cause = self.source(); + while let Some(err) = cause { + if let Some(ref typed) = err.downcast_ref() { + return Some(typed); + } + cause = err.source(); + } + + // else + None + } + + pub(crate) fn h2_reason(&self) -> h2::Reason { + // Find an h2::Reason somewhere in the cause stack, if it exists, + // otherwise assume an INTERNAL_ERROR. + self.find_source::<h2::Error>() + .and_then(|h2_err| h2_err.reason()) + .unwrap_or(h2::Reason::INTERNAL_ERROR) + } + + pub(crate) fn new_canceled() -> Error { + Error::new(Kind::Canceled) + } + + pub(crate) fn new_incomplete() -> Error { + Error::new(Kind::IncompleteMessage) + } + + pub(crate) fn new_too_large() -> Error { + Error::new(Kind::Parse(Parse::TooLarge)) + } + + pub(crate) fn new_version_h2() -> Error { + Error::new(Kind::Parse(Parse::VersionH2)) + } + + pub(crate) fn new_unexpected_message() -> Error { + Error::new(Kind::UnexpectedMessage) + } + + pub(crate) fn new_io(cause: io::Error) -> Error { + Error::new(Kind::Io).with(cause) + } + + #[cfg(feature = "tcp")] + pub(crate) fn new_listen<E: Into<Cause>>(cause: E) -> Error { + Error::new(Kind::Listen).with(cause) + } + + pub(crate) fn new_accept<E: Into<Cause>>(cause: E) -> Error { + Error::new(Kind::Accept).with(cause) + } + + pub(crate) fn new_connect<E: Into<Cause>>(cause: E) -> Error { + Error::new(Kind::Connect).with(cause) + } + + pub(crate) fn new_closed() -> Error { + Error::new(Kind::ChannelClosed) + } + + pub(crate) fn new_body<E: Into<Cause>>(cause: E) -> Error { + Error::new(Kind::Body).with(cause) + } + + pub(crate) fn new_body_write<E: Into<Cause>>(cause: E) -> Error { + Error::new(Kind::BodyWrite).with(cause) + } + + pub(crate) fn new_body_write_aborted() -> Error { + Error::new(Kind::BodyWriteAborted) + } + + fn new_user(user: User) -> Error { + Error::new(Kind::User(user)) + } + + pub(crate) fn new_user_header() -> Error { + Error::new_user(User::UnexpectedHeader) + } + + pub(crate) fn new_user_unsupported_version() -> Error { + Error::new_user(User::UnsupportedVersion) + } + + pub(crate) fn new_user_unsupported_request_method() -> Error { + Error::new_user(User::UnsupportedRequestMethod) + } + + pub(crate) fn new_user_unsupported_status_code() -> Error { + Error::new_user(User::UnsupportedStatusCode) + } + + pub(crate) fn new_user_absolute_uri_required() -> Error { + Error::new_user(User::AbsoluteUriRequired) + } + + pub(crate) fn new_user_no_upgrade() -> Error { + Error::new_user(User::NoUpgrade) + } + + pub(crate) fn new_user_manual_upgrade() -> Error { + Error::new_user(User::ManualUpgrade) + } + + pub(crate) fn new_user_make_service<E: Into<Cause>>(cause: E) -> Error { + Error::new_user(User::MakeService).with(cause) + } + + pub(crate) fn new_user_service<E: Into<Cause>>(cause: E) -> Error { + Error::new_user(User::Service).with(cause) + } + + pub(crate) fn new_user_body<E: Into<Cause>>(cause: E) -> Error { + Error::new_user(User::Body).with(cause) + } + + pub(crate) fn new_shutdown(cause: io::Error) -> Error { + Error::new(Kind::Shutdown).with(cause) + } + + pub(crate) fn new_h2(cause: ::h2::Error) -> Error { + if cause.is_io() { + Error::new_io(cause.into_io().expect("h2::Error::is_io")) + } else { + Error::new(Kind::Http2).with(cause) + } + } + + fn description(&self) -> &str { + match self.inner.kind { + Kind::Parse(Parse::Method) => "invalid HTTP method parsed", + Kind::Parse(Parse::Version) => "invalid HTTP version parsed", + Kind::Parse(Parse::VersionH2) => "invalid HTTP version parsed (found HTTP2 preface)", + Kind::Parse(Parse::Uri) => "invalid URI", + Kind::Parse(Parse::Header) => "invalid HTTP header parsed", + Kind::Parse(Parse::TooLarge) => "message head is too large", + Kind::Parse(Parse::Status) => "invalid HTTP status-code parsed", + Kind::IncompleteMessage => "connection closed before message completed", + Kind::UnexpectedMessage => "received unexpected message from connection", + Kind::ChannelClosed => "channel closed", + Kind::Connect => "error trying to connect", + Kind::Canceled => "operation was canceled", + #[cfg(feature = "tcp")] + Kind::Listen => "error creating server listener", + Kind::Accept => "error accepting connection", + Kind::Body => "error reading a body from connection", + Kind::BodyWrite => "error writing a body to connection", + Kind::BodyWriteAborted => "body write aborted", + Kind::Shutdown => "error shutting down connection", + Kind::Http2 => "http2 error", + Kind::Io => "connection error", + + Kind::User(User::Body) => "error from user's Payload stream", + Kind::User(User::MakeService) => "error from user's MakeService", + Kind::User(User::Service) => "error from user's Service", + Kind::User(User::UnexpectedHeader) => "user sent unexpected header", + Kind::User(User::UnsupportedVersion) => "request has unsupported HTTP version", + Kind::User(User::UnsupportedRequestMethod) => "request has unsupported HTTP method", + Kind::User(User::UnsupportedStatusCode) => { + "response has 1xx status code, not supported by server" + } + Kind::User(User::AbsoluteUriRequired) => "client requires absolute-form URIs", + Kind::User(User::NoUpgrade) => "no upgrade available", + Kind::User(User::ManualUpgrade) => "upgrade expected but low level API in use", + } + } +} + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut f = f.debug_tuple("hyper::Error"); + f.field(&self.inner.kind); + if let Some(ref cause) = self.inner.cause { + f.field(cause); + } + f.finish() + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if let Some(ref cause) = self.inner.cause { + write!(f, "{}: {}", self.description(), cause) + } else { + f.write_str(self.description()) + } + } +} + +impl StdError for Error { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + self.inner + .cause + .as_ref() + .map(|cause| &**cause as &(dyn StdError + 'static)) + } +} + +#[doc(hidden)] +impl From<Parse> for Error { + fn from(err: Parse) -> Error { + Error::new(Kind::Parse(err)) + } +} + +impl From<httparse::Error> for Parse { + fn from(err: httparse::Error) -> Parse { + match err { + httparse::Error::HeaderName + | httparse::Error::HeaderValue + | httparse::Error::NewLine + | httparse::Error::Token => Parse::Header, + httparse::Error::Status => Parse::Status, + httparse::Error::TooManyHeaders => Parse::TooLarge, + httparse::Error::Version => Parse::Version, + } + } +} + +impl From<http::method::InvalidMethod> for Parse { + fn from(_: http::method::InvalidMethod) -> Parse { + Parse::Method + } +} + +impl From<http::status::InvalidStatusCode> for Parse { + fn from(_: http::status::InvalidStatusCode) -> Parse { + Parse::Status + } +} + +impl From<http::uri::InvalidUri> for Parse { + fn from(_: http::uri::InvalidUri) -> Parse { + Parse::Uri + } +} + +impl From<http::uri::InvalidUriParts> for Parse { + fn from(_: http::uri::InvalidUriParts) -> Parse { + Parse::Uri + } +} + +#[doc(hidden)] +trait AssertSendSync: Send + Sync + 'static {} +#[doc(hidden)] +impl AssertSendSync for Error {} + +// ===== impl TimedOut ==== + +impl fmt::Display for TimedOut { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("operation timed out") + } +} + +impl StdError for TimedOut {} + +#[cfg(test)] +mod tests { + use super::*; + use std::mem; + + #[test] + fn error_size_of() { + assert_eq!(mem::size_of::<Error>(), mem::size_of::<usize>()); + } + + #[test] + fn h2_reason_unknown() { + let closed = Error::new_closed(); + assert_eq!(closed.h2_reason(), h2::Reason::INTERNAL_ERROR); + } + + #[test] + fn h2_reason_one_level() { + let body_err = Error::new_user_body(h2::Error::from(h2::Reason::ENHANCE_YOUR_CALM)); + assert_eq!(body_err.h2_reason(), h2::Reason::ENHANCE_YOUR_CALM); + } + + #[test] + fn h2_reason_nested() { + let recvd = Error::new_h2(h2::Error::from(h2::Reason::HTTP_1_1_REQUIRED)); + // Suppose a user were proxying the received error + let svc_err = Error::new_user_service(recvd); + assert_eq!(svc_err.h2_reason(), h2::Reason::HTTP_1_1_REQUIRED); + } +} diff --git a/third_party/rust/hyper/src/headers.rs b/third_party/rust/hyper/src/headers.rs new file mode 100644 index 0000000000..5375e78287 --- /dev/null +++ b/third_party/rust/hyper/src/headers.rs @@ -0,0 +1,115 @@ +use bytes::BytesMut; +use http::header::{HeaderValue, OccupiedEntry, ValueIter}; +use http::header::{CONTENT_LENGTH, TRANSFER_ENCODING}; +use http::method::Method; +use http::HeaderMap; + +pub fn connection_keep_alive(value: &HeaderValue) -> bool { + connection_has(value, "keep-alive") +} + +pub fn connection_close(value: &HeaderValue) -> bool { + connection_has(value, "close") +} + +fn connection_has(value: &HeaderValue, needle: &str) -> bool { + if let Ok(s) = value.to_str() { + for val in s.split(',') { + if val.trim().eq_ignore_ascii_case(needle) { + return true; + } + } + } + false +} + +pub fn content_length_parse(value: &HeaderValue) -> Option<u64> { + value.to_str().ok().and_then(|s| s.parse().ok()) +} + +pub fn content_length_parse_all(headers: &HeaderMap) -> Option<u64> { + content_length_parse_all_values(headers.get_all(CONTENT_LENGTH).into_iter()) +} + +pub fn content_length_parse_all_values(values: ValueIter<'_, HeaderValue>) -> Option<u64> { + // If multiple Content-Length headers were sent, everything can still + // be alright if they all contain the same value, and all parse + // correctly. If not, then it's an error. + + let folded = values.fold(None, |prev, line| match prev { + Some(Ok(prev)) => Some( + line.to_str() + .map_err(|_| ()) + .and_then(|s| s.parse().map_err(|_| ())) + .and_then(|n| if prev == n { Ok(n) } else { Err(()) }), + ), + None => Some( + line.to_str() + .map_err(|_| ()) + .and_then(|s| s.parse().map_err(|_| ())), + ), + Some(Err(())) => Some(Err(())), + }); + + if let Some(Ok(n)) = folded { + Some(n) + } else { + None + } +} + +pub fn method_has_defined_payload_semantics(method: &Method) -> bool { + match *method { + Method::GET | Method::HEAD | Method::DELETE | Method::CONNECT => false, + _ => true, + } +} + +pub fn set_content_length_if_missing(headers: &mut HeaderMap, len: u64) { + headers + .entry(CONTENT_LENGTH) + .or_insert_with(|| HeaderValue::from(len)); +} + +pub fn transfer_encoding_is_chunked(headers: &HeaderMap) -> bool { + is_chunked(headers.get_all(TRANSFER_ENCODING).into_iter()) +} + +pub fn is_chunked(mut encodings: ValueIter<'_, HeaderValue>) -> bool { + // chunked must always be the last encoding, according to spec + if let Some(line) = encodings.next_back() { + return is_chunked_(line); + } + + false +} + +pub fn is_chunked_(value: &HeaderValue) -> bool { + // chunked must always be the last encoding, according to spec + if let Ok(s) = value.to_str() { + if let Some(encoding) = s.rsplit(',').next() { + return encoding.trim().eq_ignore_ascii_case("chunked"); + } + } + + false +} + +pub fn add_chunked(mut entry: OccupiedEntry<'_, HeaderValue>) { + const CHUNKED: &str = "chunked"; + + if let Some(line) = entry.iter_mut().next_back() { + // + 2 for ", " + let new_cap = line.as_bytes().len() + CHUNKED.len() + 2; + let mut buf = BytesMut::with_capacity(new_cap); + buf.copy_from_slice(line.as_bytes()); + buf.copy_from_slice(b", "); + buf.copy_from_slice(CHUNKED.as_bytes()); + + *line = HeaderValue::from_maybe_shared(buf.freeze()) + .expect("original header value plus ascii is valid"); + return; + } + + entry.insert(HeaderValue::from_static(CHUNKED)); +} diff --git a/third_party/rust/hyper/src/lib.rs b/third_party/rust/hyper/src/lib.rs new file mode 100644 index 0000000000..7e19524fba --- /dev/null +++ b/third_party/rust/hyper/src/lib.rs @@ -0,0 +1,71 @@ +#![doc(html_root_url = "https://docs.rs/hyper/0.13.5")] +#![deny(missing_docs)] +#![deny(missing_debug_implementations)] +#![cfg_attr(test, deny(rust_2018_idioms))] +#![cfg_attr(test, deny(warnings))] +#![cfg_attr(all(test, feature = "nightly"), feature(test))] + +//! # hyper +//! +//! hyper is a **fast** and **correct** HTTP implementation written in and for Rust. +//! +//! ## Features +//! +//! - HTTP/1 and HTTP/2 +//! - Asynchronous design +//! - Leading in performance +//! - Tested and **correct** +//! - Extensive production use +//! - [Client](client/index.html) and [Server](server/index.html) APIs +//! +//! If just starting out, **check out the [Guides](https://hyper.rs/guides) +//! first.** +//! +//! ## "Low-level" +//! +//! hyper is a lower-level HTTP library, meant to be a building block +//! for libraries and applications. +//! +//! If looking for just a convenient HTTP client, consider the +//! [reqwest](https://crates.io/crates/reqwest) crate. +//! +//! # Optional Features +//! +//! The following optional features are available: +//! +//! - `runtime` (*enabled by default*): Enables convenient integration with +//! `tokio`, providing connectors and acceptors for TCP, and a default +//! executor. +//! - `tcp` (*enabled by default*): Enables convenient implementations over +//! TCP (using tokio). +//! - `stream` (*enabled by default*): Provides `futures::Stream` capabilities. + +#[doc(hidden)] +pub use http; +#[macro_use] +extern crate log; + +#[cfg(all(test, feature = "nightly"))] +extern crate test; + +pub use http::{header, HeaderMap, Method, Request, Response, StatusCode, Uri, Version}; + +pub use crate::body::Body; +pub use crate::client::Client; +pub use crate::error::{Error, Result}; +pub use crate::server::Server; + +#[macro_use] +mod common; +pub mod body; +pub mod client; +#[doc(hidden)] // Mistakenly public... +pub mod error; +mod headers; +#[cfg(test)] +mod mock; +mod proto; +pub mod rt; +pub mod server; +pub mod service; +pub mod upgrade; diff --git a/third_party/rust/hyper/src/mock.rs b/third_party/rust/hyper/src/mock.rs new file mode 100644 index 0000000000..1dd57de319 --- /dev/null +++ b/third_party/rust/hyper/src/mock.rs @@ -0,0 +1,235 @@ +// FIXME: re-implement tests with `async/await` +/* +#[cfg(feature = "runtime")] +use std::collections::HashMap; +use std::cmp; +use std::io::{self, Read, Write}; +#[cfg(feature = "runtime")] +use std::sync::{Arc, Mutex}; + +use bytes::Buf; +use futures::{Async, Poll}; +#[cfg(feature = "runtime")] +use futures::Future; +use futures::task::{self, Task}; +use tokio_io::{AsyncRead, AsyncWrite}; + +#[cfg(feature = "runtime")] +use crate::client::connect::{Connect, Connected, Destination}; + + + +#[cfg(feature = "runtime")] +pub struct Duplex { + inner: Arc<Mutex<DuplexInner>>, +} + +#[cfg(feature = "runtime")] +struct DuplexInner { + handle_read_task: Option<Task>, + read: AsyncIo<MockCursor>, + write: AsyncIo<MockCursor>, +} + +#[cfg(feature = "runtime")] +impl Duplex { + pub(crate) fn channel() -> (Duplex, DuplexHandle) { + let mut inner = DuplexInner { + handle_read_task: None, + read: AsyncIo::new_buf(Vec::new(), 0), + write: AsyncIo::new_buf(Vec::new(), std::usize::MAX), + }; + + inner.read.park_tasks(true); + inner.write.park_tasks(true); + + let inner = Arc::new(Mutex::new(inner)); + + let duplex = Duplex { + inner: inner.clone(), + }; + let handle = DuplexHandle { + inner: inner, + }; + + (duplex, handle) + } +} + +#[cfg(feature = "runtime")] +impl Read for Duplex { + fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + self.inner.lock().unwrap().read.read(buf) + } +} + +#[cfg(feature = "runtime")] +impl Write for Duplex { + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + let mut inner = self.inner.lock().unwrap(); + let ret = inner.write.write(buf); + if let Some(task) = inner.handle_read_task.take() { + trace!("waking DuplexHandle read"); + task.notify(); + } + ret + } + + fn flush(&mut self) -> io::Result<()> { + self.inner.lock().unwrap().write.flush() + } +} + +#[cfg(feature = "runtime")] +impl AsyncRead for Duplex { +} + +#[cfg(feature = "runtime")] +impl AsyncWrite for Duplex { + fn shutdown(&mut self) -> Poll<(), io::Error> { + Ok(().into()) + } + + fn write_buf<B: Buf>(&mut self, buf: &mut B) -> Poll<usize, io::Error> { + let mut inner = self.inner.lock().unwrap(); + if let Some(task) = inner.handle_read_task.take() { + task.notify(); + } + inner.write.write_buf(buf) + } +} + +#[cfg(feature = "runtime")] +pub struct DuplexHandle { + inner: Arc<Mutex<DuplexInner>>, +} + +#[cfg(feature = "runtime")] +impl DuplexHandle { + pub fn read(&self, buf: &mut [u8]) -> Poll<usize, io::Error> { + let mut inner = self.inner.lock().unwrap(); + assert!(buf.len() >= inner.write.inner.len()); + if inner.write.inner.is_empty() { + trace!("DuplexHandle read parking"); + inner.handle_read_task = Some(task::current()); + return Ok(Async::NotReady); + } + inner.write.read(buf).map(Async::Ready) + } + + pub fn write(&self, bytes: &[u8]) -> Poll<usize, io::Error> { + let mut inner = self.inner.lock().unwrap(); + assert_eq!(inner.read.inner.pos, 0); + assert_eq!(inner.read.inner.vec.len(), 0, "write but read isn't empty"); + inner + .read + .inner + .vec + .extend(bytes); + inner.read.block_in(bytes.len()); + Ok(Async::Ready(bytes.len())) + } +} + +#[cfg(feature = "runtime")] +impl Drop for DuplexHandle { + fn drop(&mut self) { + trace!("mock duplex handle drop"); + if !::std::thread::panicking() { + let mut inner = self.inner.lock().unwrap(); + inner.read.close(); + inner.write.close(); + } + } +} + +#[cfg(feature = "runtime")] +type BoxedConnectFut = Box<dyn Future<Item=(Duplex, Connected), Error=io::Error> + Send>; + +#[cfg(feature = "runtime")] +#[derive(Clone)] +pub struct MockConnector { + mocks: Arc<Mutex<MockedConnections>>, +} + +#[cfg(feature = "runtime")] +struct MockedConnections(HashMap<String, Vec<BoxedConnectFut>>); + +#[cfg(feature = "runtime")] +impl MockConnector { + pub fn new() -> MockConnector { + MockConnector { + mocks: Arc::new(Mutex::new(MockedConnections(HashMap::new()))), + } + } + + pub fn mock(&mut self, key: &str) -> DuplexHandle { + use futures::future; + self.mock_fut(key, future::ok::<_, ()>(())) + } + + pub fn mock_fut<F>(&mut self, key: &str, fut: F) -> DuplexHandle + where + F: Future + Send + 'static, + { + self.mock_opts(key, Connected::new(), fut) + } + + pub fn mock_opts<F>(&mut self, key: &str, connected: Connected, fut: F) -> DuplexHandle + where + F: Future + Send + 'static, + { + let key = key.to_owned(); + + let (duplex, handle) = Duplex::channel(); + + let fut = Box::new(fut.then(move |_| { + trace!("MockConnector mocked fut ready"); + Ok((duplex, connected)) + })); + self.mocks.lock().unwrap().0.entry(key) + .or_insert(Vec::new()) + .push(fut); + + handle + } +} + +#[cfg(feature = "runtime")] +impl Connect for MockConnector { + type Transport = Duplex; + type Error = io::Error; + type Future = BoxedConnectFut; + + fn connect(&self, dst: Destination) -> Self::Future { + trace!("mock connect: {:?}", dst); + let key = format!("{}://{}{}", dst.scheme(), dst.host(), if let Some(port) = dst.port() { + format!(":{}", port) + } else { + "".to_owned() + }); + let mut mocks = self.mocks.lock().unwrap(); + let mocks = mocks.0.get_mut(&key) + .expect(&format!("unknown mocks uri: {}", key)); + assert!(!mocks.is_empty(), "no additional mocks for {}", key); + mocks.remove(0) + } +} + + +#[cfg(feature = "runtime")] +impl Drop for MockedConnections { + fn drop(&mut self) { + if !::std::thread::panicking() { + for (key, mocks) in self.0.iter() { + assert_eq!( + mocks.len(), + 0, + "not all mocked connects for {:?} were used", + key, + ); + } + } + } +} +*/ diff --git a/third_party/rust/hyper/src/proto/h1/conn.rs b/third_party/rust/hyper/src/proto/h1/conn.rs new file mode 100644 index 0000000000..c8b355cd63 --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/conn.rs @@ -0,0 +1,1321 @@ +use std::fmt; +use std::io::{self}; +use std::marker::PhantomData; + +use bytes::{Buf, Bytes}; +use http::header::{HeaderValue, CONNECTION}; +use http::{HeaderMap, Method, Version}; +use tokio::io::{AsyncRead, AsyncWrite}; + +use super::io::Buffered; +use super::{Decoder, Encode, EncodedBuf, Encoder, Http1Transaction, ParseContext, Wants}; +use crate::common::{task, Pin, Poll, Unpin}; +use crate::headers::connection_keep_alive; +use crate::proto::{BodyLength, DecodedLength, MessageHead}; + +const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; + +/// This handles a connection, which will have been established over an +/// `AsyncRead + AsyncWrite` (like a socket), and will likely include multiple +/// `Transaction`s over HTTP. +/// +/// The connection will determine when a message begins and ends as well as +/// determine if this connection can be kept alive after the message, +/// or if it is complete. +pub(crate) struct Conn<I, B, T> { + io: Buffered<I, EncodedBuf<B>>, + state: State, + _marker: PhantomData<fn(T)>, +} + +impl<I, B, T> Conn<I, B, T> +where + I: AsyncRead + AsyncWrite + Unpin, + B: Buf, + T: Http1Transaction, +{ + pub fn new(io: I) -> Conn<I, B, T> { + Conn { + io: Buffered::new(io), + state: State { + allow_half_close: false, + cached_headers: None, + error: None, + keep_alive: KA::Busy, + method: None, + title_case_headers: false, + notify_read: false, + reading: Reading::Init, + writing: Writing::Init, + upgrade: None, + // We assume a modern world where the remote speaks HTTP/1.1. + // If they tell us otherwise, we'll downgrade in `read_head`. + version: Version::HTTP_11, + }, + _marker: PhantomData, + } + } + + pub fn set_flush_pipeline(&mut self, enabled: bool) { + self.io.set_flush_pipeline(enabled); + } + + pub fn set_max_buf_size(&mut self, max: usize) { + self.io.set_max_buf_size(max); + } + + pub fn set_read_buf_exact_size(&mut self, sz: usize) { + self.io.set_read_buf_exact_size(sz); + } + + pub fn set_write_strategy_flatten(&mut self) { + self.io.set_write_strategy_flatten(); + } + + pub fn set_title_case_headers(&mut self) { + self.state.title_case_headers = true; + } + + pub(crate) fn set_allow_half_close(&mut self) { + self.state.allow_half_close = true; + } + + pub fn into_inner(self) -> (I, Bytes) { + self.io.into_inner() + } + + pub fn pending_upgrade(&mut self) -> Option<crate::upgrade::Pending> { + self.state.upgrade.take() + } + + pub fn is_read_closed(&self) -> bool { + self.state.is_read_closed() + } + + pub fn is_write_closed(&self) -> bool { + self.state.is_write_closed() + } + + pub fn can_read_head(&self) -> bool { + match self.state.reading { + Reading::Init => { + if T::should_read_first() { + true + } else { + match self.state.writing { + Writing::Init => false, + _ => true, + } + } + } + _ => false, + } + } + + pub fn can_read_body(&self) -> bool { + match self.state.reading { + Reading::Body(..) | Reading::Continue(..) => true, + _ => false, + } + } + + fn should_error_on_eof(&self) -> bool { + // If we're idle, it's probably just the connection closing gracefully. + T::should_error_on_parse_eof() && !self.state.is_idle() + } + + fn has_h2_prefix(&self) -> bool { + let read_buf = self.io.read_buf(); + read_buf.len() >= 24 && read_buf[..24] == *H2_PREFACE + } + + pub(super) fn poll_read_head( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Option<crate::Result<(MessageHead<T::Incoming>, DecodedLength, Wants)>>> { + debug_assert!(self.can_read_head()); + trace!("Conn::read_head"); + + let msg = match ready!(self.io.parse::<T>( + cx, + ParseContext { + cached_headers: &mut self.state.cached_headers, + req_method: &mut self.state.method, + } + )) { + Ok(msg) => msg, + Err(e) => return self.on_read_head_error(e), + }; + + // Note: don't deconstruct `msg` into local variables, it appears + // the optimizer doesn't remove the extra copies. + + debug!("incoming body is {}", msg.decode); + + self.state.busy(); + self.state.keep_alive &= msg.keep_alive; + self.state.version = msg.head.version; + + let mut wants = if msg.wants_upgrade { + Wants::UPGRADE + } else { + Wants::EMPTY + }; + + if msg.decode == DecodedLength::ZERO { + if msg.expect_continue { + debug!("ignoring expect-continue since body is empty"); + } + self.state.reading = Reading::KeepAlive; + if !T::should_read_first() { + self.try_keep_alive(cx); + } + } else if msg.expect_continue { + self.state.reading = Reading::Continue(Decoder::new(msg.decode)); + wants = wants.add(Wants::EXPECT); + } else { + self.state.reading = Reading::Body(Decoder::new(msg.decode)); + } + + Poll::Ready(Some(Ok((msg.head, msg.decode, wants)))) + } + + fn on_read_head_error<Z>(&mut self, e: crate::Error) -> Poll<Option<crate::Result<Z>>> { + // If we are currently waiting on a message, then an empty + // message should be reported as an error. If not, it is just + // the connection closing gracefully. + let must_error = self.should_error_on_eof(); + self.close_read(); + self.io.consume_leading_lines(); + let was_mid_parse = e.is_parse() || !self.io.read_buf().is_empty(); + if was_mid_parse || must_error { + // We check if the buf contains the h2 Preface + debug!( + "parse error ({}) with {} bytes", + e, + self.io.read_buf().len() + ); + match self.on_parse_error(e) { + Ok(()) => Poll::Pending, // XXX: wat? + Err(e) => Poll::Ready(Some(Err(e))), + } + } else { + debug!("read eof"); + self.close_write(); + Poll::Ready(None) + } + } + + pub fn poll_read_body( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Option<io::Result<Bytes>>> { + debug_assert!(self.can_read_body()); + + let (reading, ret) = match self.state.reading { + Reading::Body(ref mut decoder) => { + match decoder.decode(cx, &mut self.io) { + Poll::Ready(Ok(slice)) => { + let (reading, chunk) = if decoder.is_eof() { + debug!("incoming body completed"); + ( + Reading::KeepAlive, + if !slice.is_empty() { + Some(Ok(slice)) + } else { + None + }, + ) + } else if slice.is_empty() { + error!("incoming body unexpectedly ended"); + // This should be unreachable, since all 3 decoders + // either set eof=true or return an Err when reading + // an empty slice... + (Reading::Closed, None) + } else { + return Poll::Ready(Some(Ok(slice))); + }; + (reading, Poll::Ready(chunk)) + } + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(e)) => { + debug!("incoming body decode error: {}", e); + (Reading::Closed, Poll::Ready(Some(Err(e)))) + } + } + } + Reading::Continue(ref decoder) => { + // Write the 100 Continue if not already responded... + if let Writing::Init = self.state.writing { + trace!("automatically sending 100 Continue"); + let cont = b"HTTP/1.1 100 Continue\r\n\r\n"; + self.io.headers_buf().extend_from_slice(cont); + } + + // And now recurse once in the Reading::Body state... + self.state.reading = Reading::Body(decoder.clone()); + return self.poll_read_body(cx); + } + _ => unreachable!("poll_read_body invalid state: {:?}", self.state.reading), + }; + + self.state.reading = reading; + self.try_keep_alive(cx); + ret + } + + pub fn wants_read_again(&mut self) -> bool { + let ret = self.state.notify_read; + self.state.notify_read = false; + ret + } + + pub fn poll_read_keep_alive(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + debug_assert!(!self.can_read_head() && !self.can_read_body()); + + if self.is_read_closed() { + Poll::Pending + } else if self.is_mid_message() { + self.mid_message_detect_eof(cx) + } else { + self.require_empty_read(cx) + } + } + + fn is_mid_message(&self) -> bool { + match (&self.state.reading, &self.state.writing) { + (&Reading::Init, &Writing::Init) => false, + _ => true, + } + } + + // This will check to make sure the io object read is empty. + // + // This should only be called for Clients wanting to enter the idle + // state. + fn require_empty_read(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + debug_assert!(!self.can_read_head() && !self.can_read_body() && !self.is_read_closed()); + debug_assert!(!self.is_mid_message()); + debug_assert!(T::is_client()); + + if !self.io.read_buf().is_empty() { + debug!("received an unexpected {} bytes", self.io.read_buf().len()); + return Poll::Ready(Err(crate::Error::new_unexpected_message())); + } + + let num_read = ready!(self.force_io_read(cx)).map_err(crate::Error::new_io)?; + + if num_read == 0 { + let ret = if self.should_error_on_eof() { + trace!("found unexpected EOF on busy connection: {:?}", self.state); + Poll::Ready(Err(crate::Error::new_incomplete())) + } else { + trace!("found EOF on idle connection, closing"); + Poll::Ready(Ok(())) + }; + + // order is important: should_error needs state BEFORE close_read + self.state.close_read(); + return ret; + } + + debug!( + "received unexpected {} bytes on an idle connection", + num_read + ); + Poll::Ready(Err(crate::Error::new_unexpected_message())) + } + + fn mid_message_detect_eof(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + debug_assert!(!self.can_read_head() && !self.can_read_body() && !self.is_read_closed()); + debug_assert!(self.is_mid_message()); + + if self.state.allow_half_close || !self.io.read_buf().is_empty() { + return Poll::Pending; + } + + let num_read = ready!(self.force_io_read(cx)).map_err(crate::Error::new_io)?; + + if num_read == 0 { + trace!("found unexpected EOF on busy connection: {:?}", self.state); + self.state.close_read(); + Poll::Ready(Err(crate::Error::new_incomplete())) + } else { + Poll::Ready(Ok(())) + } + } + + fn force_io_read(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<usize>> { + debug_assert!(!self.state.is_read_closed()); + + let result = ready!(self.io.poll_read_from_io(cx)); + Poll::Ready(result.map_err(|e| { + trace!("force_io_read; io error = {:?}", e); + self.state.close(); + e + })) + } + + fn maybe_notify(&mut self, cx: &mut task::Context<'_>) { + // its possible that we returned NotReady from poll() without having + // exhausted the underlying Io. We would have done this when we + // determined we couldn't keep reading until we knew how writing + // would finish. + + match self.state.reading { + Reading::Continue(..) | Reading::Body(..) | Reading::KeepAlive | Reading::Closed => { + return + } + Reading::Init => (), + }; + + match self.state.writing { + Writing::Body(..) => return, + Writing::Init | Writing::KeepAlive | Writing::Closed => (), + } + + if !self.io.is_read_blocked() { + if self.io.read_buf().is_empty() { + match self.io.poll_read_from_io(cx) { + Poll::Ready(Ok(n)) => { + if n == 0 { + trace!("maybe_notify; read eof"); + if self.state.is_idle() { + self.state.close(); + } else { + self.close_read() + } + return; + } + } + Poll::Pending => { + trace!("maybe_notify; read_from_io blocked"); + return; + } + Poll::Ready(Err(e)) => { + trace!("maybe_notify; read_from_io error: {}", e); + self.state.close(); + self.state.error = Some(crate::Error::new_io(e)); + } + } + } + self.state.notify_read = true; + } + } + + fn try_keep_alive(&mut self, cx: &mut task::Context<'_>) { + self.state.try_keep_alive::<T>(); + self.maybe_notify(cx); + } + + pub fn can_write_head(&self) -> bool { + if !T::should_read_first() { + if let Reading::Closed = self.state.reading { + return false; + } + } + match self.state.writing { + Writing::Init => true, + _ => false, + } + } + + pub fn can_write_body(&self) -> bool { + match self.state.writing { + Writing::Body(..) => true, + Writing::Init | Writing::KeepAlive | Writing::Closed => false, + } + } + + pub fn can_buffer_body(&self) -> bool { + self.io.can_buffer() + } + + pub fn write_head(&mut self, head: MessageHead<T::Outgoing>, body: Option<BodyLength>) { + if let Some(encoder) = self.encode_head(head, body) { + self.state.writing = if !encoder.is_eof() { + Writing::Body(encoder) + } else if encoder.is_last() { + Writing::Closed + } else { + Writing::KeepAlive + }; + } + } + + pub fn write_full_msg(&mut self, head: MessageHead<T::Outgoing>, body: B) { + if let Some(encoder) = + self.encode_head(head, Some(BodyLength::Known(body.remaining() as u64))) + { + let is_last = encoder.is_last(); + // Make sure we don't write a body if we weren't actually allowed + // to do so, like because its a HEAD request. + if !encoder.is_eof() { + encoder.danger_full_buf(body, self.io.write_buf()); + } + self.state.writing = if is_last { + Writing::Closed + } else { + Writing::KeepAlive + } + } + } + + fn encode_head( + &mut self, + mut head: MessageHead<T::Outgoing>, + body: Option<BodyLength>, + ) -> Option<Encoder> { + debug_assert!(self.can_write_head()); + + if !T::should_read_first() { + self.state.busy(); + } + + self.enforce_version(&mut head); + + let buf = self.io.headers_buf(); + match T::encode( + Encode { + head: &mut head, + body, + keep_alive: self.state.wants_keep_alive(), + req_method: &mut self.state.method, + title_case_headers: self.state.title_case_headers, + }, + buf, + ) { + Ok(encoder) => { + debug_assert!(self.state.cached_headers.is_none()); + debug_assert!(head.headers.is_empty()); + self.state.cached_headers = Some(head.headers); + Some(encoder) + } + Err(err) => { + self.state.error = Some(err); + self.state.writing = Writing::Closed; + None + } + } + } + + // Fix keep-alives when Connection: keep-alive header is not present + fn fix_keep_alive(&mut self, head: &mut MessageHead<T::Outgoing>) { + let outgoing_is_keep_alive = head + .headers + .get(CONNECTION) + .map(connection_keep_alive) + .unwrap_or(false); + + if !outgoing_is_keep_alive { + match head.version { + // If response is version 1.0 and keep-alive is not present in the response, + // disable keep-alive so the server closes the connection + Version::HTTP_10 => self.state.disable_keep_alive(), + // If response is version 1.1 and keep-alive is wanted, add + // Connection: keep-alive header when not present + Version::HTTP_11 => { + if self.state.wants_keep_alive() { + head.headers + .insert(CONNECTION, HeaderValue::from_static("keep-alive")); + } + } + _ => (), + } + } + } + + // If we know the remote speaks an older version, we try to fix up any messages + // to work with our older peer. + fn enforce_version(&mut self, head: &mut MessageHead<T::Outgoing>) { + if let Version::HTTP_10 = self.state.version { + // Fixes response or connection when keep-alive header is not present + self.fix_keep_alive(head); + // If the remote only knows HTTP/1.0, we should force ourselves + // to do only speak HTTP/1.0 as well. + head.version = Version::HTTP_10; + } + // If the remote speaks HTTP/1.1, then it *should* be fine with + // both HTTP/1.0 and HTTP/1.1 from us. So again, we just let + // the user's headers be. + } + + pub fn write_body(&mut self, chunk: B) { + debug_assert!(self.can_write_body() && self.can_buffer_body()); + // empty chunks should be discarded at Dispatcher level + debug_assert!(chunk.remaining() != 0); + + let state = match self.state.writing { + Writing::Body(ref mut encoder) => { + self.io.buffer(encoder.encode(chunk)); + + if encoder.is_eof() { + if encoder.is_last() { + Writing::Closed + } else { + Writing::KeepAlive + } + } else { + return; + } + } + _ => unreachable!("write_body invalid state: {:?}", self.state.writing), + }; + + self.state.writing = state; + } + + pub fn write_body_and_end(&mut self, chunk: B) { + debug_assert!(self.can_write_body() && self.can_buffer_body()); + // empty chunks should be discarded at Dispatcher level + debug_assert!(chunk.remaining() != 0); + + let state = match self.state.writing { + Writing::Body(ref encoder) => { + let can_keep_alive = encoder.encode_and_end(chunk, self.io.write_buf()); + if can_keep_alive { + Writing::KeepAlive + } else { + Writing::Closed + } + } + _ => unreachable!("write_body invalid state: {:?}", self.state.writing), + }; + + self.state.writing = state; + } + + pub fn end_body(&mut self) { + debug_assert!(self.can_write_body()); + + let state = match self.state.writing { + Writing::Body(ref mut encoder) => { + // end of stream, that means we should try to eof + match encoder.end() { + Ok(end) => { + if let Some(end) = end { + self.io.buffer(end); + } + if encoder.is_last() { + Writing::Closed + } else { + Writing::KeepAlive + } + } + Err(_not_eof) => Writing::Closed, + } + } + _ => return, + }; + + self.state.writing = state; + } + + // When we get a parse error, depending on what side we are, we might be able + // to write a response before closing the connection. + // + // - Client: there is nothing we can do + // - Server: if Response hasn't been written yet, we can send a 4xx response + fn on_parse_error(&mut self, err: crate::Error) -> crate::Result<()> { + if let Writing::Init = self.state.writing { + if self.has_h2_prefix() { + return Err(crate::Error::new_version_h2()); + } + if let Some(msg) = T::on_error(&err) { + // Drop the cached headers so as to not trigger a debug + // assert in `write_head`... + self.state.cached_headers.take(); + self.write_head(msg, None); + self.state.error = Some(err); + return Ok(()); + } + } + + // fallback is pass the error back up + Err(err) + } + + pub fn poll_flush(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + ready!(Pin::new(&mut self.io).poll_flush(cx))?; + self.try_keep_alive(cx); + trace!("flushed({}): {:?}", T::LOG, self.state); + Poll::Ready(Ok(())) + } + + pub fn poll_shutdown(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + match ready!(Pin::new(self.io.io_mut()).poll_shutdown(cx)) { + Ok(()) => { + trace!("shut down IO complete"); + Poll::Ready(Ok(())) + } + Err(e) => { + debug!("error shutting down IO: {}", e); + Poll::Ready(Err(e)) + } + } + } + + /// If the read side can be cheaply drained, do so. Otherwise, close. + pub(super) fn poll_drain_or_close_read(&mut self, cx: &mut task::Context<'_>) { + let _ = self.poll_read_body(cx); + + // If still in Reading::Body, just give up + match self.state.reading { + Reading::Init | Reading::KeepAlive => { + trace!("body drained"); + return; + } + _ => self.close_read(), + } + } + + pub fn close_read(&mut self) { + self.state.close_read(); + } + + pub fn close_write(&mut self) { + self.state.close_write(); + } + + pub fn disable_keep_alive(&mut self) { + if self.state.is_idle() { + trace!("disable_keep_alive; closing idle connection"); + self.state.close(); + } else { + trace!("disable_keep_alive; in-progress connection"); + self.state.disable_keep_alive(); + } + } + + pub fn take_error(&mut self) -> crate::Result<()> { + if let Some(err) = self.state.error.take() { + Err(err) + } else { + Ok(()) + } + } + + pub(super) fn on_upgrade(&mut self) -> crate::upgrade::OnUpgrade { + trace!("{}: prepare possible HTTP upgrade", T::LOG); + self.state.prepare_upgrade() + } +} + +impl<I, B: Buf, T> fmt::Debug for Conn<I, B, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Conn") + .field("state", &self.state) + .field("io", &self.io) + .finish() + } +} + +// B and T are never pinned +impl<I: Unpin, B, T> Unpin for Conn<I, B, T> {} + +struct State { + allow_half_close: bool, + /// Re-usable HeaderMap to reduce allocating new ones. + cached_headers: Option<HeaderMap>, + /// If an error occurs when there wasn't a direct way to return it + /// back to the user, this is set. + error: Option<crate::Error>, + /// Current keep-alive status. + keep_alive: KA, + /// If mid-message, the HTTP Method that started it. + /// + /// This is used to know things such as if the message can include + /// a body or not. + method: Option<Method>, + title_case_headers: bool, + /// Set to true when the Dispatcher should poll read operations + /// again. See the `maybe_notify` method for more. + notify_read: bool, + /// State of allowed reads + reading: Reading, + /// State of allowed writes + writing: Writing, + /// An expected pending HTTP upgrade. + upgrade: Option<crate::upgrade::Pending>, + /// Either HTTP/1.0 or 1.1 connection + version: Version, +} + +#[derive(Debug)] +enum Reading { + Init, + Continue(Decoder), + Body(Decoder), + KeepAlive, + Closed, +} + +enum Writing { + Init, + Body(Encoder), + KeepAlive, + Closed, +} + +impl fmt::Debug for State { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = f.debug_struct("State"); + builder + .field("reading", &self.reading) + .field("writing", &self.writing) + .field("keep_alive", &self.keep_alive); + + // Only show error field if it's interesting... + if let Some(ref error) = self.error { + builder.field("error", error); + } + + if self.allow_half_close { + builder.field("allow_half_close", &true); + } + + // Purposefully leaving off other fields.. + + builder.finish() + } +} + +impl fmt::Debug for Writing { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Writing::Init => f.write_str("Init"), + Writing::Body(ref enc) => f.debug_tuple("Body").field(enc).finish(), + Writing::KeepAlive => f.write_str("KeepAlive"), + Writing::Closed => f.write_str("Closed"), + } + } +} + +impl std::ops::BitAndAssign<bool> for KA { + fn bitand_assign(&mut self, enabled: bool) { + if !enabled { + trace!("remote disabling keep-alive"); + *self = KA::Disabled; + } + } +} + +#[derive(Clone, Copy, Debug)] +enum KA { + Idle, + Busy, + Disabled, +} + +impl Default for KA { + fn default() -> KA { + KA::Busy + } +} + +impl KA { + fn idle(&mut self) { + *self = KA::Idle; + } + + fn busy(&mut self) { + *self = KA::Busy; + } + + fn disable(&mut self) { + *self = KA::Disabled; + } + + fn status(&self) -> KA { + *self + } +} + +impl State { + fn close(&mut self) { + trace!("State::close()"); + self.reading = Reading::Closed; + self.writing = Writing::Closed; + self.keep_alive.disable(); + } + + fn close_read(&mut self) { + trace!("State::close_read()"); + self.reading = Reading::Closed; + self.keep_alive.disable(); + } + + fn close_write(&mut self) { + trace!("State::close_write()"); + self.writing = Writing::Closed; + self.keep_alive.disable(); + } + + fn wants_keep_alive(&self) -> bool { + if let KA::Disabled = self.keep_alive.status() { + false + } else { + true + } + } + + fn try_keep_alive<T: Http1Transaction>(&mut self) { + match (&self.reading, &self.writing) { + (&Reading::KeepAlive, &Writing::KeepAlive) => { + if let KA::Busy = self.keep_alive.status() { + self.idle::<T>(); + } else { + trace!( + "try_keep_alive({}): could keep-alive, but status = {:?}", + T::LOG, + self.keep_alive + ); + self.close(); + } + } + (&Reading::Closed, &Writing::KeepAlive) | (&Reading::KeepAlive, &Writing::Closed) => { + self.close() + } + _ => (), + } + } + + fn disable_keep_alive(&mut self) { + self.keep_alive.disable() + } + + fn busy(&mut self) { + if let KA::Disabled = self.keep_alive.status() { + return; + } + self.keep_alive.busy(); + } + + fn idle<T: Http1Transaction>(&mut self) { + debug_assert!(!self.is_idle(), "State::idle() called while idle"); + + self.method = None; + self.keep_alive.idle(); + if self.is_idle() { + self.reading = Reading::Init; + self.writing = Writing::Init; + + // !T::should_read_first() means Client. + // + // If Client connection has just gone idle, the Dispatcher + // should try the poll loop one more time, so as to poll the + // pending requests stream. + if !T::should_read_first() { + self.notify_read = true; + } + } else { + self.close(); + } + } + + fn is_idle(&self) -> bool { + if let KA::Idle = self.keep_alive.status() { + true + } else { + false + } + } + + fn is_read_closed(&self) -> bool { + match self.reading { + Reading::Closed => true, + _ => false, + } + } + + fn is_write_closed(&self) -> bool { + match self.writing { + Writing::Closed => true, + _ => false, + } + } + + fn prepare_upgrade(&mut self) -> crate::upgrade::OnUpgrade { + let (tx, rx) = crate::upgrade::pending(); + self.upgrade = Some(tx); + rx + } +} + +#[cfg(test)] +mod tests { + #[cfg(feature = "nightly")] + #[bench] + fn bench_read_head_short(b: &mut ::test::Bencher) { + use super::*; + let s = b"GET / HTTP/1.1\r\nHost: localhost:8080\r\n\r\n"; + let len = s.len(); + b.bytes = len as u64; + + // an empty IO, we'll be skipping and using the read buffer anyways + let io = tokio_test::io::Builder::new().build(); + let mut conn = Conn::<_, bytes::Bytes, crate::proto::h1::ServerTransaction>::new(io); + *conn.io.read_buf_mut() = ::bytes::BytesMut::from(&s[..]); + conn.state.cached_headers = Some(HeaderMap::with_capacity(2)); + + let mut rt = tokio::runtime::Builder::new() + .enable_all() + .basic_scheduler() + .build() + .unwrap(); + + b.iter(|| { + rt.block_on(futures_util::future::poll_fn(|cx| { + match conn.poll_read_head(cx) { + Poll::Ready(Some(Ok(x))) => { + ::test::black_box(&x); + let mut headers = x.0.headers; + headers.clear(); + conn.state.cached_headers = Some(headers); + } + f => panic!("expected Ready(Some(Ok(..))): {:?}", f), + } + + conn.io.read_buf_mut().reserve(1); + unsafe { + conn.io.read_buf_mut().set_len(len); + } + conn.state.reading = Reading::Init; + Poll::Ready(()) + })); + }); + } + + /* + //TODO: rewrite these using dispatch... someday... + use futures::{Async, Future, Stream, Sink}; + use futures::future; + + use proto::{self, ClientTransaction, MessageHead, ServerTransaction}; + use super::super::Encoder; + use mock::AsyncIo; + + use super::{Conn, Decoder, Reading, Writing}; + use ::uri::Uri; + + use std::str::FromStr; + + #[test] + fn test_conn_init_read() { + let good_message = b"GET / HTTP/1.1\r\n\r\n".to_vec(); + let len = good_message.len(); + let io = AsyncIo::new_buf(good_message, len); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + + match conn.poll().unwrap() { + Async::Ready(Some(Frame::Message { message, body: false })) => { + assert_eq!(message, MessageHead { + subject: ::proto::RequestLine(::Get, Uri::from_str("/").unwrap()), + .. MessageHead::default() + }) + }, + f => panic!("frame is not Frame::Message: {:?}", f) + } + } + + #[test] + fn test_conn_parse_partial() { + let _: Result<(), ()> = future::lazy(|| { + let good_message = b"GET / HTTP/1.1\r\nHost: foo.bar\r\n\r\n".to_vec(); + let io = AsyncIo::new_buf(good_message, 10); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + assert!(conn.poll().unwrap().is_not_ready()); + conn.io.io_mut().block_in(50); + let async = conn.poll().unwrap(); + assert!(async.is_ready()); + match async { + Async::Ready(Some(Frame::Message { .. })) => (), + f => panic!("frame is not Message: {:?}", f), + } + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_init_read_eof_idle() { + let io = AsyncIo::new_buf(vec![], 1); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.idle(); + + match conn.poll().unwrap() { + Async::Ready(None) => {}, + other => panic!("frame is not None: {:?}", other) + } + } + + #[test] + fn test_conn_init_read_eof_idle_partial_parse() { + let io = AsyncIo::new_buf(b"GET / HTTP/1.1".to_vec(), 100); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.idle(); + + match conn.poll() { + Err(ref err) if err.kind() == std::io::ErrorKind::UnexpectedEof => {}, + other => panic!("unexpected frame: {:?}", other) + } + } + + #[test] + fn test_conn_init_read_eof_busy() { + let _: Result<(), ()> = future::lazy(|| { + // server ignores + let io = AsyncIo::new_eof(); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.busy(); + + match conn.poll().unwrap() { + Async::Ready(None) => {}, + other => panic!("unexpected frame: {:?}", other) + } + + // client + let io = AsyncIo::new_eof(); + let mut conn = Conn::<_, proto::Bytes, ClientTransaction>::new(io); + conn.state.busy(); + + match conn.poll() { + Err(ref err) if err.kind() == std::io::ErrorKind::UnexpectedEof => {}, + other => panic!("unexpected frame: {:?}", other) + } + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_body_finish_read_eof() { + let _: Result<(), ()> = future::lazy(|| { + let io = AsyncIo::new_eof(); + let mut conn = Conn::<_, proto::Bytes, ClientTransaction>::new(io); + conn.state.busy(); + conn.state.writing = Writing::KeepAlive; + conn.state.reading = Reading::Body(Decoder::length(0)); + + match conn.poll() { + Ok(Async::Ready(Some(Frame::Body { chunk: None }))) => (), + other => panic!("unexpected frame: {:?}", other) + } + + // conn eofs, but tokio-proto will call poll() again, before calling flush() + // the conn eof in this case is perfectly fine + + match conn.poll() { + Ok(Async::Ready(None)) => (), + other => panic!("unexpected frame: {:?}", other) + } + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_message_empty_body_read_eof() { + let _: Result<(), ()> = future::lazy(|| { + let io = AsyncIo::new_buf(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n".to_vec(), 1024); + let mut conn = Conn::<_, proto::Bytes, ClientTransaction>::new(io); + conn.state.busy(); + conn.state.writing = Writing::KeepAlive; + + match conn.poll() { + Ok(Async::Ready(Some(Frame::Message { body: false, .. }))) => (), + other => panic!("unexpected frame: {:?}", other) + } + + // conn eofs, but tokio-proto will call poll() again, before calling flush() + // the conn eof in this case is perfectly fine + + match conn.poll() { + Ok(Async::Ready(None)) => (), + other => panic!("unexpected frame: {:?}", other) + } + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_read_body_end() { + let _: Result<(), ()> = future::lazy(|| { + let io = AsyncIo::new_buf(b"POST / HTTP/1.1\r\nContent-Length: 5\r\n\r\n12345".to_vec(), 1024); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.busy(); + + match conn.poll() { + Ok(Async::Ready(Some(Frame::Message { body: true, .. }))) => (), + other => panic!("unexpected frame: {:?}", other) + } + + match conn.poll() { + Ok(Async::Ready(Some(Frame::Body { chunk: Some(_) }))) => (), + other => panic!("unexpected frame: {:?}", other) + } + + // When the body is done, `poll` MUST return a `Body` frame with chunk set to `None` + match conn.poll() { + Ok(Async::Ready(Some(Frame::Body { chunk: None }))) => (), + other => panic!("unexpected frame: {:?}", other) + } + + match conn.poll() { + Ok(Async::NotReady) => (), + other => panic!("unexpected frame: {:?}", other) + } + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_closed_read() { + let io = AsyncIo::new_buf(vec![], 0); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.close(); + + match conn.poll().unwrap() { + Async::Ready(None) => {}, + other => panic!("frame is not None: {:?}", other) + } + } + + #[test] + fn test_conn_body_write_length() { + let _ = pretty_env_logger::try_init(); + let _: Result<(), ()> = future::lazy(|| { + let io = AsyncIo::new_buf(vec![], 0); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + let max = super::super::io::DEFAULT_MAX_BUFFER_SIZE + 4096; + conn.state.writing = Writing::Body(Encoder::length((max * 2) as u64)); + + assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'a'; max].into()) }).unwrap().is_ready()); + assert!(!conn.can_buffer_body()); + + assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'b'; 1024 * 8].into()) }).unwrap().is_not_ready()); + + conn.io.io_mut().block_in(1024 * 3); + assert!(conn.poll_complete().unwrap().is_not_ready()); + conn.io.io_mut().block_in(1024 * 3); + assert!(conn.poll_complete().unwrap().is_not_ready()); + conn.io.io_mut().block_in(max * 2); + assert!(conn.poll_complete().unwrap().is_ready()); + + assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'c'; 1024 * 8].into()) }).unwrap().is_ready()); + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_body_write_chunked() { + let _: Result<(), ()> = future::lazy(|| { + let io = AsyncIo::new_buf(vec![], 4096); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.writing = Writing::Body(Encoder::chunked()); + + assert!(conn.start_send(Frame::Body { chunk: Some("headers".into()) }).unwrap().is_ready()); + assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'x'; 8192].into()) }).unwrap().is_ready()); + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_body_flush() { + let _: Result<(), ()> = future::lazy(|| { + let io = AsyncIo::new_buf(vec![], 1024 * 1024 * 5); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.writing = Writing::Body(Encoder::length(1024 * 1024)); + assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'a'; 1024 * 1024].into()) }).unwrap().is_ready()); + assert!(!conn.can_buffer_body()); + conn.io.io_mut().block_in(1024 * 1024 * 5); + assert!(conn.poll_complete().unwrap().is_ready()); + assert!(conn.can_buffer_body()); + assert!(conn.io.io_mut().flushed()); + + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_parking() { + use std::sync::Arc; + use futures::executor::Notify; + use futures::executor::NotifyHandle; + + struct Car { + permit: bool, + } + impl Notify for Car { + fn notify(&self, _id: usize) { + assert!(self.permit, "unparked without permit"); + } + } + + fn car(permit: bool) -> NotifyHandle { + Arc::new(Car { + permit: permit, + }).into() + } + + // test that once writing is done, unparks + let f = future::lazy(|| { + let io = AsyncIo::new_buf(vec![], 4096); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.reading = Reading::KeepAlive; + assert!(conn.poll().unwrap().is_not_ready()); + + conn.state.writing = Writing::KeepAlive; + assert!(conn.poll_complete().unwrap().is_ready()); + Ok::<(), ()>(()) + }); + ::futures::executor::spawn(f).poll_future_notify(&car(true), 0).unwrap(); + + + // test that flushing when not waiting on read doesn't unpark + let f = future::lazy(|| { + let io = AsyncIo::new_buf(vec![], 4096); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.writing = Writing::KeepAlive; + assert!(conn.poll_complete().unwrap().is_ready()); + Ok::<(), ()>(()) + }); + ::futures::executor::spawn(f).poll_future_notify(&car(false), 0).unwrap(); + + + // test that flushing and writing isn't done doesn't unpark + let f = future::lazy(|| { + let io = AsyncIo::new_buf(vec![], 4096); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.reading = Reading::KeepAlive; + assert!(conn.poll().unwrap().is_not_ready()); + conn.state.writing = Writing::Body(Encoder::length(5_000)); + assert!(conn.poll_complete().unwrap().is_ready()); + Ok::<(), ()>(()) + }); + ::futures::executor::spawn(f).poll_future_notify(&car(false), 0).unwrap(); + } + + #[test] + fn test_conn_closed_write() { + let io = AsyncIo::new_buf(vec![], 0); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.close(); + + match conn.start_send(Frame::Body { chunk: Some(b"foobar".to_vec().into()) }) { + Err(_e) => {}, + other => panic!("did not return Err: {:?}", other) + } + + assert!(conn.state.is_write_closed()); + } + + #[test] + fn test_conn_write_empty_chunk() { + let io = AsyncIo::new_buf(vec![], 0); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.writing = Writing::KeepAlive; + + assert!(conn.start_send(Frame::Body { chunk: None }).unwrap().is_ready()); + assert!(conn.start_send(Frame::Body { chunk: Some(Vec::new().into()) }).unwrap().is_ready()); + conn.start_send(Frame::Body { chunk: Some(vec![b'a'].into()) }).unwrap_err(); + } + */ +} diff --git a/third_party/rust/hyper/src/proto/h1/date.rs b/third_party/rust/hyper/src/proto/h1/date.rs new file mode 100644 index 0000000000..3e972d6e00 --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/date.rs @@ -0,0 +1,82 @@ +use std::cell::RefCell; +use std::fmt::{self, Write}; +use std::str; + +use http::header::HeaderValue; +use time::{self, Duration}; + +// "Sun, 06 Nov 1994 08:49:37 GMT".len() +pub const DATE_VALUE_LENGTH: usize = 29; + +pub fn extend(dst: &mut Vec<u8>) { + CACHED.with(|cache| { + dst.extend_from_slice(cache.borrow().buffer()); + }) +} + +pub fn update() { + CACHED.with(|cache| { + cache.borrow_mut().check(); + }) +} + +pub(crate) fn update_and_header_value() -> HeaderValue { + CACHED.with(|cache| { + let mut cache = cache.borrow_mut(); + cache.check(); + HeaderValue::from_bytes(cache.buffer()).expect("Date format should be valid HeaderValue") + }) +} + +struct CachedDate { + bytes: [u8; DATE_VALUE_LENGTH], + pos: usize, + next_update: time::Timespec, +} + +thread_local!(static CACHED: RefCell<CachedDate> = RefCell::new(CachedDate::new())); + +impl CachedDate { + fn new() -> Self { + let mut cache = CachedDate { + bytes: [0; DATE_VALUE_LENGTH], + pos: 0, + next_update: time::Timespec::new(0, 0), + }; + cache.update(time::get_time()); + cache + } + + fn buffer(&self) -> &[u8] { + &self.bytes[..] + } + + fn check(&mut self) { + let now = time::get_time(); + if now > self.next_update { + self.update(now); + } + } + + fn update(&mut self, now: time::Timespec) { + self.pos = 0; + let _ = write!(self, "{}", time::at_utc(now).rfc822()); + debug_assert!(self.pos == DATE_VALUE_LENGTH); + self.next_update = now + Duration::seconds(1); + self.next_update.nsec = 0; + } +} + +impl fmt::Write for CachedDate { + fn write_str(&mut self, s: &str) -> fmt::Result { + let len = s.len(); + self.bytes[self.pos..self.pos + len].copy_from_slice(s.as_bytes()); + self.pos += len; + Ok(()) + } +} + +#[test] +fn test_date_len() { + assert_eq!(DATE_VALUE_LENGTH, "Sun, 06 Nov 1994 08:49:37 GMT".len()); +} diff --git a/third_party/rust/hyper/src/proto/h1/decode.rs b/third_party/rust/hyper/src/proto/h1/decode.rs new file mode 100644 index 0000000000..beaf9aff7a --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/decode.rs @@ -0,0 +1,674 @@ +use std::error::Error as StdError; +use std::fmt; +use std::io; +use std::usize; + +use bytes::Bytes; + +use crate::common::{task, Poll}; + +use super::io::MemRead; +use super::DecodedLength; + +use self::Kind::{Chunked, Eof, Length}; + +/// Decoders to handle different Transfer-Encodings. +/// +/// If a message body does not include a Transfer-Encoding, it *should* +/// include a Content-Length header. +#[derive(Clone, PartialEq)] +pub struct Decoder { + kind: Kind, +} + +#[derive(Debug, Clone, Copy, PartialEq)] +enum Kind { + /// A Reader used when a Content-Length header is passed with a positive integer. + Length(u64), + /// A Reader used when Transfer-Encoding is `chunked`. + Chunked(ChunkedState, u64), + /// A Reader used for responses that don't indicate a length or chunked. + /// + /// The bool tracks when EOF is seen on the transport. + /// + /// Note: This should only used for `Response`s. It is illegal for a + /// `Request` to be made with both `Content-Length` and + /// `Transfer-Encoding: chunked` missing, as explained from the spec: + /// + /// > If a Transfer-Encoding header field is present in a response and + /// > the chunked transfer coding is not the final encoding, the + /// > message body length is determined by reading the connection until + /// > it is closed by the server. If a Transfer-Encoding header field + /// > is present in a request and the chunked transfer coding is not + /// > the final encoding, the message body length cannot be determined + /// > reliably; the server MUST respond with the 400 (Bad Request) + /// > status code and then close the connection. + Eof(bool), +} + +#[derive(Debug, PartialEq, Clone, Copy)] +enum ChunkedState { + Size, + SizeLws, + Extension, + SizeLf, + Body, + BodyCr, + BodyLf, + EndCr, + EndLf, + End, +} + +impl Decoder { + // constructors + + pub fn length(x: u64) -> Decoder { + Decoder { + kind: Kind::Length(x), + } + } + + pub fn chunked() -> Decoder { + Decoder { + kind: Kind::Chunked(ChunkedState::Size, 0), + } + } + + pub fn eof() -> Decoder { + Decoder { + kind: Kind::Eof(false), + } + } + + pub(super) fn new(len: DecodedLength) -> Self { + match len { + DecodedLength::CHUNKED => Decoder::chunked(), + DecodedLength::CLOSE_DELIMITED => Decoder::eof(), + length => Decoder::length(length.danger_len()), + } + } + + // methods + + pub fn is_eof(&self) -> bool { + match self.kind { + Length(0) | Chunked(ChunkedState::End, _) | Eof(true) => true, + _ => false, + } + } + + pub fn decode<R: MemRead>( + &mut self, + cx: &mut task::Context<'_>, + body: &mut R, + ) -> Poll<Result<Bytes, io::Error>> { + trace!("decode; state={:?}", self.kind); + match self.kind { + Length(ref mut remaining) => { + if *remaining == 0 { + Poll::Ready(Ok(Bytes::new())) + } else { + let to_read = *remaining as usize; + let buf = ready!(body.read_mem(cx, to_read))?; + let num = buf.as_ref().len() as u64; + if num > *remaining { + *remaining = 0; + } else if num == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + IncompleteBody, + ))); + } else { + *remaining -= num; + } + Poll::Ready(Ok(buf)) + } + } + Chunked(ref mut state, ref mut size) => { + loop { + let mut buf = None; + // advances the chunked state + *state = ready!(state.step(cx, body, size, &mut buf))?; + if *state == ChunkedState::End { + trace!("end of chunked"); + return Poll::Ready(Ok(Bytes::new())); + } + if let Some(buf) = buf { + return Poll::Ready(Ok(buf)); + } + } + } + Eof(ref mut is_eof) => { + if *is_eof { + Poll::Ready(Ok(Bytes::new())) + } else { + // 8192 chosen because its about 2 packets, there probably + // won't be that much available, so don't have MemReaders + // allocate buffers to big + body.read_mem(cx, 8192).map_ok(|slice| { + *is_eof = slice.is_empty(); + slice + }) + } + } + } + } + + #[cfg(test)] + async fn decode_fut<R: MemRead>(&mut self, body: &mut R) -> Result<Bytes, io::Error> { + futures_util::future::poll_fn(move |cx| self.decode(cx, body)).await + } +} + +impl fmt::Debug for Decoder { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&self.kind, f) + } +} + +macro_rules! byte ( + ($rdr:ident, $cx:expr) => ({ + let buf = ready!($rdr.read_mem($cx, 1))?; + if !buf.is_empty() { + buf[0] + } else { + return Poll::Ready(Err(io::Error::new(io::ErrorKind::UnexpectedEof, + "unexpected EOF during chunk size line"))); + } + }) +); + +impl ChunkedState { + fn step<R: MemRead>( + &self, + cx: &mut task::Context<'_>, + body: &mut R, + size: &mut u64, + buf: &mut Option<Bytes>, + ) -> Poll<Result<ChunkedState, io::Error>> { + use self::ChunkedState::*; + match *self { + Size => ChunkedState::read_size(cx, body, size), + SizeLws => ChunkedState::read_size_lws(cx, body), + Extension => ChunkedState::read_extension(cx, body), + SizeLf => ChunkedState::read_size_lf(cx, body, *size), + Body => ChunkedState::read_body(cx, body, size, buf), + BodyCr => ChunkedState::read_body_cr(cx, body), + BodyLf => ChunkedState::read_body_lf(cx, body), + EndCr => ChunkedState::read_end_cr(cx, body), + EndLf => ChunkedState::read_end_lf(cx, body), + End => Poll::Ready(Ok(ChunkedState::End)), + } + } + fn read_size<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + size: &mut u64, + ) -> Poll<Result<ChunkedState, io::Error>> { + trace!("Read chunk hex size"); + let radix = 16; + match byte!(rdr, cx) { + b @ b'0'..=b'9' => { + *size *= radix; + *size += (b - b'0') as u64; + } + b @ b'a'..=b'f' => { + *size *= radix; + *size += (b + 10 - b'a') as u64; + } + b @ b'A'..=b'F' => { + *size *= radix; + *size += (b + 10 - b'A') as u64; + } + b'\t' | b' ' => return Poll::Ready(Ok(ChunkedState::SizeLws)), + b';' => return Poll::Ready(Ok(ChunkedState::Extension)), + b'\r' => return Poll::Ready(Ok(ChunkedState::SizeLf)), + _ => { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size line: Invalid Size", + ))); + } + } + Poll::Ready(Ok(ChunkedState::Size)) + } + fn read_size_lws<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + trace!("read_size_lws"); + match byte!(rdr, cx) { + // LWS can follow the chunk size, but no more digits can come + b'\t' | b' ' => Poll::Ready(Ok(ChunkedState::SizeLws)), + b';' => Poll::Ready(Ok(ChunkedState::Extension)), + b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size linear white space", + ))), + } + } + fn read_extension<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + trace!("read_extension"); + match byte!(rdr, cx) { + b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), + _ => Poll::Ready(Ok(ChunkedState::Extension)), // no supported extensions + } + } + fn read_size_lf<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + size: u64, + ) -> Poll<Result<ChunkedState, io::Error>> { + trace!("Chunk size is {:?}", size); + match byte!(rdr, cx) { + b'\n' => { + if size == 0 { + Poll::Ready(Ok(ChunkedState::EndCr)) + } else { + debug!("incoming chunked header: {0:#X} ({0} bytes)", size); + Poll::Ready(Ok(ChunkedState::Body)) + } + } + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size LF", + ))), + } + } + + fn read_body<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + rem: &mut u64, + buf: &mut Option<Bytes>, + ) -> Poll<Result<ChunkedState, io::Error>> { + trace!("Chunked read, remaining={:?}", rem); + + // cap remaining bytes at the max capacity of usize + let rem_cap = match *rem { + r if r > usize::MAX as u64 => usize::MAX, + r => r as usize, + }; + + let to_read = rem_cap; + let slice = ready!(rdr.read_mem(cx, to_read))?; + let count = slice.len(); + + if count == 0 { + *rem = 0; + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + IncompleteBody, + ))); + } + *buf = Some(slice); + *rem -= count as u64; + + if *rem > 0 { + Poll::Ready(Ok(ChunkedState::Body)) + } else { + Poll::Ready(Ok(ChunkedState::BodyCr)) + } + } + fn read_body_cr<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + match byte!(rdr, cx) { + b'\r' => Poll::Ready(Ok(ChunkedState::BodyLf)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk body CR", + ))), + } + } + fn read_body_lf<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + match byte!(rdr, cx) { + b'\n' => Poll::Ready(Ok(ChunkedState::Size)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk body LF", + ))), + } + } + + fn read_end_cr<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + match byte!(rdr, cx) { + b'\r' => Poll::Ready(Ok(ChunkedState::EndLf)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk end CR", + ))), + } + } + fn read_end_lf<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + match byte!(rdr, cx) { + b'\n' => Poll::Ready(Ok(ChunkedState::End)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk end LF", + ))), + } + } +} + +#[derive(Debug)] +struct IncompleteBody; + +impl fmt::Display for IncompleteBody { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "end of file before message length reached") + } +} + +impl StdError for IncompleteBody {} + +#[cfg(test)] +mod tests { + use super::*; + use std::pin::Pin; + use std::time::Duration; + use tokio::io::AsyncRead; + + impl<'a> MemRead for &'a [u8] { + fn read_mem(&mut self, _: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> { + let n = std::cmp::min(len, self.len()); + if n > 0 { + let (a, b) = self.split_at(n); + let buf = Bytes::copy_from_slice(a); + *self = b; + Poll::Ready(Ok(buf)) + } else { + Poll::Ready(Ok(Bytes::new())) + } + } + } + + impl<'a> MemRead for &'a mut (dyn AsyncRead + Unpin) { + fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> { + let mut v = vec![0; len]; + let n = ready!(Pin::new(self).poll_read(cx, &mut v)?); + Poll::Ready(Ok(Bytes::copy_from_slice(&v[..n]))) + } + } + + #[cfg(feature = "nightly")] + impl MemRead for Bytes { + fn read_mem(&mut self, _: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> { + let n = std::cmp::min(len, self.len()); + let ret = self.split_to(n); + Poll::Ready(Ok(ret)) + } + } + + /* + use std::io; + use std::io::Write; + use super::Decoder; + use super::ChunkedState; + use futures::{Async, Poll}; + use bytes::{BytesMut, Bytes}; + use crate::mock::AsyncIo; + */ + + #[tokio::test] + async fn test_read_chunk_size() { + use std::io::ErrorKind::{InvalidInput, UnexpectedEof}; + + async fn read(s: &str) -> u64 { + let mut state = ChunkedState::Size; + let rdr = &mut s.as_bytes(); + let mut size = 0; + loop { + let result = + futures_util::future::poll_fn(|cx| state.step(cx, rdr, &mut size, &mut None)) + .await; + let desc = format!("read_size failed for {:?}", s); + state = result.expect(desc.as_str()); + if state == ChunkedState::Body || state == ChunkedState::EndCr { + break; + } + } + size + } + + async fn read_err(s: &str, expected_err: io::ErrorKind) { + let mut state = ChunkedState::Size; + let rdr = &mut s.as_bytes(); + let mut size = 0; + loop { + let result = + futures_util::future::poll_fn(|cx| state.step(cx, rdr, &mut size, &mut None)) + .await; + state = match result { + Ok(s) => s, + Err(e) => { + assert!( + expected_err == e.kind(), + "Reading {:?}, expected {:?}, but got {:?}", + s, + expected_err, + e.kind() + ); + return; + } + }; + if state == ChunkedState::Body || state == ChunkedState::End { + panic!(format!("Was Ok. Expected Err for {:?}", s)); + } + } + } + + assert_eq!(1, read("1\r\n").await); + assert_eq!(1, read("01\r\n").await); + assert_eq!(0, read("0\r\n").await); + assert_eq!(0, read("00\r\n").await); + assert_eq!(10, read("A\r\n").await); + assert_eq!(10, read("a\r\n").await); + assert_eq!(255, read("Ff\r\n").await); + assert_eq!(255, read("Ff \r\n").await); + // Missing LF or CRLF + read_err("F\rF", InvalidInput).await; + read_err("F", UnexpectedEof).await; + // Invalid hex digit + read_err("X\r\n", InvalidInput).await; + read_err("1X\r\n", InvalidInput).await; + read_err("-\r\n", InvalidInput).await; + read_err("-1\r\n", InvalidInput).await; + // Acceptable (if not fully valid) extensions do not influence the size + assert_eq!(1, read("1;extension\r\n").await); + assert_eq!(10, read("a;ext name=value\r\n").await); + assert_eq!(1, read("1;extension;extension2\r\n").await); + assert_eq!(1, read("1;;; ;\r\n").await); + assert_eq!(2, read("2; extension...\r\n").await); + assert_eq!(3, read("3 ; extension=123\r\n").await); + assert_eq!(3, read("3 ;\r\n").await); + assert_eq!(3, read("3 ; \r\n").await); + // Invalid extensions cause an error + read_err("1 invalid extension\r\n", InvalidInput).await; + read_err("1 A\r\n", InvalidInput).await; + read_err("1;no CRLF", UnexpectedEof).await; + } + + #[tokio::test] + async fn test_read_sized_early_eof() { + let mut bytes = &b"foo bar"[..]; + let mut decoder = Decoder::length(10); + assert_eq!(decoder.decode_fut(&mut bytes).await.unwrap().len(), 7); + let e = decoder.decode_fut(&mut bytes).await.unwrap_err(); + assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof); + } + + #[tokio::test] + async fn test_read_chunked_early_eof() { + let mut bytes = &b"\ + 9\r\n\ + foo bar\ + "[..]; + let mut decoder = Decoder::chunked(); + assert_eq!(decoder.decode_fut(&mut bytes).await.unwrap().len(), 7); + let e = decoder.decode_fut(&mut bytes).await.unwrap_err(); + assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof); + } + + #[tokio::test] + async fn test_read_chunked_single_read() { + let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\n"[..]; + let buf = Decoder::chunked() + .decode_fut(&mut mock_buf) + .await + .expect("decode"); + assert_eq!(16, buf.len()); + let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String"); + assert_eq!("1234567890abcdef", &result); + } + + #[tokio::test] + async fn test_read_chunked_after_eof() { + let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\n\r\n"[..]; + let mut decoder = Decoder::chunked(); + + // normal read + let buf = decoder.decode_fut(&mut mock_buf).await.unwrap(); + assert_eq!(16, buf.len()); + let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String"); + assert_eq!("1234567890abcdef", &result); + + // eof read + let buf = decoder.decode_fut(&mut mock_buf).await.expect("decode"); + assert_eq!(0, buf.len()); + + // ensure read after eof also returns eof + let buf = decoder.decode_fut(&mut mock_buf).await.expect("decode"); + assert_eq!(0, buf.len()); + } + + // perform an async read using a custom buffer size and causing a blocking + // read at the specified byte + async fn read_async(mut decoder: Decoder, content: &[u8], block_at: usize) -> String { + let mut outs = Vec::new(); + + let mut ins = if block_at == 0 { + tokio_test::io::Builder::new() + .wait(Duration::from_millis(10)) + .read(content) + .build() + } else { + tokio_test::io::Builder::new() + .read(&content[..block_at]) + .wait(Duration::from_millis(10)) + .read(&content[block_at..]) + .build() + }; + + let mut ins = &mut ins as &mut (dyn AsyncRead + Unpin); + + loop { + let buf = decoder + .decode_fut(&mut ins) + .await + .expect("unexpected decode error"); + if buf.is_empty() { + break; // eof + } + outs.extend(buf.as_ref()); + } + + String::from_utf8(outs).expect("decode String") + } + + // iterate over the different ways that this async read could go. + // tests blocking a read at each byte along the content - The shotgun approach + async fn all_async_cases(content: &str, expected: &str, decoder: Decoder) { + let content_len = content.len(); + for block_at in 0..content_len { + let actual = read_async(decoder.clone(), content.as_bytes(), block_at).await; + assert_eq!(expected, &actual) //, "Failed async. Blocking at {}", block_at); + } + } + + #[tokio::test] + async fn test_read_length_async() { + let content = "foobar"; + all_async_cases(content, content, Decoder::length(content.len() as u64)).await; + } + + #[tokio::test] + async fn test_read_chunked_async() { + let content = "3\r\nfoo\r\n3\r\nbar\r\n0\r\n\r\n"; + let expected = "foobar"; + all_async_cases(content, expected, Decoder::chunked()).await; + } + + #[tokio::test] + async fn test_read_eof_async() { + let content = "foobar"; + all_async_cases(content, content, Decoder::eof()).await; + } + + #[cfg(feature = "nightly")] + #[bench] + fn bench_decode_chunked_1kb(b: &mut test::Bencher) { + let mut rt = new_runtime(); + + const LEN: usize = 1024; + let mut vec = Vec::new(); + vec.extend(format!("{:x}\r\n", LEN).as_bytes()); + vec.extend(&[0; LEN][..]); + vec.extend(b"\r\n"); + let content = Bytes::from(vec); + + b.bytes = LEN as u64; + + b.iter(|| { + let mut decoder = Decoder::chunked(); + rt.block_on(async { + let mut raw = content.clone(); + let chunk = decoder.decode_fut(&mut raw).await.unwrap(); + assert_eq!(chunk.len(), LEN); + }); + }); + } + + #[cfg(feature = "nightly")] + #[bench] + fn bench_decode_length_1kb(b: &mut test::Bencher) { + let mut rt = new_runtime(); + + const LEN: usize = 1024; + let content = Bytes::from(&[0; LEN][..]); + b.bytes = LEN as u64; + + b.iter(|| { + let mut decoder = Decoder::length(LEN as u64); + rt.block_on(async { + let mut raw = content.clone(); + let chunk = decoder.decode_fut(&mut raw).await.unwrap(); + assert_eq!(chunk.len(), LEN); + }); + }); + } + + #[cfg(feature = "nightly")] + fn new_runtime() -> tokio::runtime::Runtime { + tokio::runtime::Builder::new() + .enable_all() + .basic_scheduler() + .build() + .expect("rt build") + } +} diff --git a/third_party/rust/hyper/src/proto/h1/dispatch.rs b/third_party/rust/hyper/src/proto/h1/dispatch.rs new file mode 100644 index 0000000000..84ee412c3c --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/dispatch.rs @@ -0,0 +1,702 @@ +use std::error::Error as StdError; + +use bytes::{Buf, Bytes}; +use http::{Request, Response, StatusCode}; +use tokio::io::{AsyncRead, AsyncWrite}; + +use super::{Http1Transaction, Wants}; +use crate::body::{Body, Payload}; +use crate::common::{task, Future, Never, Pin, Poll, Unpin}; +use crate::proto::{ + BodyLength, Conn, DecodedLength, Dispatched, MessageHead, RequestHead, RequestLine, + ResponseHead, +}; +use crate::service::HttpService; + +pub(crate) struct Dispatcher<D, Bs: Payload, I, T> { + conn: Conn<I, Bs::Data, T>, + dispatch: D, + body_tx: Option<crate::body::Sender>, + body_rx: Pin<Box<Option<Bs>>>, + is_closing: bool, +} + +pub(crate) trait Dispatch { + type PollItem; + type PollBody; + type PollError; + type RecvItem; + fn poll_msg( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), Self::PollError>>>; + fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, Body)>) -> crate::Result<()>; + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), ()>>; + fn should_poll(&self) -> bool; +} + +pub struct Server<S: HttpService<B>, B> { + in_flight: Pin<Box<Option<S::Future>>>, + pub(crate) service: S, +} + +pub struct Client<B> { + callback: Option<crate::client::dispatch::Callback<Request<B>, Response<Body>>>, + rx: ClientRx<B>, + rx_closed: bool, +} + +type ClientRx<B> = crate::client::dispatch::Receiver<Request<B>, Response<Body>>; + +impl<D, Bs, I, T> Dispatcher<D, Bs, I, T> +where + D: Dispatch< + PollItem = MessageHead<T::Outgoing>, + PollBody = Bs, + RecvItem = MessageHead<T::Incoming>, + > + Unpin, + D::PollError: Into<Box<dyn StdError + Send + Sync>>, + I: AsyncRead + AsyncWrite + Unpin, + T: Http1Transaction + Unpin, + Bs: Payload, +{ + pub fn new(dispatch: D, conn: Conn<I, Bs::Data, T>) -> Self { + Dispatcher { + conn, + dispatch, + body_tx: None, + body_rx: Box::pin(None), + is_closing: false, + } + } + + pub fn disable_keep_alive(&mut self) { + self.conn.disable_keep_alive(); + if self.conn.is_write_closed() { + self.close(); + } + } + + pub fn into_inner(self) -> (I, Bytes, D) { + let (io, buf) = self.conn.into_inner(); + (io, buf, self.dispatch) + } + + /// Run this dispatcher until HTTP says this connection is done, + /// but don't call `AsyncWrite::shutdown` on the underlying IO. + /// + /// This is useful for old-style HTTP upgrades, but ignores + /// newer-style upgrade API. + pub(crate) fn poll_without_shutdown( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<crate::Result<()>> + where + Self: Unpin, + { + Pin::new(self).poll_catch(cx, false).map_ok(|ds| { + if let Dispatched::Upgrade(pending) = ds { + pending.manual(); + } + }) + } + + fn poll_catch( + &mut self, + cx: &mut task::Context<'_>, + should_shutdown: bool, + ) -> Poll<crate::Result<Dispatched>> { + Poll::Ready(ready!(self.poll_inner(cx, should_shutdown)).or_else(|e| { + // An error means we're shutting down either way. + // We just try to give the error to the user, + // and close the connection with an Ok. If we + // cannot give it to the user, then return the Err. + self.dispatch.recv_msg(Err(e))?; + Ok(Dispatched::Shutdown) + })) + } + + fn poll_inner( + &mut self, + cx: &mut task::Context<'_>, + should_shutdown: bool, + ) -> Poll<crate::Result<Dispatched>> { + T::update_date(); + + ready!(self.poll_loop(cx))?; + + if self.is_done() { + if let Some(pending) = self.conn.pending_upgrade() { + self.conn.take_error()?; + return Poll::Ready(Ok(Dispatched::Upgrade(pending))); + } else if should_shutdown { + ready!(self.conn.poll_shutdown(cx)).map_err(crate::Error::new_shutdown)?; + } + self.conn.take_error()?; + Poll::Ready(Ok(Dispatched::Shutdown)) + } else { + Poll::Pending + } + } + + fn poll_loop(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + // Limit the looping on this connection, in case it is ready far too + // often, so that other futures don't starve. + // + // 16 was chosen arbitrarily, as that is number of pipelined requests + // benchmarks often use. Perhaps it should be a config option instead. + for _ in 0..16 { + let _ = self.poll_read(cx)?; + let _ = self.poll_write(cx)?; + let _ = self.poll_flush(cx)?; + + // This could happen if reading paused before blocking on IO, + // such as getting to the end of a framed message, but then + // writing/flushing set the state back to Init. In that case, + // if the read buffer still had bytes, we'd want to try poll_read + // again, or else we wouldn't ever be woken up again. + // + // Using this instead of task::current() and notify() inside + // the Conn is noticeably faster in pipelined benchmarks. + if !self.conn.wants_read_again() { + //break; + return Poll::Ready(Ok(())); + } + } + + trace!("poll_loop yielding (self = {:p})", self); + + task::yield_now(cx).map(|never| match never {}) + } + + fn poll_read(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + loop { + if self.is_closing { + return Poll::Ready(Ok(())); + } else if self.conn.can_read_head() { + ready!(self.poll_read_head(cx))?; + } else if let Some(mut body) = self.body_tx.take() { + if self.conn.can_read_body() { + match body.poll_ready(cx) { + Poll::Ready(Ok(())) => (), + Poll::Pending => { + self.body_tx = Some(body); + return Poll::Pending; + } + Poll::Ready(Err(_canceled)) => { + // user doesn't care about the body + // so we should stop reading + trace!("body receiver dropped before eof, draining or closing"); + self.conn.poll_drain_or_close_read(cx); + continue; + } + } + match self.conn.poll_read_body(cx) { + Poll::Ready(Some(Ok(chunk))) => match body.try_send_data(chunk) { + Ok(()) => { + self.body_tx = Some(body); + } + Err(_canceled) => { + if self.conn.can_read_body() { + trace!("body receiver dropped before eof, closing"); + self.conn.close_read(); + } + } + }, + Poll::Ready(None) => { + // just drop, the body will close automatically + } + Poll::Pending => { + self.body_tx = Some(body); + return Poll::Pending; + } + Poll::Ready(Some(Err(e))) => { + body.send_error(crate::Error::new_body(e)); + } + } + } else { + // just drop, the body will close automatically + } + } else { + return self.conn.poll_read_keep_alive(cx); + } + } + } + + fn poll_read_head(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + // can dispatch receive, or does it still care about, an incoming message? + match ready!(self.dispatch.poll_ready(cx)) { + Ok(()) => (), + Err(()) => { + trace!("dispatch no longer receiving messages"); + self.close(); + return Poll::Ready(Ok(())); + } + } + // dispatch is ready for a message, try to read one + match ready!(self.conn.poll_read_head(cx)) { + Some(Ok((head, body_len, wants))) => { + let mut body = match body_len { + DecodedLength::ZERO => Body::empty(), + other => { + let (tx, rx) = Body::new_channel(other, wants.contains(Wants::EXPECT)); + self.body_tx = Some(tx); + rx + } + }; + if wants.contains(Wants::UPGRADE) { + body.set_on_upgrade(self.conn.on_upgrade()); + } + self.dispatch.recv_msg(Ok((head, body)))?; + Poll::Ready(Ok(())) + } + Some(Err(err)) => { + debug!("read_head error: {}", err); + self.dispatch.recv_msg(Err(err))?; + // if here, the dispatcher gave the user the error + // somewhere else. we still need to shutdown, but + // not as a second error. + self.close(); + Poll::Ready(Ok(())) + } + None => { + // read eof, the write side will have been closed too unless + // allow_read_close was set to true, in which case just do + // nothing... + debug_assert!(self.conn.is_read_closed()); + if self.conn.is_write_closed() { + self.close(); + } + Poll::Ready(Ok(())) + } + } + } + + fn poll_write(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + loop { + if self.is_closing { + return Poll::Ready(Ok(())); + } else if self.body_rx.is_none() + && self.conn.can_write_head() + && self.dispatch.should_poll() + { + if let Some(msg) = ready!(self.dispatch.poll_msg(cx)) { + let (head, mut body) = msg.map_err(crate::Error::new_user_service)?; + + // Check if the body knows its full data immediately. + // + // If so, we can skip a bit of bookkeeping that streaming + // bodies need to do. + if let Some(full) = crate::body::take_full_data(&mut body) { + self.conn.write_full_msg(head, full); + return Poll::Ready(Ok(())); + } + + let body_type = if body.is_end_stream() { + self.body_rx.set(None); + None + } else { + let btype = body + .size_hint() + .exact() + .map(BodyLength::Known) + .or_else(|| Some(BodyLength::Unknown)); + self.body_rx.set(Some(body)); + btype + }; + self.conn.write_head(head, body_type); + } else { + self.close(); + return Poll::Ready(Ok(())); + } + } else if !self.conn.can_buffer_body() { + ready!(self.poll_flush(cx))?; + } else { + // A new scope is needed :( + if let (Some(mut body), clear_body) = + OptGuard::new(self.body_rx.as_mut()).guard_mut() + { + debug_assert!(!*clear_body, "opt guard defaults to keeping body"); + if !self.conn.can_write_body() { + trace!( + "no more write body allowed, user body is_end_stream = {}", + body.is_end_stream(), + ); + *clear_body = true; + continue; + } + + let item = ready!(body.as_mut().poll_data(cx)); + if let Some(item) = item { + let chunk = item.map_err(|e| { + *clear_body = true; + crate::Error::new_user_body(e) + })?; + let eos = body.is_end_stream(); + if eos { + *clear_body = true; + if chunk.remaining() == 0 { + trace!("discarding empty chunk"); + self.conn.end_body(); + } else { + self.conn.write_body_and_end(chunk); + } + } else { + if chunk.remaining() == 0 { + trace!("discarding empty chunk"); + continue; + } + self.conn.write_body(chunk); + } + } else { + *clear_body = true; + self.conn.end_body(); + } + } else { + return Poll::Pending; + } + } + } + } + + fn poll_flush(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + self.conn.poll_flush(cx).map_err(|err| { + debug!("error writing: {}", err); + crate::Error::new_body_write(err) + }) + } + + fn close(&mut self) { + self.is_closing = true; + self.conn.close_read(); + self.conn.close_write(); + } + + fn is_done(&self) -> bool { + if self.is_closing { + return true; + } + + let read_done = self.conn.is_read_closed(); + + if !T::should_read_first() && read_done { + // a client that cannot read may was well be done. + true + } else { + let write_done = self.conn.is_write_closed() + || (!self.dispatch.should_poll() && self.body_rx.is_none()); + read_done && write_done + } + } +} + +impl<D, Bs, I, T> Future for Dispatcher<D, Bs, I, T> +where + D: Dispatch< + PollItem = MessageHead<T::Outgoing>, + PollBody = Bs, + RecvItem = MessageHead<T::Incoming>, + > + Unpin, + D::PollError: Into<Box<dyn StdError + Send + Sync>>, + I: AsyncRead + AsyncWrite + Unpin, + T: Http1Transaction + Unpin, + Bs: Payload, +{ + type Output = crate::Result<Dispatched>; + + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + self.poll_catch(cx, true) + } +} + +// ===== impl OptGuard ===== + +/// A drop guard to allow a mutable borrow of an Option while being able to +/// set whether the `Option` should be cleared on drop. +struct OptGuard<'a, T>(Pin<&'a mut Option<T>>, bool); + +impl<'a, T> OptGuard<'a, T> { + fn new(pin: Pin<&'a mut Option<T>>) -> Self { + OptGuard(pin, false) + } + + fn guard_mut(&mut self) -> (Option<Pin<&mut T>>, &mut bool) { + (self.0.as_mut().as_pin_mut(), &mut self.1) + } +} + +impl<'a, T> Drop for OptGuard<'a, T> { + fn drop(&mut self) { + if self.1 { + self.0.set(None); + } + } +} + +// ===== impl Server ===== + +impl<S, B> Server<S, B> +where + S: HttpService<B>, +{ + pub fn new(service: S) -> Server<S, B> { + Server { + in_flight: Box::pin(None), + service, + } + } + + pub fn into_service(self) -> S { + self.service + } +} + +// Service is never pinned +impl<S: HttpService<B>, B> Unpin for Server<S, B> {} + +impl<S, Bs> Dispatch for Server<S, Body> +where + S: HttpService<Body, ResBody = Bs>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + Bs: Payload, +{ + type PollItem = MessageHead<StatusCode>; + type PollBody = Bs; + type PollError = S::Error; + type RecvItem = RequestHead; + + fn poll_msg( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), Self::PollError>>> { + let ret = if let Some(ref mut fut) = self.in_flight.as_mut().as_pin_mut() { + let resp = ready!(fut.as_mut().poll(cx)?); + let (parts, body) = resp.into_parts(); + let head = MessageHead { + version: parts.version, + subject: parts.status, + headers: parts.headers, + }; + Poll::Ready(Some(Ok((head, body)))) + } else { + unreachable!("poll_msg shouldn't be called if no inflight"); + }; + + // Since in_flight finished, remove it + self.in_flight.set(None); + ret + } + + fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, Body)>) -> crate::Result<()> { + let (msg, body) = msg?; + let mut req = Request::new(body); + *req.method_mut() = msg.subject.0; + *req.uri_mut() = msg.subject.1; + *req.headers_mut() = msg.headers; + *req.version_mut() = msg.version; + let fut = self.service.call(req); + self.in_flight.set(Some(fut)); + Ok(()) + } + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), ()>> { + if self.in_flight.is_some() { + Poll::Pending + } else { + self.service.poll_ready(cx).map_err(|_e| { + // FIXME: return error value. + trace!("service closed"); + }) + } + } + + fn should_poll(&self) -> bool { + self.in_flight.is_some() + } +} + +// ===== impl Client ===== + +impl<B> Client<B> { + pub fn new(rx: ClientRx<B>) -> Client<B> { + Client { + callback: None, + rx, + rx_closed: false, + } + } +} + +impl<B> Dispatch for Client<B> +where + B: Payload, +{ + type PollItem = RequestHead; + type PollBody = B; + type PollError = Never; + type RecvItem = ResponseHead; + + fn poll_msg( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), Never>>> { + debug_assert!(!self.rx_closed); + match self.rx.poll_next(cx) { + Poll::Ready(Some((req, mut cb))) => { + // check that future hasn't been canceled already + match cb.poll_canceled(cx) { + Poll::Ready(()) => { + trace!("request canceled"); + Poll::Ready(None) + } + Poll::Pending => { + let (parts, body) = req.into_parts(); + let head = RequestHead { + version: parts.version, + subject: RequestLine(parts.method, parts.uri), + headers: parts.headers, + }; + self.callback = Some(cb); + Poll::Ready(Some(Ok((head, body)))) + } + } + } + Poll::Ready(None) => { + // user has dropped sender handle + trace!("client tx closed"); + self.rx_closed = true; + Poll::Ready(None) + } + Poll::Pending => Poll::Pending, + } + } + + fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, Body)>) -> crate::Result<()> { + match msg { + Ok((msg, body)) => { + if let Some(cb) = self.callback.take() { + let mut res = Response::new(body); + *res.status_mut() = msg.subject; + *res.headers_mut() = msg.headers; + *res.version_mut() = msg.version; + cb.send(Ok(res)); + Ok(()) + } else { + // Getting here is likely a bug! An error should have happened + // in Conn::require_empty_read() before ever parsing a + // full message! + Err(crate::Error::new_unexpected_message()) + } + } + Err(err) => { + if let Some(cb) = self.callback.take() { + cb.send(Err((err, None))); + Ok(()) + } else if !self.rx_closed { + self.rx.close(); + if let Some((req, cb)) = self.rx.try_recv() { + trace!("canceling queued request with connection error: {}", err); + // in this case, the message was never even started, so it's safe to tell + // the user that the request was completely canceled + cb.send(Err((crate::Error::new_canceled().with(err), Some(req)))); + Ok(()) + } else { + Err(err) + } + } else { + Err(err) + } + } + } + } + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), ()>> { + match self.callback { + Some(ref mut cb) => match cb.poll_canceled(cx) { + Poll::Ready(()) => { + trace!("callback receiver has dropped"); + Poll::Ready(Err(())) + } + Poll::Pending => Poll::Ready(Ok(())), + }, + None => Poll::Ready(Err(())), + } + } + + fn should_poll(&self) -> bool { + self.callback.is_none() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::proto::h1::ClientTransaction; + use std::time::Duration; + + #[test] + fn client_read_bytes_before_writing_request() { + let _ = pretty_env_logger::try_init(); + + tokio_test::task::spawn(()).enter(|cx, _| { + let (io, mut handle) = tokio_test::io::Builder::new().build_with_handle(); + + // Block at 0 for now, but we will release this response before + // the request is ready to write later... + //let io = AsyncIo::new_buf(b"HTTP/1.1 200 OK\r\n\r\n".to_vec(), 0); + let (mut tx, rx) = crate::client::dispatch::channel(); + let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(io); + let mut dispatcher = Dispatcher::new(Client::new(rx), conn); + + // First poll is needed to allow tx to send... + assert!(Pin::new(&mut dispatcher).poll(cx).is_pending()); + + // Unblock our IO, which has a response before we've sent request! + // + handle.read(b"HTTP/1.1 200 OK\r\n\r\n"); + + let mut res_rx = tx + .try_send(crate::Request::new(crate::Body::empty())) + .unwrap(); + + tokio_test::assert_ready_ok!(Pin::new(&mut dispatcher).poll(cx)); + let err = tokio_test::assert_ready_ok!(Pin::new(&mut res_rx).poll(cx)) + .expect_err("callback should send error"); + + match (err.0.kind(), err.1) { + (&crate::error::Kind::Canceled, Some(_)) => (), + other => panic!("expected Canceled, got {:?}", other), + } + }); + } + + #[tokio::test] + async fn body_empty_chunks_ignored() { + let _ = pretty_env_logger::try_init(); + + let io = tokio_test::io::Builder::new() + // no reading or writing, just be blocked for the test... + .wait(Duration::from_secs(5)) + .build(); + + let (mut tx, rx) = crate::client::dispatch::channel(); + let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(io); + let mut dispatcher = tokio_test::task::spawn(Dispatcher::new(Client::new(rx), conn)); + + // First poll is needed to allow tx to send... + assert!(dispatcher.poll().is_pending()); + + let body = { + let (mut tx, body) = crate::Body::channel(); + tx.try_send_data("".into()).unwrap(); + body + }; + + let _res_rx = tx.try_send(crate::Request::new(body)).unwrap(); + + // Ensure conn.write_body wasn't called with the empty chunk. + // If it is, it will trigger an assertion. + assert!(dispatcher.poll().is_pending()); + } +} diff --git a/third_party/rust/hyper/src/proto/h1/encode.rs b/third_party/rust/hyper/src/proto/h1/encode.rs new file mode 100644 index 0000000000..95b0d82b67 --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/encode.rs @@ -0,0 +1,418 @@ +use std::fmt; +use std::io::IoSlice; + +use bytes::buf::ext::{BufExt, Chain, Take}; +use bytes::Buf; + +use super::io::WriteBuf; + +type StaticBuf = &'static [u8]; + +/// Encoders to handle different Transfer-Encodings. +#[derive(Debug, Clone, PartialEq)] +pub struct Encoder { + kind: Kind, + is_last: bool, +} + +#[derive(Debug)] +pub struct EncodedBuf<B> { + kind: BufKind<B>, +} + +#[derive(Debug)] +pub struct NotEof; + +#[derive(Debug, PartialEq, Clone)] +enum Kind { + /// An Encoder for when Transfer-Encoding includes `chunked`. + Chunked, + /// An Encoder for when Content-Length is set. + /// + /// Enforces that the body is not longer than the Content-Length header. + Length(u64), + /// An Encoder for when neither Content-Length nor Chunked encoding is set. + /// + /// This is mostly only used with HTTP/1.0 with a length. This kind requires + /// the connection to be closed when the body is finished. + CloseDelimited, +} + +#[derive(Debug)] +enum BufKind<B> { + Exact(B), + Limited(Take<B>), + Chunked(Chain<Chain<ChunkSize, B>, StaticBuf>), + ChunkedEnd(StaticBuf), +} + +impl Encoder { + fn new(kind: Kind) -> Encoder { + Encoder { + kind, + is_last: false, + } + } + pub fn chunked() -> Encoder { + Encoder::new(Kind::Chunked) + } + + pub fn length(len: u64) -> Encoder { + Encoder::new(Kind::Length(len)) + } + + pub fn close_delimited() -> Encoder { + Encoder::new(Kind::CloseDelimited) + } + + pub fn is_eof(&self) -> bool { + match self.kind { + Kind::Length(0) => true, + _ => false, + } + } + + pub fn set_last(mut self, is_last: bool) -> Self { + self.is_last = is_last; + self + } + + pub fn is_last(&self) -> bool { + self.is_last + } + + pub fn end<B>(&self) -> Result<Option<EncodedBuf<B>>, NotEof> { + match self.kind { + Kind::Length(0) => Ok(None), + Kind::Chunked => Ok(Some(EncodedBuf { + kind: BufKind::ChunkedEnd(b"0\r\n\r\n"), + })), + _ => Err(NotEof), + } + } + + pub fn encode<B>(&mut self, msg: B) -> EncodedBuf<B> + where + B: Buf, + { + let len = msg.remaining(); + debug_assert!(len > 0, "encode() called with empty buf"); + + let kind = match self.kind { + Kind::Chunked => { + trace!("encoding chunked {}B", len); + let buf = ChunkSize::new(len) + .chain(msg) + .chain(b"\r\n" as &'static [u8]); + BufKind::Chunked(buf) + } + Kind::Length(ref mut remaining) => { + trace!("sized write, len = {}", len); + if len as u64 > *remaining { + let limit = *remaining as usize; + *remaining = 0; + BufKind::Limited(msg.take(limit)) + } else { + *remaining -= len as u64; + BufKind::Exact(msg) + } + } + Kind::CloseDelimited => { + trace!("close delimited write {}B", len); + BufKind::Exact(msg) + } + }; + EncodedBuf { kind } + } + + pub(super) fn encode_and_end<B>(&self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) -> bool + where + B: Buf, + { + let len = msg.remaining(); + debug_assert!(len > 0, "encode() called with empty buf"); + + match self.kind { + Kind::Chunked => { + trace!("encoding chunked {}B", len); + let buf = ChunkSize::new(len) + .chain(msg) + .chain(b"\r\n0\r\n\r\n" as &'static [u8]); + dst.buffer(buf); + !self.is_last + } + Kind::Length(remaining) => { + use std::cmp::Ordering; + + trace!("sized write, len = {}", len); + match (len as u64).cmp(&remaining) { + Ordering::Equal => { + dst.buffer(msg); + !self.is_last + } + Ordering::Greater => { + dst.buffer(msg.take(remaining as usize)); + !self.is_last + } + Ordering::Less => { + dst.buffer(msg); + false + } + } + } + Kind::CloseDelimited => { + trace!("close delimited write {}B", len); + dst.buffer(msg); + false + } + } + } + + /// Encodes the full body, without verifying the remaining length matches. + /// + /// This is used in conjunction with Payload::__hyper_full_data(), which + /// means we can trust that the buf has the correct size (the buf itself + /// was checked to make the headers). + pub(super) fn danger_full_buf<B>(self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) + where + B: Buf, + { + debug_assert!(msg.remaining() > 0, "encode() called with empty buf"); + debug_assert!( + match self.kind { + Kind::Length(len) => len == msg.remaining() as u64, + _ => true, + }, + "danger_full_buf length mismatches" + ); + + match self.kind { + Kind::Chunked => { + let len = msg.remaining(); + trace!("encoding chunked {}B", len); + let buf = ChunkSize::new(len) + .chain(msg) + .chain(b"\r\n0\r\n\r\n" as &'static [u8]); + dst.buffer(buf); + } + _ => { + dst.buffer(msg); + } + } + } +} + +impl<B> Buf for EncodedBuf<B> +where + B: Buf, +{ + #[inline] + fn remaining(&self) -> usize { + match self.kind { + BufKind::Exact(ref b) => b.remaining(), + BufKind::Limited(ref b) => b.remaining(), + BufKind::Chunked(ref b) => b.remaining(), + BufKind::ChunkedEnd(ref b) => b.remaining(), + } + } + + #[inline] + fn bytes(&self) -> &[u8] { + match self.kind { + BufKind::Exact(ref b) => b.bytes(), + BufKind::Limited(ref b) => b.bytes(), + BufKind::Chunked(ref b) => b.bytes(), + BufKind::ChunkedEnd(ref b) => b.bytes(), + } + } + + #[inline] + fn advance(&mut self, cnt: usize) { + match self.kind { + BufKind::Exact(ref mut b) => b.advance(cnt), + BufKind::Limited(ref mut b) => b.advance(cnt), + BufKind::Chunked(ref mut b) => b.advance(cnt), + BufKind::ChunkedEnd(ref mut b) => b.advance(cnt), + } + } + + #[inline] + fn bytes_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { + match self.kind { + BufKind::Exact(ref b) => b.bytes_vectored(dst), + BufKind::Limited(ref b) => b.bytes_vectored(dst), + BufKind::Chunked(ref b) => b.bytes_vectored(dst), + BufKind::ChunkedEnd(ref b) => b.bytes_vectored(dst), + } + } +} + +#[cfg(target_pointer_width = "32")] +const USIZE_BYTES: usize = 4; + +#[cfg(target_pointer_width = "64")] +const USIZE_BYTES: usize = 8; + +// each byte will become 2 hex +const CHUNK_SIZE_MAX_BYTES: usize = USIZE_BYTES * 2; + +#[derive(Clone, Copy)] +struct ChunkSize { + bytes: [u8; CHUNK_SIZE_MAX_BYTES + 2], + pos: u8, + len: u8, +} + +impl ChunkSize { + fn new(len: usize) -> ChunkSize { + use std::fmt::Write; + let mut size = ChunkSize { + bytes: [0; CHUNK_SIZE_MAX_BYTES + 2], + pos: 0, + len: 0, + }; + write!(&mut size, "{:X}\r\n", len).expect("CHUNK_SIZE_MAX_BYTES should fit any usize"); + size + } +} + +impl Buf for ChunkSize { + #[inline] + fn remaining(&self) -> usize { + (self.len - self.pos).into() + } + + #[inline] + fn bytes(&self) -> &[u8] { + &self.bytes[self.pos.into()..self.len.into()] + } + + #[inline] + fn advance(&mut self, cnt: usize) { + assert!(cnt <= self.remaining()); + self.pos += cnt as u8; // just asserted cnt fits in u8 + } +} + +impl fmt::Debug for ChunkSize { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ChunkSize") + .field("bytes", &&self.bytes[..self.len.into()]) + .field("pos", &self.pos) + .finish() + } +} + +impl fmt::Write for ChunkSize { + fn write_str(&mut self, num: &str) -> fmt::Result { + use std::io::Write; + (&mut self.bytes[self.len.into()..]) + .write_all(num.as_bytes()) + .expect("&mut [u8].write() cannot error"); + self.len += num.len() as u8; // safe because bytes is never bigger than 256 + Ok(()) + } +} + +impl<B: Buf> From<B> for EncodedBuf<B> { + fn from(buf: B) -> Self { + EncodedBuf { + kind: BufKind::Exact(buf), + } + } +} + +impl<B: Buf> From<Take<B>> for EncodedBuf<B> { + fn from(buf: Take<B>) -> Self { + EncodedBuf { + kind: BufKind::Limited(buf), + } + } +} + +impl<B: Buf> From<Chain<Chain<ChunkSize, B>, StaticBuf>> for EncodedBuf<B> { + fn from(buf: Chain<Chain<ChunkSize, B>, StaticBuf>) -> Self { + EncodedBuf { + kind: BufKind::Chunked(buf), + } + } +} + +#[cfg(test)] +mod tests { + use bytes::BufMut; + + use super::super::io::Cursor; + use super::Encoder; + + #[test] + fn chunked() { + let mut encoder = Encoder::chunked(); + let mut dst = Vec::new(); + + let msg1 = b"foo bar".as_ref(); + let buf1 = encoder.encode(msg1); + dst.put(buf1); + assert_eq!(dst, b"7\r\nfoo bar\r\n"); + + let msg2 = b"baz quux herp".as_ref(); + let buf2 = encoder.encode(msg2); + dst.put(buf2); + + assert_eq!(dst, b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n"); + + let end = encoder.end::<Cursor<Vec<u8>>>().unwrap().unwrap(); + dst.put(end); + + assert_eq!( + dst, + b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n0\r\n\r\n".as_ref() + ); + } + + #[test] + fn length() { + let max_len = 8; + let mut encoder = Encoder::length(max_len as u64); + let mut dst = Vec::new(); + + let msg1 = b"foo bar".as_ref(); + let buf1 = encoder.encode(msg1); + dst.put(buf1); + + assert_eq!(dst, b"foo bar"); + assert!(!encoder.is_eof()); + encoder.end::<()>().unwrap_err(); + + let msg2 = b"baz".as_ref(); + let buf2 = encoder.encode(msg2); + dst.put(buf2); + + assert_eq!(dst.len(), max_len); + assert_eq!(dst, b"foo barb"); + assert!(encoder.is_eof()); + assert!(encoder.end::<()>().unwrap().is_none()); + } + + #[test] + fn eof() { + let mut encoder = Encoder::close_delimited(); + let mut dst = Vec::new(); + + let msg1 = b"foo bar".as_ref(); + let buf1 = encoder.encode(msg1); + dst.put(buf1); + + assert_eq!(dst, b"foo bar"); + assert!(!encoder.is_eof()); + encoder.end::<()>().unwrap_err(); + + let msg2 = b"baz".as_ref(); + let buf2 = encoder.encode(msg2); + dst.put(buf2); + + assert_eq!(dst, b"foo barbaz"); + assert!(!encoder.is_eof()); + encoder.end::<()>().unwrap_err(); + } +} diff --git a/third_party/rust/hyper/src/proto/h1/io.rs b/third_party/rust/hyper/src/proto/h1/io.rs new file mode 100644 index 0000000000..00f4f64f47 --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/io.rs @@ -0,0 +1,907 @@ +use std::cell::Cell; +use std::cmp; +use std::fmt; +use std::io::{self, IoSlice}; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use tokio::io::{AsyncRead, AsyncWrite}; + +use super::{Http1Transaction, ParseContext, ParsedMessage}; +use crate::common::buf::BufList; +use crate::common::{task, Pin, Poll, Unpin}; + +/// The initial buffer size allocated before trying to read from IO. +pub(crate) const INIT_BUFFER_SIZE: usize = 8192; + +/// The minimum value that can be set to max buffer size. +pub const MINIMUM_MAX_BUFFER_SIZE: usize = INIT_BUFFER_SIZE; + +/// The default maximum read buffer size. If the buffer gets this big and +/// a message is still not complete, a `TooLarge` error is triggered. +// Note: if this changes, update server::conn::Http::max_buf_size docs. +pub(crate) const DEFAULT_MAX_BUFFER_SIZE: usize = 8192 + 4096 * 100; + +/// The maximum number of distinct `Buf`s to hold in a list before requiring +/// a flush. Only affects when the buffer strategy is to queue buffers. +/// +/// Note that a flush can happen before reaching the maximum. This simply +/// forces a flush if the queue gets this big. +const MAX_BUF_LIST_BUFFERS: usize = 16; + +pub struct Buffered<T, B> { + flush_pipeline: bool, + io: T, + read_blocked: bool, + read_buf: BytesMut, + read_buf_strategy: ReadStrategy, + write_buf: WriteBuf<B>, +} + +impl<T, B> fmt::Debug for Buffered<T, B> +where + B: Buf, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Buffered") + .field("read_buf", &self.read_buf) + .field("write_buf", &self.write_buf) + .finish() + } +} + +impl<T, B> Buffered<T, B> +where + T: AsyncRead + AsyncWrite + Unpin, + B: Buf, +{ + pub fn new(io: T) -> Buffered<T, B> { + Buffered { + flush_pipeline: false, + io, + read_blocked: false, + read_buf: BytesMut::with_capacity(0), + read_buf_strategy: ReadStrategy::default(), + write_buf: WriteBuf::new(), + } + } + + pub fn set_flush_pipeline(&mut self, enabled: bool) { + debug_assert!(!self.write_buf.has_remaining()); + self.flush_pipeline = enabled; + if enabled { + self.set_write_strategy_flatten(); + } + } + + pub fn set_max_buf_size(&mut self, max: usize) { + assert!( + max >= MINIMUM_MAX_BUFFER_SIZE, + "The max_buf_size cannot be smaller than {}.", + MINIMUM_MAX_BUFFER_SIZE, + ); + self.read_buf_strategy = ReadStrategy::with_max(max); + self.write_buf.max_buf_size = max; + } + + pub fn set_read_buf_exact_size(&mut self, sz: usize) { + self.read_buf_strategy = ReadStrategy::Exact(sz); + } + + pub fn set_write_strategy_flatten(&mut self) { + // this should always be called only at construction time, + // so this assert is here to catch myself + debug_assert!(self.write_buf.queue.bufs_cnt() == 0); + self.write_buf.set_strategy(WriteStrategy::Flatten); + } + + pub fn read_buf(&self) -> &[u8] { + self.read_buf.as_ref() + } + + #[cfg(test)] + #[cfg(feature = "nightly")] + pub(super) fn read_buf_mut(&mut self) -> &mut BytesMut { + &mut self.read_buf + } + + /// Return the "allocated" available space, not the potential space + /// that could be allocated in the future. + fn read_buf_remaining_mut(&self) -> usize { + self.read_buf.capacity() - self.read_buf.len() + } + + pub fn headers_buf(&mut self) -> &mut Vec<u8> { + let buf = self.write_buf.headers_mut(); + &mut buf.bytes + } + + pub(super) fn write_buf(&mut self) -> &mut WriteBuf<B> { + &mut self.write_buf + } + + pub fn buffer<BB: Buf + Into<B>>(&mut self, buf: BB) { + self.write_buf.buffer(buf) + } + + pub fn can_buffer(&self) -> bool { + self.flush_pipeline || self.write_buf.can_buffer() + } + + pub fn consume_leading_lines(&mut self) { + if !self.read_buf.is_empty() { + let mut i = 0; + while i < self.read_buf.len() { + match self.read_buf[i] { + b'\r' | b'\n' => i += 1, + _ => break, + } + } + self.read_buf.advance(i); + } + } + + pub(super) fn parse<S>( + &mut self, + cx: &mut task::Context<'_>, + parse_ctx: ParseContext<'_>, + ) -> Poll<crate::Result<ParsedMessage<S::Incoming>>> + where + S: Http1Transaction, + { + loop { + match S::parse( + &mut self.read_buf, + ParseContext { + cached_headers: parse_ctx.cached_headers, + req_method: parse_ctx.req_method, + }, + )? { + Some(msg) => { + debug!("parsed {} headers", msg.head.headers.len()); + return Poll::Ready(Ok(msg)); + } + None => { + let max = self.read_buf_strategy.max(); + if self.read_buf.len() >= max { + debug!("max_buf_size ({}) reached, closing", max); + return Poll::Ready(Err(crate::Error::new_too_large())); + } + } + } + if ready!(self.poll_read_from_io(cx)).map_err(crate::Error::new_io)? == 0 { + trace!("parse eof"); + return Poll::Ready(Err(crate::Error::new_incomplete())); + } + } + } + + pub fn poll_read_from_io(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<usize>> { + self.read_blocked = false; + let next = self.read_buf_strategy.next(); + if self.read_buf_remaining_mut() < next { + self.read_buf.reserve(next); + } + match Pin::new(&mut self.io).poll_read_buf(cx, &mut self.read_buf) { + Poll::Ready(Ok(n)) => { + debug!("read {} bytes", n); + self.read_buf_strategy.record(n); + Poll::Ready(Ok(n)) + } + Poll::Pending => { + self.read_blocked = true; + Poll::Pending + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + } + } + + pub fn into_inner(self) -> (T, Bytes) { + (self.io, self.read_buf.freeze()) + } + + pub fn io_mut(&mut self) -> &mut T { + &mut self.io + } + + pub fn is_read_blocked(&self) -> bool { + self.read_blocked + } + + pub fn poll_flush(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + if self.flush_pipeline && !self.read_buf.is_empty() { + Poll::Ready(Ok(())) + } else if self.write_buf.remaining() == 0 { + Pin::new(&mut self.io).poll_flush(cx) + } else { + if let WriteStrategy::Flatten = self.write_buf.strategy { + return self.poll_flush_flattened(cx); + } + loop { + let n = + ready!(Pin::new(&mut self.io).poll_write_buf(cx, &mut self.write_buf.auto()))?; + debug!("flushed {} bytes", n); + if self.write_buf.remaining() == 0 { + break; + } else if n == 0 { + trace!( + "write returned zero, but {} bytes remaining", + self.write_buf.remaining() + ); + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + } + Pin::new(&mut self.io).poll_flush(cx) + } + } + + /// Specialized version of `flush` when strategy is Flatten. + /// + /// Since all buffered bytes are flattened into the single headers buffer, + /// that skips some bookkeeping around using multiple buffers. + fn poll_flush_flattened(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + loop { + let n = ready!(Pin::new(&mut self.io).poll_write(cx, self.write_buf.headers.bytes()))?; + debug!("flushed {} bytes", n); + self.write_buf.headers.advance(n); + if self.write_buf.headers.remaining() == 0 { + self.write_buf.headers.reset(); + break; + } else if n == 0 { + trace!( + "write returned zero, but {} bytes remaining", + self.write_buf.remaining() + ); + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + } + Pin::new(&mut self.io).poll_flush(cx) + } + + #[cfg(test)] + fn flush<'a>(&'a mut self) -> impl std::future::Future<Output = io::Result<()>> + 'a { + futures_util::future::poll_fn(move |cx| self.poll_flush(cx)) + } +} + +// The `B` is a `Buf`, we never project a pin to it +impl<T: Unpin, B> Unpin for Buffered<T, B> {} + +// TODO: This trait is old... at least rename to PollBytes or something... +pub trait MemRead { + fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>>; +} + +impl<T, B> MemRead for Buffered<T, B> +where + T: AsyncRead + AsyncWrite + Unpin, + B: Buf, +{ + fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> { + if !self.read_buf.is_empty() { + let n = std::cmp::min(len, self.read_buf.len()); + Poll::Ready(Ok(self.read_buf.split_to(n).freeze())) + } else { + let n = ready!(self.poll_read_from_io(cx))?; + Poll::Ready(Ok(self.read_buf.split_to(::std::cmp::min(len, n)).freeze())) + } + } +} + +#[derive(Clone, Copy, Debug)] +enum ReadStrategy { + Adaptive { + decrease_now: bool, + next: usize, + max: usize, + }, + Exact(usize), +} + +impl ReadStrategy { + fn with_max(max: usize) -> ReadStrategy { + ReadStrategy::Adaptive { + decrease_now: false, + next: INIT_BUFFER_SIZE, + max, + } + } + + fn next(&self) -> usize { + match *self { + ReadStrategy::Adaptive { next, .. } => next, + ReadStrategy::Exact(exact) => exact, + } + } + + fn max(&self) -> usize { + match *self { + ReadStrategy::Adaptive { max, .. } => max, + ReadStrategy::Exact(exact) => exact, + } + } + + fn record(&mut self, bytes_read: usize) { + if let ReadStrategy::Adaptive { + ref mut decrease_now, + ref mut next, + max, + .. + } = *self + { + if bytes_read >= *next { + *next = cmp::min(incr_power_of_two(*next), max); + *decrease_now = false; + } else { + let decr_to = prev_power_of_two(*next); + if bytes_read < decr_to { + if *decrease_now { + *next = cmp::max(decr_to, INIT_BUFFER_SIZE); + *decrease_now = false; + } else { + // Decreasing is a two "record" process. + *decrease_now = true; + } + } else { + // A read within the current range should cancel + // a potential decrease, since we just saw proof + // that we still need this size. + *decrease_now = false; + } + } + } + } +} + +fn incr_power_of_two(n: usize) -> usize { + n.saturating_mul(2) +} + +fn prev_power_of_two(n: usize) -> usize { + // Only way this shift can underflow is if n is less than 4. + // (Which would means `usize::MAX >> 64` and underflowed!) + debug_assert!(n >= 4); + (::std::usize::MAX >> (n.leading_zeros() + 2)) + 1 +} + +impl Default for ReadStrategy { + fn default() -> ReadStrategy { + ReadStrategy::with_max(DEFAULT_MAX_BUFFER_SIZE) + } +} + +#[derive(Clone)] +pub struct Cursor<T> { + bytes: T, + pos: usize, +} + +impl<T: AsRef<[u8]>> Cursor<T> { + #[inline] + pub(crate) fn new(bytes: T) -> Cursor<T> { + Cursor { bytes, pos: 0 } + } +} + +impl Cursor<Vec<u8>> { + fn reset(&mut self) { + self.pos = 0; + self.bytes.clear(); + } +} + +impl<T: AsRef<[u8]>> fmt::Debug for Cursor<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Cursor") + .field("pos", &self.pos) + .field("len", &self.bytes.as_ref().len()) + .finish() + } +} + +impl<T: AsRef<[u8]>> Buf for Cursor<T> { + #[inline] + fn remaining(&self) -> usize { + self.bytes.as_ref().len() - self.pos + } + + #[inline] + fn bytes(&self) -> &[u8] { + &self.bytes.as_ref()[self.pos..] + } + + #[inline] + fn advance(&mut self, cnt: usize) { + debug_assert!(self.pos + cnt <= self.bytes.as_ref().len()); + self.pos += cnt; + } +} + +// an internal buffer to collect writes before flushes +pub(super) struct WriteBuf<B> { + /// Re-usable buffer that holds message headers + headers: Cursor<Vec<u8>>, + max_buf_size: usize, + /// Deque of user buffers if strategy is Queue + queue: BufList<B>, + strategy: WriteStrategy, +} + +impl<B: Buf> WriteBuf<B> { + fn new() -> WriteBuf<B> { + WriteBuf { + headers: Cursor::new(Vec::with_capacity(INIT_BUFFER_SIZE)), + max_buf_size: DEFAULT_MAX_BUFFER_SIZE, + queue: BufList::new(), + strategy: WriteStrategy::Auto, + } + } +} + +impl<B> WriteBuf<B> +where + B: Buf, +{ + fn set_strategy(&mut self, strategy: WriteStrategy) { + self.strategy = strategy; + } + + #[inline] + fn auto(&mut self) -> WriteBufAuto<'_, B> { + WriteBufAuto::new(self) + } + + pub(super) fn buffer<BB: Buf + Into<B>>(&mut self, mut buf: BB) { + debug_assert!(buf.has_remaining()); + match self.strategy { + WriteStrategy::Flatten => { + let head = self.headers_mut(); + //perf: This is a little faster than <Vec as BufMut>>::put, + //but accomplishes the same result. + loop { + let adv = { + let slice = buf.bytes(); + if slice.is_empty() { + return; + } + head.bytes.extend_from_slice(slice); + slice.len() + }; + buf.advance(adv); + } + } + WriteStrategy::Auto | WriteStrategy::Queue => { + self.queue.push(buf.into()); + } + } + } + + fn can_buffer(&self) -> bool { + match self.strategy { + WriteStrategy::Flatten => self.remaining() < self.max_buf_size, + WriteStrategy::Auto | WriteStrategy::Queue => { + self.queue.bufs_cnt() < MAX_BUF_LIST_BUFFERS && self.remaining() < self.max_buf_size + } + } + } + + fn headers_mut(&mut self) -> &mut Cursor<Vec<u8>> { + debug_assert!(!self.queue.has_remaining()); + &mut self.headers + } +} + +impl<B: Buf> fmt::Debug for WriteBuf<B> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WriteBuf") + .field("remaining", &self.remaining()) + .field("strategy", &self.strategy) + .finish() + } +} + +impl<B: Buf> Buf for WriteBuf<B> { + #[inline] + fn remaining(&self) -> usize { + self.headers.remaining() + self.queue.remaining() + } + + #[inline] + fn bytes(&self) -> &[u8] { + let headers = self.headers.bytes(); + if !headers.is_empty() { + headers + } else { + self.queue.bytes() + } + } + + #[inline] + fn advance(&mut self, cnt: usize) { + let hrem = self.headers.remaining(); + + match hrem.cmp(&cnt) { + cmp::Ordering::Equal => self.headers.reset(), + cmp::Ordering::Greater => self.headers.advance(cnt), + cmp::Ordering::Less => { + let qcnt = cnt - hrem; + self.headers.reset(); + self.queue.advance(qcnt); + } + } + } + + #[inline] + fn bytes_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { + let n = self.headers.bytes_vectored(dst); + self.queue.bytes_vectored(&mut dst[n..]) + n + } +} + +/// Detects when wrapped `WriteBuf` is used for vectored IO, and +/// adjusts the `WriteBuf` strategy if not. +struct WriteBufAuto<'a, B: Buf> { + bytes_called: Cell<bool>, + bytes_vec_called: Cell<bool>, + inner: &'a mut WriteBuf<B>, +} + +impl<'a, B: Buf> WriteBufAuto<'a, B> { + fn new(inner: &'a mut WriteBuf<B>) -> WriteBufAuto<'a, B> { + WriteBufAuto { + bytes_called: Cell::new(false), + bytes_vec_called: Cell::new(false), + inner, + } + } +} + +impl<'a, B: Buf> Buf for WriteBufAuto<'a, B> { + #[inline] + fn remaining(&self) -> usize { + self.inner.remaining() + } + + #[inline] + fn bytes(&self) -> &[u8] { + self.bytes_called.set(true); + self.inner.bytes() + } + + #[inline] + fn advance(&mut self, cnt: usize) { + self.inner.advance(cnt) + } + + #[inline] + fn bytes_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { + self.bytes_vec_called.set(true); + self.inner.bytes_vectored(dst) + } +} + +impl<'a, B: Buf + 'a> Drop for WriteBufAuto<'a, B> { + fn drop(&mut self) { + if let WriteStrategy::Auto = self.inner.strategy { + if self.bytes_vec_called.get() { + self.inner.strategy = WriteStrategy::Queue; + } else if self.bytes_called.get() { + trace!("detected no usage of vectored write, flattening"); + self.inner.strategy = WriteStrategy::Flatten; + self.inner.headers.bytes.put(&mut self.inner.queue); + } + } + } +} + +#[derive(Debug)] +enum WriteStrategy { + Auto, + Flatten, + Queue, +} + +#[cfg(test)] +mod tests { + use super::*; + use std::time::Duration; + + use tokio_test::io::Builder as Mock; + + #[cfg(feature = "nightly")] + use test::Bencher; + + /* + impl<T: Read> MemRead for AsyncIo<T> { + fn read_mem(&mut self, len: usize) -> Poll<Bytes, io::Error> { + let mut v = vec![0; len]; + let n = try_nb!(self.read(v.as_mut_slice())); + Ok(Async::Ready(BytesMut::from(&v[..n]).freeze())) + } + } + */ + + #[tokio::test] + async fn iobuf_write_empty_slice() { + // First, let's just check that the Mock would normally return an + // error on an unexpected write, even if the buffer is empty... + let mut mock = Mock::new().build(); + futures_util::future::poll_fn(|cx| { + Pin::new(&mut mock).poll_write_buf(cx, &mut Cursor::new(&[])) + }) + .await + .expect_err("should be a broken pipe"); + + // underlying io will return the logic error upon write, + // so we are testing that the io_buf does not trigger a write + // when there is nothing to flush + let mock = Mock::new().build(); + let mut io_buf = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + io_buf.flush().await.expect("should short-circuit flush"); + } + + #[tokio::test] + async fn parse_reads_until_blocked() { + use crate::proto::h1::ClientTransaction; + + let mock = Mock::new() + // Split over multiple reads will read all of it + .read(b"HTTP/1.1 200 OK\r\n") + .read(b"Server: hyper\r\n") + // missing last line ending + .wait(Duration::from_secs(1)) + .build(); + + let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + + // We expect a `parse` to be not ready, and so can't await it directly. + // Rather, this `poll_fn` will wrap the `Poll` result. + futures_util::future::poll_fn(|cx| { + let parse_ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut None, + }; + assert!(buffered + .parse::<ClientTransaction>(cx, parse_ctx) + .is_pending()); + Poll::Ready(()) + }) + .await; + + assert_eq!( + buffered.read_buf, + b"HTTP/1.1 200 OK\r\nServer: hyper\r\n"[..] + ); + } + + #[test] + fn read_strategy_adaptive_increments() { + let mut strategy = ReadStrategy::default(); + assert_eq!(strategy.next(), 8192); + + // Grows if record == next + strategy.record(8192); + assert_eq!(strategy.next(), 16384); + + strategy.record(16384); + assert_eq!(strategy.next(), 32768); + + // Enormous records still increment at same rate + strategy.record(::std::usize::MAX); + assert_eq!(strategy.next(), 65536); + + let max = strategy.max(); + while strategy.next() < max { + strategy.record(max); + } + + assert_eq!(strategy.next(), max, "never goes over max"); + strategy.record(max + 1); + assert_eq!(strategy.next(), max, "never goes over max"); + } + + #[test] + fn read_strategy_adaptive_decrements() { + let mut strategy = ReadStrategy::default(); + strategy.record(8192); + assert_eq!(strategy.next(), 16384); + + strategy.record(1); + assert_eq!( + strategy.next(), + 16384, + "first smaller record doesn't decrement yet" + ); + strategy.record(8192); + assert_eq!(strategy.next(), 16384, "record was with range"); + + strategy.record(1); + assert_eq!( + strategy.next(), + 16384, + "in-range record should make this the 'first' again" + ); + + strategy.record(1); + assert_eq!(strategy.next(), 8192, "second smaller record decrements"); + + strategy.record(1); + assert_eq!(strategy.next(), 8192, "first doesn't decrement"); + strategy.record(1); + assert_eq!(strategy.next(), 8192, "doesn't decrement under minimum"); + } + + #[test] + fn read_strategy_adaptive_stays_the_same() { + let mut strategy = ReadStrategy::default(); + strategy.record(8192); + assert_eq!(strategy.next(), 16384); + + strategy.record(8193); + assert_eq!( + strategy.next(), + 16384, + "first smaller record doesn't decrement yet" + ); + + strategy.record(8193); + assert_eq!( + strategy.next(), + 16384, + "with current step does not decrement" + ); + } + + #[test] + fn read_strategy_adaptive_max_fuzz() { + fn fuzz(max: usize) { + let mut strategy = ReadStrategy::with_max(max); + while strategy.next() < max { + strategy.record(::std::usize::MAX); + } + let mut next = strategy.next(); + while next > 8192 { + strategy.record(1); + strategy.record(1); + next = strategy.next(); + assert!( + next.is_power_of_two(), + "decrement should be powers of two: {} (max = {})", + next, + max, + ); + } + } + + let mut max = 8192; + while max < std::usize::MAX { + fuzz(max); + max = (max / 2).saturating_mul(3); + } + fuzz(::std::usize::MAX); + } + + #[test] + #[should_panic] + #[cfg(debug_assertions)] // needs to trigger a debug_assert + fn write_buf_requires_non_empty_bufs() { + let mock = Mock::new().build(); + let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + + buffered.buffer(Cursor::new(Vec::new())); + } + + /* + TODO: needs tokio_test::io to allow configure write_buf calls + #[test] + fn write_buf_queue() { + let _ = pretty_env_logger::try_init(); + + let mock = AsyncIo::new_buf(vec![], 1024); + let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + + + buffered.headers_buf().extend(b"hello "); + buffered.buffer(Cursor::new(b"world, ".to_vec())); + buffered.buffer(Cursor::new(b"it's ".to_vec())); + buffered.buffer(Cursor::new(b"hyper!".to_vec())); + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3); + buffered.flush().unwrap(); + + assert_eq!(buffered.io, b"hello world, it's hyper!"); + assert_eq!(buffered.io.num_writes(), 1); + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0); + } + */ + + #[tokio::test] + async fn write_buf_flatten() { + let _ = pretty_env_logger::try_init(); + + let mock = Mock::new() + // Just a single write + .write(b"hello world, it's hyper!") + .build(); + + let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + buffered.write_buf.set_strategy(WriteStrategy::Flatten); + + buffered.headers_buf().extend(b"hello "); + buffered.buffer(Cursor::new(b"world, ".to_vec())); + buffered.buffer(Cursor::new(b"it's ".to_vec())); + buffered.buffer(Cursor::new(b"hyper!".to_vec())); + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0); + + buffered.flush().await.expect("flush"); + } + + #[tokio::test] + async fn write_buf_auto_flatten() { + let _ = pretty_env_logger::try_init(); + + let mock = Mock::new() + // Expects write_buf to only consume first buffer + .write(b"hello ") + // And then the Auto strategy will have flattened + .write(b"world, it's hyper!") + .build(); + + let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + + // we have 4 buffers, but hope to detect that vectored IO isn't + // being used, and switch to flattening automatically, + // resulting in only 2 writes + buffered.headers_buf().extend(b"hello "); + buffered.buffer(Cursor::new(b"world, ".to_vec())); + buffered.buffer(Cursor::new(b"it's ".to_vec())); + buffered.buffer(Cursor::new(b"hyper!".to_vec())); + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3); + + buffered.flush().await.expect("flush"); + + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0); + } + + #[tokio::test] + async fn write_buf_queue_disable_auto() { + let _ = pretty_env_logger::try_init(); + + let mock = Mock::new() + .write(b"hello ") + .write(b"world, ") + .write(b"it's ") + .write(b"hyper!") + .build(); + + let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + buffered.write_buf.set_strategy(WriteStrategy::Queue); + + // we have 4 buffers, and vec IO disabled, but explicitly said + // don't try to auto detect (via setting strategy above) + + buffered.headers_buf().extend(b"hello "); + buffered.buffer(Cursor::new(b"world, ".to_vec())); + buffered.buffer(Cursor::new(b"it's ".to_vec())); + buffered.buffer(Cursor::new(b"hyper!".to_vec())); + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3); + + buffered.flush().await.expect("flush"); + + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0); + } + + #[cfg(feature = "nightly")] + #[bench] + fn bench_write_buf_flatten_buffer_chunk(b: &mut Bencher) { + let s = "Hello, World!"; + b.bytes = s.len() as u64; + + let mut write_buf = WriteBuf::<bytes::Bytes>::new(); + write_buf.set_strategy(WriteStrategy::Flatten); + b.iter(|| { + let chunk = bytes::Bytes::from(s); + write_buf.buffer(chunk); + ::test::black_box(&write_buf); + write_buf.headers.bytes.clear(); + }) + } +} diff --git a/third_party/rust/hyper/src/proto/h1/mod.rs b/third_party/rust/hyper/src/proto/h1/mod.rs new file mode 100644 index 0000000000..2d0bf39bc9 --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/mod.rs @@ -0,0 +1,95 @@ +use bytes::BytesMut; +use http::{HeaderMap, Method}; + +use crate::proto::{BodyLength, DecodedLength, MessageHead}; + +pub(crate) use self::conn::Conn; +pub use self::decode::Decoder; +pub(crate) use self::dispatch::Dispatcher; +pub use self::encode::{EncodedBuf, Encoder}; +pub use self::io::Cursor; //TODO: move out of h1::io +pub use self::io::MINIMUM_MAX_BUFFER_SIZE; + +mod conn; +pub(super) mod date; +mod decode; +pub(crate) mod dispatch; +mod encode; +mod io; +mod role; + +pub(crate) type ServerTransaction = role::Server; +pub(crate) type ClientTransaction = role::Client; + +pub(crate) trait Http1Transaction { + type Incoming; + type Outgoing: Default; + const LOG: &'static str; + fn parse(bytes: &mut BytesMut, ctx: ParseContext<'_>) -> ParseResult<Self::Incoming>; + fn encode(enc: Encode<'_, Self::Outgoing>, dst: &mut Vec<u8>) -> crate::Result<Encoder>; + + fn on_error(err: &crate::Error) -> Option<MessageHead<Self::Outgoing>>; + + fn is_client() -> bool { + !Self::is_server() + } + + fn is_server() -> bool { + !Self::is_client() + } + + fn should_error_on_parse_eof() -> bool { + Self::is_client() + } + + fn should_read_first() -> bool { + Self::is_server() + } + + fn update_date() {} +} + +/// Result newtype for Http1Transaction::parse. +pub(crate) type ParseResult<T> = Result<Option<ParsedMessage<T>>, crate::error::Parse>; + +#[derive(Debug)] +pub(crate) struct ParsedMessage<T> { + head: MessageHead<T>, + decode: DecodedLength, + expect_continue: bool, + keep_alive: bool, + wants_upgrade: bool, +} + +pub(crate) struct ParseContext<'a> { + cached_headers: &'a mut Option<HeaderMap>, + req_method: &'a mut Option<Method>, +} + +/// Passed to Http1Transaction::encode +pub(crate) struct Encode<'a, T> { + head: &'a mut MessageHead<T>, + body: Option<BodyLength>, + keep_alive: bool, + req_method: &'a mut Option<Method>, + title_case_headers: bool, +} + +/// Extra flags that a request "wants", like expect-continue or upgrades. +#[derive(Clone, Copy, Debug)] +struct Wants(u8); + +impl Wants { + const EMPTY: Wants = Wants(0b00); + const EXPECT: Wants = Wants(0b01); + const UPGRADE: Wants = Wants(0b10); + + #[must_use] + fn add(self, other: Wants) -> Wants { + Wants(self.0 | other.0) + } + + fn contains(&self, other: Wants) -> bool { + (self.0 & other.0) == other.0 + } +} diff --git a/third_party/rust/hyper/src/proto/h1/role.rs b/third_party/rust/hyper/src/proto/h1/role.rs new file mode 100644 index 0000000000..e99f4cf541 --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/role.rs @@ -0,0 +1,1835 @@ +// `mem::uninitialized` replaced with `mem::MaybeUninit`, +// can't upgrade yet +#![allow(deprecated)] + +use std::fmt::{self, Write}; +use std::mem; + +use bytes::BytesMut; +use http::header::{self, Entry, HeaderName, HeaderValue}; +use http::{HeaderMap, Method, StatusCode, Version}; + +use crate::error::Parse; +use crate::headers; +use crate::proto::h1::{ + date, Encode, Encoder, Http1Transaction, ParseContext, ParseResult, ParsedMessage, +}; +use crate::proto::{BodyLength, DecodedLength, MessageHead, RequestHead, RequestLine}; + +const MAX_HEADERS: usize = 100; +const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific + +macro_rules! header_name { + ($bytes:expr) => {{ + #[cfg(debug_assertions)] + { + match HeaderName::from_bytes($bytes) { + Ok(name) => name, + Err(_) => panic!( + "illegal header name from httparse: {:?}", + ::bytes::Bytes::copy_from_slice($bytes) + ), + } + } + + #[cfg(not(debug_assertions))] + { + HeaderName::from_bytes($bytes).expect("header name validated by httparse") + } + }}; +} + +macro_rules! header_value { + ($bytes:expr) => {{ + #[cfg(debug_assertions)] + { + let __hvb: ::bytes::Bytes = $bytes; + match HeaderValue::from_maybe_shared(__hvb.clone()) { + Ok(name) => name, + Err(_) => panic!("illegal header value from httparse: {:?}", __hvb), + } + } + + #[cfg(not(debug_assertions))] + { + // Unsafe: httparse already validated header value + unsafe { HeaderValue::from_maybe_shared_unchecked($bytes) } + } + }}; +} + +// There are 2 main roles, Client and Server. + +pub(crate) enum Client {} + +pub(crate) enum Server {} + +impl Http1Transaction for Server { + type Incoming = RequestLine; + type Outgoing = StatusCode; + const LOG: &'static str = "{role=server}"; + + fn parse(buf: &mut BytesMut, ctx: ParseContext<'_>) -> ParseResult<RequestLine> { + if buf.is_empty() { + return Ok(None); + } + + let mut keep_alive; + let is_http_11; + let subject; + let version; + let len; + let headers_len; + + // Unsafe: both headers_indices and headers are using uninitialized memory, + // but we *never* read any of it until after httparse has assigned + // values into it. By not zeroing out the stack memory, this saves + // a good ~5% on pipeline benchmarks. + let mut headers_indices: [HeaderIndices; MAX_HEADERS] = unsafe { mem::uninitialized() }; + { + let mut headers: [httparse::Header<'_>; MAX_HEADERS] = unsafe { mem::uninitialized() }; + trace!( + "Request.parse([Header; {}], [u8; {}])", + headers.len(), + buf.len() + ); + let mut req = httparse::Request::new(&mut headers); + let bytes = buf.as_ref(); + match req.parse(bytes) { + Ok(httparse::Status::Complete(parsed_len)) => { + trace!("Request.parse Complete({})", parsed_len); + len = parsed_len; + subject = RequestLine( + Method::from_bytes(req.method.unwrap().as_bytes())?, + req.path.unwrap().parse()?, + ); + version = if req.version.unwrap() == 1 { + keep_alive = true; + is_http_11 = true; + Version::HTTP_11 + } else { + keep_alive = false; + is_http_11 = false; + Version::HTTP_10 + }; + + record_header_indices(bytes, &req.headers, &mut headers_indices)?; + headers_len = req.headers.len(); + } + Ok(httparse::Status::Partial) => return Ok(None), + Err(err) => { + return Err(match err { + // if invalid Token, try to determine if for method or path + httparse::Error::Token => { + if req.method.is_none() { + Parse::Method + } else { + debug_assert!(req.path.is_none()); + Parse::Uri + } + } + other => other.into(), + }); + } + } + }; + + let slice = buf.split_to(len).freeze(); + + // According to https://tools.ietf.org/html/rfc7230#section-3.3.3 + // 1. (irrelevant to Request) + // 2. (irrelevant to Request) + // 3. Transfer-Encoding: chunked has a chunked body. + // 4. If multiple differing Content-Length headers or invalid, close connection. + // 5. Content-Length header has a sized body. + // 6. Length 0. + // 7. (irrelevant to Request) + + let mut decoder = DecodedLength::ZERO; + let mut expect_continue = false; + let mut con_len = None; + let mut is_te = false; + let mut is_te_chunked = false; + let mut wants_upgrade = subject.0 == Method::CONNECT; + + let mut headers = ctx.cached_headers.take().unwrap_or_else(HeaderMap::new); + + headers.reserve(headers_len); + + for header in &headers_indices[..headers_len] { + let name = header_name!(&slice[header.name.0..header.name.1]); + let value = header_value!(slice.slice(header.value.0..header.value.1)); + + match name { + header::TRANSFER_ENCODING => { + // https://tools.ietf.org/html/rfc7230#section-3.3.3 + // If Transfer-Encoding header is present, and 'chunked' is + // not the final encoding, and this is a Request, then it is + // malformed. A server should respond with 400 Bad Request. + if !is_http_11 { + debug!("HTTP/1.0 cannot have Transfer-Encoding header"); + return Err(Parse::Header); + } + is_te = true; + if headers::is_chunked_(&value) { + is_te_chunked = true; + decoder = DecodedLength::CHUNKED; + } + } + header::CONTENT_LENGTH => { + if is_te { + continue; + } + let len = value + .to_str() + .map_err(|_| Parse::Header) + .and_then(|s| s.parse().map_err(|_| Parse::Header))?; + if let Some(prev) = con_len { + if prev != len { + debug!( + "multiple Content-Length headers with different values: [{}, {}]", + prev, len, + ); + return Err(Parse::Header); + } + // we don't need to append this secondary length + continue; + } + decoder = DecodedLength::checked_new(len)?; + con_len = Some(len); + } + header::CONNECTION => { + // keep_alive was previously set to default for Version + if keep_alive { + // HTTP/1.1 + keep_alive = !headers::connection_close(&value); + } else { + // HTTP/1.0 + keep_alive = headers::connection_keep_alive(&value); + } + } + header::EXPECT => { + expect_continue = value.as_bytes() == b"100-continue"; + } + header::UPGRADE => { + // Upgrades are only allowed with HTTP/1.1 + wants_upgrade = is_http_11; + } + + _ => (), + } + + headers.append(name, value); + } + + if is_te && !is_te_chunked { + debug!("request with transfer-encoding header, but not chunked, bad request"); + return Err(Parse::Header); + } + + *ctx.req_method = Some(subject.0.clone()); + + Ok(Some(ParsedMessage { + head: MessageHead { + version, + subject, + headers, + }, + decode: decoder, + expect_continue, + keep_alive, + wants_upgrade, + })) + } + + fn encode( + mut msg: Encode<'_, Self::Outgoing>, + mut dst: &mut Vec<u8>, + ) -> crate::Result<Encoder> { + trace!( + "Server::encode status={:?}, body={:?}, req_method={:?}", + msg.head.subject, + msg.body, + msg.req_method + ); + debug_assert!( + !msg.title_case_headers, + "no server config for title case headers" + ); + + let mut wrote_len = false; + + // hyper currently doesn't support returning 1xx status codes as a Response + // This is because Service only allows returning a single Response, and + // so if you try to reply with a e.g. 100 Continue, you have no way of + // replying with the latter status code response. + let (ret, mut is_last) = if msg.head.subject == StatusCode::SWITCHING_PROTOCOLS { + (Ok(()), true) + } else if msg.req_method == &Some(Method::CONNECT) && msg.head.subject.is_success() { + // Sending content-length or transfer-encoding header on 2xx response + // to CONNECT is forbidden in RFC 7231. + wrote_len = true; + (Ok(()), true) + } else if msg.head.subject.is_informational() { + warn!("response with 1xx status code not supported"); + *msg.head = MessageHead::default(); + msg.head.subject = StatusCode::INTERNAL_SERVER_ERROR; + msg.body = None; + (Err(crate::Error::new_user_unsupported_status_code()), true) + } else { + (Ok(()), !msg.keep_alive) + }; + + // In some error cases, we don't know about the invalid message until already + // pushing some bytes onto the `dst`. In those cases, we don't want to send + // the half-pushed message, so rewind to before. + let orig_len = dst.len(); + let rewind = |dst: &mut Vec<u8>| { + dst.truncate(orig_len); + }; + + let init_cap = 30 + msg.head.headers.len() * AVERAGE_HEADER_SIZE; + dst.reserve(init_cap); + if msg.head.version == Version::HTTP_11 && msg.head.subject == StatusCode::OK { + extend(dst, b"HTTP/1.1 200 OK\r\n"); + } else { + match msg.head.version { + Version::HTTP_10 => extend(dst, b"HTTP/1.0 "), + Version::HTTP_11 => extend(dst, b"HTTP/1.1 "), + Version::HTTP_2 => { + warn!("response with HTTP2 version coerced to HTTP/1.1"); + extend(dst, b"HTTP/1.1 "); + } + other => panic!("unexpected response version: {:?}", other), + } + + extend(dst, msg.head.subject.as_str().as_bytes()); + extend(dst, b" "); + // a reason MUST be written, as many parsers will expect it. + extend( + dst, + msg.head + .subject + .canonical_reason() + .unwrap_or("<none>") + .as_bytes(), + ); + extend(dst, b"\r\n"); + } + + let mut encoder = Encoder::length(0); + let mut wrote_date = false; + let mut cur_name = None; + let mut is_name_written = false; + let mut must_write_chunked = false; + let mut prev_con_len = None; + + macro_rules! handle_is_name_written { + () => {{ + if is_name_written { + // we need to clean up and write the newline + debug_assert_ne!( + &dst[dst.len() - 2..], + b"\r\n", + "previous header wrote newline but set is_name_written" + ); + + if must_write_chunked { + extend(dst, b", chunked\r\n"); + } else { + extend(dst, b"\r\n"); + } + } + }}; + } + + 'headers: for (opt_name, value) in msg.head.headers.drain() { + if let Some(n) = opt_name { + cur_name = Some(n); + handle_is_name_written!(); + is_name_written = false; + } + let name = cur_name.as_ref().expect("current header name"); + match *name { + header::CONTENT_LENGTH => { + if wrote_len && !is_name_written { + warn!("unexpected content-length found, canceling"); + rewind(dst); + return Err(crate::Error::new_user_header()); + } + match msg.body { + Some(BodyLength::Known(known_len)) => { + // The Payload claims to know a length, and + // the headers are already set. For performance + // reasons, we are just going to trust that + // the values match. + // + // In debug builds, we'll assert they are the + // same to help developers find bugs. + #[cfg(debug_assertions)] + { + if let Some(len) = headers::content_length_parse(&value) { + assert!( + len == known_len, + "payload claims content-length of {}, custom content-length header claims {}", + known_len, + len, + ); + } + } + + if !is_name_written { + encoder = Encoder::length(known_len); + extend(dst, b"content-length: "); + extend(dst, value.as_bytes()); + wrote_len = true; + is_name_written = true; + } + continue 'headers; + } + Some(BodyLength::Unknown) => { + // The Payload impl didn't know how long the + // body is, but a length header was included. + // We have to parse the value to return our + // Encoder... + + if let Some(len) = headers::content_length_parse(&value) { + if let Some(prev) = prev_con_len { + if prev != len { + warn!( + "multiple Content-Length values found: [{}, {}]", + prev, len + ); + rewind(dst); + return Err(crate::Error::new_user_header()); + } + debug_assert!(is_name_written); + continue 'headers; + } else { + // we haven't written content-length yet! + encoder = Encoder::length(len); + extend(dst, b"content-length: "); + extend(dst, value.as_bytes()); + wrote_len = true; + is_name_written = true; + prev_con_len = Some(len); + continue 'headers; + } + } else { + warn!("illegal Content-Length value: {:?}", value); + rewind(dst); + return Err(crate::Error::new_user_header()); + } + } + None => { + // We have no body to actually send, + // but the headers claim a content-length. + // There's only 2 ways this makes sense: + // + // - The header says the length is `0`. + // - This is a response to a `HEAD` request. + if msg.req_method == &Some(Method::HEAD) { + debug_assert_eq!(encoder, Encoder::length(0)); + } else { + if value.as_bytes() != b"0" { + warn!( + "content-length value found, but empty body provided: {:?}", + value + ); + } + continue 'headers; + } + } + } + wrote_len = true; + } + header::TRANSFER_ENCODING => { + if wrote_len && !is_name_written { + warn!("unexpected transfer-encoding found, canceling"); + rewind(dst); + return Err(crate::Error::new_user_header()); + } + // check that we actually can send a chunked body... + if msg.head.version == Version::HTTP_10 + || !Server::can_chunked(msg.req_method, msg.head.subject) + { + continue; + } + wrote_len = true; + // Must check each value, because `chunked` needs to be the + // last encoding, or else we add it. + must_write_chunked = !headers::is_chunked_(&value); + + if !is_name_written { + encoder = Encoder::chunked(); + is_name_written = true; + extend(dst, b"transfer-encoding: "); + extend(dst, value.as_bytes()); + } else { + extend(dst, b", "); + extend(dst, value.as_bytes()); + } + continue 'headers; + } + header::CONNECTION => { + if !is_last && headers::connection_close(&value) { + is_last = true; + } + if !is_name_written { + is_name_written = true; + extend(dst, b"connection: "); + extend(dst, value.as_bytes()); + } else { + extend(dst, b", "); + extend(dst, value.as_bytes()); + } + continue 'headers; + } + header::DATE => { + wrote_date = true; + } + _ => (), + } + //TODO: this should perhaps instead combine them into + //single lines, as RFC7230 suggests is preferable. + + // non-special write Name and Value + debug_assert!( + !is_name_written, + "{:?} set is_name_written and didn't continue loop", + name, + ); + extend(dst, name.as_str().as_bytes()); + extend(dst, b": "); + extend(dst, value.as_bytes()); + extend(dst, b"\r\n"); + } + + handle_is_name_written!(); + + if !wrote_len { + encoder = match msg.body { + Some(BodyLength::Unknown) => { + if msg.head.version == Version::HTTP_10 + || !Server::can_chunked(msg.req_method, msg.head.subject) + { + Encoder::close_delimited() + } else { + extend(dst, b"transfer-encoding: chunked\r\n"); + Encoder::chunked() + } + } + None | Some(BodyLength::Known(0)) => { + if msg.head.subject != StatusCode::NOT_MODIFIED { + extend(dst, b"content-length: 0\r\n"); + } + Encoder::length(0) + } + Some(BodyLength::Known(len)) => { + if msg.head.subject == StatusCode::NOT_MODIFIED { + Encoder::length(0) + } else { + extend(dst, b"content-length: "); + let _ = ::itoa::write(&mut dst, len); + extend(dst, b"\r\n"); + Encoder::length(len) + } + } + }; + } + + if !Server::can_have_body(msg.req_method, msg.head.subject) { + trace!( + "server body forced to 0; method={:?}, status={:?}", + msg.req_method, + msg.head.subject + ); + encoder = Encoder::length(0); + } + + // cached date is much faster than formatting every request + if !wrote_date { + dst.reserve(date::DATE_VALUE_LENGTH + 8); + extend(dst, b"date: "); + date::extend(dst); + extend(dst, b"\r\n\r\n"); + } else { + extend(dst, b"\r\n"); + } + + ret.map(|()| encoder.set_last(is_last)) + } + + fn on_error(err: &crate::Error) -> Option<MessageHead<Self::Outgoing>> { + use crate::error::Kind; + let status = match *err.kind() { + Kind::Parse(Parse::Method) + | Kind::Parse(Parse::Header) + | Kind::Parse(Parse::Uri) + | Kind::Parse(Parse::Version) => StatusCode::BAD_REQUEST, + Kind::Parse(Parse::TooLarge) => StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE, + _ => return None, + }; + + debug!("sending automatic response ({}) for parse error", status); + let mut msg = MessageHead::default(); + msg.subject = status; + Some(msg) + } + + fn is_server() -> bool { + true + } + + fn update_date() { + date::update(); + } +} + +impl Server { + fn can_have_body(method: &Option<Method>, status: StatusCode) -> bool { + Server::can_chunked(method, status) + } + + fn can_chunked(method: &Option<Method>, status: StatusCode) -> bool { + if method == &Some(Method::HEAD) || method == &Some(Method::CONNECT) && status.is_success() + { + false + } else { + match status { + // TODO: support for 1xx codes needs improvement everywhere + // would be 100...199 => false + StatusCode::SWITCHING_PROTOCOLS + | StatusCode::NO_CONTENT + | StatusCode::NOT_MODIFIED => false, + _ => true, + } + } + } +} + +impl Http1Transaction for Client { + type Incoming = StatusCode; + type Outgoing = RequestLine; + const LOG: &'static str = "{role=client}"; + + fn parse(buf: &mut BytesMut, ctx: ParseContext<'_>) -> ParseResult<StatusCode> { + // Loop to skip information status code headers (100 Continue, etc). + loop { + if buf.is_empty() { + return Ok(None); + } + // Unsafe: see comment in Server Http1Transaction, above. + let mut headers_indices: [HeaderIndices; MAX_HEADERS] = unsafe { mem::uninitialized() }; + let (len, status, version, headers_len) = { + let mut headers: [httparse::Header<'_>; MAX_HEADERS] = + unsafe { mem::uninitialized() }; + trace!( + "Response.parse([Header; {}], [u8; {}])", + headers.len(), + buf.len() + ); + let mut res = httparse::Response::new(&mut headers); + let bytes = buf.as_ref(); + match res.parse(bytes)? { + httparse::Status::Complete(len) => { + trace!("Response.parse Complete({})", len); + let status = StatusCode::from_u16(res.code.unwrap())?; + let version = if res.version.unwrap() == 1 { + Version::HTTP_11 + } else { + Version::HTTP_10 + }; + record_header_indices(bytes, &res.headers, &mut headers_indices)?; + let headers_len = res.headers.len(); + (len, status, version, headers_len) + } + httparse::Status::Partial => return Ok(None), + } + }; + + let slice = buf.split_to(len).freeze(); + + let mut headers = ctx.cached_headers.take().unwrap_or_else(HeaderMap::new); + + let mut keep_alive = version == Version::HTTP_11; + + headers.reserve(headers_len); + for header in &headers_indices[..headers_len] { + let name = header_name!(&slice[header.name.0..header.name.1]); + let value = header_value!(slice.slice(header.value.0..header.value.1)); + + if let header::CONNECTION = name { + // keep_alive was previously set to default for Version + if keep_alive { + // HTTP/1.1 + keep_alive = !headers::connection_close(&value); + } else { + // HTTP/1.0 + keep_alive = headers::connection_keep_alive(&value); + } + } + headers.append(name, value); + } + + let head = MessageHead { + version, + subject: status, + headers, + }; + if let Some((decode, is_upgrade)) = Client::decoder(&head, ctx.req_method)? { + return Ok(Some(ParsedMessage { + head, + decode, + expect_continue: false, + // a client upgrade means the connection can't be used + // again, as it is definitely upgrading. + keep_alive: keep_alive && !is_upgrade, + wants_upgrade: is_upgrade, + })); + } + } + } + + fn encode(msg: Encode<'_, Self::Outgoing>, dst: &mut Vec<u8>) -> crate::Result<Encoder> { + trace!( + "Client::encode method={:?}, body={:?}", + msg.head.subject.0, + msg.body + ); + + *msg.req_method = Some(msg.head.subject.0.clone()); + + let body = Client::set_length(msg.head, msg.body); + + let init_cap = 30 + msg.head.headers.len() * AVERAGE_HEADER_SIZE; + dst.reserve(init_cap); + + extend(dst, msg.head.subject.0.as_str().as_bytes()); + extend(dst, b" "); + //TODO: add API to http::Uri to encode without std::fmt + let _ = write!(FastWrite(dst), "{} ", msg.head.subject.1); + + match msg.head.version { + Version::HTTP_10 => extend(dst, b"HTTP/1.0"), + Version::HTTP_11 => extend(dst, b"HTTP/1.1"), + Version::HTTP_2 => { + warn!("request with HTTP2 version coerced to HTTP/1.1"); + extend(dst, b"HTTP/1.1"); + } + other => panic!("unexpected request version: {:?}", other), + } + extend(dst, b"\r\n"); + + if msg.title_case_headers { + write_headers_title_case(&msg.head.headers, dst); + } else { + write_headers(&msg.head.headers, dst); + } + extend(dst, b"\r\n"); + msg.head.headers.clear(); //TODO: remove when switching to drain() + + Ok(body) + } + + fn on_error(_err: &crate::Error) -> Option<MessageHead<Self::Outgoing>> { + // we can't tell the server about any errors it creates + None + } + + fn is_client() -> bool { + true + } +} + +impl Client { + /// Returns Some(length, wants_upgrade) if successful. + /// + /// Returns None if this message head should be skipped (like a 100 status). + fn decoder( + inc: &MessageHead<StatusCode>, + method: &mut Option<Method>, + ) -> Result<Option<(DecodedLength, bool)>, Parse> { + // According to https://tools.ietf.org/html/rfc7230#section-3.3.3 + // 1. HEAD responses, and Status 1xx, 204, and 304 cannot have a body. + // 2. Status 2xx to a CONNECT cannot have a body. + // 3. Transfer-Encoding: chunked has a chunked body. + // 4. If multiple differing Content-Length headers or invalid, close connection. + // 5. Content-Length header has a sized body. + // 6. (irrelevant to Response) + // 7. Read till EOF. + + match inc.subject.as_u16() { + 101 => { + return Ok(Some((DecodedLength::ZERO, true))); + } + 100 | 102..=199 => { + trace!("ignoring informational response: {}", inc.subject.as_u16()); + return Ok(None); + } + 204 | 304 => return Ok(Some((DecodedLength::ZERO, false))), + _ => (), + } + match *method { + Some(Method::HEAD) => { + return Ok(Some((DecodedLength::ZERO, false))); + } + Some(Method::CONNECT) => { + if let 200..=299 = inc.subject.as_u16() { + return Ok(Some((DecodedLength::ZERO, true))); + } + } + Some(_) => {} + None => { + trace!("Client::decoder is missing the Method"); + } + } + + if inc.headers.contains_key(header::TRANSFER_ENCODING) { + // https://tools.ietf.org/html/rfc7230#section-3.3.3 + // If Transfer-Encoding header is present, and 'chunked' is + // not the final encoding, and this is a Request, then it is + // malformed. A server should respond with 400 Bad Request. + if inc.version == Version::HTTP_10 { + debug!("HTTP/1.0 cannot have Transfer-Encoding header"); + Err(Parse::Header) + } else if headers::transfer_encoding_is_chunked(&inc.headers) { + Ok(Some((DecodedLength::CHUNKED, false))) + } else { + trace!("not chunked, read till eof"); + Ok(Some((DecodedLength::CLOSE_DELIMITED, false))) + } + } else if let Some(len) = headers::content_length_parse_all(&inc.headers) { + Ok(Some((DecodedLength::checked_new(len)?, false))) + } else if inc.headers.contains_key(header::CONTENT_LENGTH) { + debug!("illegal Content-Length header"); + Err(Parse::Header) + } else { + trace!("neither Transfer-Encoding nor Content-Length"); + Ok(Some((DecodedLength::CLOSE_DELIMITED, false))) + } + } +} + +impl Client { + fn set_length(head: &mut RequestHead, body: Option<BodyLength>) -> Encoder { + let body = if let Some(body) = body { + body + } else { + head.headers.remove(header::TRANSFER_ENCODING); + return Encoder::length(0); + }; + + // HTTP/1.0 doesn't know about chunked + let can_chunked = head.version == Version::HTTP_11; + let headers = &mut head.headers; + + // If the user already set specific headers, we should respect them, regardless + // of what the Payload knows about itself. They set them for a reason. + + // Because of the borrow checker, we can't check the for an existing + // Content-Length header while holding an `Entry` for the Transfer-Encoding + // header, so unfortunately, we must do the check here, first. + + let existing_con_len = headers::content_length_parse_all(headers); + let mut should_remove_con_len = false; + + if !can_chunked { + // Chunked isn't legal, so if it is set, we need to remove it. + if headers.remove(header::TRANSFER_ENCODING).is_some() { + trace!("removing illegal transfer-encoding header"); + } + + return if let Some(len) = existing_con_len { + Encoder::length(len) + } else if let BodyLength::Known(len) = body { + set_content_length(headers, len) + } else { + // HTTP/1.0 client requests without a content-length + // cannot have any body at all. + Encoder::length(0) + }; + } + + // If the user set a transfer-encoding, respect that. Let's just + // make sure `chunked` is the final encoding. + let encoder = match headers.entry(header::TRANSFER_ENCODING) { + Entry::Occupied(te) => { + should_remove_con_len = true; + if headers::is_chunked(te.iter()) { + Some(Encoder::chunked()) + } else { + warn!("user provided transfer-encoding does not end in 'chunked'"); + + // There's a Transfer-Encoding, but it doesn't end in 'chunked'! + // An example that could trigger this: + // + // Transfer-Encoding: gzip + // + // This can be bad, depending on if this is a request or a + // response. + // + // - A request is illegal if there is a `Transfer-Encoding` + // but it doesn't end in `chunked`. + // - A response that has `Transfer-Encoding` but doesn't + // end in `chunked` isn't illegal, it just forces this + // to be close-delimited. + // + // We can try to repair this, by adding `chunked` ourselves. + + headers::add_chunked(te); + Some(Encoder::chunked()) + } + } + Entry::Vacant(te) => { + if let Some(len) = existing_con_len { + Some(Encoder::length(len)) + } else if let BodyLength::Unknown = body { + // GET, HEAD, and CONNECT almost never have bodies. + // + // So instead of sending a "chunked" body with a 0-chunk, + // assume no body here. If you *must* send a body, + // set the headers explicitly. + match head.subject.0 { + Method::GET | Method::HEAD | Method::CONNECT => Some(Encoder::length(0)), + _ => { + te.insert(HeaderValue::from_static("chunked")); + Some(Encoder::chunked()) + } + } + } else { + None + } + } + }; + + // This is because we need a second mutable borrow to remove + // content-length header. + if let Some(encoder) = encoder { + if should_remove_con_len && existing_con_len.is_some() { + headers.remove(header::CONTENT_LENGTH); + } + return encoder; + } + + // User didn't set transfer-encoding, AND we know body length, + // so we can just set the Content-Length automatically. + + let len = if let BodyLength::Known(len) = body { + len + } else { + unreachable!("BodyLength::Unknown would set chunked"); + }; + + set_content_length(headers, len) + } +} + +fn set_content_length(headers: &mut HeaderMap, len: u64) -> Encoder { + // At this point, there should not be a valid Content-Length + // header. However, since we'll be indexing in anyways, we can + // warn the user if there was an existing illegal header. + // + // Or at least, we can in theory. It's actually a little bit slower, + // so perhaps only do that while the user is developing/testing. + + if cfg!(debug_assertions) { + match headers.entry(header::CONTENT_LENGTH) { + Entry::Occupied(mut cl) => { + // Internal sanity check, we should have already determined + // that the header was illegal before calling this function. + debug_assert!(headers::content_length_parse_all_values(cl.iter()).is_none()); + // Uh oh, the user set `Content-Length` headers, but set bad ones. + // This would be an illegal message anyways, so let's try to repair + // with our known good length. + error!("user provided content-length header was invalid"); + + cl.insert(HeaderValue::from(len)); + Encoder::length(len) + } + Entry::Vacant(cl) => { + cl.insert(HeaderValue::from(len)); + Encoder::length(len) + } + } + } else { + headers.insert(header::CONTENT_LENGTH, HeaderValue::from(len)); + Encoder::length(len) + } +} + +#[derive(Clone, Copy)] +struct HeaderIndices { + name: (usize, usize), + value: (usize, usize), +} + +fn record_header_indices( + bytes: &[u8], + headers: &[httparse::Header<'_>], + indices: &mut [HeaderIndices], +) -> Result<(), crate::error::Parse> { + let bytes_ptr = bytes.as_ptr() as usize; + + for (header, indices) in headers.iter().zip(indices.iter_mut()) { + if header.name.len() >= (1 << 16) { + debug!("header name larger than 64kb: {:?}", header.name); + return Err(crate::error::Parse::TooLarge); + } + let name_start = header.name.as_ptr() as usize - bytes_ptr; + let name_end = name_start + header.name.len(); + indices.name = (name_start, name_end); + let value_start = header.value.as_ptr() as usize - bytes_ptr; + let value_end = value_start + header.value.len(); + indices.value = (value_start, value_end); + } + + Ok(()) +} + +// Write header names as title case. The header name is assumed to be ASCII, +// therefore it is trivial to convert an ASCII character from lowercase to +// uppercase. It is as simple as XORing the lowercase character byte with +// space. +fn title_case(dst: &mut Vec<u8>, name: &[u8]) { + dst.reserve(name.len()); + + let mut iter = name.iter(); + + // Uppercase the first character + if let Some(c) = iter.next() { + if *c >= b'a' && *c <= b'z' { + dst.push(*c ^ b' '); + } else { + dst.push(*c); + } + } + + while let Some(c) = iter.next() { + dst.push(*c); + + if *c == b'-' { + if let Some(c) = iter.next() { + if *c >= b'a' && *c <= b'z' { + dst.push(*c ^ b' '); + } else { + dst.push(*c); + } + } + } + } +} + +fn write_headers_title_case(headers: &HeaderMap, dst: &mut Vec<u8>) { + for (name, value) in headers { + title_case(dst, name.as_str().as_bytes()); + extend(dst, b": "); + extend(dst, value.as_bytes()); + extend(dst, b"\r\n"); + } +} + +fn write_headers(headers: &HeaderMap, dst: &mut Vec<u8>) { + for (name, value) in headers { + extend(dst, name.as_str().as_bytes()); + extend(dst, b": "); + extend(dst, value.as_bytes()); + extend(dst, b"\r\n"); + } +} + +struct FastWrite<'a>(&'a mut Vec<u8>); + +impl<'a> fmt::Write for FastWrite<'a> { + #[inline] + fn write_str(&mut self, s: &str) -> fmt::Result { + extend(self.0, s.as_bytes()); + Ok(()) + } + + #[inline] + fn write_fmt(&mut self, args: fmt::Arguments<'_>) -> fmt::Result { + fmt::write(self, args) + } +} + +#[inline] +fn extend(dst: &mut Vec<u8>, data: &[u8]) { + dst.extend_from_slice(data); +} + +#[cfg(test)] +mod tests { + use bytes::BytesMut; + + use super::*; + + #[test] + fn test_parse_request() { + let _ = pretty_env_logger::try_init(); + let mut raw = BytesMut::from("GET /echo HTTP/1.1\r\nHost: hyper.rs\r\n\r\n"); + let mut method = None; + let msg = Server::parse( + &mut raw, + ParseContext { + cached_headers: &mut None, + req_method: &mut method, + }, + ) + .unwrap() + .unwrap(); + assert_eq!(raw.len(), 0); + assert_eq!(msg.head.subject.0, crate::Method::GET); + assert_eq!(msg.head.subject.1, "/echo"); + assert_eq!(msg.head.version, crate::Version::HTTP_11); + assert_eq!(msg.head.headers.len(), 1); + assert_eq!(msg.head.headers["Host"], "hyper.rs"); + assert_eq!(method, Some(crate::Method::GET)); + } + + #[test] + fn test_parse_response() { + let _ = pretty_env_logger::try_init(); + let mut raw = BytesMut::from("HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"); + let ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut Some(crate::Method::GET), + }; + let msg = Client::parse(&mut raw, ctx).unwrap().unwrap(); + assert_eq!(raw.len(), 0); + assert_eq!(msg.head.subject, crate::StatusCode::OK); + assert_eq!(msg.head.version, crate::Version::HTTP_11); + assert_eq!(msg.head.headers.len(), 1); + assert_eq!(msg.head.headers["Content-Length"], "0"); + } + + #[test] + fn test_parse_request_errors() { + let mut raw = BytesMut::from("GET htt:p// HTTP/1.1\r\nHost: hyper.rs\r\n\r\n"); + let ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut None, + }; + Server::parse(&mut raw, ctx).unwrap_err(); + } + + #[test] + fn test_decoder_request() { + fn parse(s: &str) -> ParsedMessage<RequestLine> { + let mut bytes = BytesMut::from(s); + Server::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut None, + }, + ) + .expect("parse ok") + .expect("parse complete") + } + + fn parse_err(s: &str, comment: &str) -> crate::error::Parse { + let mut bytes = BytesMut::from(s); + Server::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut None, + }, + ) + .expect_err(comment) + } + + // no length or transfer-encoding means 0-length body + assert_eq!( + parse( + "\ + GET / HTTP/1.1\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::ZERO + ); + + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::ZERO + ); + + // transfer-encoding: chunked + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: gzip, chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: gzip\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + // content-length + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + content-length: 10\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::new(10) + ); + + // transfer-encoding and content-length = chunked + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + content-length: 10\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\ + content-length: 10\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: gzip\r\n\ + content-length: 10\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + // multiple content-lengths of same value are fine + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + content-length: 10\r\n\ + content-length: 10\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::new(10) + ); + + // multiple content-lengths with different values is an error + parse_err( + "\ + POST / HTTP/1.1\r\n\ + content-length: 10\r\n\ + content-length: 11\r\n\ + \r\n\ + ", + "multiple content-lengths", + ); + + // transfer-encoding that isn't chunked is an error + parse_err( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: gzip\r\n\ + \r\n\ + ", + "transfer-encoding but not chunked", + ); + + parse_err( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: chunked, gzip\r\n\ + \r\n\ + ", + "transfer-encoding doesn't end in chunked", + ); + + // http/1.0 + + assert_eq!( + parse( + "\ + POST / HTTP/1.0\r\n\ + content-length: 10\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::new(10) + ); + + // 1.0 doesn't understand chunked, so its an error + parse_err( + "\ + POST / HTTP/1.0\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + ", + "1.0 chunked", + ); + } + + #[test] + fn test_decoder_response() { + fn parse(s: &str) -> ParsedMessage<StatusCode> { + parse_with_method(s, Method::GET) + } + + fn parse_ignores(s: &str) { + let mut bytes = BytesMut::from(s); + assert!(Client::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut Some(Method::GET), + } + ) + .expect("parse ok") + .is_none()) + } + + fn parse_with_method(s: &str, m: Method) -> ParsedMessage<StatusCode> { + let mut bytes = BytesMut::from(s); + Client::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut Some(m), + }, + ) + .expect("parse ok") + .expect("parse complete") + } + + fn parse_err(s: &str) -> crate::error::Parse { + let mut bytes = BytesMut::from(s); + Client::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut Some(Method::GET), + }, + ) + .expect_err("parse should err") + } + + // no content-length or transfer-encoding means close-delimited + assert_eq!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CLOSE_DELIMITED + ); + + // 204 and 304 never have a body + assert_eq!( + parse( + "\ + HTTP/1.1 204 No Content\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::ZERO + ); + + assert_eq!( + parse( + "\ + HTTP/1.1 304 Not Modified\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::ZERO + ); + + // content-length + assert_eq!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 8\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::new(8) + ); + + assert_eq!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 8\r\n\ + content-length: 8\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::new(8) + ); + + parse_err( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 8\r\n\ + content-length: 9\r\n\ + \r\n\ + ", + ); + + // transfer-encoding: chunked + assert_eq!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + // transfer-encoding not-chunked is close-delimited + assert_eq!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + transfer-encoding: yolo\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CLOSE_DELIMITED + ); + + // transfer-encoding and content-length = chunked + assert_eq!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 10\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + // HEAD can have content-length, but not body + assert_eq!( + parse_with_method( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 8\r\n\ + \r\n\ + ", + Method::HEAD + ) + .decode, + DecodedLength::ZERO + ); + + // CONNECT with 200 never has body + { + let msg = parse_with_method( + "\ + HTTP/1.1 200 OK\r\n\ + \r\n\ + ", + Method::CONNECT, + ); + assert_eq!(msg.decode, DecodedLength::ZERO); + assert!(!msg.keep_alive, "should be upgrade"); + assert!(msg.wants_upgrade, "should be upgrade"); + } + + // CONNECT receiving non 200 can have a body + assert_eq!( + parse_with_method( + "\ + HTTP/1.1 400 Bad Request\r\n\ + \r\n\ + ", + Method::CONNECT + ) + .decode, + DecodedLength::CLOSE_DELIMITED + ); + + // 1xx status codes + parse_ignores( + "\ + HTTP/1.1 100 Continue\r\n\ + \r\n\ + ", + ); + + parse_ignores( + "\ + HTTP/1.1 103 Early Hints\r\n\ + \r\n\ + ", + ); + + // 101 upgrade not supported yet + { + let msg = parse( + "\ + HTTP/1.1 101 Switching Protocols\r\n\ + \r\n\ + ", + ); + assert_eq!(msg.decode, DecodedLength::ZERO); + assert!(!msg.keep_alive, "should be last"); + assert!(msg.wants_upgrade, "should be upgrade"); + } + + // http/1.0 + assert_eq!( + parse( + "\ + HTTP/1.0 200 OK\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CLOSE_DELIMITED + ); + + // 1.0 doesn't understand chunked + parse_err( + "\ + HTTP/1.0 200 OK\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + ", + ); + + // keep-alive + assert!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 0\r\n\ + \r\n\ + " + ) + .keep_alive, + "HTTP/1.1 keep-alive is default" + ); + + assert!( + !parse( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 0\r\n\ + connection: foo, close, bar\r\n\ + \r\n\ + " + ) + .keep_alive, + "connection close is always close" + ); + + assert!( + !parse( + "\ + HTTP/1.0 200 OK\r\n\ + content-length: 0\r\n\ + \r\n\ + " + ) + .keep_alive, + "HTTP/1.0 close is default" + ); + + assert!( + parse( + "\ + HTTP/1.0 200 OK\r\n\ + content-length: 0\r\n\ + connection: foo, keep-alive, bar\r\n\ + \r\n\ + " + ) + .keep_alive, + "connection keep-alive is always keep-alive" + ); + } + + #[test] + fn test_client_request_encode_title_case() { + use crate::proto::BodyLength; + use http::header::HeaderValue; + + let mut head = MessageHead::default(); + head.headers + .insert("content-length", HeaderValue::from_static("10")); + head.headers + .insert("content-type", HeaderValue::from_static("application/json")); + head.headers.insert("*-*", HeaderValue::from_static("o_o")); + + let mut vec = Vec::new(); + Client::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut None, + title_case_headers: true, + }, + &mut vec, + ) + .unwrap(); + + assert_eq!(vec, b"GET / HTTP/1.1\r\nContent-Length: 10\r\nContent-Type: application/json\r\n*-*: o_o\r\n\r\n".to_vec()); + } + + #[test] + fn test_server_encode_connect_method() { + let mut head = MessageHead::default(); + + let mut vec = Vec::new(); + let encoder = Server::encode( + Encode { + head: &mut head, + body: None, + keep_alive: true, + req_method: &mut Some(Method::CONNECT), + title_case_headers: false, + }, + &mut vec, + ) + .unwrap(); + + assert!(encoder.is_last()); + } + + #[test] + fn parse_header_htabs() { + let mut bytes = BytesMut::from("HTTP/1.1 200 OK\r\nserver: hello\tworld\r\n\r\n"); + let parsed = Client::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut Some(Method::GET), + }, + ) + .expect("parse ok") + .expect("parse complete"); + + assert_eq!(parsed.head.headers["server"], "hello\tworld"); + } + + #[cfg(feature = "nightly")] + use test::Bencher; + + #[cfg(feature = "nightly")] + #[bench] + fn bench_parse_incoming(b: &mut Bencher) { + let mut raw = BytesMut::from( + &b"GET /super_long_uri/and_whatever?what_should_we_talk_about/\ + I_wonder/Hard_to_write_in_an_uri_after_all/you_have_to_make\ + _up_the_punctuation_yourself/how_fun_is_that?test=foo&test1=\ + foo1&test2=foo2&test3=foo3&test4=foo4 HTTP/1.1\r\nHost: \ + hyper.rs\r\nAccept: a lot of things\r\nAccept-Charset: \ + utf8\r\nAccept-Encoding: *\r\nAccess-Control-Allow-\ + Credentials: None\r\nAccess-Control-Allow-Origin: None\r\n\ + Access-Control-Allow-Methods: None\r\nAccess-Control-Allow-\ + Headers: None\r\nContent-Encoding: utf8\r\nContent-Security-\ + Policy: None\r\nContent-Type: text/html\r\nOrigin: hyper\ + \r\nSec-Websocket-Extensions: It looks super important!\r\n\ + Sec-Websocket-Origin: hyper\r\nSec-Websocket-Version: 4.3\r\ + \nStrict-Transport-Security: None\r\nUser-Agent: hyper\r\n\ + X-Content-Duration: None\r\nX-Content-Security-Policy: None\ + \r\nX-DNSPrefetch-Control: None\r\nX-Frame-Options: \ + Something important obviously\r\nX-Requested-With: Nothing\ + \r\n\r\n"[..], + ); + let len = raw.len(); + let mut headers = Some(HeaderMap::new()); + + b.bytes = len as u64; + b.iter(|| { + let mut msg = Server::parse( + &mut raw, + ParseContext { + cached_headers: &mut headers, + req_method: &mut None, + }, + ) + .unwrap() + .unwrap(); + ::test::black_box(&msg); + msg.head.headers.clear(); + headers = Some(msg.head.headers); + restart(&mut raw, len); + }); + + fn restart(b: &mut BytesMut, len: usize) { + b.reserve(1); + unsafe { + b.set_len(len); + } + } + } + + #[cfg(feature = "nightly")] + #[bench] + fn bench_parse_short(b: &mut Bencher) { + let s = &b"GET / HTTP/1.1\r\nHost: localhost:8080\r\n\r\n"[..]; + let mut raw = BytesMut::from(s); + let len = raw.len(); + let mut headers = Some(HeaderMap::new()); + + b.bytes = len as u64; + b.iter(|| { + let mut msg = Server::parse( + &mut raw, + ParseContext { + cached_headers: &mut headers, + req_method: &mut None, + }, + ) + .unwrap() + .unwrap(); + ::test::black_box(&msg); + msg.head.headers.clear(); + headers = Some(msg.head.headers); + restart(&mut raw, len); + }); + + fn restart(b: &mut BytesMut, len: usize) { + b.reserve(1); + unsafe { + b.set_len(len); + } + } + } + + #[cfg(feature = "nightly")] + #[bench] + fn bench_server_encode_headers_preset(b: &mut Bencher) { + use crate::proto::BodyLength; + use http::header::HeaderValue; + + let len = 108; + b.bytes = len as u64; + + let mut head = MessageHead::default(); + let mut headers = HeaderMap::new(); + headers.insert("content-length", HeaderValue::from_static("10")); + headers.insert("content-type", HeaderValue::from_static("application/json")); + + b.iter(|| { + let mut vec = Vec::new(); + head.headers = headers.clone(); + Server::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut Some(Method::GET), + title_case_headers: false, + }, + &mut vec, + ) + .unwrap(); + assert_eq!(vec.len(), len); + ::test::black_box(vec); + }) + } + + #[cfg(feature = "nightly")] + #[bench] + fn bench_server_encode_no_headers(b: &mut Bencher) { + use crate::proto::BodyLength; + + let len = 76; + b.bytes = len as u64; + + let mut head = MessageHead::default(); + let mut vec = Vec::with_capacity(128); + + b.iter(|| { + Server::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut Some(Method::GET), + title_case_headers: false, + }, + &mut vec, + ) + .unwrap(); + assert_eq!(vec.len(), len); + ::test::black_box(&vec); + + vec.clear(); + }) + } +} diff --git a/third_party/rust/hyper/src/proto/h2/client.rs b/third_party/rust/hyper/src/proto/h2/client.rs new file mode 100644 index 0000000000..bf4cfccea5 --- /dev/null +++ b/third_party/rust/hyper/src/proto/h2/client.rs @@ -0,0 +1,292 @@ +#[cfg(feature = "runtime")] +use std::time::Duration; + +use futures_channel::{mpsc, oneshot}; +use futures_util::future::{self, Either, FutureExt as _, TryFutureExt as _}; +use futures_util::stream::StreamExt as _; +use h2::client::{Builder, SendRequest}; +use tokio::io::{AsyncRead, AsyncWrite}; + +use super::{decode_content_length, ping, PipeToSendStream, SendBuf}; +use crate::body::Payload; +use crate::common::{task, Exec, Future, Never, Pin, Poll}; +use crate::headers; +use crate::proto::Dispatched; +use crate::{Body, Request, Response}; + +type ClientRx<B> = crate::client::dispatch::Receiver<Request<B>, Response<Body>>; + +///// An mpsc channel is used to help notify the `Connection` task when *all* +///// other handles to it have been dropped, so that it can shutdown. +type ConnDropRef = mpsc::Sender<Never>; + +///// A oneshot channel watches the `Connection` task, and when it completes, +///// the "dispatch" task will be notified and can shutdown sooner. +type ConnEof = oneshot::Receiver<Never>; + +// Our defaults are chosen for the "majority" case, which usually are not +// resource constrained, and so the spec default of 64kb can be too limiting +// for performance. +const DEFAULT_CONN_WINDOW: u32 = 1024 * 1024 * 5; // 5mb +const DEFAULT_STREAM_WINDOW: u32 = 1024 * 1024 * 2; // 2mb + +#[derive(Clone, Debug)] +pub(crate) struct Config { + pub(crate) adaptive_window: bool, + pub(crate) initial_conn_window_size: u32, + pub(crate) initial_stream_window_size: u32, + #[cfg(feature = "runtime")] + pub(crate) keep_alive_interval: Option<Duration>, + #[cfg(feature = "runtime")] + pub(crate) keep_alive_timeout: Duration, + #[cfg(feature = "runtime")] + pub(crate) keep_alive_while_idle: bool, +} + +impl Default for Config { + fn default() -> Config { + Config { + adaptive_window: false, + initial_conn_window_size: DEFAULT_CONN_WINDOW, + initial_stream_window_size: DEFAULT_STREAM_WINDOW, + #[cfg(feature = "runtime")] + keep_alive_interval: None, + #[cfg(feature = "runtime")] + keep_alive_timeout: Duration::from_secs(20), + #[cfg(feature = "runtime")] + keep_alive_while_idle: false, + } + } +} + +pub(crate) async fn handshake<T, B>( + io: T, + req_rx: ClientRx<B>, + config: &Config, + exec: Exec, +) -> crate::Result<ClientTask<B>> +where + T: AsyncRead + AsyncWrite + Send + Unpin + 'static, + B: Payload, +{ + let (h2_tx, mut conn) = Builder::default() + .initial_window_size(config.initial_stream_window_size) + .initial_connection_window_size(config.initial_conn_window_size) + .enable_push(false) + .handshake::<_, SendBuf<B::Data>>(io) + .await + .map_err(crate::Error::new_h2)?; + + // An mpsc channel is used entirely to detect when the + // 'Client' has been dropped. This is to get around a bug + // in h2 where dropping all SendRequests won't notify a + // parked Connection. + let (conn_drop_ref, rx) = mpsc::channel(1); + let (cancel_tx, conn_eof) = oneshot::channel(); + + let conn_drop_rx = rx.into_future().map(|(item, _rx)| { + if let Some(never) = item { + match never {} + } + }); + + let ping_config = ping::Config { + bdp_initial_window: if config.adaptive_window { + Some(config.initial_stream_window_size) + } else { + None + }, + #[cfg(feature = "runtime")] + keep_alive_interval: config.keep_alive_interval, + #[cfg(feature = "runtime")] + keep_alive_timeout: config.keep_alive_timeout, + #[cfg(feature = "runtime")] + keep_alive_while_idle: config.keep_alive_while_idle, + }; + + let ping = if ping_config.is_enabled() { + let pp = conn.ping_pong().expect("conn.ping_pong"); + let (recorder, mut ponger) = ping::channel(pp, ping_config); + + let conn = future::poll_fn(move |cx| { + match ponger.poll(cx) { + Poll::Ready(ping::Ponged::SizeUpdate(wnd)) => { + conn.set_target_window_size(wnd); + conn.set_initial_window_size(wnd)?; + } + #[cfg(feature = "runtime")] + Poll::Ready(ping::Ponged::KeepAliveTimedOut) => { + debug!("connection keep-alive timed out"); + return Poll::Ready(Ok(())); + } + Poll::Pending => {} + } + + Pin::new(&mut conn).poll(cx) + }); + let conn = conn.map_err(|e| debug!("connection error: {}", e)); + + exec.execute(conn_task(conn, conn_drop_rx, cancel_tx)); + recorder + } else { + let conn = conn.map_err(|e| debug!("connection error: {}", e)); + + exec.execute(conn_task(conn, conn_drop_rx, cancel_tx)); + ping::disabled() + }; + + Ok(ClientTask { + ping, + conn_drop_ref, + conn_eof, + executor: exec, + h2_tx, + req_rx, + }) +} + +async fn conn_task<C, D>(conn: C, drop_rx: D, cancel_tx: oneshot::Sender<Never>) +where + C: Future + Unpin, + D: Future<Output = ()> + Unpin, +{ + match future::select(conn, drop_rx).await { + Either::Left(_) => { + // ok or err, the `conn` has finished + } + Either::Right(((), conn)) => { + // mpsc has been dropped, hopefully polling + // the connection some more should start shutdown + // and then close + trace!("send_request dropped, starting conn shutdown"); + drop(cancel_tx); + let _ = conn.await; + } + } +} + +pub(crate) struct ClientTask<B> +where + B: Payload, +{ + ping: ping::Recorder, + conn_drop_ref: ConnDropRef, + conn_eof: ConnEof, + executor: Exec, + h2_tx: SendRequest<SendBuf<B::Data>>, + req_rx: ClientRx<B>, +} + +impl<B> Future for ClientTask<B> +where + B: Payload + 'static, +{ + type Output = crate::Result<Dispatched>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + loop { + match ready!(self.h2_tx.poll_ready(cx)) { + Ok(()) => (), + Err(err) => { + self.ping.ensure_not_timed_out()?; + return if err.reason() == Some(::h2::Reason::NO_ERROR) { + trace!("connection gracefully shutdown"); + Poll::Ready(Ok(Dispatched::Shutdown)) + } else { + Poll::Ready(Err(crate::Error::new_h2(err))) + }; + } + }; + + match Pin::new(&mut self.req_rx).poll_next(cx) { + Poll::Ready(Some((req, cb))) => { + // check that future hasn't been canceled already + if cb.is_canceled() { + trace!("request callback is canceled"); + continue; + } + let (head, body) = req.into_parts(); + let mut req = ::http::Request::from_parts(head, ()); + super::strip_connection_headers(req.headers_mut(), true); + if let Some(len) = body.size_hint().exact() { + if len != 0 || headers::method_has_defined_payload_semantics(req.method()) { + headers::set_content_length_if_missing(req.headers_mut(), len); + } + } + let eos = body.is_end_stream(); + let (fut, body_tx) = match self.h2_tx.send_request(req, eos) { + Ok(ok) => ok, + Err(err) => { + debug!("client send request error: {}", err); + cb.send(Err((crate::Error::new_h2(err), None))); + continue; + } + }; + + let ping = self.ping.clone(); + if !eos { + let mut pipe = Box::pin(PipeToSendStream::new(body, body_tx)).map(|res| { + if let Err(e) = res { + debug!("client request body error: {}", e); + } + }); + + // eagerly see if the body pipe is ready and + // can thus skip allocating in the executor + match Pin::new(&mut pipe).poll(cx) { + Poll::Ready(_) => (), + Poll::Pending => { + let conn_drop_ref = self.conn_drop_ref.clone(); + // keep the ping recorder's knowledge of an + // "open stream" alive while this body is + // still sending... + let ping = ping.clone(); + let pipe = pipe.map(move |x| { + drop(conn_drop_ref); + drop(ping); + x + }); + self.executor.execute(pipe); + } + } + } + + let fut = fut.map(move |result| match result { + Ok(res) => { + // record that we got the response headers + ping.record_non_data(); + + let content_length = decode_content_length(res.headers()); + let res = res.map(|stream| { + let ping = ping.for_stream(&stream); + crate::Body::h2(stream, content_length, ping) + }); + Ok(res) + } + Err(err) => { + ping.ensure_not_timed_out().map_err(|e| (e, None))?; + + debug!("client response error: {}", err); + Err((crate::Error::new_h2(err), None)) + } + }); + self.executor.execute(cb.send_when(fut)); + continue; + } + + Poll::Ready(None) => { + trace!("client::dispatch::Sender dropped"); + return Poll::Ready(Ok(Dispatched::Shutdown)); + } + + Poll::Pending => match ready!(Pin::new(&mut self.conn_eof).poll(cx)) { + Ok(never) => match never {}, + Err(_conn_is_eof) => { + trace!("connection task is closed, closing dispatch task"); + return Poll::Ready(Ok(Dispatched::Shutdown)); + } + }, + } + } + } +} diff --git a/third_party/rust/hyper/src/proto/h2/mod.rs b/third_party/rust/hyper/src/proto/h2/mod.rs new file mode 100644 index 0000000000..e25f038cad --- /dev/null +++ b/third_party/rust/hyper/src/proto/h2/mod.rs @@ -0,0 +1,263 @@ +use bytes::Buf; +use h2::SendStream; +use http::header::{ + HeaderName, CONNECTION, PROXY_AUTHENTICATE, PROXY_AUTHORIZATION, TE, TRAILER, + TRANSFER_ENCODING, UPGRADE, +}; +use http::HeaderMap; +use pin_project::pin_project; + +use super::DecodedLength; +use crate::body::Payload; +use crate::common::{task, Future, Pin, Poll}; +use crate::headers::content_length_parse_all; + +pub(crate) mod client; +pub(crate) mod ping; +pub(crate) mod server; + +pub(crate) use self::client::ClientTask; +pub(crate) use self::server::Server; + +/// Default initial stream window size defined in HTTP2 spec. +pub(crate) const SPEC_WINDOW_SIZE: u32 = 65_535; + +fn strip_connection_headers(headers: &mut HeaderMap, is_request: bool) { + // List of connection headers from: + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Connection + // + // TE headers are allowed in HTTP/2 requests as long as the value is "trailers", so they're + // tested separately. + let connection_headers = [ + HeaderName::from_lowercase(b"keep-alive").unwrap(), + HeaderName::from_lowercase(b"proxy-connection").unwrap(), + PROXY_AUTHENTICATE, + PROXY_AUTHORIZATION, + TRAILER, + TRANSFER_ENCODING, + UPGRADE, + ]; + + for header in connection_headers.iter() { + if headers.remove(header).is_some() { + warn!("Connection header illegal in HTTP/2: {}", header.as_str()); + } + } + + if is_request { + if headers + .get(TE) + .map(|te_header| te_header != "trailers") + .unwrap_or(false) + { + warn!("TE headers not set to \"trailers\" are illegal in HTTP/2 requests"); + headers.remove(TE); + } + } else if headers.remove(TE).is_some() { + warn!("TE headers illegal in HTTP/2 responses"); + } + + if let Some(header) = headers.remove(CONNECTION) { + warn!( + "Connection header illegal in HTTP/2: {}", + CONNECTION.as_str() + ); + let header_contents = header.to_str().unwrap(); + + // A `Connection` header may have a comma-separated list of names of other headers that + // are meant for only this specific connection. + // + // Iterate these names and remove them as headers. Connection-specific headers are + // forbidden in HTTP2, as that information has been moved into frame types of the h2 + // protocol. + for name in header_contents.split(',') { + let name = name.trim(); + headers.remove(name); + } + } +} + +fn decode_content_length(headers: &HeaderMap) -> DecodedLength { + if let Some(len) = content_length_parse_all(headers) { + // If the length is u64::MAX, oh well, just reported chunked. + DecodedLength::checked_new(len).unwrap_or_else(|_| DecodedLength::CHUNKED) + } else { + DecodedLength::CHUNKED + } +} + +// body adapters used by both Client and Server + +#[pin_project] +struct PipeToSendStream<S> +where + S: Payload, +{ + body_tx: SendStream<SendBuf<S::Data>>, + data_done: bool, + #[pin] + stream: S, +} + +impl<S> PipeToSendStream<S> +where + S: Payload, +{ + fn new(stream: S, tx: SendStream<SendBuf<S::Data>>) -> PipeToSendStream<S> { + PipeToSendStream { + body_tx: tx, + data_done: false, + stream, + } + } +} + +impl<S> Future for PipeToSendStream<S> +where + S: Payload, +{ + type Output = crate::Result<()>; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + let mut me = self.project(); + loop { + if !*me.data_done { + // we don't have the next chunk of data yet, so just reserve 1 byte to make + // sure there's some capacity available. h2 will handle the capacity management + // for the actual body chunk. + me.body_tx.reserve_capacity(1); + + if me.body_tx.capacity() == 0 { + loop { + match ready!(me.body_tx.poll_capacity(cx)) { + Some(Ok(0)) => {} + Some(Ok(_)) => break, + Some(Err(e)) => { + return Poll::Ready(Err(crate::Error::new_body_write(e))) + } + None => { + // None means the stream is no longer in a + // streaming state, we either finished it + // somehow, or the remote reset us. + return Poll::Ready(Err(crate::Error::new_body_write( + "send stream capacity unexpectedly closed", + ))); + } + } + } + } else if let Poll::Ready(reason) = me + .body_tx + .poll_reset(cx) + .map_err(crate::Error::new_body_write)? + { + debug!("stream received RST_STREAM: {:?}", reason); + return Poll::Ready(Err(crate::Error::new_body_write(::h2::Error::from( + reason, + )))); + } + + match ready!(me.stream.as_mut().poll_data(cx)) { + Some(Ok(chunk)) => { + let is_eos = me.stream.is_end_stream(); + trace!( + "send body chunk: {} bytes, eos={}", + chunk.remaining(), + is_eos, + ); + + let buf = SendBuf(Some(chunk)); + me.body_tx + .send_data(buf, is_eos) + .map_err(crate::Error::new_body_write)?; + + if is_eos { + return Poll::Ready(Ok(())); + } + } + Some(Err(e)) => return Poll::Ready(Err(me.body_tx.on_user_err(e))), + None => { + me.body_tx.reserve_capacity(0); + let is_eos = me.stream.is_end_stream(); + if is_eos { + return Poll::Ready(me.body_tx.send_eos_frame()); + } else { + *me.data_done = true; + // loop again to poll_trailers + } + } + } + } else { + if let Poll::Ready(reason) = me + .body_tx + .poll_reset(cx) + .map_err(crate::Error::new_body_write)? + { + debug!("stream received RST_STREAM: {:?}", reason); + return Poll::Ready(Err(crate::Error::new_body_write(::h2::Error::from( + reason, + )))); + } + + match ready!(me.stream.poll_trailers(cx)) { + Ok(Some(trailers)) => { + me.body_tx + .send_trailers(trailers) + .map_err(crate::Error::new_body_write)?; + return Poll::Ready(Ok(())); + } + Ok(None) => { + // There were no trailers, so send an empty DATA frame... + return Poll::Ready(me.body_tx.send_eos_frame()); + } + Err(e) => return Poll::Ready(Err(me.body_tx.on_user_err(e))), + } + } + } + } +} + +trait SendStreamExt { + fn on_user_err<E>(&mut self, err: E) -> crate::Error + where + E: Into<Box<dyn std::error::Error + Send + Sync>>; + fn send_eos_frame(&mut self) -> crate::Result<()>; +} + +impl<B: Buf> SendStreamExt for SendStream<SendBuf<B>> { + fn on_user_err<E>(&mut self, err: E) -> crate::Error + where + E: Into<Box<dyn std::error::Error + Send + Sync>>, + { + let err = crate::Error::new_user_body(err); + debug!("send body user stream error: {}", err); + self.send_reset(err.h2_reason()); + err + } + + fn send_eos_frame(&mut self) -> crate::Result<()> { + trace!("send body eos"); + self.send_data(SendBuf(None), true) + .map_err(crate::Error::new_body_write) + } +} + +struct SendBuf<B>(Option<B>); + +impl<B: Buf> Buf for SendBuf<B> { + #[inline] + fn remaining(&self) -> usize { + self.0.as_ref().map(|b| b.remaining()).unwrap_or(0) + } + + #[inline] + fn bytes(&self) -> &[u8] { + self.0.as_ref().map(|b| b.bytes()).unwrap_or(&[]) + } + + #[inline] + fn advance(&mut self, cnt: usize) { + if let Some(b) = self.0.as_mut() { + b.advance(cnt) + } + } +} diff --git a/third_party/rust/hyper/src/proto/h2/ping.rs b/third_party/rust/hyper/src/proto/h2/ping.rs new file mode 100644 index 0000000000..c4fe2dd15c --- /dev/null +++ b/third_party/rust/hyper/src/proto/h2/ping.rs @@ -0,0 +1,506 @@ +/// HTTP2 Ping usage +/// +/// hyper uses HTTP2 pings for two purposes: +/// +/// 1. Adaptive flow control using BDP +/// 2. Connection keep-alive +/// +/// Both cases are optional. +/// +/// # BDP Algorithm +/// +/// 1. When receiving a DATA frame, if a BDP ping isn't outstanding: +/// 1a. Record current time. +/// 1b. Send a BDP ping. +/// 2. Increment the number of received bytes. +/// 3. When the BDP ping ack is received: +/// 3a. Record duration from sent time. +/// 3b. Merge RTT with a running average. +/// 3c. Calculate bdp as bytes/rtt. +/// 3d. If bdp is over 2/3 max, set new max to bdp and update windows. + +#[cfg(feature = "runtime")] +use std::fmt; +#[cfg(feature = "runtime")] +use std::future::Future; +#[cfg(feature = "runtime")] +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::task::{self, Poll}; +use std::time::Duration; +#[cfg(not(feature = "runtime"))] +use std::time::Instant; + +use h2::{Ping, PingPong}; +#[cfg(feature = "runtime")] +use tokio::time::{Delay, Instant}; + +type WindowSize = u32; + +pub(super) fn disabled() -> Recorder { + Recorder { shared: None } +} + +pub(super) fn channel(ping_pong: PingPong, config: Config) -> (Recorder, Ponger) { + debug_assert!( + config.is_enabled(), + "ping channel requires bdp or keep-alive config", + ); + + let bdp = config.bdp_initial_window.map(|wnd| Bdp { + bdp: wnd, + max_bandwidth: 0.0, + rtt: 0.0, + }); + + let bytes = bdp.as_ref().map(|_| 0); + + #[cfg(feature = "runtime")] + let keep_alive = config.keep_alive_interval.map(|interval| KeepAlive { + interval, + timeout: config.keep_alive_timeout, + while_idle: config.keep_alive_while_idle, + timer: tokio::time::delay_for(interval), + state: KeepAliveState::Init, + }); + + #[cfg(feature = "runtime")] + let last_read_at = keep_alive.as_ref().map(|_| Instant::now()); + + let shared = Arc::new(Mutex::new(Shared { + bytes, + #[cfg(feature = "runtime")] + last_read_at, + #[cfg(feature = "runtime")] + is_keep_alive_timed_out: false, + ping_pong, + ping_sent_at: None, + })); + + ( + Recorder { + shared: Some(shared.clone()), + }, + Ponger { + bdp, + #[cfg(feature = "runtime")] + keep_alive, + shared, + }, + ) +} + +#[derive(Clone)] +pub(super) struct Config { + pub(super) bdp_initial_window: Option<WindowSize>, + /// If no frames are received in this amount of time, a PING frame is sent. + #[cfg(feature = "runtime")] + pub(super) keep_alive_interval: Option<Duration>, + /// After sending a keepalive PING, the connection will be closed if + /// a pong is not received in this amount of time. + #[cfg(feature = "runtime")] + pub(super) keep_alive_timeout: Duration, + /// If true, sends pings even when there are no active streams. + #[cfg(feature = "runtime")] + pub(super) keep_alive_while_idle: bool, +} + +#[derive(Clone)] +pub(crate) struct Recorder { + shared: Option<Arc<Mutex<Shared>>>, +} + +pub(super) struct Ponger { + bdp: Option<Bdp>, + #[cfg(feature = "runtime")] + keep_alive: Option<KeepAlive>, + shared: Arc<Mutex<Shared>>, +} + +struct Shared { + ping_pong: PingPong, + ping_sent_at: Option<Instant>, + + // bdp + /// If `Some`, bdp is enabled, and this tracks how many bytes have been + /// read during the current sample. + bytes: Option<usize>, + + // keep-alive + /// If `Some`, keep-alive is enabled, and the Instant is how long ago + /// the connection read the last frame. + #[cfg(feature = "runtime")] + last_read_at: Option<Instant>, + + #[cfg(feature = "runtime")] + is_keep_alive_timed_out: bool, +} + +struct Bdp { + /// Current BDP in bytes + bdp: u32, + /// Largest bandwidth we've seen so far. + max_bandwidth: f64, + /// Round trip time in seconds + rtt: f64, +} + +#[cfg(feature = "runtime")] +struct KeepAlive { + /// If no frames are received in this amount of time, a PING frame is sent. + interval: Duration, + /// After sending a keepalive PING, the connection will be closed if + /// a pong is not received in this amount of time. + timeout: Duration, + /// If true, sends pings even when there are no active streams. + while_idle: bool, + + state: KeepAliveState, + timer: Delay, +} + +#[cfg(feature = "runtime")] +enum KeepAliveState { + Init, + Scheduled, + PingSent, +} + +pub(super) enum Ponged { + SizeUpdate(WindowSize), + #[cfg(feature = "runtime")] + KeepAliveTimedOut, +} + +#[cfg(feature = "runtime")] +#[derive(Debug)] +pub(super) struct KeepAliveTimedOut; + +// ===== impl Config ===== + +impl Config { + pub(super) fn is_enabled(&self) -> bool { + #[cfg(feature = "runtime")] + { + self.bdp_initial_window.is_some() || self.keep_alive_interval.is_some() + } + + #[cfg(not(feature = "runtime"))] + { + self.bdp_initial_window.is_some() + } + } +} + +// ===== impl Recorder ===== + +impl Recorder { + pub(crate) fn record_data(&self, len: usize) { + let shared = if let Some(ref shared) = self.shared { + shared + } else { + return; + }; + + let mut locked = shared.lock().unwrap(); + + #[cfg(feature = "runtime")] + locked.update_last_read_at(); + + if let Some(ref mut bytes) = locked.bytes { + *bytes += len; + } else { + // no need to send bdp ping if bdp is disabled + return; + } + + if !locked.is_ping_sent() { + locked.send_ping(); + } + } + + pub(crate) fn record_non_data(&self) { + #[cfg(feature = "runtime")] + { + let shared = if let Some(ref shared) = self.shared { + shared + } else { + return; + }; + + let mut locked = shared.lock().unwrap(); + + locked.update_last_read_at(); + } + } + + /// If the incoming stream is already closed, convert self into + /// a disabled reporter. + pub(super) fn for_stream(self, stream: &h2::RecvStream) -> Self { + if stream.is_end_stream() { + disabled() + } else { + self + } + } + + pub(super) fn ensure_not_timed_out(&self) -> crate::Result<()> { + #[cfg(feature = "runtime")] + { + if let Some(ref shared) = self.shared { + let locked = shared.lock().unwrap(); + if locked.is_keep_alive_timed_out { + return Err(KeepAliveTimedOut.crate_error()); + } + } + } + + // else + Ok(()) + } +} + +// ===== impl Ponger ===== + +impl Ponger { + pub(super) fn poll(&mut self, cx: &mut task::Context<'_>) -> Poll<Ponged> { + let mut locked = self.shared.lock().unwrap(); + #[cfg(feature = "runtime")] + let is_idle = self.is_idle(); + + #[cfg(feature = "runtime")] + { + if let Some(ref mut ka) = self.keep_alive { + ka.schedule(is_idle, &locked); + ka.maybe_ping(cx, &mut locked); + } + } + + if !locked.is_ping_sent() { + // XXX: this doesn't register a waker...? + return Poll::Pending; + } + + let (bytes, rtt) = match locked.ping_pong.poll_pong(cx) { + Poll::Ready(Ok(_pong)) => { + let rtt = locked + .ping_sent_at + .expect("pong received implies ping_sent_at") + .elapsed(); + locked.ping_sent_at = None; + trace!("recv pong"); + + #[cfg(feature = "runtime")] + { + if let Some(ref mut ka) = self.keep_alive { + locked.update_last_read_at(); + ka.schedule(is_idle, &locked); + } + } + + if self.bdp.is_some() { + let bytes = locked.bytes.expect("bdp enabled implies bytes"); + locked.bytes = Some(0); // reset + trace!("received BDP ack; bytes = {}, rtt = {:?}", bytes, rtt); + (bytes, rtt) + } else { + // no bdp, done! + return Poll::Pending; + } + } + Poll::Ready(Err(e)) => { + debug!("pong error: {}", e); + return Poll::Pending; + } + Poll::Pending => { + #[cfg(feature = "runtime")] + { + if let Some(ref mut ka) = self.keep_alive { + if let Err(KeepAliveTimedOut) = ka.maybe_timeout(cx) { + self.keep_alive = None; + locked.is_keep_alive_timed_out = true; + return Poll::Ready(Ponged::KeepAliveTimedOut); + } + } + } + + return Poll::Pending; + } + }; + + drop(locked); + + if let Some(bdp) = self.bdp.as_mut().and_then(|bdp| bdp.calculate(bytes, rtt)) { + Poll::Ready(Ponged::SizeUpdate(bdp)) + } else { + // XXX: this doesn't register a waker...? + Poll::Pending + } + } + + #[cfg(feature = "runtime")] + fn is_idle(&self) -> bool { + Arc::strong_count(&self.shared) <= 2 + } +} + +// ===== impl Shared ===== + +impl Shared { + fn send_ping(&mut self) { + match self.ping_pong.send_ping(Ping::opaque()) { + Ok(()) => { + self.ping_sent_at = Some(Instant::now()); + trace!("sent ping"); + } + Err(err) => { + debug!("error sending ping: {}", err); + } + } + } + + fn is_ping_sent(&self) -> bool { + self.ping_sent_at.is_some() + } + + #[cfg(feature = "runtime")] + fn update_last_read_at(&mut self) { + if self.last_read_at.is_some() { + self.last_read_at = Some(Instant::now()); + } + } + + #[cfg(feature = "runtime")] + fn last_read_at(&self) -> Instant { + self.last_read_at.expect("keep_alive expects last_read_at") + } +} + +// ===== impl Bdp ===== + +/// Any higher than this likely will be hitting the TCP flow control. +const BDP_LIMIT: usize = 1024 * 1024 * 16; + +impl Bdp { + fn calculate(&mut self, bytes: usize, rtt: Duration) -> Option<WindowSize> { + // No need to do any math if we're at the limit. + if self.bdp as usize == BDP_LIMIT { + return None; + } + + // average the rtt + let rtt = seconds(rtt); + if self.rtt == 0.0 { + // First sample means rtt is first rtt. + self.rtt = rtt; + } else { + // Weigh this rtt as 1/8 for a moving average. + self.rtt += (rtt - self.rtt) * 0.125; + } + + // calculate the current bandwidth + let bw = (bytes as f64) / (self.rtt * 1.5); + trace!("current bandwidth = {:.1}B/s", bw); + + if bw < self.max_bandwidth { + // not a faster bandwidth, so don't update + return None; + } else { + self.max_bandwidth = bw; + } + + // if the current `bytes` sample is at least 2/3 the previous + // bdp, increase to double the current sample. + if bytes >= self.bdp as usize * 2 / 3 { + self.bdp = (bytes * 2).min(BDP_LIMIT) as WindowSize; + trace!("BDP increased to {}", self.bdp); + Some(self.bdp) + } else { + None + } + } +} + +fn seconds(dur: Duration) -> f64 { + const NANOS_PER_SEC: f64 = 1_000_000_000.0; + let secs = dur.as_secs() as f64; + secs + (dur.subsec_nanos() as f64) / NANOS_PER_SEC +} + +// ===== impl KeepAlive ===== + +#[cfg(feature = "runtime")] +impl KeepAlive { + fn schedule(&mut self, is_idle: bool, shared: &Shared) { + match self.state { + KeepAliveState::Init => { + if !self.while_idle && is_idle { + return; + } + + self.state = KeepAliveState::Scheduled; + let interval = shared.last_read_at() + self.interval; + self.timer.reset(interval); + } + KeepAliveState::Scheduled | KeepAliveState::PingSent => (), + } + } + + fn maybe_ping(&mut self, cx: &mut task::Context<'_>, shared: &mut Shared) { + match self.state { + KeepAliveState::Scheduled => { + if Pin::new(&mut self.timer).poll(cx).is_pending() { + return; + } + // check if we've received a frame while we were scheduled + if shared.last_read_at() + self.interval > self.timer.deadline() { + self.state = KeepAliveState::Init; + cx.waker().wake_by_ref(); // schedule us again + return; + } + trace!("keep-alive interval ({:?}) reached", self.interval); + shared.send_ping(); + self.state = KeepAliveState::PingSent; + let timeout = Instant::now() + self.timeout; + self.timer.reset(timeout); + } + KeepAliveState::Init | KeepAliveState::PingSent => (), + } + } + + fn maybe_timeout(&mut self, cx: &mut task::Context<'_>) -> Result<(), KeepAliveTimedOut> { + match self.state { + KeepAliveState::PingSent => { + if Pin::new(&mut self.timer).poll(cx).is_pending() { + return Ok(()); + } + trace!("keep-alive timeout ({:?}) reached", self.timeout); + Err(KeepAliveTimedOut) + } + KeepAliveState::Init | KeepAliveState::Scheduled => Ok(()), + } + } +} + +// ===== impl KeepAliveTimedOut ===== + +#[cfg(feature = "runtime")] +impl KeepAliveTimedOut { + pub(super) fn crate_error(self) -> crate::Error { + crate::Error::new(crate::error::Kind::Http2).with(self) + } +} + +#[cfg(feature = "runtime")] +impl fmt::Display for KeepAliveTimedOut { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("keep-alive timed out") + } +} + +#[cfg(feature = "runtime")] +impl std::error::Error for KeepAliveTimedOut { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(&crate::error::TimedOut) + } +} diff --git a/third_party/rust/hyper/src/proto/h2/server.rs b/third_party/rust/hyper/src/proto/h2/server.rs new file mode 100644 index 0000000000..bf81c1190f --- /dev/null +++ b/third_party/rust/hyper/src/proto/h2/server.rs @@ -0,0 +1,439 @@ +use std::error::Error as StdError; +use std::marker::Unpin; +#[cfg(feature = "runtime")] +use std::time::Duration; + +use h2::server::{Connection, Handshake, SendResponse}; +use h2::Reason; +use pin_project::{pin_project, project}; +use tokio::io::{AsyncRead, AsyncWrite}; + +use super::{decode_content_length, ping, PipeToSendStream, SendBuf}; +use crate::body::Payload; +use crate::common::exec::H2Exec; +use crate::common::{task, Future, Pin, Poll}; +use crate::headers; +use crate::proto::Dispatched; +use crate::service::HttpService; + +use crate::{Body, Response}; + +// Our defaults are chosen for the "majority" case, which usually are not +// resource constrained, and so the spec default of 64kb can be too limiting +// for performance. +// +// At the same time, a server more often has multiple clients connected, and +// so is more likely to use more resources than a client would. +const DEFAULT_CONN_WINDOW: u32 = 1024 * 1024; // 1mb +const DEFAULT_STREAM_WINDOW: u32 = 1024 * 1024; // 1mb + +#[derive(Clone, Debug)] +pub(crate) struct Config { + pub(crate) adaptive_window: bool, + pub(crate) initial_conn_window_size: u32, + pub(crate) initial_stream_window_size: u32, + pub(crate) max_concurrent_streams: Option<u32>, + #[cfg(feature = "runtime")] + pub(crate) keep_alive_interval: Option<Duration>, + #[cfg(feature = "runtime")] + pub(crate) keep_alive_timeout: Duration, +} + +impl Default for Config { + fn default() -> Config { + Config { + adaptive_window: false, + initial_conn_window_size: DEFAULT_CONN_WINDOW, + initial_stream_window_size: DEFAULT_STREAM_WINDOW, + max_concurrent_streams: None, + #[cfg(feature = "runtime")] + keep_alive_interval: None, + #[cfg(feature = "runtime")] + keep_alive_timeout: Duration::from_secs(20), + } + } +} + +#[pin_project] +pub(crate) struct Server<T, S, B, E> +where + S: HttpService<Body>, + B: Payload, +{ + exec: E, + service: S, + state: State<T, B>, +} + +enum State<T, B> +where + B: Payload, +{ + Handshaking { + ping_config: ping::Config, + hs: Handshake<T, SendBuf<B::Data>>, + }, + Serving(Serving<T, B>), + Closed, +} + +struct Serving<T, B> +where + B: Payload, +{ + ping: Option<(ping::Recorder, ping::Ponger)>, + conn: Connection<T, SendBuf<B::Data>>, + closing: Option<crate::Error>, +} + +impl<T, S, B, E> Server<T, S, B, E> +where + T: AsyncRead + AsyncWrite + Unpin, + S: HttpService<Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: Payload, + E: H2Exec<S::Future, B>, +{ + pub(crate) fn new(io: T, service: S, config: &Config, exec: E) -> Server<T, S, B, E> { + let mut builder = h2::server::Builder::default(); + builder + .initial_window_size(config.initial_stream_window_size) + .initial_connection_window_size(config.initial_conn_window_size); + if let Some(max) = config.max_concurrent_streams { + builder.max_concurrent_streams(max); + } + let handshake = builder.handshake(io); + + let bdp = if config.adaptive_window { + Some(config.initial_stream_window_size) + } else { + None + }; + + let ping_config = ping::Config { + bdp_initial_window: bdp, + #[cfg(feature = "runtime")] + keep_alive_interval: config.keep_alive_interval, + #[cfg(feature = "runtime")] + keep_alive_timeout: config.keep_alive_timeout, + // If keep-alive is enabled for servers, always enabled while + // idle, so it can more aggresively close dead connections. + #[cfg(feature = "runtime")] + keep_alive_while_idle: true, + }; + + Server { + exec, + state: State::Handshaking { + ping_config, + hs: handshake, + }, + service, + } + } + + pub fn graceful_shutdown(&mut self) { + trace!("graceful_shutdown"); + match self.state { + State::Handshaking { .. } => { + // fall-through, to replace state with Closed + } + State::Serving(ref mut srv) => { + if srv.closing.is_none() { + srv.conn.graceful_shutdown(); + } + return; + } + State::Closed => { + return; + } + } + self.state = State::Closed; + } +} + +impl<T, S, B, E> Future for Server<T, S, B, E> +where + T: AsyncRead + AsyncWrite + Unpin, + S: HttpService<Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: Payload, + E: H2Exec<S::Future, B>, +{ + type Output = crate::Result<Dispatched>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + let me = &mut *self; + loop { + let next = match me.state { + State::Handshaking { + ref mut hs, + ref ping_config, + } => { + let mut conn = ready!(Pin::new(hs).poll(cx).map_err(crate::Error::new_h2))?; + let ping = if ping_config.is_enabled() { + let pp = conn.ping_pong().expect("conn.ping_pong"); + Some(ping::channel(pp, ping_config.clone())) + } else { + None + }; + State::Serving(Serving { + ping, + conn, + closing: None, + }) + } + State::Serving(ref mut srv) => { + ready!(srv.poll_server(cx, &mut me.service, &mut me.exec))?; + return Poll::Ready(Ok(Dispatched::Shutdown)); + } + State::Closed => { + // graceful_shutdown was called before handshaking finished, + // nothing to do here... + return Poll::Ready(Ok(Dispatched::Shutdown)); + } + }; + me.state = next; + } + } +} + +impl<T, B> Serving<T, B> +where + T: AsyncRead + AsyncWrite + Unpin, + B: Payload, +{ + fn poll_server<S, E>( + &mut self, + cx: &mut task::Context<'_>, + service: &mut S, + exec: &mut E, + ) -> Poll<crate::Result<()>> + where + S: HttpService<Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + E: H2Exec<S::Future, B>, + { + if self.closing.is_none() { + loop { + self.poll_ping(cx); + + // Check that the service is ready to accept a new request. + // + // - If not, just drive the connection some. + // - If ready, try to accept a new request from the connection. + match service.poll_ready(cx) { + Poll::Ready(Ok(())) => (), + Poll::Pending => { + // use `poll_closed` instead of `poll_accept`, + // in order to avoid accepting a request. + ready!(self.conn.poll_closed(cx).map_err(crate::Error::new_h2))?; + trace!("incoming connection complete"); + return Poll::Ready(Ok(())); + } + Poll::Ready(Err(err)) => { + let err = crate::Error::new_user_service(err); + debug!("service closed: {}", err); + + let reason = err.h2_reason(); + if reason == Reason::NO_ERROR { + // NO_ERROR is only used for graceful shutdowns... + trace!("interpretting NO_ERROR user error as graceful_shutdown"); + self.conn.graceful_shutdown(); + } else { + trace!("abruptly shutting down with {:?}", reason); + self.conn.abrupt_shutdown(reason); + } + self.closing = Some(err); + break; + } + } + + // When the service is ready, accepts an incoming request. + match ready!(self.conn.poll_accept(cx)) { + Some(Ok((req, respond))) => { + trace!("incoming request"); + let content_length = decode_content_length(req.headers()); + let ping = self + .ping + .as_ref() + .map(|ping| ping.0.clone()) + .unwrap_or_else(ping::disabled); + + // Record the headers received + ping.record_non_data(); + + let req = req.map(|stream| crate::Body::h2(stream, content_length, ping)); + let fut = H2Stream::new(service.call(req), respond); + exec.execute_h2stream(fut); + } + Some(Err(e)) => { + return Poll::Ready(Err(crate::Error::new_h2(e))); + } + None => { + // no more incoming streams... + if let Some((ref ping, _)) = self.ping { + ping.ensure_not_timed_out()?; + } + + trace!("incoming connection complete"); + return Poll::Ready(Ok(())); + } + } + } + } + + debug_assert!( + self.closing.is_some(), + "poll_server broke loop without closing" + ); + + ready!(self.conn.poll_closed(cx).map_err(crate::Error::new_h2))?; + + Poll::Ready(Err(self.closing.take().expect("polled after error"))) + } + + fn poll_ping(&mut self, cx: &mut task::Context<'_>) { + if let Some((_, ref mut estimator)) = self.ping { + match estimator.poll(cx) { + Poll::Ready(ping::Ponged::SizeUpdate(wnd)) => { + self.conn.set_target_window_size(wnd); + let _ = self.conn.set_initial_window_size(wnd); + } + #[cfg(feature = "runtime")] + Poll::Ready(ping::Ponged::KeepAliveTimedOut) => { + debug!("keep-alive timed out, closing connection"); + self.conn.abrupt_shutdown(h2::Reason::NO_ERROR); + } + Poll::Pending => {} + } + } + } +} + +#[allow(missing_debug_implementations)] +#[pin_project] +pub struct H2Stream<F, B> +where + B: Payload, +{ + reply: SendResponse<SendBuf<B::Data>>, + #[pin] + state: H2StreamState<F, B>, +} + +#[pin_project] +enum H2StreamState<F, B> +where + B: Payload, +{ + Service(#[pin] F), + Body(#[pin] PipeToSendStream<B>), +} + +impl<F, B> H2Stream<F, B> +where + B: Payload, +{ + fn new(fut: F, respond: SendResponse<SendBuf<B::Data>>) -> H2Stream<F, B> { + H2Stream { + reply: respond, + state: H2StreamState::Service(fut), + } + } +} + +macro_rules! reply { + ($me:expr, $res:expr, $eos:expr) => {{ + match $me.reply.send_response($res, $eos) { + Ok(tx) => tx, + Err(e) => { + debug!("send response error: {}", e); + $me.reply.send_reset(Reason::INTERNAL_ERROR); + return Poll::Ready(Err(crate::Error::new_h2(e))); + } + } + }}; +} + +impl<F, B, E> H2Stream<F, B> +where + F: Future<Output = Result<Response<B>, E>>, + B: Payload, + E: Into<Box<dyn StdError + Send + Sync>>, +{ + #[project] + fn poll2(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + let mut me = self.project(); + loop { + #[project] + let next = match me.state.as_mut().project() { + H2StreamState::Service(h) => { + let res = match h.poll(cx) { + Poll::Ready(Ok(r)) => r, + Poll::Pending => { + // Response is not yet ready, so we want to check if the client has sent a + // RST_STREAM frame which would cancel the current request. + if let Poll::Ready(reason) = + me.reply.poll_reset(cx).map_err(crate::Error::new_h2)? + { + debug!("stream received RST_STREAM: {:?}", reason); + return Poll::Ready(Err(crate::Error::new_h2(reason.into()))); + } + return Poll::Pending; + } + Poll::Ready(Err(e)) => { + let err = crate::Error::new_user_service(e); + warn!("http2 service errored: {}", err); + me.reply.send_reset(err.h2_reason()); + return Poll::Ready(Err(err)); + } + }; + + let (head, body) = res.into_parts(); + let mut res = ::http::Response::from_parts(head, ()); + super::strip_connection_headers(res.headers_mut(), false); + + // set Date header if it isn't already set... + res.headers_mut() + .entry(::http::header::DATE) + .or_insert_with(crate::proto::h1::date::update_and_header_value); + + // automatically set Content-Length from body... + if let Some(len) = body.size_hint().exact() { + headers::set_content_length_if_missing(res.headers_mut(), len); + } + + if !body.is_end_stream() { + let body_tx = reply!(me, res, false); + H2StreamState::Body(PipeToSendStream::new(body, body_tx)) + } else { + reply!(me, res, true); + return Poll::Ready(Ok(())); + } + } + H2StreamState::Body(pipe) => { + return pipe.poll(cx); + } + }; + me.state.set(next); + } + } +} + +impl<F, B, E> Future for H2Stream<F, B> +where + F: Future<Output = Result<Response<B>, E>>, + B: Payload, + E: Into<Box<dyn StdError + Send + Sync>>, +{ + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + self.poll2(cx).map(|res| { + if let Err(e) = res { + debug!("stream error: {}", e); + } + }) + } +} diff --git a/third_party/rust/hyper/src/proto/mod.rs b/third_party/rust/hyper/src/proto/mod.rs new file mode 100644 index 0000000000..7268e21265 --- /dev/null +++ b/third_party/rust/hyper/src/proto/mod.rs @@ -0,0 +1,145 @@ +//! Pieces pertaining to the HTTP message protocol. +use http::{HeaderMap, Method, StatusCode, Uri, Version}; + +pub(crate) use self::body_length::DecodedLength; +pub(crate) use self::h1::{dispatch, Conn, ServerTransaction}; + +pub(crate) mod h1; +pub(crate) mod h2; + +/// An Incoming Message head. Includes request/status line, and headers. +#[derive(Clone, Debug, Default, PartialEq)] +pub struct MessageHead<S> { + /// HTTP version of the message. + pub version: Version, + /// Subject (request line or status line) of Incoming message. + pub subject: S, + /// Headers of the Incoming message. + pub headers: HeaderMap, +} + +/// An incoming request message. +pub type RequestHead = MessageHead<RequestLine>; + +#[derive(Debug, Default, PartialEq)] +pub struct RequestLine(pub Method, pub Uri); + +/// An incoming response message. +pub type ResponseHead = MessageHead<StatusCode>; + +#[derive(Debug)] +pub enum BodyLength { + /// Content-Length + Known(u64), + /// Transfer-Encoding: chunked (if h1) + Unknown, +} + +/// Status of when a Disaptcher future completes. +pub(crate) enum Dispatched { + /// Dispatcher completely shutdown connection. + Shutdown, + /// Dispatcher has pending upgrade, and so did not shutdown. + Upgrade(crate::upgrade::Pending), +} + +/// A separate module to encapsulate the invariants of the DecodedLength type. +mod body_length { + use std::fmt; + + #[derive(Clone, Copy, PartialEq, Eq)] + pub(crate) struct DecodedLength(u64); + + const MAX_LEN: u64 = std::u64::MAX - 2; + + impl DecodedLength { + pub(crate) const CLOSE_DELIMITED: DecodedLength = DecodedLength(::std::u64::MAX); + pub(crate) const CHUNKED: DecodedLength = DecodedLength(::std::u64::MAX - 1); + pub(crate) const ZERO: DecodedLength = DecodedLength(0); + + #[cfg(test)] + pub(crate) fn new(len: u64) -> Self { + debug_assert!(len <= MAX_LEN); + DecodedLength(len) + } + + /// Takes the length as a content-length without other checks. + /// + /// Should only be called if previously confirmed this isn't + /// CLOSE_DELIMITED or CHUNKED. + #[inline] + pub(crate) fn danger_len(self) -> u64 { + debug_assert!(self.0 < Self::CHUNKED.0); + self.0 + } + + /// Converts to an Option<u64> representing a Known or Unknown length. + pub(crate) fn into_opt(self) -> Option<u64> { + match self { + DecodedLength::CHUNKED | DecodedLength::CLOSE_DELIMITED => None, + DecodedLength(known) => Some(known), + } + } + + /// Checks the `u64` is within the maximum allowed for content-length. + pub(crate) fn checked_new(len: u64) -> Result<Self, crate::error::Parse> { + if len <= MAX_LEN { + Ok(DecodedLength(len)) + } else { + warn!("content-length bigger than maximum: {} > {}", len, MAX_LEN); + Err(crate::error::Parse::TooLarge) + } + } + + pub(crate) fn sub_if(&mut self, amt: u64) { + match *self { + DecodedLength::CHUNKED | DecodedLength::CLOSE_DELIMITED => (), + DecodedLength(ref mut known) => { + *known -= amt; + } + } + } + } + + impl fmt::Debug for DecodedLength { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + DecodedLength::CLOSE_DELIMITED => f.write_str("CLOSE_DELIMITED"), + DecodedLength::CHUNKED => f.write_str("CHUNKED"), + DecodedLength(n) => f.debug_tuple("DecodedLength").field(&n).finish(), + } + } + } + + impl fmt::Display for DecodedLength { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + DecodedLength::CLOSE_DELIMITED => f.write_str("close-delimited"), + DecodedLength::CHUNKED => f.write_str("chunked encoding"), + DecodedLength::ZERO => f.write_str("empty"), + DecodedLength(n) => write!(f, "content-length ({} bytes)", n), + } + } + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn sub_if_known() { + let mut len = DecodedLength::new(30); + len.sub_if(20); + + assert_eq!(len.0, 10); + } + + #[test] + fn sub_if_chunked() { + let mut len = DecodedLength::CHUNKED; + len.sub_if(20); + + assert_eq!(len, DecodedLength::CHUNKED); + } + } +} diff --git a/third_party/rust/hyper/src/rt.rs b/third_party/rust/hyper/src/rt.rs new file mode 100644 index 0000000000..4e60139a87 --- /dev/null +++ b/third_party/rust/hyper/src/rt.rs @@ -0,0 +1,8 @@ +//! Runtime components +//! +//! By default, hyper includes the [tokio](https://tokio.rs) runtime. +//! +//! If the `runtime` feature is disabled, the types in this module can be used +//! to plug in other runtimes. + +pub use crate::common::Executor; diff --git a/third_party/rust/hyper/src/server/accept.rs b/third_party/rust/hyper/src/server/accept.rs new file mode 100644 index 0000000000..e56e3acf84 --- /dev/null +++ b/third_party/rust/hyper/src/server/accept.rs @@ -0,0 +1,101 @@ +//! The `Accept` trait and supporting types. +//! +//! This module contains: +//! +//! - The [`Accept`](Accept) trait used to asynchronously accept incoming +//! connections. +//! - Utilities like `poll_fn` to ease creating a custom `Accept`. + +#[cfg(feature = "stream")] +use futures_core::Stream; + +use crate::common::{ + task::{self, Poll}, + Pin, +}; + +/// Asynchronously accept incoming connections. +pub trait Accept { + /// The connection type that can be accepted. + type Conn; + /// The error type that can occur when accepting a connection. + type Error; + + /// Poll to accept the next connection. + fn poll_accept( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<Option<Result<Self::Conn, Self::Error>>>; +} + +/// Create an `Accept` with a polling function. +/// +/// # Example +/// +/// ``` +/// use std::task::Poll; +/// use hyper::server::{accept, Server}; +/// +/// # let mock_conn = (); +/// // If we created some mocked connection... +/// let mut conn = Some(mock_conn); +/// +/// // And accept just the mocked conn once... +/// let once = accept::poll_fn(move |cx| { +/// Poll::Ready(conn.take().map(Ok::<_, ()>)) +/// }); +/// +/// let builder = Server::builder(once); +/// ``` +pub fn poll_fn<F, IO, E>(func: F) -> impl Accept<Conn = IO, Error = E> +where + F: FnMut(&mut task::Context<'_>) -> Poll<Option<Result<IO, E>>>, +{ + struct PollFn<F>(F); + + impl<F, IO, E> Accept for PollFn<F> + where + F: FnMut(&mut task::Context<'_>) -> Poll<Option<Result<IO, E>>>, + { + type Conn = IO; + type Error = E; + fn poll_accept( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<Option<Result<Self::Conn, Self::Error>>> { + unsafe { (self.get_unchecked_mut().0)(cx) } + } + } + + PollFn(func) +} + +/// Adapt a `Stream` of incoming connections into an `Accept`. +/// +/// # Optional +/// +/// This function requires enabling the `stream` feature in your +/// `Cargo.toml`. +#[cfg(feature = "stream")] +pub fn from_stream<S, IO, E>(stream: S) -> impl Accept<Conn = IO, Error = E> +where + S: Stream<Item = Result<IO, E>>, +{ + struct FromStream<S>(S); + + impl<S, IO, E> Accept for FromStream<S> + where + S: Stream<Item = Result<IO, E>>, + { + type Conn = IO; + type Error = E; + fn poll_accept( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<Option<Result<Self::Conn, Self::Error>>> { + unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0).poll_next(cx) } + } + } + + FromStream(stream) +} diff --git a/third_party/rust/hyper/src/server/conn.rs b/third_party/rust/hyper/src/server/conn.rs new file mode 100644 index 0000000000..aa8233da46 --- /dev/null +++ b/third_party/rust/hyper/src/server/conn.rs @@ -0,0 +1,1025 @@ +//! Lower-level Server connection API. +//! +//! The types in this module are to provide a lower-level API based around a +//! single connection. Accepting a connection and binding it with a service +//! are not handled at this level. This module provides the building blocks to +//! customize those things externally. +//! +//! If you don't have need to manage connections yourself, consider using the +//! higher-level [Server](super) API. + +use std::error::Error as StdError; +use std::fmt; +use std::mem; +#[cfg(feature = "tcp")] +use std::net::SocketAddr; +#[cfg(feature = "runtime")] +use std::time::Duration; + +use bytes::Bytes; +use pin_project::{pin_project, project}; +use tokio::io::{AsyncRead, AsyncWrite}; + +use super::Accept; +use crate::body::{Body, Payload}; +use crate::common::exec::{Exec, H2Exec, NewSvcExec}; +use crate::common::io::Rewind; +use crate::common::{task, Future, Pin, Poll, Unpin}; +use crate::error::{Kind, Parse}; +use crate::proto; +use crate::service::{HttpService, MakeServiceRef}; +use crate::upgrade::Upgraded; + +use self::spawn_all::NewSvcTask; +pub(super) use self::spawn_all::NoopWatcher; +pub(super) use self::spawn_all::Watcher; +pub(super) use self::upgrades::UpgradeableConnection; + +#[cfg(feature = "tcp")] +pub use super::tcp::{AddrIncoming, AddrStream}; + +/// A lower-level configuration of the HTTP protocol. +/// +/// This structure is used to configure options for an HTTP server connection. +/// +/// If you don't have need to manage connections yourself, consider using the +/// higher-level [Server](super) API. +#[derive(Clone, Debug)] +pub struct Http<E = Exec> { + exec: E, + h1_half_close: bool, + h1_keep_alive: bool, + h1_writev: bool, + h2_builder: proto::h2::server::Config, + mode: ConnectionMode, + max_buf_size: Option<usize>, + pipeline_flush: bool, +} + +/// The internal mode of HTTP protocol which indicates the behavior when a parse error occurs. +#[derive(Clone, Debug, PartialEq)] +enum ConnectionMode { + /// Always use HTTP/1 and do not upgrade when a parse error occurs. + H1Only, + /// Always use HTTP/2. + H2Only, + /// Use HTTP/1 and try to upgrade to h2 when a parse error occurs. + Fallback, +} + +/// A stream mapping incoming IOs to new services. +/// +/// Yields `Connecting`s that are futures that should be put on a reactor. +#[must_use = "streams do nothing unless polled"] +#[pin_project] +#[derive(Debug)] +pub(super) struct Serve<I, S, E = Exec> { + #[pin] + incoming: I, + make_service: S, + protocol: Http<E>, +} + +/// A future building a new `Service` to a `Connection`. +/// +/// Wraps the future returned from `MakeService` into one that returns +/// a `Connection`. +#[must_use = "futures do nothing unless polled"] +#[pin_project] +#[derive(Debug)] +pub struct Connecting<I, F, E = Exec> { + #[pin] + future: F, + io: Option<I>, + protocol: Http<E>, +} + +#[must_use = "futures do nothing unless polled"] +#[pin_project] +#[derive(Debug)] +pub(super) struct SpawnAll<I, S, E> { + // TODO: re-add `pub(super)` once rustdoc can handle this. + // + // See https://github.com/rust-lang/rust/issues/64705 + #[pin] + pub serve: Serve<I, S, E>, +} + +/// A future binding a connection with a Service. +/// +/// Polling this future will drive HTTP forward. +#[must_use = "futures do nothing unless polled"] +#[pin_project] +pub struct Connection<T, S, E = Exec> +where + S: HttpService<Body>, +{ + pub(super) conn: Option<ProtoServer<T, S::ResBody, S, E>>, + fallback: Fallback<E>, +} + +#[pin_project] +pub(super) enum ProtoServer<T, B, S, E = Exec> +where + S: HttpService<Body>, + B: Payload, +{ + H1( + #[pin] + proto::h1::Dispatcher< + proto::h1::dispatch::Server<S, Body>, + B, + T, + proto::ServerTransaction, + >, + ), + H2(#[pin] proto::h2::Server<Rewind<T>, S, B, E>), +} + +#[derive(Clone, Debug)] +enum Fallback<E> { + ToHttp2(proto::h2::server::Config, E), + Http1Only, +} + +impl<E> Fallback<E> { + fn to_h2(&self) -> bool { + match *self { + Fallback::ToHttp2(..) => true, + Fallback::Http1Only => false, + } + } +} + +impl<E> Unpin for Fallback<E> {} + +/// Deconstructed parts of a `Connection`. +/// +/// This allows taking apart a `Connection` at a later time, in order to +/// reclaim the IO object, and additional related pieces. +#[derive(Debug)] +pub struct Parts<T, S> { + /// The original IO object used in the handshake. + pub io: T, + /// A buffer of bytes that have been read but not processed as HTTP. + /// + /// If the client sent additional bytes after its last request, and + /// this connection "ended" with an upgrade, the read buffer will contain + /// those bytes. + /// + /// You will want to check for any existing bytes if you plan to continue + /// communicating on the IO object. + pub read_buf: Bytes, + /// The `Service` used to serve this connection. + pub service: S, + _inner: (), +} + +// ===== impl Http ===== + +impl Http { + /// Creates a new instance of the HTTP protocol, ready to spawn a server or + /// start accepting connections. + pub fn new() -> Http { + Http { + exec: Exec::Default, + h1_half_close: false, + h1_keep_alive: true, + h1_writev: true, + h2_builder: Default::default(), + mode: ConnectionMode::Fallback, + max_buf_size: None, + pipeline_flush: false, + } + } +} + +impl<E> Http<E> { + /// Sets whether HTTP1 is required. + /// + /// Default is false + pub fn http1_only(&mut self, val: bool) -> &mut Self { + if val { + self.mode = ConnectionMode::H1Only; + } else { + self.mode = ConnectionMode::Fallback; + } + self + } + + /// Set whether HTTP/1 connections should support half-closures. + /// + /// Clients can chose to shutdown their write-side while waiting + /// for the server to respond. Setting this to `true` will + /// prevent closing the connection immediately if `read` + /// detects an EOF in the middle of a request. + /// + /// Default is `false`. + pub fn http1_half_close(&mut self, val: bool) -> &mut Self { + self.h1_half_close = val; + self + } + + /// Enables or disables HTTP/1 keep-alive. + /// + /// Default is true. + pub fn http1_keep_alive(&mut self, val: bool) -> &mut Self { + self.h1_keep_alive = val; + self + } + + // renamed due different semantics of http2 keep alive + #[doc(hidden)] + #[deprecated(note = "renamed to `http1_keep_alive`")] + pub fn keep_alive(&mut self, val: bool) -> &mut Self { + self.http1_keep_alive(val) + } + + /// Set whether HTTP/1 connections should try to use vectored writes, + /// or always flatten into a single buffer. + /// + /// Note that setting this to false may mean more copies of body data, + /// but may also improve performance when an IO transport doesn't + /// support vectored writes well, such as most TLS implementations. + /// + /// Default is `true`. + #[inline] + pub fn http1_writev(&mut self, val: bool) -> &mut Self { + self.h1_writev = val; + self + } + + /// Sets whether HTTP2 is required. + /// + /// Default is false + pub fn http2_only(&mut self, val: bool) -> &mut Self { + if val { + self.mode = ConnectionMode::H2Only; + } else { + self.mode = ConnectionMode::Fallback; + } + self + } + + /// Sets the [`SETTINGS_INITIAL_WINDOW_SIZE`][spec] option for HTTP2 + /// stream-level flow control. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + /// + /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_INITIAL_WINDOW_SIZE + pub fn http2_initial_stream_window_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self { + if let Some(sz) = sz.into() { + self.h2_builder.adaptive_window = false; + self.h2_builder.initial_stream_window_size = sz; + } + self + } + + /// Sets the max connection-level flow control for HTTP2. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + pub fn http2_initial_connection_window_size( + &mut self, + sz: impl Into<Option<u32>>, + ) -> &mut Self { + if let Some(sz) = sz.into() { + self.h2_builder.adaptive_window = false; + self.h2_builder.initial_conn_window_size = sz; + } + self + } + + /// Sets whether to use an adaptive flow control. + /// + /// Enabling this will override the limits set in + /// `http2_initial_stream_window_size` and + /// `http2_initial_connection_window_size`. + pub fn http2_adaptive_window(&mut self, enabled: bool) -> &mut Self { + use proto::h2::SPEC_WINDOW_SIZE; + + self.h2_builder.adaptive_window = enabled; + if enabled { + self.h2_builder.initial_conn_window_size = SPEC_WINDOW_SIZE; + self.h2_builder.initial_stream_window_size = SPEC_WINDOW_SIZE; + } + self + } + + /// Sets the [`SETTINGS_MAX_CONCURRENT_STREAMS`][spec] option for HTTP2 + /// connections. + /// + /// Default is no limit (`std::u32::MAX`). Passing `None` will do nothing. + /// + /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_MAX_CONCURRENT_STREAMS + pub fn http2_max_concurrent_streams(&mut self, max: impl Into<Option<u32>>) -> &mut Self { + self.h2_builder.max_concurrent_streams = max.into(); + self + } + + /// Sets an interval for HTTP2 Ping frames should be sent to keep a + /// connection alive. + /// + /// Pass `None` to disable HTTP2 keep-alive. + /// + /// Default is currently disabled. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + pub fn http2_keep_alive_interval( + &mut self, + interval: impl Into<Option<Duration>>, + ) -> &mut Self { + self.h2_builder.keep_alive_interval = interval.into(); + self + } + + /// Sets a timeout for receiving an acknowledgement of the keep-alive ping. + /// + /// If the ping is not acknowledged within the timeout, the connection will + /// be closed. Does nothing if `http2_keep_alive_interval` is disabled. + /// + /// Default is 20 seconds. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + pub fn http2_keep_alive_timeout(&mut self, timeout: Duration) -> &mut Self { + self.h2_builder.keep_alive_timeout = timeout; + self + } + + /// Set the maximum buffer size for the connection. + /// + /// Default is ~400kb. + /// + /// # Panics + /// + /// The minimum value allowed is 8192. This method panics if the passed `max` is less than the minimum. + pub fn max_buf_size(&mut self, max: usize) -> &mut Self { + assert!( + max >= proto::h1::MINIMUM_MAX_BUFFER_SIZE, + "the max_buf_size cannot be smaller than the minimum that h1 specifies." + ); + self.max_buf_size = Some(max); + self + } + + /// Aggregates flushes to better support pipelined responses. + /// + /// Experimental, may have bugs. + /// + /// Default is false. + pub fn pipeline_flush(&mut self, enabled: bool) -> &mut Self { + self.pipeline_flush = enabled; + self + } + + /// Set the executor used to spawn background tasks. + /// + /// Default uses implicit default (like `tokio::spawn`). + pub fn with_executor<E2>(self, exec: E2) -> Http<E2> { + Http { + exec, + h1_half_close: self.h1_half_close, + h1_keep_alive: self.h1_keep_alive, + h1_writev: self.h1_writev, + h2_builder: self.h2_builder, + mode: self.mode, + max_buf_size: self.max_buf_size, + pipeline_flush: self.pipeline_flush, + } + } + + /// Bind a connection together with a [`Service`](crate::service::Service). + /// + /// This returns a Future that must be polled in order for HTTP to be + /// driven on the connection. + /// + /// # Example + /// + /// ``` + /// # use hyper::{Body, Request, Response}; + /// # use hyper::service::Service; + /// # use hyper::server::conn::Http; + /// # use tokio::io::{AsyncRead, AsyncWrite}; + /// # async fn run<I, S>(some_io: I, some_service: S) + /// # where + /// # I: AsyncRead + AsyncWrite + Unpin + Send + 'static, + /// # S: Service<hyper::Request<Body>, Response=hyper::Response<Body>> + Send + 'static, + /// # S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + /// # S::Future: Send, + /// # { + /// let http = Http::new(); + /// let conn = http.serve_connection(some_io, some_service); + /// + /// if let Err(e) = conn.await { + /// eprintln!("server connection error: {}", e); + /// } + /// # } + /// # fn main() {} + /// ``` + pub fn serve_connection<S, I, Bd>(&self, io: I, service: S) -> Connection<I, S, E> + where + S: HttpService<Body, ResBody = Bd>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + Bd: Payload, + I: AsyncRead + AsyncWrite + Unpin, + E: H2Exec<S::Future, Bd>, + { + let proto = match self.mode { + ConnectionMode::H1Only | ConnectionMode::Fallback => { + let mut conn = proto::Conn::new(io); + if !self.h1_keep_alive { + conn.disable_keep_alive(); + } + if self.h1_half_close { + conn.set_allow_half_close(); + } + if !self.h1_writev { + conn.set_write_strategy_flatten(); + } + conn.set_flush_pipeline(self.pipeline_flush); + if let Some(max) = self.max_buf_size { + conn.set_max_buf_size(max); + } + let sd = proto::h1::dispatch::Server::new(service); + ProtoServer::H1(proto::h1::Dispatcher::new(sd, conn)) + } + ConnectionMode::H2Only => { + let rewind_io = Rewind::new(io); + let h2 = + proto::h2::Server::new(rewind_io, service, &self.h2_builder, self.exec.clone()); + ProtoServer::H2(h2) + } + }; + + Connection { + conn: Some(proto), + fallback: if self.mode == ConnectionMode::Fallback { + Fallback::ToHttp2(self.h2_builder.clone(), self.exec.clone()) + } else { + Fallback::Http1Only + }, + } + } + + pub(super) fn serve<I, IO, IE, S, Bd>(&self, incoming: I, make_service: S) -> Serve<I, S, E> + where + I: Accept<Conn = IO, Error = IE>, + IE: Into<Box<dyn StdError + Send + Sync>>, + IO: AsyncRead + AsyncWrite + Unpin, + S: MakeServiceRef<IO, Body, ResBody = Bd>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + Bd: Payload, + E: H2Exec<<S::Service as HttpService<Body>>::Future, Bd>, + { + Serve { + incoming, + make_service, + protocol: self.clone(), + } + } +} + +// ===== impl Connection ===== + +impl<I, B, S, E> Connection<I, S, E> +where + S: HttpService<Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + I: AsyncRead + AsyncWrite + Unpin, + B: Payload + 'static, + E: H2Exec<S::Future, B>, +{ + /// Start a graceful shutdown process for this connection. + /// + /// This `Connection` should continue to be polled until shutdown + /// can finish. + /// + /// # Note + /// + /// This should only be called while the `Connection` future is still + /// pending. If called after `Connection::poll` has resolved, this does + /// nothing. + pub fn graceful_shutdown(self: Pin<&mut Self>) { + match self.project().conn { + Some(ProtoServer::H1(ref mut h1)) => { + h1.disable_keep_alive(); + } + Some(ProtoServer::H2(ref mut h2)) => { + h2.graceful_shutdown(); + } + None => (), + } + } + + /// Return the inner IO object, and additional information. + /// + /// If the IO object has been "rewound" the io will not contain those bytes rewound. + /// This should only be called after `poll_without_shutdown` signals + /// that the connection is "done". Otherwise, it may not have finished + /// flushing all necessary HTTP bytes. + /// + /// # Panics + /// This method will panic if this connection is using an h2 protocol. + pub fn into_parts(self) -> Parts<I, S> { + self.try_into_parts() + .unwrap_or_else(|| panic!("h2 cannot into_inner")) + } + + /// Return the inner IO object, and additional information, if available. + /// + /// This method will return a `None` if this connection is using an h2 protocol. + pub fn try_into_parts(self) -> Option<Parts<I, S>> { + match self.conn.unwrap() { + ProtoServer::H1(h1) => { + let (io, read_buf, dispatch) = h1.into_inner(); + Some(Parts { + io, + read_buf, + service: dispatch.into_service(), + _inner: (), + }) + } + ProtoServer::H2(_h2) => None, + } + } + + /// Poll the connection for completion, but without calling `shutdown` + /// on the underlying IO. + /// + /// This is useful to allow running a connection while doing an HTTP + /// upgrade. Once the upgrade is completed, the connection would be "done", + /// but it is not desired to actually shutdown the IO object. Instead you + /// would take it back using `into_parts`. + /// + /// Use [`poll_fn`](https://docs.rs/futures/0.1.25/futures/future/fn.poll_fn.html) + /// and [`try_ready!`](https://docs.rs/futures/0.1.25/futures/macro.try_ready.html) + /// to work with this function; or use the `without_shutdown` wrapper. + pub fn poll_without_shutdown(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> + where + S: Unpin, + S::Future: Unpin, + B: Unpin, + { + loop { + let polled = match *self.conn.as_mut().unwrap() { + ProtoServer::H1(ref mut h1) => h1.poll_without_shutdown(cx), + ProtoServer::H2(ref mut h2) => return Pin::new(h2).poll(cx).map_ok(|_| ()), + }; + match ready!(polled) { + Ok(()) => return Poll::Ready(Ok(())), + Err(e) => match *e.kind() { + Kind::Parse(Parse::VersionH2) if self.fallback.to_h2() => { + self.upgrade_h2(); + continue; + } + _ => return Poll::Ready(Err(e)), + }, + } + } + } + + /// Prevent shutdown of the underlying IO object at the end of service the request, + /// instead run `into_parts`. This is a convenience wrapper over `poll_without_shutdown`. + pub fn without_shutdown(self) -> impl Future<Output = crate::Result<Parts<I, S>>> + where + S: Unpin, + S::Future: Unpin, + B: Unpin, + { + let mut conn = Some(self); + futures_util::future::poll_fn(move |cx| { + ready!(conn.as_mut().unwrap().poll_without_shutdown(cx))?; + Poll::Ready(Ok(conn.take().unwrap().into_parts())) + }) + } + + fn upgrade_h2(&mut self) { + trace!("Trying to upgrade connection to h2"); + let conn = self.conn.take(); + + let (io, read_buf, dispatch) = match conn.unwrap() { + ProtoServer::H1(h1) => h1.into_inner(), + ProtoServer::H2(_h2) => { + panic!("h2 cannot into_inner"); + } + }; + let mut rewind_io = Rewind::new(io); + rewind_io.rewind(read_buf); + let (builder, exec) = match self.fallback { + Fallback::ToHttp2(ref builder, ref exec) => (builder, exec), + Fallback::Http1Only => unreachable!("upgrade_h2 with Fallback::Http1Only"), + }; + let h2 = proto::h2::Server::new(rewind_io, dispatch.into_service(), builder, exec.clone()); + + debug_assert!(self.conn.is_none()); + self.conn = Some(ProtoServer::H2(h2)); + } + + /// Enable this connection to support higher-level HTTP upgrades. + /// + /// See [the `upgrade` module](crate::upgrade) for more. + pub fn with_upgrades(self) -> UpgradeableConnection<I, S, E> + where + I: Send, + { + UpgradeableConnection { inner: self } + } +} + +impl<I, B, S, E> Future for Connection<I, S, E> +where + S: HttpService<Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + I: AsyncRead + AsyncWrite + Unpin + 'static, + B: Payload + 'static, + E: H2Exec<S::Future, B>, +{ + type Output = crate::Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + loop { + match ready!(Pin::new(self.conn.as_mut().unwrap()).poll(cx)) { + Ok(done) => { + if let proto::Dispatched::Upgrade(pending) = done { + // With no `Send` bound on `I`, we can't try to do + // upgrades here. In case a user was trying to use + // `Body::on_upgrade` with this API, send a special + // error letting them know about that. + pending.manual(); + } + return Poll::Ready(Ok(())); + } + Err(e) => match *e.kind() { + Kind::Parse(Parse::VersionH2) if self.fallback.to_h2() => { + self.upgrade_h2(); + continue; + } + _ => return Poll::Ready(Err(e)), + }, + } + } + } +} + +impl<I, S> fmt::Debug for Connection<I, S> +where + S: HttpService<Body>, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Connection").finish() + } +} +// ===== impl Serve ===== + +impl<I, S, E> Serve<I, S, E> { + /// Get a reference to the incoming stream. + #[inline] + pub fn incoming_ref(&self) -> &I { + &self.incoming + } + + /* + /// Get a mutable reference to the incoming stream. + #[inline] + pub fn incoming_mut(&mut self) -> &mut I { + &mut self.incoming + } + */ + + /// Spawn all incoming connections onto the executor in `Http`. + pub(super) fn spawn_all(self) -> SpawnAll<I, S, E> { + SpawnAll { serve: self } + } +} + +impl<I, IO, IE, S, B, E> Serve<I, S, E> +where + I: Accept<Conn = IO, Error = IE>, + IO: AsyncRead + AsyncWrite + Unpin, + IE: Into<Box<dyn StdError + Send + Sync>>, + S: MakeServiceRef<IO, Body, ResBody = B>, + B: Payload, + E: H2Exec<<S::Service as HttpService<Body>>::Future, B>, +{ + fn poll_next_( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<Option<crate::Result<Connecting<IO, S::Future, E>>>> { + let me = self.project(); + match ready!(me.make_service.poll_ready_ref(cx)) { + Ok(()) => (), + Err(e) => { + trace!("make_service closed"); + return Poll::Ready(Some(Err(crate::Error::new_user_make_service(e)))); + } + } + + if let Some(item) = ready!(me.incoming.poll_accept(cx)) { + let io = item.map_err(crate::Error::new_accept)?; + let new_fut = me.make_service.make_service_ref(&io); + Poll::Ready(Some(Ok(Connecting { + future: new_fut, + io: Some(io), + protocol: me.protocol.clone(), + }))) + } else { + Poll::Ready(None) + } + } +} + +// ===== impl Connecting ===== + +impl<I, F, S, FE, E, B> Future for Connecting<I, F, E> +where + I: AsyncRead + AsyncWrite + Unpin, + F: Future<Output = Result<S, FE>>, + S: HttpService<Body, ResBody = B>, + B: Payload, + E: H2Exec<S::Future, B>, +{ + type Output = Result<Connection<I, S, E>, FE>; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + let service = ready!(me.future.poll(cx))?; + let io = me.io.take().expect("polled after complete"); + Poll::Ready(Ok(me.protocol.serve_connection(io, service))) + } +} + +// ===== impl SpawnAll ===== + +#[cfg(feature = "tcp")] +impl<S, E> SpawnAll<AddrIncoming, S, E> { + pub(super) fn local_addr(&self) -> SocketAddr { + self.serve.incoming.local_addr() + } +} + +impl<I, S, E> SpawnAll<I, S, E> { + pub(super) fn incoming_ref(&self) -> &I { + self.serve.incoming_ref() + } +} + +impl<I, IO, IE, S, B, E> SpawnAll<I, S, E> +where + I: Accept<Conn = IO, Error = IE>, + IE: Into<Box<dyn StdError + Send + Sync>>, + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + S: MakeServiceRef<IO, Body, ResBody = B>, + B: Payload, + E: H2Exec<<S::Service as HttpService<Body>>::Future, B>, +{ + pub(super) fn poll_watch<W>( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + watcher: &W, + ) -> Poll<crate::Result<()>> + where + E: NewSvcExec<IO, S::Future, S::Service, E, W>, + W: Watcher<IO, S::Service, E>, + { + let mut me = self.project(); + loop { + if let Some(connecting) = ready!(me.serve.as_mut().poll_next_(cx)?) { + let fut = NewSvcTask::new(connecting, watcher.clone()); + me.serve + .as_mut() + .project() + .protocol + .exec + .execute_new_svc(fut); + } else { + return Poll::Ready(Ok(())); + } + } + } +} + +// ===== impl ProtoServer ===== + +impl<T, B, S, E> Future for ProtoServer<T, B, S, E> +where + T: AsyncRead + AsyncWrite + Unpin, + S: HttpService<Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: Payload, + E: H2Exec<S::Future, B>, +{ + type Output = crate::Result<proto::Dispatched>; + + #[project] + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + #[project] + match self.project() { + ProtoServer::H1(s) => s.poll(cx), + ProtoServer::H2(s) => s.poll(cx), + } + } +} + +pub(crate) mod spawn_all { + use std::error::Error as StdError; + use tokio::io::{AsyncRead, AsyncWrite}; + + use super::{Connecting, UpgradeableConnection}; + use crate::body::{Body, Payload}; + use crate::common::exec::H2Exec; + use crate::common::{task, Future, Pin, Poll, Unpin}; + use crate::service::HttpService; + use pin_project::{pin_project, project}; + + // Used by `SpawnAll` to optionally watch a `Connection` future. + // + // The regular `hyper::Server` just uses a `NoopWatcher`, which does + // not need to watch anything, and so returns the `Connection` untouched. + // + // The `Server::with_graceful_shutdown` needs to keep track of all active + // connections, and signal that they start to shutdown when prompted, so + // it has a `GracefulWatcher` implementation to do that. + pub trait Watcher<I, S: HttpService<Body>, E>: Clone { + type Future: Future<Output = crate::Result<()>>; + + fn watch(&self, conn: UpgradeableConnection<I, S, E>) -> Self::Future; + } + + #[allow(missing_debug_implementations)] + #[derive(Copy, Clone)] + pub struct NoopWatcher; + + impl<I, S, E> Watcher<I, S, E> for NoopWatcher + where + I: AsyncRead + AsyncWrite + Unpin + Send + 'static, + S: HttpService<Body>, + E: H2Exec<S::Future, S::ResBody>, + { + type Future = UpgradeableConnection<I, S, E>; + + fn watch(&self, conn: UpgradeableConnection<I, S, E>) -> Self::Future { + conn + } + } + + // This is a `Future<Item=(), Error=()>` spawned to an `Executor` inside + // the `SpawnAll`. By being a nameable type, we can be generic over the + // user's `Service::Future`, and thus an `Executor` can execute it. + // + // Doing this allows for the server to conditionally require `Send` futures, + // depending on the `Executor` configured. + // + // Users cannot import this type, nor the associated `NewSvcExec`. Instead, + // a blanket implementation for `Executor<impl Future>` is sufficient. + + #[pin_project] + #[allow(missing_debug_implementations)] + pub struct NewSvcTask<I, N, S: HttpService<Body>, E, W: Watcher<I, S, E>> { + #[pin] + state: State<I, N, S, E, W>, + } + + #[pin_project] + pub enum State<I, N, S: HttpService<Body>, E, W: Watcher<I, S, E>> { + Connecting(#[pin] Connecting<I, N, E>, W), + Connected(#[pin] W::Future), + } + + impl<I, N, S: HttpService<Body>, E, W: Watcher<I, S, E>> NewSvcTask<I, N, S, E, W> { + pub(super) fn new(connecting: Connecting<I, N, E>, watcher: W) -> Self { + NewSvcTask { + state: State::Connecting(connecting, watcher), + } + } + } + + impl<I, N, S, NE, B, E, W> Future for NewSvcTask<I, N, S, E, W> + where + I: AsyncRead + AsyncWrite + Unpin + Send + 'static, + N: Future<Output = Result<S, NE>>, + NE: Into<Box<dyn StdError + Send + Sync>>, + S: HttpService<Body, ResBody = B>, + B: Payload, + E: H2Exec<S::Future, B>, + W: Watcher<I, S, E>, + { + type Output = (); + + #[project] + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + // If it weren't for needing to name this type so the `Send` bounds + // could be projected to the `Serve` executor, this could just be + // an `async fn`, and much safer. Woe is me. + + let mut me = self.project(); + loop { + let next = { + #[project] + match me.state.as_mut().project() { + State::Connecting(connecting, watcher) => { + let res = ready!(connecting.poll(cx)); + let conn = match res { + Ok(conn) => conn, + Err(err) => { + let err = crate::Error::new_user_make_service(err); + debug!("connecting error: {}", err); + return Poll::Ready(()); + } + }; + let connected = watcher.watch(conn.with_upgrades()); + State::Connected(connected) + } + State::Connected(future) => { + return future.poll(cx).map(|res| { + if let Err(err) = res { + debug!("connection error: {}", err); + } + }); + } + } + }; + + me.state.set(next); + } + } + } +} + +mod upgrades { + use super::*; + + // A future binding a connection with a Service with Upgrade support. + // + // This type is unnameable outside the crate, and so basically just an + // `impl Future`, without requiring Rust 1.26. + #[must_use = "futures do nothing unless polled"] + #[allow(missing_debug_implementations)] + pub struct UpgradeableConnection<T, S, E> + where + S: HttpService<Body>, + { + pub(super) inner: Connection<T, S, E>, + } + + impl<I, B, S, E> UpgradeableConnection<I, S, E> + where + S: HttpService<Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + I: AsyncRead + AsyncWrite + Unpin, + B: Payload + 'static, + E: H2Exec<S::Future, B>, + { + /// Start a graceful shutdown process for this connection. + /// + /// This `Connection` should continue to be polled until shutdown + /// can finish. + pub fn graceful_shutdown(mut self: Pin<&mut Self>) { + Pin::new(&mut self.inner).graceful_shutdown() + } + } + + impl<I, B, S, E> Future for UpgradeableConnection<I, S, E> + where + S: HttpService<Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + I: AsyncRead + AsyncWrite + Unpin + Send + 'static, + B: Payload + 'static, + E: super::H2Exec<S::Future, B>, + { + type Output = crate::Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + loop { + match ready!(Pin::new(self.inner.conn.as_mut().unwrap()).poll(cx)) { + Ok(proto::Dispatched::Shutdown) => return Poll::Ready(Ok(())), + Ok(proto::Dispatched::Upgrade(pending)) => { + let h1 = match mem::replace(&mut self.inner.conn, None) { + Some(ProtoServer::H1(h1)) => h1, + _ => unreachable!("Upgrade expects h1"), + }; + + let (io, buf, _) = h1.into_inner(); + pending.fulfill(Upgraded::new(io, buf)); + return Poll::Ready(Ok(())); + } + Err(e) => match *e.kind() { + Kind::Parse(Parse::VersionH2) if self.inner.fallback.to_h2() => { + self.inner.upgrade_h2(); + continue; + } + _ => return Poll::Ready(Err(e)), + }, + } + } + } + } +} diff --git a/third_party/rust/hyper/src/server/mod.rs b/third_party/rust/hyper/src/server/mod.rs new file mode 100644 index 0000000000..ed6068c867 --- /dev/null +++ b/third_party/rust/hyper/src/server/mod.rs @@ -0,0 +1,480 @@ +//! HTTP Server +//! +//! A `Server` is created to listen on a port, parse HTTP requests, and hand +//! them off to a `Service`. +//! +//! There are two levels of APIs provide for constructing HTTP servers: +//! +//! - The higher-level [`Server`](Server) type. +//! - The lower-level [`conn`](conn) module. +//! +//! # Server +//! +//! The [`Server`](Server) is main way to start listening for HTTP requests. +//! It wraps a listener with a [`MakeService`](crate::service), and then should +//! be executed to start serving requests. +//! +//! [`Server`](Server) accepts connections in both HTTP1 and HTTP2 by default. +//! +//! ## Example +//! +//! ```no_run +//! use std::convert::Infallible; +//! use std::net::SocketAddr; +//! use hyper::{Body, Request, Response, Server}; +//! use hyper::service::{make_service_fn, service_fn}; +//! +//! async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> { +//! Ok(Response::new(Body::from("Hello World"))) +//! } +//! +//! # #[cfg(feature = "runtime")] +//! #[tokio::main] +//! async fn main() { +//! // Construct our SocketAddr to listen on... +//! let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); +//! +//! // And a MakeService to handle each connection... +//! let make_service = make_service_fn(|_conn| async { +//! Ok::<_, Infallible>(service_fn(handle)) +//! }); +//! +//! // Then bind and serve... +//! let server = Server::bind(&addr).serve(make_service); +//! +//! // And run forever... +//! if let Err(e) = server.await { +//! eprintln!("server error: {}", e); +//! } +//! } +//! # #[cfg(not(feature = "runtime"))] +//! # fn main() {} +//! ``` + +pub mod accept; +pub mod conn; +mod shutdown; +#[cfg(feature = "tcp")] +mod tcp; + +use std::error::Error as StdError; +use std::fmt; +#[cfg(feature = "tcp")] +use std::net::{SocketAddr, TcpListener as StdTcpListener}; + +#[cfg(feature = "tcp")] +use std::time::Duration; + +use pin_project::pin_project; +use tokio::io::{AsyncRead, AsyncWrite}; + +use self::accept::Accept; +use crate::body::{Body, Payload}; +use crate::common::exec::{Exec, H2Exec, NewSvcExec}; +use crate::common::{task, Future, Pin, Poll, Unpin}; +use crate::service::{HttpService, MakeServiceRef}; +// Renamed `Http` as `Http_` for now so that people upgrading don't see an +// error that `hyper::server::Http` is private... +use self::conn::{Http as Http_, NoopWatcher, SpawnAll}; +use self::shutdown::{Graceful, GracefulWatcher}; +#[cfg(feature = "tcp")] +use self::tcp::AddrIncoming; + +/// A listening HTTP server that accepts connections in both HTTP1 and HTTP2 by default. +/// +/// `Server` is a `Future` mapping a bound listener with a set of service +/// handlers. It is built using the [`Builder`](Builder), and the future +/// completes when the server has been shutdown. It should be run by an +/// `Executor`. +#[pin_project] +pub struct Server<I, S, E = Exec> { + #[pin] + spawn_all: SpawnAll<I, S, E>, +} + +/// A builder for a [`Server`](Server). +#[derive(Debug)] +pub struct Builder<I, E = Exec> { + incoming: I, + protocol: Http_<E>, +} + +// ===== impl Server ===== + +impl<I> Server<I, ()> { + /// Starts a [`Builder`](Builder) with the provided incoming stream. + pub fn builder(incoming: I) -> Builder<I> { + Builder { + incoming, + protocol: Http_::new(), + } + } +} + +#[cfg(feature = "tcp")] +impl Server<AddrIncoming, ()> { + /// Binds to the provided address, and returns a [`Builder`](Builder). + /// + /// # Panics + /// + /// This method will panic if binding to the address fails. For a method + /// to bind to an address and return a `Result`, see `Server::try_bind`. + pub fn bind(addr: &SocketAddr) -> Builder<AddrIncoming> { + let incoming = AddrIncoming::new(addr).unwrap_or_else(|e| { + panic!("error binding to {}: {}", addr, e); + }); + Server::builder(incoming) + } + + /// Tries to bind to the provided address, and returns a [`Builder`](Builder). + pub fn try_bind(addr: &SocketAddr) -> crate::Result<Builder<AddrIncoming>> { + AddrIncoming::new(addr).map(Server::builder) + } + + /// Create a new instance from a `std::net::TcpListener` instance. + pub fn from_tcp(listener: StdTcpListener) -> Result<Builder<AddrIncoming>, crate::Error> { + AddrIncoming::from_std(listener).map(Server::builder) + } +} + +#[cfg(feature = "tcp")] +impl<S, E> Server<AddrIncoming, S, E> { + /// Returns the local address that this server is bound to. + pub fn local_addr(&self) -> SocketAddr { + self.spawn_all.local_addr() + } +} + +impl<I, IO, IE, S, E, B> Server<I, S, E> +where + I: Accept<Conn = IO, Error = IE>, + IE: Into<Box<dyn StdError + Send + Sync>>, + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + S: MakeServiceRef<IO, Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: Payload, + E: H2Exec<<S::Service as HttpService<Body>>::Future, B>, + E: NewSvcExec<IO, S::Future, S::Service, E, GracefulWatcher>, +{ + /// Prepares a server to handle graceful shutdown when the provided future + /// completes. + /// + /// # Example + /// + /// ``` + /// # fn main() {} + /// # #[cfg(feature = "tcp")] + /// # async fn run() { + /// # use hyper::{Body, Response, Server, Error}; + /// # use hyper::service::{make_service_fn, service_fn}; + /// # let make_service = make_service_fn(|_| async { + /// # Ok::<_, Error>(service_fn(|_req| async { + /// # Ok::<_, Error>(Response::new(Body::from("Hello World"))) + /// # })) + /// # }); + /// // Make a server from the previous examples... + /// let server = Server::bind(&([127, 0, 0, 1], 3000).into()) + /// .serve(make_service); + /// + /// // Prepare some signal for when the server should start shutting down... + /// let (tx, rx) = tokio::sync::oneshot::channel::<()>(); + /// let graceful = server + /// .with_graceful_shutdown(async { + /// rx.await.ok(); + /// }); + /// + /// // Await the `server` receiving the signal... + /// if let Err(e) = graceful.await { + /// eprintln!("server error: {}", e); + /// } + /// + /// // And later, trigger the signal by calling `tx.send(())`. + /// let _ = tx.send(()); + /// # } + /// ``` + pub fn with_graceful_shutdown<F>(self, signal: F) -> Graceful<I, S, F, E> + where + F: Future<Output = ()>, + { + Graceful::new(self.spawn_all, signal) + } +} + +impl<I, IO, IE, S, B, E> Future for Server<I, S, E> +where + I: Accept<Conn = IO, Error = IE>, + IE: Into<Box<dyn StdError + Send + Sync>>, + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + S: MakeServiceRef<IO, Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: Payload, + E: H2Exec<<S::Service as HttpService<Body>>::Future, B>, + E: NewSvcExec<IO, S::Future, S::Service, E, NoopWatcher>, +{ + type Output = crate::Result<()>; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + self.project().spawn_all.poll_watch(cx, &NoopWatcher) + } +} + +impl<I: fmt::Debug, S: fmt::Debug> fmt::Debug for Server<I, S> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Server") + .field("listener", &self.spawn_all.incoming_ref()) + .finish() + } +} + +// ===== impl Builder ===== + +impl<I, E> Builder<I, E> { + /// Start a new builder, wrapping an incoming stream and low-level options. + /// + /// For a more convenient constructor, see [`Server::bind`](Server::bind). + pub fn new(incoming: I, protocol: Http_<E>) -> Self { + Builder { incoming, protocol } + } + + /// Sets whether to use keep-alive for HTTP/1 connections. + /// + /// Default is `true`. + pub fn http1_keepalive(mut self, val: bool) -> Self { + self.protocol.http1_keep_alive(val); + self + } + + /// Set whether HTTP/1 connections should support half-closures. + /// + /// Clients can chose to shutdown their write-side while waiting + /// for the server to respond. Setting this to `true` will + /// prevent closing the connection immediately if `read` + /// detects an EOF in the middle of a request. + /// + /// Default is `false`. + pub fn http1_half_close(mut self, val: bool) -> Self { + self.protocol.http1_half_close(val); + self + } + + /// Set the maximum buffer size. + /// + /// Default is ~ 400kb. + pub fn http1_max_buf_size(mut self, val: usize) -> Self { + self.protocol.max_buf_size(val); + self + } + + // Sets whether to bunch up HTTP/1 writes until the read buffer is empty. + // + // This isn't really desirable in most cases, only really being useful in + // silly pipeline benchmarks. + #[doc(hidden)] + pub fn http1_pipeline_flush(mut self, val: bool) -> Self { + self.protocol.pipeline_flush(val); + self + } + + /// Set whether HTTP/1 connections should try to use vectored writes, + /// or always flatten into a single buffer. + /// + /// # Note + /// + /// Setting this to `false` may mean more copies of body data, + /// but may also improve performance when an IO transport doesn't + /// support vectored writes well, such as most TLS implementations. + /// + /// Default is `true`. + pub fn http1_writev(mut self, val: bool) -> Self { + self.protocol.http1_writev(val); + self + } + + /// Sets whether HTTP/1 is required. + /// + /// Default is `false`. + pub fn http1_only(mut self, val: bool) -> Self { + self.protocol.http1_only(val); + self + } + + /// Sets whether HTTP/2 is required. + /// + /// Default is `false`. + pub fn http2_only(mut self, val: bool) -> Self { + self.protocol.http2_only(val); + self + } + + /// Sets the [`SETTINGS_INITIAL_WINDOW_SIZE`][spec] option for HTTP2 + /// stream-level flow control. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + /// + /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_INITIAL_WINDOW_SIZE + pub fn http2_initial_stream_window_size(mut self, sz: impl Into<Option<u32>>) -> Self { + self.protocol.http2_initial_stream_window_size(sz.into()); + self + } + + /// Sets the max connection-level flow control for HTTP2 + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + pub fn http2_initial_connection_window_size(mut self, sz: impl Into<Option<u32>>) -> Self { + self.protocol + .http2_initial_connection_window_size(sz.into()); + self + } + + /// Sets whether to use an adaptive flow control. + /// + /// Enabling this will override the limits set in + /// `http2_initial_stream_window_size` and + /// `http2_initial_connection_window_size`. + pub fn http2_adaptive_window(mut self, enabled: bool) -> Self { + self.protocol.http2_adaptive_window(enabled); + self + } + + /// Sets the [`SETTINGS_MAX_CONCURRENT_STREAMS`][spec] option for HTTP2 + /// connections. + /// + /// Default is no limit (`std::u32::MAX`). Passing `None` will do nothing. + /// + /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_MAX_CONCURRENT_STREAMS + pub fn http2_max_concurrent_streams(mut self, max: impl Into<Option<u32>>) -> Self { + self.protocol.http2_max_concurrent_streams(max.into()); + self + } + + /// Sets an interval for HTTP2 Ping frames should be sent to keep a + /// connection alive. + /// + /// Pass `None` to disable HTTP2 keep-alive. + /// + /// Default is currently disabled. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + pub fn http2_keep_alive_interval(mut self, interval: impl Into<Option<Duration>>) -> Self { + self.protocol.http2_keep_alive_interval(interval); + self + } + + /// Sets a timeout for receiving an acknowledgement of the keep-alive ping. + /// + /// If the ping is not acknowledged within the timeout, the connection will + /// be closed. Does nothing if `http2_keep_alive_interval` is disabled. + /// + /// Default is 20 seconds. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + pub fn http2_keep_alive_timeout(mut self, timeout: Duration) -> Self { + self.protocol.http2_keep_alive_timeout(timeout); + self + } + + /// Sets the `Executor` to deal with connection tasks. + /// + /// Default is `tokio::spawn`. + pub fn executor<E2>(self, executor: E2) -> Builder<I, E2> { + Builder { + incoming: self.incoming, + protocol: self.protocol.with_executor(executor), + } + } + + /// Consume this `Builder`, creating a [`Server`](Server). + /// + /// # Example + /// + /// ``` + /// # #[cfg(feature = "tcp")] + /// # async fn run() { + /// use hyper::{Body, Error, Response, Server}; + /// use hyper::service::{make_service_fn, service_fn}; + /// + /// // Construct our SocketAddr to listen on... + /// let addr = ([127, 0, 0, 1], 3000).into(); + /// + /// // And a MakeService to handle each connection... + /// let make_svc = make_service_fn(|_| async { + /// Ok::<_, Error>(service_fn(|_req| async { + /// Ok::<_, Error>(Response::new(Body::from("Hello World"))) + /// })) + /// }); + /// + /// // Then bind and serve... + /// let server = Server::bind(&addr) + /// .serve(make_svc); + /// + /// // Run forever-ish... + /// if let Err(err) = server.await { + /// eprintln!("server error: {}", err); + /// } + /// # } + /// ``` + pub fn serve<S, B>(self, new_service: S) -> Server<I, S, E> + where + I: Accept, + I::Error: Into<Box<dyn StdError + Send + Sync>>, + I::Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static, + S: MakeServiceRef<I::Conn, Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: Payload, + E: NewSvcExec<I::Conn, S::Future, S::Service, E, NoopWatcher>, + E: H2Exec<<S::Service as HttpService<Body>>::Future, B>, + { + let serve = self.protocol.serve(self.incoming, new_service); + let spawn_all = serve.spawn_all(); + Server { spawn_all } + } +} + +#[cfg(feature = "tcp")] +impl<E> Builder<AddrIncoming, E> { + /// Set whether TCP keepalive messages are enabled on accepted connections. + /// + /// If `None` is specified, keepalive is disabled, otherwise the duration + /// specified will be the time to remain idle before sending TCP keepalive + /// probes. + pub fn tcp_keepalive(mut self, keepalive: Option<Duration>) -> Self { + self.incoming.set_keepalive(keepalive); + self + } + + /// Set the value of `TCP_NODELAY` option for accepted connections. + pub fn tcp_nodelay(mut self, enabled: bool) -> Self { + self.incoming.set_nodelay(enabled); + self + } + + /// Set whether to sleep on accept errors. + /// + /// A possible scenario is that the process has hit the max open files + /// allowed, and so trying to accept a new connection will fail with + /// EMFILE. In some cases, it's preferable to just wait for some time, if + /// the application will likely close some files (or connections), and try + /// to accept the connection again. If this option is true, the error will + /// be logged at the error level, since it is still a big deal, and then + /// the listener will sleep for 1 second. + /// + /// In other cases, hitting the max open files should be treat similarly + /// to being out-of-memory, and simply error (and shutdown). Setting this + /// option to false will allow that. + /// + /// For more details see [`AddrIncoming::set_sleep_on_errors`] + pub fn tcp_sleep_on_accept_errors(mut self, val: bool) -> Self { + self.incoming.set_sleep_on_errors(val); + self + } +} diff --git a/third_party/rust/hyper/src/server/shutdown.rs b/third_party/rust/hyper/src/server/shutdown.rs new file mode 100644 index 0000000000..1dc668ce40 --- /dev/null +++ b/third_party/rust/hyper/src/server/shutdown.rs @@ -0,0 +1,119 @@ +use std::error::Error as StdError; + +use pin_project::{pin_project, project}; +use tokio::io::{AsyncRead, AsyncWrite}; + +use super::conn::{SpawnAll, UpgradeableConnection, Watcher}; +use super::Accept; +use crate::body::{Body, Payload}; +use crate::common::drain::{self, Draining, Signal, Watch, Watching}; +use crate::common::exec::{H2Exec, NewSvcExec}; +use crate::common::{task, Future, Pin, Poll, Unpin}; +use crate::service::{HttpService, MakeServiceRef}; + +#[allow(missing_debug_implementations)] +#[pin_project] +pub struct Graceful<I, S, F, E> { + #[pin] + state: State<I, S, F, E>, +} + +#[pin_project] +pub(super) enum State<I, S, F, E> { + Running { + drain: Option<(Signal, Watch)>, + #[pin] + spawn_all: SpawnAll<I, S, E>, + #[pin] + signal: F, + }, + Draining(Draining), +} + +impl<I, S, F, E> Graceful<I, S, F, E> { + pub(super) fn new(spawn_all: SpawnAll<I, S, E>, signal: F) -> Self { + let drain = Some(drain::channel()); + Graceful { + state: State::Running { + drain, + spawn_all, + signal, + }, + } + } +} + +impl<I, IO, IE, S, B, F, E> Future for Graceful<I, S, F, E> +where + I: Accept<Conn = IO, Error = IE>, + IE: Into<Box<dyn StdError + Send + Sync>>, + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + S: MakeServiceRef<IO, Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: Payload, + F: Future<Output = ()>, + E: H2Exec<<S::Service as HttpService<Body>>::Future, B>, + E: NewSvcExec<IO, S::Future, S::Service, E, GracefulWatcher>, +{ + type Output = crate::Result<()>; + + #[project] + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + let mut me = self.project(); + loop { + let next = { + #[project] + match me.state.as_mut().project() { + State::Running { + drain, + spawn_all, + signal, + } => match signal.poll(cx) { + Poll::Ready(()) => { + debug!("signal received, starting graceful shutdown"); + let sig = drain.take().expect("drain channel").0; + State::Draining(sig.drain()) + } + Poll::Pending => { + let watch = drain.as_ref().expect("drain channel").1.clone(); + return spawn_all.poll_watch(cx, &GracefulWatcher(watch)); + } + }, + State::Draining(ref mut draining) => { + return Pin::new(draining).poll(cx).map(Ok); + } + } + }; + me.state.set(next); + } + } +} + +#[allow(missing_debug_implementations)] +#[derive(Clone)] +pub struct GracefulWatcher(Watch); + +impl<I, S, E> Watcher<I, S, E> for GracefulWatcher +where + I: AsyncRead + AsyncWrite + Unpin + Send + 'static, + S: HttpService<Body>, + E: H2Exec<S::Future, S::ResBody>, +{ + type Future = + Watching<UpgradeableConnection<I, S, E>, fn(Pin<&mut UpgradeableConnection<I, S, E>>)>; + + fn watch(&self, conn: UpgradeableConnection<I, S, E>) -> Self::Future { + self.0.clone().watch(conn, on_drain) + } +} + +fn on_drain<I, S, E>(conn: Pin<&mut UpgradeableConnection<I, S, E>>) +where + S: HttpService<Body>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + I: AsyncRead + AsyncWrite + Unpin, + S::ResBody: Payload + 'static, + E: H2Exec<S::Future, S::ResBody>, +{ + conn.graceful_shutdown() +} diff --git a/third_party/rust/hyper/src/server/tcp.rs b/third_party/rust/hyper/src/server/tcp.rs new file mode 100644 index 0000000000..b823818693 --- /dev/null +++ b/third_party/rust/hyper/src/server/tcp.rs @@ -0,0 +1,299 @@ +use std::fmt; +use std::io; +use std::net::{SocketAddr, TcpListener as StdTcpListener}; +use std::time::Duration; + +use futures_util::FutureExt as _; +use tokio::net::TcpListener; +use tokio::time::Delay; + +use crate::common::{task, Future, Pin, Poll}; + +pub use self::addr_stream::AddrStream; +use super::Accept; + +/// A stream of connections from binding to an address. +#[must_use = "streams do nothing unless polled"] +pub struct AddrIncoming { + addr: SocketAddr, + listener: TcpListener, + sleep_on_errors: bool, + tcp_keepalive_timeout: Option<Duration>, + tcp_nodelay: bool, + timeout: Option<Delay>, +} + +impl AddrIncoming { + pub(super) fn new(addr: &SocketAddr) -> crate::Result<Self> { + let std_listener = StdTcpListener::bind(addr).map_err(crate::Error::new_listen)?; + + AddrIncoming::from_std(std_listener) + } + + pub(super) fn from_std(std_listener: StdTcpListener) -> crate::Result<Self> { + let listener = TcpListener::from_std(std_listener).map_err(crate::Error::new_listen)?; + let addr = listener.local_addr().map_err(crate::Error::new_listen)?; + Ok(AddrIncoming { + listener, + addr, + sleep_on_errors: true, + tcp_keepalive_timeout: None, + tcp_nodelay: false, + timeout: None, + }) + } + + /// Creates a new `AddrIncoming` binding to provided socket address. + pub fn bind(addr: &SocketAddr) -> crate::Result<Self> { + AddrIncoming::new(addr) + } + + /// Get the local address bound to this listener. + pub fn local_addr(&self) -> SocketAddr { + self.addr + } + + /// Set whether TCP keepalive messages are enabled on accepted connections. + /// + /// If `None` is specified, keepalive is disabled, otherwise the duration + /// specified will be the time to remain idle before sending TCP keepalive + /// probes. + pub fn set_keepalive(&mut self, keepalive: Option<Duration>) -> &mut Self { + self.tcp_keepalive_timeout = keepalive; + self + } + + /// Set the value of `TCP_NODELAY` option for accepted connections. + pub fn set_nodelay(&mut self, enabled: bool) -> &mut Self { + self.tcp_nodelay = enabled; + self + } + + /// Set whether to sleep on accept errors. + /// + /// A possible scenario is that the process has hit the max open files + /// allowed, and so trying to accept a new connection will fail with + /// `EMFILE`. In some cases, it's preferable to just wait for some time, if + /// the application will likely close some files (or connections), and try + /// to accept the connection again. If this option is `true`, the error + /// will be logged at the `error` level, since it is still a big deal, + /// and then the listener will sleep for 1 second. + /// + /// In other cases, hitting the max open files should be treat similarly + /// to being out-of-memory, and simply error (and shutdown). Setting + /// this option to `false` will allow that. + /// + /// Default is `true`. + pub fn set_sleep_on_errors(&mut self, val: bool) { + self.sleep_on_errors = val; + } + + fn poll_next_(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<AddrStream>> { + // Check if a previous timeout is active that was set by IO errors. + if let Some(ref mut to) = self.timeout { + match Pin::new(to).poll(cx) { + Poll::Ready(()) => {} + Poll::Pending => return Poll::Pending, + } + } + self.timeout = None; + + let accept = self.listener.accept(); + futures_util::pin_mut!(accept); + + loop { + match accept.poll_unpin(cx) { + Poll::Ready(Ok((socket, addr))) => { + if let Some(dur) = self.tcp_keepalive_timeout { + if let Err(e) = socket.set_keepalive(Some(dur)) { + trace!("error trying to set TCP keepalive: {}", e); + } + } + if let Err(e) = socket.set_nodelay(self.tcp_nodelay) { + trace!("error trying to set TCP nodelay: {}", e); + } + return Poll::Ready(Ok(AddrStream::new(socket, addr))); + } + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(e)) => { + // Connection errors can be ignored directly, continue by + // accepting the next request. + if is_connection_error(&e) { + debug!("accepted connection already errored: {}", e); + continue; + } + + if self.sleep_on_errors { + error!("accept error: {}", e); + + // Sleep 1s. + let mut timeout = tokio::time::delay_for(Duration::from_secs(1)); + + match Pin::new(&mut timeout).poll(cx) { + Poll::Ready(()) => { + // Wow, it's been a second already? Ok then... + continue; + } + Poll::Pending => { + self.timeout = Some(timeout); + return Poll::Pending; + } + } + } else { + return Poll::Ready(Err(e)); + } + } + } + } + } +} + +impl Accept for AddrIncoming { + type Conn = AddrStream; + type Error = io::Error; + + fn poll_accept( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<Option<Result<Self::Conn, Self::Error>>> { + let result = ready!(self.poll_next_(cx)); + Poll::Ready(Some(result)) + } +} + +/// This function defines errors that are per-connection. Which basically +/// means that if we get this error from `accept()` system call it means +/// next connection might be ready to be accepted. +/// +/// All other errors will incur a timeout before next `accept()` is performed. +/// The timeout is useful to handle resource exhaustion errors like ENFILE +/// and EMFILE. Otherwise, could enter into tight loop. +fn is_connection_error(e: &io::Error) -> bool { + match e.kind() { + io::ErrorKind::ConnectionRefused + | io::ErrorKind::ConnectionAborted + | io::ErrorKind::ConnectionReset => true, + _ => false, + } +} + +impl fmt::Debug for AddrIncoming { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AddrIncoming") + .field("addr", &self.addr) + .field("sleep_on_errors", &self.sleep_on_errors) + .field("tcp_keepalive_timeout", &self.tcp_keepalive_timeout) + .field("tcp_nodelay", &self.tcp_nodelay) + .finish() + } +} + +mod addr_stream { + use bytes::{Buf, BufMut}; + use std::io; + use std::net::SocketAddr; + use tokio::io::{AsyncRead, AsyncWrite}; + use tokio::net::TcpStream; + + use crate::common::{task, Pin, Poll}; + + /// A transport returned yieled by `AddrIncoming`. + #[derive(Debug)] + pub struct AddrStream { + inner: TcpStream, + pub(super) remote_addr: SocketAddr, + } + + impl AddrStream { + pub(super) fn new(tcp: TcpStream, addr: SocketAddr) -> AddrStream { + AddrStream { + inner: tcp, + remote_addr: addr, + } + } + + /// Returns the remote (peer) address of this connection. + #[inline] + pub fn remote_addr(&self) -> SocketAddr { + self.remote_addr + } + + /// Consumes the AddrStream and returns the underlying IO object + #[inline] + pub fn into_inner(self) -> TcpStream { + self.inner + } + + /// Attempt to receive data on the socket, without removing that data + /// from the queue, registering the current task for wakeup if data is + /// not yet available. + pub fn poll_peek( + &mut self, + cx: &mut task::Context<'_>, + buf: &mut [u8], + ) -> Poll<io::Result<usize>> { + self.inner.poll_peek(cx, buf) + } + } + + impl AsyncRead for AddrStream { + unsafe fn prepare_uninitialized_buffer( + &self, + buf: &mut [std::mem::MaybeUninit<u8>], + ) -> bool { + self.inner.prepare_uninitialized_buffer(buf) + } + + #[inline] + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut [u8], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.inner).poll_read(cx, buf) + } + + #[inline] + fn poll_read_buf<B: BufMut>( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut B, + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.inner).poll_read_buf(cx, buf) + } + } + + impl AsyncWrite for AddrStream { + #[inline] + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.inner).poll_write(cx, buf) + } + + #[inline] + fn poll_write_buf<B: Buf>( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut B, + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.inner).poll_write_buf(cx, buf) + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + // TCP flush is a noop + Poll::Ready(Ok(())) + } + + #[inline] + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<io::Result<()>> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } + } +} diff --git a/third_party/rust/hyper/src/service/http.rs b/third_party/rust/hyper/src/service/http.rs new file mode 100644 index 0000000000..9c91f652c8 --- /dev/null +++ b/third_party/rust/hyper/src/service/http.rs @@ -0,0 +1,58 @@ +use std::error::Error as StdError; + +use crate::body::Payload; +use crate::common::{task, Future, Poll}; +use crate::{Request, Response}; + +/// An asynchronous function from `Request` to `Response`. +pub trait HttpService<ReqBody>: sealed::Sealed<ReqBody> { + /// The `Payload` body of the `http::Response`. + type ResBody: Payload; + + /// The error type that can occur within this `Service`. + /// + /// Note: Returning an `Error` to a hyper server will cause the connection + /// to be abruptly aborted. In most cases, it is better to return a `Response` + /// with a 4xx or 5xx status code. + type Error: Into<Box<dyn StdError + Send + Sync>>; + + /// The `Future` returned by this `Service`. + type Future: Future<Output = Result<Response<Self::ResBody>, Self::Error>>; + + #[doc(hidden)] + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>; + + #[doc(hidden)] + fn call(&mut self, req: Request<ReqBody>) -> Self::Future; +} + +impl<T, B1, B2> HttpService<B1> for T +where + T: tower_service::Service<Request<B1>, Response = Response<B2>>, + B2: Payload, + T::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + type ResBody = B2; + + type Error = T::Error; + type Future = T::Future; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + tower_service::Service::poll_ready(self, cx) + } + + fn call(&mut self, req: Request<B1>) -> Self::Future { + tower_service::Service::call(self, req) + } +} + +impl<T, B1, B2> sealed::Sealed<B1> for T +where + T: tower_service::Service<Request<B1>, Response = Response<B2>>, + B2: Payload, +{ +} + +mod sealed { + pub trait Sealed<T> {} +} diff --git a/third_party/rust/hyper/src/service/make.rs b/third_party/rust/hyper/src/service/make.rs new file mode 100644 index 0000000000..490992e118 --- /dev/null +++ b/third_party/rust/hyper/src/service/make.rs @@ -0,0 +1,186 @@ +use std::error::Error as StdError; +use std::fmt; + +use tokio::io::{AsyncRead, AsyncWrite}; + +use super::{HttpService, Service}; +use crate::body::Payload; +use crate::common::{task, Future, Poll}; + +// The same "trait alias" as tower::MakeConnection, but inlined to reduce +// dependencies. +pub trait MakeConnection<Target>: self::sealed::Sealed<(Target,)> { + type Connection: AsyncRead + AsyncWrite; + type Error; + type Future: Future<Output = Result<Self::Connection, Self::Error>>; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>; + fn make_connection(&mut self, target: Target) -> Self::Future; +} + +impl<S, Target> self::sealed::Sealed<(Target,)> for S where S: Service<Target> {} + +impl<S, Target> MakeConnection<Target> for S +where + S: Service<Target>, + S::Response: AsyncRead + AsyncWrite, +{ + type Connection = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + Service::poll_ready(self, cx) + } + + fn make_connection(&mut self, target: Target) -> Self::Future { + Service::call(self, target) + } +} + +// Just a sort-of "trait alias" of `MakeService`, not to be implemented +// by anyone, only used as bounds. +pub trait MakeServiceRef<Target, ReqBody>: self::sealed::Sealed<(Target, ReqBody)> { + type ResBody: Payload; + type Error: Into<Box<dyn StdError + Send + Sync>>; + type Service: HttpService<ReqBody, ResBody = Self::ResBody, Error = Self::Error>; + type MakeError: Into<Box<dyn StdError + Send + Sync>>; + type Future: Future<Output = Result<Self::Service, Self::MakeError>>; + + // Acting like a #[non_exhaustive] for associated types of this trait. + // + // Basically, no one outside of hyper should be able to set this type + // or declare bounds on it, so it should prevent people from creating + // trait objects or otherwise writing code that requires using *all* + // of the associated types. + // + // Why? So we can add new associated types to this alias in the future, + // if necessary. + type __DontNameMe: self::sealed::CantImpl; + + fn poll_ready_ref(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::MakeError>>; + + fn make_service_ref(&mut self, target: &Target) -> Self::Future; +} + +impl<T, Target, E, ME, S, F, IB, OB> MakeServiceRef<Target, IB> for T +where + T: for<'a> Service<&'a Target, Error = ME, Response = S, Future = F>, + E: Into<Box<dyn StdError + Send + Sync>>, + ME: Into<Box<dyn StdError + Send + Sync>>, + S: HttpService<IB, ResBody = OB, Error = E>, + F: Future<Output = Result<S, ME>>, + IB: Payload, + OB: Payload, +{ + type Error = E; + type Service = S; + type ResBody = OB; + type MakeError = ME; + type Future = F; + + type __DontNameMe = self::sealed::CantName; + + fn poll_ready_ref(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::MakeError>> { + self.poll_ready(cx) + } + + fn make_service_ref(&mut self, target: &Target) -> Self::Future { + self.call(target) + } +} + +impl<T, Target, S, B1, B2> self::sealed::Sealed<(Target, B1)> for T +where + T: for<'a> Service<&'a Target, Response = S>, + S: HttpService<B1, ResBody = B2>, + B1: Payload, + B2: Payload, +{ +} + +/// Create a `MakeService` from a function. +/// +/// # Example +/// +/// ``` +/// # #[cfg(feature = "runtime")] +/// # async fn run() { +/// use std::convert::Infallible; +/// use hyper::{Body, Request, Response, Server}; +/// use hyper::server::conn::AddrStream; +/// use hyper::service::{make_service_fn, service_fn}; +/// +/// let addr = ([127, 0, 0, 1], 3000).into(); +/// +/// let make_svc = make_service_fn(|socket: &AddrStream| { +/// let remote_addr = socket.remote_addr(); +/// async move { +/// Ok::<_, Infallible>(service_fn(move |_: Request<Body>| async move { +/// Ok::<_, Infallible>( +/// Response::new(Body::from(format!("Hello, {}!", remote_addr))) +/// ) +/// })) +/// } +/// }); +/// +/// // Then bind and serve... +/// let server = Server::bind(&addr) +/// .serve(make_svc); +/// +/// // Finally, spawn `server` onto an Executor... +/// if let Err(e) = server.await { +/// eprintln!("server error: {}", e); +/// } +/// # } +/// # fn main() {} +/// ``` +pub fn make_service_fn<F, Target, Ret>(f: F) -> MakeServiceFn<F> +where + F: FnMut(&Target) -> Ret, + Ret: Future, +{ + MakeServiceFn { f } +} + +/// `MakeService` returned from [`make_service_fn`] +#[derive(Clone, Copy)] +pub struct MakeServiceFn<F> { + f: F, +} + +impl<'t, F, Ret, Target, Svc, MkErr> Service<&'t Target> for MakeServiceFn<F> +where + F: FnMut(&Target) -> Ret, + Ret: Future<Output = Result<Svc, MkErr>>, + MkErr: Into<Box<dyn StdError + Send + Sync>>, +{ + type Error = MkErr; + type Response = Svc; + type Future = Ret; + + fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, target: &'t Target) -> Self::Future { + (self.f)(target) + } +} + +impl<F> fmt::Debug for MakeServiceFn<F> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MakeServiceFn").finish() + } +} + +mod sealed { + pub trait Sealed<X> {} + + pub trait CantImpl {} + + #[allow(missing_debug_implementations)] + pub enum CantName {} + + impl CantImpl for CantName {} +} diff --git a/third_party/rust/hyper/src/service/mod.rs b/third_party/rust/hyper/src/service/mod.rs new file mode 100644 index 0000000000..bb5a77406f --- /dev/null +++ b/third_party/rust/hyper/src/service/mod.rs @@ -0,0 +1,49 @@ +//! Asynchronous Services +//! +//! A [`Service`](Service) is a trait representing an asynchronous +//! function of a request to a response. It's similar to +//! `async fn(Request) -> Result<Response, Error>`. +//! +//! The argument and return value isn't strictly required to be for HTTP. +//! Therefore, hyper uses several "trait aliases" to reduce clutter around +//! bounds. These are: +//! +//! - `HttpService`: This is blanketly implemented for all types that +//! implement `Service<http::Request<B1>, Response = http::Response<B2>>`. +//! - `MakeService`: When a `Service` returns a new `Service` as its "response", +//! we consider it a `MakeService`. Again, blanketly implemented in those cases. +//! - `MakeConnection`: A `Service` that returns a "connection", a type that +//! implements `AsyncRead` and `AsyncWrite`. +//! +//! # HttpService +//! +//! In hyper, especially in the server setting, a `Service` is usually bound +//! to a single connection. It defines how to respond to **all** requests that +//! connection will receive. +//! +//! While it's possible to implement `Service` for a type manually, the helper +//! [`service_fn`](service_fn) should be sufficient for most cases. +//! +//! # MakeService +//! +//! Since a `Service` is bound to a single connection, a [`Server`](crate::Server) +//! needs a way to make them as it accepts connections. This is what a +//! `MakeService` does. +//! +//! Resources that need to be shared by all `Service`s can be put into a +//! `MakeService`, and then passed to individual `Service`s when `call` +//! is called. + +pub use tower_service::Service; + +mod http; +mod make; +mod oneshot; +mod util; + +pub(crate) use self::http::HttpService; +pub(crate) use self::make::{MakeConnection, MakeServiceRef}; +pub(crate) use self::oneshot::{oneshot, Oneshot}; + +pub use self::make::make_service_fn; +pub use self::util::service_fn; diff --git a/third_party/rust/hyper/src/service/oneshot.rs b/third_party/rust/hyper/src/service/oneshot.rs new file mode 100644 index 0000000000..94f4b43a80 --- /dev/null +++ b/third_party/rust/hyper/src/service/oneshot.rs @@ -0,0 +1,71 @@ +// TODO: Eventually to be replaced with tower_util::Oneshot. + +use std::marker::Unpin; +use std::mem; + +use tower_service::Service; + +use crate::common::{task, Future, Pin, Poll}; + +pub(crate) fn oneshot<S, Req>(svc: S, req: Req) -> Oneshot<S, Req> +where + S: Service<Req>, +{ + Oneshot { + state: State::NotReady(svc, req), + } +} + +// A `Future` consuming a `Service` and request, waiting until the `Service` +// is ready, and then calling `Service::call` with the request, and +// waiting for that `Future`. +#[allow(missing_debug_implementations)] +pub struct Oneshot<S: Service<Req>, Req> { + state: State<S, Req>, +} + +enum State<S: Service<Req>, Req> { + NotReady(S, Req), + Called(S::Future), + Tmp, +} + +// Unpin is projected to S::Future, but never S. +impl<S, Req> Unpin for Oneshot<S, Req> +where + S: Service<Req>, + S::Future: Unpin, +{ +} + +impl<S, Req> Future for Oneshot<S, Req> +where + S: Service<Req>, +{ + type Output = Result<S::Response, S::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + // Safety: The service's future is never moved once we get one. + let mut me = unsafe { Pin::get_unchecked_mut(self) }; + + loop { + match me.state { + State::NotReady(ref mut svc, _) => { + ready!(svc.poll_ready(cx))?; + // fallthrough out of the match's borrow + } + State::Called(ref mut fut) => { + return unsafe { Pin::new_unchecked(fut) }.poll(cx); + } + State::Tmp => unreachable!(), + } + + match mem::replace(&mut me.state, State::Tmp) { + State::NotReady(mut svc, req) => { + me.state = State::Called(svc.call(req)); + } + _ => unreachable!(), + } + } + } +} diff --git a/third_party/rust/hyper/src/service/util.rs b/third_party/rust/hyper/src/service/util.rs new file mode 100644 index 0000000000..be597f7a02 --- /dev/null +++ b/third_party/rust/hyper/src/service/util.rs @@ -0,0 +1,84 @@ +use std::error::Error as StdError; +use std::fmt; +use std::marker::PhantomData; + +use crate::body::Payload; +use crate::common::{task, Future, Poll}; +use crate::{Request, Response}; + +/// Create a `Service` from a function. +/// +/// # Example +/// +/// ``` +/// use hyper::{Body, Request, Response, Version}; +/// use hyper::service::service_fn; +/// +/// let service = service_fn(|req: Request<Body>| async move { +/// if req.version() == Version::HTTP_11 { +/// Ok(Response::new(Body::from("Hello World"))) +/// } else { +/// // Note: it's usually better to return a Response +/// // with an appropriate StatusCode instead of an Err. +/// Err("not HTTP/1.1, abort connection") +/// } +/// }); +/// ``` +pub fn service_fn<F, R, S>(f: F) -> ServiceFn<F, R> +where + F: FnMut(Request<R>) -> S, + S: Future, +{ + ServiceFn { + f, + _req: PhantomData, + } +} + +/// Service returned by [`service_fn`] +pub struct ServiceFn<F, R> { + f: F, + _req: PhantomData<fn(R)>, +} + +impl<F, ReqBody, Ret, ResBody, E> tower_service::Service<crate::Request<ReqBody>> + for ServiceFn<F, ReqBody> +where + F: FnMut(Request<ReqBody>) -> Ret, + ReqBody: Payload, + Ret: Future<Output = Result<Response<ResBody>, E>>, + E: Into<Box<dyn StdError + Send + Sync>>, + ResBody: Payload, +{ + type Response = crate::Response<ResBody>; + type Error = E; + type Future = Ret; + + fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request<ReqBody>) -> Self::Future { + (self.f)(req) + } +} + +impl<F, R> fmt::Debug for ServiceFn<F, R> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("impl Service").finish() + } +} + +impl<F, R> Clone for ServiceFn<F, R> +where + F: Clone, +{ + fn clone(&self) -> Self { + ServiceFn { + f: self.f.clone(), + _req: PhantomData, + } + } +} + +impl<F, R> Copy for ServiceFn<F, R> where F: Copy {} diff --git a/third_party/rust/hyper/src/upgrade.rs b/third_party/rust/hyper/src/upgrade.rs new file mode 100644 index 0000000000..55f390431f --- /dev/null +++ b/third_party/rust/hyper/src/upgrade.rs @@ -0,0 +1,365 @@ +//! HTTP Upgrades +//! +//! See [this example][example] showing how upgrades work with both +//! Clients and Servers. +//! +//! [example]: https://github.com/hyperium/hyper/blob/master/examples/upgrades.rs + +use std::any::TypeId; +use std::error::Error as StdError; +use std::fmt; +use std::io; +use std::marker::Unpin; + +use bytes::{Buf, Bytes}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::sync::oneshot; + +use crate::common::io::Rewind; +use crate::common::{task, Future, Pin, Poll}; + +/// An upgraded HTTP connection. +/// +/// This type holds a trait object internally of the original IO that +/// was used to speak HTTP before the upgrade. It can be used directly +/// as a `Read` or `Write` for convenience. +/// +/// Alternatively, if the exact type is known, this can be deconstructed +/// into its parts. +pub struct Upgraded { + io: Rewind<Box<dyn Io + Send>>, +} + +/// A future for a possible HTTP upgrade. +/// +/// If no upgrade was available, or it doesn't succeed, yields an `Error`. +pub struct OnUpgrade { + rx: Option<oneshot::Receiver<crate::Result<Upgraded>>>, +} + +/// The deconstructed parts of an [`Upgraded`](Upgraded) type. +/// +/// Includes the original IO type, and a read buffer of bytes that the +/// HTTP state machine may have already read before completing an upgrade. +#[derive(Debug)] +pub struct Parts<T> { + /// The original IO object used before the upgrade. + pub io: T, + /// A buffer of bytes that have been read but not processed as HTTP. + /// + /// For instance, if the `Connection` is used for an HTTP upgrade request, + /// it is possible the server sent back the first bytes of the new protocol + /// along with the response upgrade. + /// + /// You will want to check for any existing bytes if you plan to continue + /// communicating on the IO object. + pub read_buf: Bytes, + _inner: (), +} + +pub(crate) struct Pending { + tx: oneshot::Sender<crate::Result<Upgraded>>, +} + +/// Error cause returned when an upgrade was expected but canceled +/// for whatever reason. +/// +/// This likely means the actual `Conn` future wasn't polled and upgraded. +#[derive(Debug)] +struct UpgradeExpected(()); + +pub(crate) fn pending() -> (Pending, OnUpgrade) { + let (tx, rx) = oneshot::channel(); + (Pending { tx }, OnUpgrade { rx: Some(rx) }) +} + +// ===== impl Upgraded ===== + +impl Upgraded { + pub(crate) fn new<T>(io: T, read_buf: Bytes) -> Self + where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + Upgraded { + io: Rewind::new_buffered(Box::new(ForwardsWriteBuf(io)), read_buf), + } + } + + /// Tries to downcast the internal trait object to the type passed. + /// + /// On success, returns the downcasted parts. On error, returns the + /// `Upgraded` back. + pub fn downcast<T: AsyncRead + AsyncWrite + Unpin + 'static>(self) -> Result<Parts<T>, Self> { + let (io, buf) = self.io.into_inner(); + match io.__hyper_downcast::<ForwardsWriteBuf<T>>() { + Ok(t) => Ok(Parts { + io: t.0, + read_buf: buf, + _inner: (), + }), + Err(io) => Err(Upgraded { + io: Rewind::new_buffered(io, buf), + }), + } + } +} + +impl AsyncRead for Upgraded { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool { + self.io.prepare_uninitialized_buffer(buf) + } + + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut [u8], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.io).poll_read(cx, buf) + } +} + +impl AsyncWrite for Upgraded { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.io).poll_write(cx, buf) + } + + fn poll_write_buf<B: Buf>( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut B, + ) -> Poll<io::Result<usize>> { + Pin::new(self.io.get_mut()).poll_write_dyn_buf(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.io).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.io).poll_shutdown(cx) + } +} + +impl fmt::Debug for Upgraded { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Upgraded").finish() + } +} + +// ===== impl OnUpgrade ===== + +impl OnUpgrade { + pub(crate) fn none() -> Self { + OnUpgrade { rx: None } + } + + pub(crate) fn is_none(&self) -> bool { + self.rx.is_none() + } +} + +impl Future for OnUpgrade { + type Output = Result<Upgraded, crate::Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + match self.rx { + Some(ref mut rx) => Pin::new(rx).poll(cx).map(|res| match res { + Ok(Ok(upgraded)) => Ok(upgraded), + Ok(Err(err)) => Err(err), + Err(_oneshot_canceled) => { + Err(crate::Error::new_canceled().with(UpgradeExpected(()))) + } + }), + None => Poll::Ready(Err(crate::Error::new_user_no_upgrade())), + } + } +} + +impl fmt::Debug for OnUpgrade { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("OnUpgrade").finish() + } +} + +// ===== impl Pending ===== + +impl Pending { + pub(crate) fn fulfill(self, upgraded: Upgraded) { + trace!("pending upgrade fulfill"); + let _ = self.tx.send(Ok(upgraded)); + } + + /// Don't fulfill the pending Upgrade, but instead signal that + /// upgrades are handled manually. + pub(crate) fn manual(self) { + trace!("pending upgrade handled manually"); + let _ = self.tx.send(Err(crate::Error::new_user_manual_upgrade())); + } +} + +// ===== impl UpgradeExpected ===== + +impl fmt::Display for UpgradeExpected { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "upgrade expected but not completed") + } +} + +impl StdError for UpgradeExpected {} + +// ===== impl Io ===== + +struct ForwardsWriteBuf<T>(T); + +pub(crate) trait Io: AsyncRead + AsyncWrite + Unpin + 'static { + fn poll_write_dyn_buf( + &mut self, + cx: &mut task::Context<'_>, + buf: &mut dyn Buf, + ) -> Poll<io::Result<usize>>; + + fn __hyper_type_id(&self) -> TypeId { + TypeId::of::<Self>() + } +} + +impl dyn Io + Send { + fn __hyper_is<T: Io>(&self) -> bool { + let t = TypeId::of::<T>(); + self.__hyper_type_id() == t + } + + fn __hyper_downcast<T: Io>(self: Box<Self>) -> Result<Box<T>, Box<Self>> { + if self.__hyper_is::<T>() { + // Taken from `std::error::Error::downcast()`. + unsafe { + let raw: *mut dyn Io = Box::into_raw(self); + Ok(Box::from_raw(raw as *mut T)) + } + } else { + Err(self) + } + } +} + +impl<T: AsyncRead + Unpin> AsyncRead for ForwardsWriteBuf<T> { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool { + self.0.prepare_uninitialized_buffer(buf) + } + + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut [u8], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.0).poll_read(cx, buf) + } +} + +impl<T: AsyncWrite + Unpin> AsyncWrite for ForwardsWriteBuf<T> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_write_buf<B: Buf>( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut B, + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.0).poll_write_buf(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.0).poll_shutdown(cx) + } +} + +impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Io for ForwardsWriteBuf<T> { + fn poll_write_dyn_buf( + &mut self, + cx: &mut task::Context<'_>, + mut buf: &mut dyn Buf, + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.0).poll_write_buf(cx, &mut buf) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::io::AsyncWriteExt; + + #[test] + fn upgraded_downcast() { + let upgraded = Upgraded::new(Mock, Bytes::new()); + + let upgraded = upgraded.downcast::<std::io::Cursor<Vec<u8>>>().unwrap_err(); + + upgraded.downcast::<Mock>().unwrap(); + } + + #[tokio::test] + async fn upgraded_forwards_write_buf() { + // sanity check that the underlying IO implements write_buf + Mock.write_buf(&mut "hello".as_bytes()).await.unwrap(); + + let mut upgraded = Upgraded::new(Mock, Bytes::new()); + upgraded.write_buf(&mut "hello".as_bytes()).await.unwrap(); + } + + // TODO: replace with tokio_test::io when it can test write_buf + struct Mock; + + impl AsyncRead for Mock { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut task::Context<'_>, + _buf: &mut [u8], + ) -> Poll<io::Result<usize>> { + unreachable!("Mock::poll_read") + } + } + + impl AsyncWrite for Mock { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut task::Context<'_>, + _buf: &[u8], + ) -> Poll<io::Result<usize>> { + panic!("poll_write shouldn't be called"); + } + + fn poll_write_buf<B: Buf>( + self: Pin<&mut Self>, + _cx: &mut task::Context<'_>, + buf: &mut B, + ) -> Poll<io::Result<usize>> { + let n = buf.remaining(); + buf.advance(n); + Poll::Ready(Ok(n)) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + unreachable!("Mock::poll_flush") + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut task::Context<'_>, + ) -> Poll<io::Result<()>> { + unreachable!("Mock::poll_shutdown") + } + } +} |