diff options
Diffstat (limited to '')
68 files changed, 25124 insertions, 0 deletions
diff --git a/third_party/rust/hyper/src/body/aggregate.rs b/third_party/rust/hyper/src/body/aggregate.rs new file mode 100644 index 0000000000..99662419d3 --- /dev/null +++ b/third_party/rust/hyper/src/body/aggregate.rs @@ -0,0 +1,31 @@ +use bytes::Buf; + +use super::HttpBody; +use crate::common::buf::BufList; + +/// Aggregate the data buffers from a body asynchronously. +/// +/// The returned `impl Buf` groups the `Buf`s from the `HttpBody` without +/// copying them. This is ideal if you don't require a contiguous buffer. +/// +/// # Note +/// +/// Care needs to be taken if the remote is untrusted. The function doesn't implement any length +/// checks and an malicious peer might make it consume arbitrary amounts of memory. Checking the +/// `Content-Length` is a possibility, but it is not strictly mandated to be present. +pub async fn aggregate<T>(body: T) -> Result<impl Buf, T::Error> +where + T: HttpBody, +{ + let mut bufs = BufList::new(); + + futures_util::pin_mut!(body); + while let Some(buf) = body.data().await { + let buf = buf?; + if buf.has_remaining() { + bufs.push(buf); + } + } + + Ok(bufs) +} diff --git a/third_party/rust/hyper/src/body/body.rs b/third_party/rust/hyper/src/body/body.rs new file mode 100644 index 0000000000..9dc1a034f9 --- /dev/null +++ b/third_party/rust/hyper/src/body/body.rs @@ -0,0 +1,785 @@ +use std::borrow::Cow; +#[cfg(feature = "stream")] +use std::error::Error as StdError; +use std::fmt; + +use bytes::Bytes; +use futures_channel::mpsc; +use futures_channel::oneshot; +use futures_core::Stream; // for mpsc::Receiver +#[cfg(feature = "stream")] +use futures_util::TryStreamExt; +use http::HeaderMap; +use http_body::{Body as HttpBody, SizeHint}; + +use super::DecodedLength; +#[cfg(feature = "stream")] +use crate::common::sync_wrapper::SyncWrapper; +use crate::common::Future; +#[cfg(all(feature = "client", any(feature = "http1", feature = "http2")))] +use crate::common::Never; +use crate::common::{task, watch, Pin, Poll}; +#[cfg(all(feature = "http2", any(feature = "client", feature = "server")))] +use crate::proto::h2::ping; + +type BodySender = mpsc::Sender<Result<Bytes, crate::Error>>; +type TrailersSender = oneshot::Sender<HeaderMap>; + +/// A stream of `Bytes`, used when receiving bodies. +/// +/// A good default [`HttpBody`](crate::body::HttpBody) to use in many +/// applications. +/// +/// Note: To read the full body, use [`body::to_bytes`](crate::body::to_bytes) +/// or [`body::aggregate`](crate::body::aggregate). +#[must_use = "streams do nothing unless polled"] +pub struct Body { + kind: Kind, + /// Keep the extra bits in an `Option<Box<Extra>>`, so that + /// Body stays small in the common case (no extras needed). + extra: Option<Box<Extra>>, +} + +enum Kind { + Once(Option<Bytes>), + Chan { + content_length: DecodedLength, + want_tx: watch::Sender, + data_rx: mpsc::Receiver<Result<Bytes, crate::Error>>, + trailers_rx: oneshot::Receiver<HeaderMap>, + }, + #[cfg(all(feature = "http2", any(feature = "client", feature = "server")))] + H2 { + ping: ping::Recorder, + content_length: DecodedLength, + recv: h2::RecvStream, + }, + #[cfg(feature = "ffi")] + Ffi(crate::ffi::UserBody), + #[cfg(feature = "stream")] + Wrapped( + SyncWrapper< + Pin<Box<dyn Stream<Item = Result<Bytes, Box<dyn StdError + Send + Sync>>> + Send>>, + >, + ), +} + +struct Extra { + /// Allow the client to pass a future to delay the `Body` from returning + /// EOF. This allows the `Client` to try to put the idle connection + /// back into the pool before the body is "finished". + /// + /// The reason for this is so that creating a new request after finishing + /// streaming the body of a response could sometimes result in creating + /// a brand new connection, since the pool didn't know about the idle + /// connection yet. + delayed_eof: Option<DelayEof>, +} + +#[cfg(all(feature = "client", any(feature = "http1", feature = "http2")))] +type DelayEofUntil = oneshot::Receiver<Never>; + +enum DelayEof { + /// Initial state, stream hasn't seen EOF yet. + #[cfg(any(feature = "http1", feature = "http2"))] + #[cfg(feature = "client")] + NotEof(DelayEofUntil), + /// Transitions to this state once we've seen `poll` try to + /// return EOF (`None`). This future is then polled, and + /// when it completes, the Body finally returns EOF (`None`). + #[cfg(any(feature = "http1", feature = "http2"))] + #[cfg(feature = "client")] + Eof(DelayEofUntil), +} + +/// A sender half created through [`Body::channel()`]. +/// +/// Useful when wanting to stream chunks from another thread. +/// +/// ## Body Closing +/// +/// Note that the request body will always be closed normally when the sender is dropped (meaning +/// that the empty terminating chunk will be sent to the remote). If you desire to close the +/// connection with an incomplete response (e.g. in the case of an error during asynchronous +/// processing), call the [`Sender::abort()`] method to abort the body in an abnormal fashion. +/// +/// [`Body::channel()`]: struct.Body.html#method.channel +/// [`Sender::abort()`]: struct.Sender.html#method.abort +#[must_use = "Sender does nothing unless sent on"] +pub struct Sender { + want_rx: watch::Receiver, + data_tx: BodySender, + trailers_tx: Option<TrailersSender>, +} + +const WANT_PENDING: usize = 1; +const WANT_READY: usize = 2; + +impl Body { + /// Create an empty `Body` stream. + /// + /// # Example + /// + /// ``` + /// use hyper::{Body, Request}; + /// + /// // create a `GET /` request + /// let get = Request::new(Body::empty()); + /// ``` + #[inline] + pub fn empty() -> Body { + Body::new(Kind::Once(None)) + } + + /// Create a `Body` stream with an associated sender half. + /// + /// Useful when wanting to stream chunks from another thread. + #[inline] + pub fn channel() -> (Sender, Body) { + Self::new_channel(DecodedLength::CHUNKED, /*wanter =*/ false) + } + + pub(crate) fn new_channel(content_length: DecodedLength, wanter: bool) -> (Sender, Body) { + let (data_tx, data_rx) = mpsc::channel(0); + let (trailers_tx, trailers_rx) = oneshot::channel(); + + // If wanter is true, `Sender::poll_ready()` won't becoming ready + // until the `Body` has been polled for data once. + let want = if wanter { WANT_PENDING } else { WANT_READY }; + + let (want_tx, want_rx) = watch::channel(want); + + let tx = Sender { + want_rx, + data_tx, + trailers_tx: Some(trailers_tx), + }; + let rx = Body::new(Kind::Chan { + content_length, + want_tx, + data_rx, + trailers_rx, + }); + + (tx, rx) + } + + /// Wrap a futures `Stream` in a box inside `Body`. + /// + /// # Example + /// + /// ``` + /// # use hyper::Body; + /// let chunks: Vec<Result<_, std::io::Error>> = vec![ + /// Ok("hello"), + /// Ok(" "), + /// Ok("world"), + /// ]; + /// + /// let stream = futures_util::stream::iter(chunks); + /// + /// let body = Body::wrap_stream(stream); + /// ``` + /// + /// # Optional + /// + /// This function requires enabling the `stream` feature in your + /// `Cargo.toml`. + #[cfg(feature = "stream")] + #[cfg_attr(docsrs, doc(cfg(feature = "stream")))] + pub fn wrap_stream<S, O, E>(stream: S) -> Body + where + S: Stream<Item = Result<O, E>> + Send + 'static, + O: Into<Bytes> + 'static, + E: Into<Box<dyn StdError + Send + Sync>> + 'static, + { + let mapped = stream.map_ok(Into::into).map_err(Into::into); + Body::new(Kind::Wrapped(SyncWrapper::new(Box::pin(mapped)))) + } + + fn new(kind: Kind) -> Body { + Body { kind, extra: None } + } + + #[cfg(all(feature = "http2", any(feature = "client", feature = "server")))] + pub(crate) fn h2( + recv: h2::RecvStream, + mut content_length: DecodedLength, + ping: ping::Recorder, + ) -> Self { + // If the stream is already EOS, then the "unknown length" is clearly + // actually ZERO. + if !content_length.is_exact() && recv.is_end_stream() { + content_length = DecodedLength::ZERO; + } + let body = Body::new(Kind::H2 { + ping, + content_length, + recv, + }); + + body + } + + #[cfg(any(feature = "http1", feature = "http2"))] + #[cfg(feature = "client")] + pub(crate) fn delayed_eof(&mut self, fut: DelayEofUntil) { + self.extra_mut().delayed_eof = Some(DelayEof::NotEof(fut)); + } + + fn take_delayed_eof(&mut self) -> Option<DelayEof> { + self.extra + .as_mut() + .and_then(|extra| extra.delayed_eof.take()) + } + + #[cfg(any(feature = "http1", feature = "http2"))] + fn extra_mut(&mut self) -> &mut Extra { + self.extra + .get_or_insert_with(|| Box::new(Extra { delayed_eof: None })) + } + + fn poll_eof(&mut self, cx: &mut task::Context<'_>) -> Poll<Option<crate::Result<Bytes>>> { + match self.take_delayed_eof() { + #[cfg(any(feature = "http1", feature = "http2"))] + #[cfg(feature = "client")] + Some(DelayEof::NotEof(mut delay)) => match self.poll_inner(cx) { + ok @ Poll::Ready(Some(Ok(..))) | ok @ Poll::Pending => { + self.extra_mut().delayed_eof = Some(DelayEof::NotEof(delay)); + ok + } + Poll::Ready(None) => match Pin::new(&mut delay).poll(cx) { + Poll::Ready(Ok(never)) => match never {}, + Poll::Pending => { + self.extra_mut().delayed_eof = Some(DelayEof::Eof(delay)); + Poll::Pending + } + Poll::Ready(Err(_done)) => Poll::Ready(None), + }, + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), + }, + #[cfg(any(feature = "http1", feature = "http2"))] + #[cfg(feature = "client")] + Some(DelayEof::Eof(mut delay)) => match Pin::new(&mut delay).poll(cx) { + Poll::Ready(Ok(never)) => match never {}, + Poll::Pending => { + self.extra_mut().delayed_eof = Some(DelayEof::Eof(delay)); + Poll::Pending + } + Poll::Ready(Err(_done)) => Poll::Ready(None), + }, + #[cfg(any( + not(any(feature = "http1", feature = "http2")), + not(feature = "client") + ))] + Some(delay_eof) => match delay_eof {}, + None => self.poll_inner(cx), + } + } + + #[cfg(feature = "ffi")] + pub(crate) fn as_ffi_mut(&mut self) -> &mut crate::ffi::UserBody { + match self.kind { + Kind::Ffi(ref mut body) => return body, + _ => { + self.kind = Kind::Ffi(crate::ffi::UserBody::new()); + } + } + + match self.kind { + Kind::Ffi(ref mut body) => body, + _ => unreachable!(), + } + } + + fn poll_inner(&mut self, cx: &mut task::Context<'_>) -> Poll<Option<crate::Result<Bytes>>> { + match self.kind { + Kind::Once(ref mut val) => Poll::Ready(val.take().map(Ok)), + Kind::Chan { + content_length: ref mut len, + ref mut data_rx, + ref mut want_tx, + .. + } => { + want_tx.send(WANT_READY); + + match ready!(Pin::new(data_rx).poll_next(cx)?) { + Some(chunk) => { + len.sub_if(chunk.len() as u64); + Poll::Ready(Some(Ok(chunk))) + } + None => Poll::Ready(None), + } + } + #[cfg(all(feature = "http2", any(feature = "client", feature = "server")))] + Kind::H2 { + ref ping, + recv: ref mut h2, + content_length: ref mut len, + } => match ready!(h2.poll_data(cx)) { + Some(Ok(bytes)) => { + let _ = h2.flow_control().release_capacity(bytes.len()); + len.sub_if(bytes.len() as u64); + ping.record_data(bytes.len()); + Poll::Ready(Some(Ok(bytes))) + } + Some(Err(e)) => Poll::Ready(Some(Err(crate::Error::new_body(e)))), + None => Poll::Ready(None), + }, + + #[cfg(feature = "ffi")] + Kind::Ffi(ref mut body) => body.poll_data(cx), + + #[cfg(feature = "stream")] + Kind::Wrapped(ref mut s) => match ready!(s.get_mut().as_mut().poll_next(cx)) { + Some(res) => Poll::Ready(Some(res.map_err(crate::Error::new_body))), + None => Poll::Ready(None), + }, + } + } + + #[cfg(feature = "http1")] + pub(super) fn take_full_data(&mut self) -> Option<Bytes> { + if let Kind::Once(ref mut chunk) = self.kind { + chunk.take() + } else { + None + } + } +} + +impl Default for Body { + /// Returns `Body::empty()`. + #[inline] + fn default() -> Body { + Body::empty() + } +} + +impl HttpBody for Body { + type Data = Bytes; + type Error = crate::Error; + + fn poll_data( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<Option<Result<Self::Data, Self::Error>>> { + self.poll_eof(cx) + } + + fn poll_trailers( + #[cfg_attr(not(feature = "http2"), allow(unused_mut))] mut self: Pin<&mut Self>, + #[cfg_attr(not(feature = "http2"), allow(unused))] cx: &mut task::Context<'_>, + ) -> Poll<Result<Option<HeaderMap>, Self::Error>> { + match self.kind { + #[cfg(all(feature = "http2", any(feature = "client", feature = "server")))] + Kind::H2 { + recv: ref mut h2, + ref ping, + .. + } => match ready!(h2.poll_trailers(cx)) { + Ok(t) => { + ping.record_non_data(); + Poll::Ready(Ok(t)) + } + Err(e) => Poll::Ready(Err(crate::Error::new_h2(e))), + }, + Kind::Chan { + ref mut trailers_rx, + .. + } => match ready!(Pin::new(trailers_rx).poll(cx)) { + Ok(t) => Poll::Ready(Ok(Some(t))), + Err(_) => Poll::Ready(Ok(None)), + }, + #[cfg(feature = "ffi")] + Kind::Ffi(ref mut body) => body.poll_trailers(cx), + _ => Poll::Ready(Ok(None)), + } + } + + fn is_end_stream(&self) -> bool { + match self.kind { + Kind::Once(ref val) => val.is_none(), + Kind::Chan { content_length, .. } => content_length == DecodedLength::ZERO, + #[cfg(all(feature = "http2", any(feature = "client", feature = "server")))] + Kind::H2 { recv: ref h2, .. } => h2.is_end_stream(), + #[cfg(feature = "ffi")] + Kind::Ffi(..) => false, + #[cfg(feature = "stream")] + Kind::Wrapped(..) => false, + } + } + + fn size_hint(&self) -> SizeHint { + macro_rules! opt_len { + ($content_length:expr) => {{ + let mut hint = SizeHint::default(); + + if let Some(content_length) = $content_length.into_opt() { + hint.set_exact(content_length); + } + + hint + }}; + } + + match self.kind { + Kind::Once(Some(ref val)) => SizeHint::with_exact(val.len() as u64), + Kind::Once(None) => SizeHint::with_exact(0), + #[cfg(feature = "stream")] + Kind::Wrapped(..) => SizeHint::default(), + Kind::Chan { content_length, .. } => opt_len!(content_length), + #[cfg(all(feature = "http2", any(feature = "client", feature = "server")))] + Kind::H2 { content_length, .. } => opt_len!(content_length), + #[cfg(feature = "ffi")] + Kind::Ffi(..) => SizeHint::default(), + } + } +} + +impl fmt::Debug for Body { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + #[derive(Debug)] + struct Streaming; + #[derive(Debug)] + struct Empty; + #[derive(Debug)] + struct Full<'a>(&'a Bytes); + + let mut builder = f.debug_tuple("Body"); + match self.kind { + Kind::Once(None) => builder.field(&Empty), + Kind::Once(Some(ref chunk)) => builder.field(&Full(chunk)), + _ => builder.field(&Streaming), + }; + + builder.finish() + } +} + +/// # Optional +/// +/// This function requires enabling the `stream` feature in your +/// `Cargo.toml`. +#[cfg(feature = "stream")] +impl Stream for Body { + type Item = crate::Result<Bytes>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> { + HttpBody::poll_data(self, cx) + } +} + +/// # Optional +/// +/// This function requires enabling the `stream` feature in your +/// `Cargo.toml`. +#[cfg(feature = "stream")] +impl From<Box<dyn Stream<Item = Result<Bytes, Box<dyn StdError + Send + Sync>>> + Send>> for Body { + #[inline] + fn from( + stream: Box<dyn Stream<Item = Result<Bytes, Box<dyn StdError + Send + Sync>>> + Send>, + ) -> Body { + Body::new(Kind::Wrapped(SyncWrapper::new(stream.into()))) + } +} + +impl From<Bytes> for Body { + #[inline] + fn from(chunk: Bytes) -> Body { + if chunk.is_empty() { + Body::empty() + } else { + Body::new(Kind::Once(Some(chunk))) + } + } +} + +impl From<Vec<u8>> for Body { + #[inline] + fn from(vec: Vec<u8>) -> Body { + Body::from(Bytes::from(vec)) + } +} + +impl From<&'static [u8]> for Body { + #[inline] + fn from(slice: &'static [u8]) -> Body { + Body::from(Bytes::from(slice)) + } +} + +impl From<Cow<'static, [u8]>> for Body { + #[inline] + fn from(cow: Cow<'static, [u8]>) -> Body { + match cow { + Cow::Borrowed(b) => Body::from(b), + Cow::Owned(o) => Body::from(o), + } + } +} + +impl From<String> for Body { + #[inline] + fn from(s: String) -> Body { + Body::from(Bytes::from(s.into_bytes())) + } +} + +impl From<&'static str> for Body { + #[inline] + fn from(slice: &'static str) -> Body { + Body::from(Bytes::from(slice.as_bytes())) + } +} + +impl From<Cow<'static, str>> for Body { + #[inline] + fn from(cow: Cow<'static, str>) -> Body { + match cow { + Cow::Borrowed(b) => Body::from(b), + Cow::Owned(o) => Body::from(o), + } + } +} + +impl Sender { + /// Check to see if this `Sender` can send more data. + pub fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + // Check if the receiver end has tried polling for the body yet + ready!(self.poll_want(cx)?); + self.data_tx + .poll_ready(cx) + .map_err(|_| crate::Error::new_closed()) + } + + fn poll_want(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + match self.want_rx.load(cx) { + WANT_READY => Poll::Ready(Ok(())), + WANT_PENDING => Poll::Pending, + watch::CLOSED => Poll::Ready(Err(crate::Error::new_closed())), + unexpected => unreachable!("want_rx value: {}", unexpected), + } + } + + async fn ready(&mut self) -> crate::Result<()> { + futures_util::future::poll_fn(|cx| self.poll_ready(cx)).await + } + + /// Send data on data channel when it is ready. + pub async fn send_data(&mut self, chunk: Bytes) -> crate::Result<()> { + self.ready().await?; + self.data_tx + .try_send(Ok(chunk)) + .map_err(|_| crate::Error::new_closed()) + } + + /// Send trailers on trailers channel. + pub async fn send_trailers(&mut self, trailers: HeaderMap) -> crate::Result<()> { + let tx = match self.trailers_tx.take() { + Some(tx) => tx, + None => return Err(crate::Error::new_closed()), + }; + tx.send(trailers).map_err(|_| crate::Error::new_closed()) + } + + /// Try to send data on this channel. + /// + /// # Errors + /// + /// Returns `Err(Bytes)` if the channel could not (currently) accept + /// another `Bytes`. + /// + /// # Note + /// + /// This is mostly useful for when trying to send from some other thread + /// that doesn't have an async context. If in an async context, prefer + /// `send_data()` instead. + pub fn try_send_data(&mut self, chunk: Bytes) -> Result<(), Bytes> { + self.data_tx + .try_send(Ok(chunk)) + .map_err(|err| err.into_inner().expect("just sent Ok")) + } + + /// Aborts the body in an abnormal fashion. + pub fn abort(self) { + let _ = self + .data_tx + // clone so the send works even if buffer is full + .clone() + .try_send(Err(crate::Error::new_body_write_aborted())); + } + + #[cfg(feature = "http1")] + pub(crate) fn send_error(&mut self, err: crate::Error) { + let _ = self.data_tx.try_send(Err(err)); + } +} + +impl fmt::Debug for Sender { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + #[derive(Debug)] + struct Open; + #[derive(Debug)] + struct Closed; + + let mut builder = f.debug_tuple("Sender"); + match self.want_rx.peek() { + watch::CLOSED => builder.field(&Closed), + _ => builder.field(&Open), + }; + + builder.finish() + } +} + +#[cfg(test)] +mod tests { + use std::mem; + use std::task::Poll; + + use super::{Body, DecodedLength, HttpBody, Sender, SizeHint}; + + #[test] + fn test_size_of() { + // These are mostly to help catch *accidentally* increasing + // the size by too much. + + let body_size = mem::size_of::<Body>(); + let body_expected_size = mem::size_of::<u64>() * 6; + assert!( + body_size <= body_expected_size, + "Body size = {} <= {}", + body_size, + body_expected_size, + ); + + assert_eq!(body_size, mem::size_of::<Option<Body>>(), "Option<Body>"); + + assert_eq!( + mem::size_of::<Sender>(), + mem::size_of::<usize>() * 5, + "Sender" + ); + + assert_eq!( + mem::size_of::<Sender>(), + mem::size_of::<Option<Sender>>(), + "Option<Sender>" + ); + } + + #[test] + fn size_hint() { + fn eq(body: Body, b: SizeHint, note: &str) { + let a = body.size_hint(); + assert_eq!(a.lower(), b.lower(), "lower for {:?}", note); + assert_eq!(a.upper(), b.upper(), "upper for {:?}", note); + } + + eq(Body::from("Hello"), SizeHint::with_exact(5), "from str"); + + eq(Body::empty(), SizeHint::with_exact(0), "empty"); + + eq(Body::channel().1, SizeHint::new(), "channel"); + + eq( + Body::new_channel(DecodedLength::new(4), /*wanter =*/ false).1, + SizeHint::with_exact(4), + "channel with length", + ); + } + + #[tokio::test] + async fn channel_abort() { + let (tx, mut rx) = Body::channel(); + + tx.abort(); + + let err = rx.data().await.unwrap().unwrap_err(); + assert!(err.is_body_write_aborted(), "{:?}", err); + } + + #[tokio::test] + async fn channel_abort_when_buffer_is_full() { + let (mut tx, mut rx) = Body::channel(); + + tx.try_send_data("chunk 1".into()).expect("send 1"); + // buffer is full, but can still send abort + tx.abort(); + + let chunk1 = rx.data().await.expect("item 1").expect("chunk 1"); + assert_eq!(chunk1, "chunk 1"); + + let err = rx.data().await.unwrap().unwrap_err(); + assert!(err.is_body_write_aborted(), "{:?}", err); + } + + #[test] + fn channel_buffers_one() { + let (mut tx, _rx) = Body::channel(); + + tx.try_send_data("chunk 1".into()).expect("send 1"); + + // buffer is now full + let chunk2 = tx.try_send_data("chunk 2".into()).expect_err("send 2"); + assert_eq!(chunk2, "chunk 2"); + } + + #[tokio::test] + async fn channel_empty() { + let (_, mut rx) = Body::channel(); + + assert!(rx.data().await.is_none()); + } + + #[test] + fn channel_ready() { + let (mut tx, _rx) = Body::new_channel(DecodedLength::CHUNKED, /*wanter = */ false); + + let mut tx_ready = tokio_test::task::spawn(tx.ready()); + + assert!(tx_ready.poll().is_ready(), "tx is ready immediately"); + } + + #[test] + fn channel_wanter() { + let (mut tx, mut rx) = Body::new_channel(DecodedLength::CHUNKED, /*wanter = */ true); + + let mut tx_ready = tokio_test::task::spawn(tx.ready()); + let mut rx_data = tokio_test::task::spawn(rx.data()); + + assert!( + tx_ready.poll().is_pending(), + "tx isn't ready before rx has been polled" + ); + + assert!(rx_data.poll().is_pending(), "poll rx.data"); + assert!(tx_ready.is_woken(), "rx poll wakes tx"); + + assert!( + tx_ready.poll().is_ready(), + "tx is ready after rx has been polled" + ); + } + + #[test] + fn channel_notices_closure() { + let (mut tx, rx) = Body::new_channel(DecodedLength::CHUNKED, /*wanter = */ true); + + let mut tx_ready = tokio_test::task::spawn(tx.ready()); + + assert!( + tx_ready.poll().is_pending(), + "tx isn't ready before rx has been polled" + ); + + drop(rx); + assert!(tx_ready.is_woken(), "dropping rx wakes tx"); + + match tx_ready.poll() { + Poll::Ready(Err(ref e)) if e.is_closed() => (), + unexpected => panic!("tx poll ready unexpected: {:?}", unexpected), + } + } +} diff --git a/third_party/rust/hyper/src/body/length.rs b/third_party/rust/hyper/src/body/length.rs new file mode 100644 index 0000000000..e2bbee8039 --- /dev/null +++ b/third_party/rust/hyper/src/body/length.rs @@ -0,0 +1,123 @@ +use std::fmt; + +#[derive(Clone, Copy, PartialEq, Eq)] +pub(crate) struct DecodedLength(u64); + +#[cfg(any(feature = "http1", feature = "http2"))] +impl From<Option<u64>> for DecodedLength { + fn from(len: Option<u64>) -> Self { + len.and_then(|len| { + // If the length is u64::MAX, oh well, just reported chunked. + Self::checked_new(len).ok() + }) + .unwrap_or(DecodedLength::CHUNKED) + } +} + +#[cfg(any(feature = "http1", feature = "http2", test))] +const MAX_LEN: u64 = std::u64::MAX - 2; + +impl DecodedLength { + pub(crate) const CLOSE_DELIMITED: DecodedLength = DecodedLength(::std::u64::MAX); + pub(crate) const CHUNKED: DecodedLength = DecodedLength(::std::u64::MAX - 1); + pub(crate) const ZERO: DecodedLength = DecodedLength(0); + + #[cfg(test)] + pub(crate) fn new(len: u64) -> Self { + debug_assert!(len <= MAX_LEN); + DecodedLength(len) + } + + /// Takes the length as a content-length without other checks. + /// + /// Should only be called if previously confirmed this isn't + /// CLOSE_DELIMITED or CHUNKED. + #[inline] + #[cfg(feature = "http1")] + pub(crate) fn danger_len(self) -> u64 { + debug_assert!(self.0 < Self::CHUNKED.0); + self.0 + } + + /// Converts to an Option<u64> representing a Known or Unknown length. + pub(crate) fn into_opt(self) -> Option<u64> { + match self { + DecodedLength::CHUNKED | DecodedLength::CLOSE_DELIMITED => None, + DecodedLength(known) => Some(known), + } + } + + /// Checks the `u64` is within the maximum allowed for content-length. + #[cfg(any(feature = "http1", feature = "http2"))] + pub(crate) fn checked_new(len: u64) -> Result<Self, crate::error::Parse> { + use tracing::warn; + + if len <= MAX_LEN { + Ok(DecodedLength(len)) + } else { + warn!("content-length bigger than maximum: {} > {}", len, MAX_LEN); + Err(crate::error::Parse::TooLarge) + } + } + + pub(crate) fn sub_if(&mut self, amt: u64) { + match *self { + DecodedLength::CHUNKED | DecodedLength::CLOSE_DELIMITED => (), + DecodedLength(ref mut known) => { + *known -= amt; + } + } + } + + /// Returns whether this represents an exact length. + /// + /// This includes 0, which of course is an exact known length. + /// + /// It would return false if "chunked" or otherwise size-unknown. + #[cfg(feature = "http2")] + pub(crate) fn is_exact(&self) -> bool { + self.0 <= MAX_LEN + } +} + +impl fmt::Debug for DecodedLength { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + DecodedLength::CLOSE_DELIMITED => f.write_str("CLOSE_DELIMITED"), + DecodedLength::CHUNKED => f.write_str("CHUNKED"), + DecodedLength(n) => f.debug_tuple("DecodedLength").field(&n).finish(), + } + } +} + +impl fmt::Display for DecodedLength { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + DecodedLength::CLOSE_DELIMITED => f.write_str("close-delimited"), + DecodedLength::CHUNKED => f.write_str("chunked encoding"), + DecodedLength::ZERO => f.write_str("empty"), + DecodedLength(n) => write!(f, "content-length ({} bytes)", n), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sub_if_known() { + let mut len = DecodedLength::new(30); + len.sub_if(20); + + assert_eq!(len.0, 10); + } + + #[test] + fn sub_if_chunked() { + let mut len = DecodedLength::CHUNKED; + len.sub_if(20); + + assert_eq!(len, DecodedLength::CHUNKED); + } +} diff --git a/third_party/rust/hyper/src/body/mod.rs b/third_party/rust/hyper/src/body/mod.rs new file mode 100644 index 0000000000..5e2181e941 --- /dev/null +++ b/third_party/rust/hyper/src/body/mod.rs @@ -0,0 +1,65 @@ +//! Streaming bodies for Requests and Responses +//! +//! For both [Clients](crate::client) and [Servers](crate::server), requests and +//! responses use streaming bodies, instead of complete buffering. This +//! allows applications to not use memory they don't need, and allows exerting +//! back-pressure on connections by only reading when asked. +//! +//! There are two pieces to this in hyper: +//! +//! - **The [`HttpBody`](HttpBody) trait** describes all possible bodies. +//! hyper allows any body type that implements `HttpBody`, allowing +//! applications to have fine-grained control over their streaming. +//! - **The [`Body`](Body) concrete type**, which is an implementation of +//! `HttpBody`, and returned by hyper as a "receive stream" (so, for server +//! requests and client responses). It is also a decent default implementation +//! if you don't have very custom needs of your send streams. + +pub use bytes::{Buf, Bytes}; +pub use http_body::Body as HttpBody; +pub use http_body::SizeHint; + +pub use self::aggregate::aggregate; +pub use self::body::{Body, Sender}; +pub(crate) use self::length::DecodedLength; +pub use self::to_bytes::to_bytes; + +mod aggregate; +mod body; +mod length; +mod to_bytes; + +/// An optimization to try to take a full body if immediately available. +/// +/// This is currently limited to *only* `hyper::Body`s. +#[cfg(feature = "http1")] +pub(crate) fn take_full_data<T: HttpBody + 'static>(body: &mut T) -> Option<T::Data> { + use std::any::{Any, TypeId}; + + // This static type check can be optimized at compile-time. + if TypeId::of::<T>() == TypeId::of::<Body>() { + let mut full = (body as &mut dyn Any) + .downcast_mut::<Body>() + .expect("must be Body") + .take_full_data(); + // This second cast is required to make the type system happy. + // Without it, the compiler cannot reason that the type is actually + // `T::Data`. Oh wells. + // + // It's still a measurable win! + (&mut full as &mut dyn Any) + .downcast_mut::<Option<T::Data>>() + .expect("must be T::Data") + .take() + } else { + None + } +} + +fn _assert_send_sync() { + fn _assert_send<T: Send>() {} + fn _assert_sync<T: Sync>() {} + + _assert_send::<Body>(); + _assert_sync::<Body>(); +} diff --git a/third_party/rust/hyper/src/body/to_bytes.rs b/third_party/rust/hyper/src/body/to_bytes.rs new file mode 100644 index 0000000000..038c6fd0f3 --- /dev/null +++ b/third_party/rust/hyper/src/body/to_bytes.rs @@ -0,0 +1,82 @@ +use bytes::{Buf, BufMut, Bytes}; + +use super::HttpBody; + +/// Concatenate the buffers from a body into a single `Bytes` asynchronously. +/// +/// This may require copying the data into a single buffer. If you don't need +/// a contiguous buffer, prefer the [`aggregate`](crate::body::aggregate()) +/// function. +/// +/// # Note +/// +/// Care needs to be taken if the remote is untrusted. The function doesn't implement any length +/// checks and an malicious peer might make it consume arbitrary amounts of memory. Checking the +/// `Content-Length` is a possibility, but it is not strictly mandated to be present. +/// +/// # Example +/// +/// ``` +/// # #[cfg(all(feature = "client", feature = "tcp", any(feature = "http1", feature = "http2")))] +/// # async fn doc() -> hyper::Result<()> { +/// use hyper::{body::HttpBody}; +/// +/// # let request = hyper::Request::builder() +/// # .method(hyper::Method::POST) +/// # .uri("http://httpbin.org/post") +/// # .header("content-type", "application/json") +/// # .body(hyper::Body::from(r#"{"library":"hyper"}"#)).unwrap(); +/// # let client = hyper::Client::new(); +/// let response = client.request(request).await?; +/// +/// const MAX_ALLOWED_RESPONSE_SIZE: u64 = 1024; +/// +/// let response_content_length = match response.body().size_hint().upper() { +/// Some(v) => v, +/// None => MAX_ALLOWED_RESPONSE_SIZE + 1 // Just to protect ourselves from a malicious response +/// }; +/// +/// if response_content_length < MAX_ALLOWED_RESPONSE_SIZE { +/// let body_bytes = hyper::body::to_bytes(response.into_body()).await?; +/// println!("body: {:?}", body_bytes); +/// } +/// +/// # Ok(()) +/// # } +/// ``` +pub async fn to_bytes<T>(body: T) -> Result<Bytes, T::Error> +where + T: HttpBody, +{ + futures_util::pin_mut!(body); + + // If there's only 1 chunk, we can just return Buf::to_bytes() + let mut first = if let Some(buf) = body.data().await { + buf? + } else { + return Ok(Bytes::new()); + }; + + let second = if let Some(buf) = body.data().await { + buf? + } else { + return Ok(first.copy_to_bytes(first.remaining())); + }; + + // Don't pre-emptively reserve *too* much. + let rest = (body.size_hint().lower() as usize).min(1024 * 16); + let cap = first + .remaining() + .saturating_add(second.remaining()) + .saturating_add(rest); + // With more than 1 buf, we gotta flatten into a Vec first. + let mut vec = Vec::with_capacity(cap); + vec.put(first); + vec.put(second); + + while let Some(buf) = body.data().await { + vec.put(buf?); + } + + Ok(vec.into()) +} diff --git a/third_party/rust/hyper/src/cfg.rs b/third_party/rust/hyper/src/cfg.rs new file mode 100644 index 0000000000..71a5351d21 --- /dev/null +++ b/third_party/rust/hyper/src/cfg.rs @@ -0,0 +1,44 @@ +macro_rules! cfg_feature { + ( + #![$meta:meta] + $($item:item)* + ) => { + $( + #[cfg($meta)] + #[cfg_attr(docsrs, doc(cfg($meta)))] + $item + )* + } +} + +macro_rules! cfg_proto { + ($($item:item)*) => { + cfg_feature! { + #![all( + any(feature = "http1", feature = "http2"), + any(feature = "client", feature = "server"), + )] + $($item)* + } + } +} + +cfg_proto! { + macro_rules! cfg_client { + ($($item:item)*) => { + cfg_feature! { + #![feature = "client"] + $($item)* + } + } + } + + macro_rules! cfg_server { + ($($item:item)*) => { + cfg_feature! { + #![feature = "server"] + $($item)* + } + } + } +} diff --git a/third_party/rust/hyper/src/client/client.rs b/third_party/rust/hyper/src/client/client.rs new file mode 100644 index 0000000000..4425e25899 --- /dev/null +++ b/third_party/rust/hyper/src/client/client.rs @@ -0,0 +1,1495 @@ +use std::error::Error as StdError; +use std::fmt; +use std::mem; +use std::time::Duration; + +use futures_channel::oneshot; +use futures_util::future::{self, Either, FutureExt as _, TryFutureExt as _}; +use http::header::{HeaderValue, HOST}; +use http::uri::{Port, Scheme}; +use http::{Method, Request, Response, Uri, Version}; +use tracing::{debug, trace, warn}; + +use super::conn; +use super::connect::{self, sealed::Connect, Alpn, Connected, Connection}; +use super::pool::{ + self, CheckoutIsClosedError, Key as PoolKey, Pool, Poolable, Pooled, Reservation, +}; +#[cfg(feature = "tcp")] +use super::HttpConnector; +use crate::body::{Body, HttpBody}; +use crate::common::{exec::BoxSendFuture, sync_wrapper::SyncWrapper, lazy as hyper_lazy, task, Future, Lazy, Pin, Poll}; +use crate::rt::Executor; + +/// A Client to make outgoing HTTP requests. +/// +/// `Client` is cheap to clone and cloning is the recommended way to share a `Client`. The +/// underlying connection pool will be reused. +#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))] +pub struct Client<C, B = Body> { + config: Config, + conn_builder: conn::Builder, + connector: C, + pool: Pool<PoolClient<B>>, +} + +#[derive(Clone, Copy, Debug)] +struct Config { + retry_canceled_requests: bool, + set_host: bool, + ver: Ver, +} + +/// A `Future` that will resolve to an HTTP Response. +/// +/// This is returned by `Client::request` (and `Client::get`). +#[must_use = "futures do nothing unless polled"] +pub struct ResponseFuture { + inner: SyncWrapper<Pin<Box<dyn Future<Output = crate::Result<Response<Body>>> + Send>>>, +} + +// ===== impl Client ===== + +#[cfg(feature = "tcp")] +impl Client<HttpConnector, Body> { + /// Create a new Client with the default [config](Builder). + /// + /// # Note + /// + /// The default connector does **not** handle TLS. Speaking to `https` + /// destinations will require [configuring a connector that implements + /// TLS](https://hyper.rs/guides/client/configuration). + #[cfg_attr(docsrs, doc(cfg(feature = "tcp")))] + #[inline] + pub fn new() -> Client<HttpConnector, Body> { + Builder::default().build_http() + } +} + +#[cfg(feature = "tcp")] +impl Default for Client<HttpConnector, Body> { + fn default() -> Client<HttpConnector, Body> { + Client::new() + } +} + +impl Client<(), Body> { + /// Create a builder to configure a new `Client`. + /// + /// # Example + /// + /// ``` + /// # #[cfg(feature = "runtime")] + /// # fn run () { + /// use std::time::Duration; + /// use hyper::Client; + /// + /// let client = Client::builder() + /// .pool_idle_timeout(Duration::from_secs(30)) + /// .http2_only(true) + /// .build_http(); + /// # let infer: Client<_, hyper::Body> = client; + /// # drop(infer); + /// # } + /// # fn main() {} + /// ``` + #[inline] + pub fn builder() -> Builder { + Builder::default() + } +} + +impl<C, B> Client<C, B> +where + C: Connect + Clone + Send + Sync + 'static, + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + /// Send a `GET` request to the supplied `Uri`. + /// + /// # Note + /// + /// This requires that the `HttpBody` type have a `Default` implementation. + /// It *should* return an "empty" version of itself, such that + /// `HttpBody::is_end_stream` is `true`. + /// + /// # Example + /// + /// ``` + /// # #[cfg(feature = "runtime")] + /// # fn run () { + /// use hyper::{Client, Uri}; + /// + /// let client = Client::new(); + /// + /// let future = client.get(Uri::from_static("http://httpbin.org/ip")); + /// # } + /// # fn main() {} + /// ``` + pub fn get(&self, uri: Uri) -> ResponseFuture + where + B: Default, + { + let body = B::default(); + if !body.is_end_stream() { + warn!("default HttpBody used for get() does not return true for is_end_stream"); + } + + let mut req = Request::new(body); + *req.uri_mut() = uri; + self.request(req) + } + + /// Send a constructed `Request` using this `Client`. + /// + /// # Example + /// + /// ``` + /// # #[cfg(feature = "runtime")] + /// # fn run () { + /// use hyper::{Body, Method, Client, Request}; + /// + /// let client = Client::new(); + /// + /// let req = Request::builder() + /// .method(Method::POST) + /// .uri("http://httpbin.org/post") + /// .body(Body::from("Hallo!")) + /// .expect("request builder"); + /// + /// let future = client.request(req); + /// # } + /// # fn main() {} + /// ``` + pub fn request(&self, mut req: Request<B>) -> ResponseFuture { + let is_http_connect = req.method() == Method::CONNECT; + match req.version() { + Version::HTTP_11 => (), + Version::HTTP_10 => { + if is_http_connect { + warn!("CONNECT is not allowed for HTTP/1.0"); + return ResponseFuture::new(future::err( + crate::Error::new_user_unsupported_request_method(), + )); + } + } + Version::HTTP_2 => (), + // completely unsupported HTTP version (like HTTP/0.9)! + other => return ResponseFuture::error_version(other), + }; + + let pool_key = match extract_domain(req.uri_mut(), is_http_connect) { + Ok(s) => s, + Err(err) => { + return ResponseFuture::new(future::err(err)); + } + }; + + ResponseFuture::new(self.clone().retryably_send_request(req, pool_key)) + } + + async fn retryably_send_request( + self, + mut req: Request<B>, + pool_key: PoolKey, + ) -> crate::Result<Response<Body>> { + let uri = req.uri().clone(); + + loop { + req = match self.send_request(req, pool_key.clone()).await { + Ok(resp) => return Ok(resp), + Err(ClientError::Normal(err)) => return Err(err), + Err(ClientError::Canceled { + connection_reused, + mut req, + reason, + }) => { + if !self.config.retry_canceled_requests || !connection_reused { + // if client disabled, don't retry + // a fresh connection means we definitely can't retry + return Err(reason); + } + + trace!( + "unstarted request canceled, trying again (reason={:?})", + reason + ); + *req.uri_mut() = uri.clone(); + req + } + } + } + } + + async fn send_request( + &self, + mut req: Request<B>, + pool_key: PoolKey, + ) -> Result<Response<Body>, ClientError<B>> { + let mut pooled = match self.connection_for(pool_key).await { + Ok(pooled) => pooled, + Err(ClientConnectError::Normal(err)) => return Err(ClientError::Normal(err)), + Err(ClientConnectError::H2CheckoutIsClosed(reason)) => { + return Err(ClientError::Canceled { + connection_reused: true, + req, + reason, + }) + } + }; + + if pooled.is_http1() { + if req.version() == Version::HTTP_2 { + warn!("Connection is HTTP/1, but request requires HTTP/2"); + return Err(ClientError::Normal( + crate::Error::new_user_unsupported_version(), + )); + } + + if self.config.set_host { + let uri = req.uri().clone(); + req.headers_mut().entry(HOST).or_insert_with(|| { + let hostname = uri.host().expect("authority implies host"); + if let Some(port) = get_non_default_port(&uri) { + let s = format!("{}:{}", hostname, port); + HeaderValue::from_str(&s) + } else { + HeaderValue::from_str(hostname) + } + .expect("uri host is valid header value") + }); + } + + // CONNECT always sends authority-form, so check it first... + if req.method() == Method::CONNECT { + authority_form(req.uri_mut()); + } else if pooled.conn_info.is_proxied { + absolute_form(req.uri_mut()); + } else { + origin_form(req.uri_mut()); + } + } else if req.method() == Method::CONNECT { + authority_form(req.uri_mut()); + } + + let fut = pooled + .send_request_retryable(req) + .map_err(ClientError::map_with_reused(pooled.is_reused())); + + // If the Connector included 'extra' info, add to Response... + let extra_info = pooled.conn_info.extra.clone(); + let fut = fut.map_ok(move |mut res| { + if let Some(extra) = extra_info { + extra.set(res.extensions_mut()); + } + res + }); + + // As of futures@0.1.21, there is a race condition in the mpsc + // channel, such that sending when the receiver is closing can + // result in the message being stuck inside the queue. It won't + // ever notify until the Sender side is dropped. + // + // To counteract this, we must check if our senders 'want' channel + // has been closed after having tried to send. If so, error out... + if pooled.is_closed() { + return fut.await; + } + + let mut res = fut.await?; + + // If pooled is HTTP/2, we can toss this reference immediately. + // + // when pooled is dropped, it will try to insert back into the + // pool. To delay that, spawn a future that completes once the + // sender is ready again. + // + // This *should* only be once the related `Connection` has polled + // for a new request to start. + // + // It won't be ready if there is a body to stream. + if pooled.is_http2() || !pooled.is_pool_enabled() || pooled.is_ready() { + drop(pooled); + } else if !res.body().is_end_stream() { + let (delayed_tx, delayed_rx) = oneshot::channel(); + res.body_mut().delayed_eof(delayed_rx); + let on_idle = future::poll_fn(move |cx| pooled.poll_ready(cx)).map(move |_| { + // At this point, `pooled` is dropped, and had a chance + // to insert into the pool (if conn was idle) + drop(delayed_tx); + }); + + self.conn_builder.exec.execute(on_idle); + } else { + // There's no body to delay, but the connection isn't + // ready yet. Only re-insert when it's ready + let on_idle = future::poll_fn(move |cx| pooled.poll_ready(cx)).map(|_| ()); + + self.conn_builder.exec.execute(on_idle); + } + + Ok(res) + } + + async fn connection_for( + &self, + pool_key: PoolKey, + ) -> Result<Pooled<PoolClient<B>>, ClientConnectError> { + // This actually races 2 different futures to try to get a ready + // connection the fastest, and to reduce connection churn. + // + // - If the pool has an idle connection waiting, that's used + // immediately. + // - Otherwise, the Connector is asked to start connecting to + // the destination Uri. + // - Meanwhile, the pool Checkout is watching to see if any other + // request finishes and tries to insert an idle connection. + // - If a new connection is started, but the Checkout wins after + // (an idle connection became available first), the started + // connection future is spawned into the runtime to complete, + // and then be inserted into the pool as an idle connection. + let checkout = self.pool.checkout(pool_key.clone()); + let connect = self.connect_to(pool_key); + let is_ver_h2 = self.config.ver == Ver::Http2; + + // The order of the `select` is depended on below... + + match future::select(checkout, connect).await { + // Checkout won, connect future may have been started or not. + // + // If it has, let it finish and insert back into the pool, + // so as to not waste the socket... + Either::Left((Ok(checked_out), connecting)) => { + // This depends on the `select` above having the correct + // order, such that if the checkout future were ready + // immediately, the connect future will never have been + // started. + // + // If it *wasn't* ready yet, then the connect future will + // have been started... + if connecting.started() { + let bg = connecting + .map_err(|err| { + trace!("background connect error: {}", err); + }) + .map(|_pooled| { + // dropping here should just place it in + // the Pool for us... + }); + // An execute error here isn't important, we're just trying + // to prevent a waste of a socket... + self.conn_builder.exec.execute(bg); + } + Ok(checked_out) + } + // Connect won, checkout can just be dropped. + Either::Right((Ok(connected), _checkout)) => Ok(connected), + // Either checkout or connect could get canceled: + // + // 1. Connect is canceled if this is HTTP/2 and there is + // an outstanding HTTP/2 connecting task. + // 2. Checkout is canceled if the pool cannot deliver an + // idle connection reliably. + // + // In both cases, we should just wait for the other future. + Either::Left((Err(err), connecting)) => { + if err.is_canceled() { + connecting.await.map_err(ClientConnectError::Normal) + } else { + Err(ClientConnectError::Normal(err)) + } + } + Either::Right((Err(err), checkout)) => { + if err.is_canceled() { + checkout.await.map_err(move |err| { + if is_ver_h2 + && err.is_canceled() + && err.find_source::<CheckoutIsClosedError>().is_some() + { + ClientConnectError::H2CheckoutIsClosed(err) + } else { + ClientConnectError::Normal(err) + } + }) + } else { + Err(ClientConnectError::Normal(err)) + } + } + } + } + + fn connect_to( + &self, + pool_key: PoolKey, + ) -> impl Lazy<Output = crate::Result<Pooled<PoolClient<B>>>> + Unpin { + let executor = self.conn_builder.exec.clone(); + let pool = self.pool.clone(); + #[cfg(not(feature = "http2"))] + let conn_builder = self.conn_builder.clone(); + #[cfg(feature = "http2")] + let mut conn_builder = self.conn_builder.clone(); + let ver = self.config.ver; + let is_ver_h2 = ver == Ver::Http2; + let connector = self.connector.clone(); + let dst = domain_as_uri(pool_key.clone()); + hyper_lazy(move || { + // Try to take a "connecting lock". + // + // If the pool_key is for HTTP/2, and there is already a + // connection being established, then this can't take a + // second lock. The "connect_to" future is Canceled. + let connecting = match pool.connecting(&pool_key, ver) { + Some(lock) => lock, + None => { + let canceled = + crate::Error::new_canceled().with("HTTP/2 connection in progress"); + return Either::Right(future::err(canceled)); + } + }; + Either::Left( + connector + .connect(connect::sealed::Internal, dst) + .map_err(crate::Error::new_connect) + .and_then(move |io| { + let connected = io.connected(); + // If ALPN is h2 and we aren't http2_only already, + // then we need to convert our pool checkout into + // a single HTTP2 one. + let connecting = if connected.alpn == Alpn::H2 && !is_ver_h2 { + match connecting.alpn_h2(&pool) { + Some(lock) => { + trace!("ALPN negotiated h2, updating pool"); + lock + } + None => { + // Another connection has already upgraded, + // the pool checkout should finish up for us. + let canceled = crate::Error::new_canceled() + .with("ALPN upgraded to HTTP/2"); + return Either::Right(future::err(canceled)); + } + } + } else { + connecting + }; + + #[cfg_attr(not(feature = "http2"), allow(unused))] + let is_h2 = is_ver_h2 || connected.alpn == Alpn::H2; + #[cfg(feature = "http2")] + { + conn_builder.http2_only(is_h2); + } + + Either::Left(Box::pin(async move { + let (tx, conn) = conn_builder.handshake(io).await?; + + trace!("handshake complete, spawning background dispatcher task"); + executor.execute( + conn.map_err(|e| debug!("client connection error: {}", e)) + .map(|_| ()), + ); + + // Wait for 'conn' to ready up before we + // declare this tx as usable + let tx = tx.when_ready().await?; + + let tx = { + #[cfg(feature = "http2")] + { + if is_h2 { + PoolTx::Http2(tx.into_http2()) + } else { + PoolTx::Http1(tx) + } + } + #[cfg(not(feature = "http2"))] + PoolTx::Http1(tx) + }; + + Ok(pool.pooled( + connecting, + PoolClient { + conn_info: connected, + tx, + }, + )) + })) + }), + ) + }) + } +} + +impl<C, B> tower_service::Service<Request<B>> for Client<C, B> +where + C: Connect + Clone + Send + Sync + 'static, + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + type Response = Response<Body>; + type Error = crate::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request<B>) -> Self::Future { + self.request(req) + } +} + +impl<C, B> tower_service::Service<Request<B>> for &'_ Client<C, B> +where + C: Connect + Clone + Send + Sync + 'static, + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + type Response = Response<Body>; + type Error = crate::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request<B>) -> Self::Future { + self.request(req) + } +} + +impl<C: Clone, B> Clone for Client<C, B> { + fn clone(&self) -> Client<C, B> { + Client { + config: self.config.clone(), + conn_builder: self.conn_builder.clone(), + connector: self.connector.clone(), + pool: self.pool.clone(), + } + } +} + +impl<C, B> fmt::Debug for Client<C, B> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Client").finish() + } +} + +// ===== impl ResponseFuture ===== + +impl ResponseFuture { + fn new<F>(value: F) -> Self + where + F: Future<Output = crate::Result<Response<Body>>> + Send + 'static, + { + Self { + inner: SyncWrapper::new(Box::pin(value)) + } + } + + fn error_version(ver: Version) -> Self { + warn!("Request has unsupported version \"{:?}\"", ver); + ResponseFuture::new(Box::pin(future::err( + crate::Error::new_user_unsupported_version(), + ))) + } +} + +impl fmt::Debug for ResponseFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("Future<Response>") + } +} + +impl Future for ResponseFuture { + type Output = crate::Result<Response<Body>>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + self.inner.get_mut().as_mut().poll(cx) + } +} + +// ===== impl PoolClient ===== + +// FIXME: allow() required due to `impl Trait` leaking types to this lint +#[allow(missing_debug_implementations)] +struct PoolClient<B> { + conn_info: Connected, + tx: PoolTx<B>, +} + +enum PoolTx<B> { + Http1(conn::SendRequest<B>), + #[cfg(feature = "http2")] + Http2(conn::Http2SendRequest<B>), +} + +impl<B> PoolClient<B> { + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + match self.tx { + PoolTx::Http1(ref mut tx) => tx.poll_ready(cx), + #[cfg(feature = "http2")] + PoolTx::Http2(_) => Poll::Ready(Ok(())), + } + } + + fn is_http1(&self) -> bool { + !self.is_http2() + } + + fn is_http2(&self) -> bool { + match self.tx { + PoolTx::Http1(_) => false, + #[cfg(feature = "http2")] + PoolTx::Http2(_) => true, + } + } + + fn is_ready(&self) -> bool { + match self.tx { + PoolTx::Http1(ref tx) => tx.is_ready(), + #[cfg(feature = "http2")] + PoolTx::Http2(ref tx) => tx.is_ready(), + } + } + + fn is_closed(&self) -> bool { + match self.tx { + PoolTx::Http1(ref tx) => tx.is_closed(), + #[cfg(feature = "http2")] + PoolTx::Http2(ref tx) => tx.is_closed(), + } + } +} + +impl<B: HttpBody + 'static> PoolClient<B> { + fn send_request_retryable( + &mut self, + req: Request<B>, + ) -> impl Future<Output = Result<Response<Body>, (crate::Error, Option<Request<B>>)>> + where + B: Send, + { + match self.tx { + #[cfg(not(feature = "http2"))] + PoolTx::Http1(ref mut tx) => tx.send_request_retryable(req), + #[cfg(feature = "http2")] + PoolTx::Http1(ref mut tx) => Either::Left(tx.send_request_retryable(req)), + #[cfg(feature = "http2")] + PoolTx::Http2(ref mut tx) => Either::Right(tx.send_request_retryable(req)), + } + } +} + +impl<B> Poolable for PoolClient<B> +where + B: Send + 'static, +{ + fn is_open(&self) -> bool { + match self.tx { + PoolTx::Http1(ref tx) => tx.is_ready(), + #[cfg(feature = "http2")] + PoolTx::Http2(ref tx) => tx.is_ready(), + } + } + + fn reserve(self) -> Reservation<Self> { + match self.tx { + PoolTx::Http1(tx) => Reservation::Unique(PoolClient { + conn_info: self.conn_info, + tx: PoolTx::Http1(tx), + }), + #[cfg(feature = "http2")] + PoolTx::Http2(tx) => { + let b = PoolClient { + conn_info: self.conn_info.clone(), + tx: PoolTx::Http2(tx.clone()), + }; + let a = PoolClient { + conn_info: self.conn_info, + tx: PoolTx::Http2(tx), + }; + Reservation::Shared(a, b) + } + } + } + + fn can_share(&self) -> bool { + self.is_http2() + } +} + +// ===== impl ClientError ===== + +// FIXME: allow() required due to `impl Trait` leaking types to this lint +#[allow(missing_debug_implementations)] +enum ClientError<B> { + Normal(crate::Error), + Canceled { + connection_reused: bool, + req: Request<B>, + reason: crate::Error, + }, +} + +impl<B> ClientError<B> { + fn map_with_reused(conn_reused: bool) -> impl Fn((crate::Error, Option<Request<B>>)) -> Self { + move |(err, orig_req)| { + if let Some(req) = orig_req { + ClientError::Canceled { + connection_reused: conn_reused, + reason: err, + req, + } + } else { + ClientError::Normal(err) + } + } + } +} + +enum ClientConnectError { + Normal(crate::Error), + H2CheckoutIsClosed(crate::Error), +} + +/// A marker to identify what version a pooled connection is. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub(super) enum Ver { + Auto, + Http2, +} + +fn origin_form(uri: &mut Uri) { + let path = match uri.path_and_query() { + Some(path) if path.as_str() != "/" => { + let mut parts = ::http::uri::Parts::default(); + parts.path_and_query = Some(path.clone()); + Uri::from_parts(parts).expect("path is valid uri") + } + _none_or_just_slash => { + debug_assert!(Uri::default() == "/"); + Uri::default() + } + }; + *uri = path +} + +fn absolute_form(uri: &mut Uri) { + debug_assert!(uri.scheme().is_some(), "absolute_form needs a scheme"); + debug_assert!( + uri.authority().is_some(), + "absolute_form needs an authority" + ); + // If the URI is to HTTPS, and the connector claimed to be a proxy, + // then it *should* have tunneled, and so we don't want to send + // absolute-form in that case. + if uri.scheme() == Some(&Scheme::HTTPS) { + origin_form(uri); + } +} + +fn authority_form(uri: &mut Uri) { + if let Some(path) = uri.path_and_query() { + // `https://hyper.rs` would parse with `/` path, don't + // annoy people about that... + if path != "/" { + warn!("HTTP/1.1 CONNECT request stripping path: {:?}", path); + } + } + *uri = match uri.authority() { + Some(auth) => { + let mut parts = ::http::uri::Parts::default(); + parts.authority = Some(auth.clone()); + Uri::from_parts(parts).expect("authority is valid") + } + None => { + unreachable!("authority_form with relative uri"); + } + }; +} + +fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> crate::Result<PoolKey> { + let uri_clone = uri.clone(); + match (uri_clone.scheme(), uri_clone.authority()) { + (Some(scheme), Some(auth)) => Ok((scheme.clone(), auth.clone())), + (None, Some(auth)) if is_http_connect => { + let scheme = match auth.port_u16() { + Some(443) => { + set_scheme(uri, Scheme::HTTPS); + Scheme::HTTPS + } + _ => { + set_scheme(uri, Scheme::HTTP); + Scheme::HTTP + } + }; + Ok((scheme, auth.clone())) + } + _ => { + debug!("Client requires absolute-form URIs, received: {:?}", uri); + Err(crate::Error::new_user_absolute_uri_required()) + } + } +} + +fn domain_as_uri((scheme, auth): PoolKey) -> Uri { + http::uri::Builder::new() + .scheme(scheme) + .authority(auth) + .path_and_query("/") + .build() + .expect("domain is valid Uri") +} + +fn set_scheme(uri: &mut Uri, scheme: Scheme) { + debug_assert!( + uri.scheme().is_none(), + "set_scheme expects no existing scheme" + ); + let old = mem::replace(uri, Uri::default()); + let mut parts: ::http::uri::Parts = old.into(); + parts.scheme = Some(scheme); + parts.path_and_query = Some("/".parse().expect("slash is a valid path")); + *uri = Uri::from_parts(parts).expect("scheme is valid"); +} + +fn get_non_default_port(uri: &Uri) -> Option<Port<&str>> { + match (uri.port().map(|p| p.as_u16()), is_schema_secure(uri)) { + (Some(443), true) => None, + (Some(80), false) => None, + _ => uri.port(), + } +} + +fn is_schema_secure(uri: &Uri) -> bool { + uri.scheme_str() + .map(|scheme_str| matches!(scheme_str, "wss" | "https")) + .unwrap_or_default() +} + +/// A builder to configure a new [`Client`](Client). +/// +/// # Example +/// +/// ``` +/// # #[cfg(feature = "runtime")] +/// # fn run () { +/// use std::time::Duration; +/// use hyper::Client; +/// +/// let client = Client::builder() +/// .pool_idle_timeout(Duration::from_secs(30)) +/// .http2_only(true) +/// .build_http(); +/// # let infer: Client<_, hyper::Body> = client; +/// # drop(infer); +/// # } +/// # fn main() {} +/// ``` +#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))] +#[derive(Clone)] +pub struct Builder { + client_config: Config, + conn_builder: conn::Builder, + pool_config: pool::Config, +} + +impl Default for Builder { + fn default() -> Self { + Self { + client_config: Config { + retry_canceled_requests: true, + set_host: true, + ver: Ver::Auto, + }, + conn_builder: conn::Builder::new(), + pool_config: pool::Config { + idle_timeout: Some(Duration::from_secs(90)), + max_idle_per_host: std::usize::MAX, + }, + } + } +} + +impl Builder { + #[doc(hidden)] + #[deprecated( + note = "name is confusing, to disable the connection pool, call pool_max_idle_per_host(0)" + )] + pub fn keep_alive(&mut self, val: bool) -> &mut Self { + if !val { + // disable + self.pool_max_idle_per_host(0) + } else if self.pool_config.max_idle_per_host == 0 { + // enable + self.pool_max_idle_per_host(std::usize::MAX) + } else { + // already enabled + self + } + } + + #[doc(hidden)] + #[deprecated(note = "renamed to `pool_idle_timeout`")] + pub fn keep_alive_timeout<D>(&mut self, val: D) -> &mut Self + where + D: Into<Option<Duration>>, + { + self.pool_idle_timeout(val) + } + + /// Set an optional timeout for idle sockets being kept-alive. + /// + /// Pass `None` to disable timeout. + /// + /// Default is 90 seconds. + pub fn pool_idle_timeout<D>(&mut self, val: D) -> &mut Self + where + D: Into<Option<Duration>>, + { + self.pool_config.idle_timeout = val.into(); + self + } + + #[doc(hidden)] + #[deprecated(note = "renamed to `pool_max_idle_per_host`")] + pub fn max_idle_per_host(&mut self, max_idle: usize) -> &mut Self { + self.pool_config.max_idle_per_host = max_idle; + self + } + + /// Sets the maximum idle connection per host allowed in the pool. + /// + /// Default is `usize::MAX` (no limit). + pub fn pool_max_idle_per_host(&mut self, max_idle: usize) -> &mut Self { + self.pool_config.max_idle_per_host = max_idle; + self + } + + // HTTP/1 options + + /// Sets the exact size of the read buffer to *always* use. + /// + /// Note that setting this option unsets the `http1_max_buf_size` option. + /// + /// Default is an adaptive read buffer. + pub fn http1_read_buf_exact_size(&mut self, sz: usize) -> &mut Self { + self.conn_builder.http1_read_buf_exact_size(Some(sz)); + self + } + + /// Set the maximum buffer size for the connection. + /// + /// Default is ~400kb. + /// + /// Note that setting this option unsets the `http1_read_exact_buf_size` option. + /// + /// # Panics + /// + /// The minimum value allowed is 8192. This method panics if the passed `max` is less than the minimum. + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_max_buf_size(&mut self, max: usize) -> &mut Self { + self.conn_builder.http1_max_buf_size(max); + self + } + + /// Set whether HTTP/1 connections will accept spaces between header names + /// and the colon that follow them in responses. + /// + /// Newline codepoints (`\r` and `\n`) will be transformed to spaces when + /// parsing. + /// + /// You probably don't need this, here is what [RFC 7230 Section 3.2.4.] has + /// to say about it: + /// + /// > No whitespace is allowed between the header field-name and colon. In + /// > the past, differences in the handling of such whitespace have led to + /// > security vulnerabilities in request routing and response handling. A + /// > server MUST reject any received request message that contains + /// > whitespace between a header field-name and colon with a response code + /// > of 400 (Bad Request). A proxy MUST remove any such whitespace from a + /// > response message before forwarding the message downstream. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + /// + /// [RFC 7230 Section 3.2.4.]: https://tools.ietf.org/html/rfc7230#section-3.2.4 + pub fn http1_allow_spaces_after_header_name_in_responses(&mut self, val: bool) -> &mut Self { + self.conn_builder + .http1_allow_spaces_after_header_name_in_responses(val); + self + } + + /// Set whether HTTP/1 connections will accept obsolete line folding for + /// header values. + /// + /// You probably don't need this, here is what [RFC 7230 Section 3.2.4.] has + /// to say about it: + /// + /// > A server that receives an obs-fold in a request message that is not + /// > within a message/http container MUST either reject the message by + /// > sending a 400 (Bad Request), preferably with a representation + /// > explaining that obsolete line folding is unacceptable, or replace + /// > each received obs-fold with one or more SP octets prior to + /// > interpreting the field value or forwarding the message downstream. + /// + /// > A proxy or gateway that receives an obs-fold in a response message + /// > that is not within a message/http container MUST either discard the + /// > message and replace it with a 502 (Bad Gateway) response, preferably + /// > with a representation explaining that unacceptable line folding was + /// > received, or replace each received obs-fold with one or more SP + /// > octets prior to interpreting the field value or forwarding the + /// > message downstream. + /// + /// > A user agent that receives an obs-fold in a response message that is + /// > not within a message/http container MUST replace each received + /// > obs-fold with one or more SP octets prior to interpreting the field + /// > value. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + /// + /// [RFC 7230 Section 3.2.4.]: https://tools.ietf.org/html/rfc7230#section-3.2.4 + pub fn http1_allow_obsolete_multiline_headers_in_responses(&mut self, val: bool) -> &mut Self { + self.conn_builder + .http1_allow_obsolete_multiline_headers_in_responses(val); + self + } + + /// Sets whether invalid header lines should be silently ignored in HTTP/1 responses. + /// + /// This mimicks the behaviour of major browsers. You probably don't want this. + /// You should only want this if you are implementing a proxy whose main + /// purpose is to sit in front of browsers whose users access arbitrary content + /// which may be malformed, and they expect everything that works without + /// the proxy to keep working with the proxy. + /// + /// This option will prevent Hyper's client from returning an error encountered + /// when parsing a header, except if the error was caused by the character NUL + /// (ASCII code 0), as Chrome specifically always reject those. + /// + /// The ignorable errors are: + /// * empty header names; + /// * characters that are not allowed in header names, except for `\0` and `\r`; + /// * when `allow_spaces_after_header_name_in_responses` is not enabled, + /// spaces and tabs between the header name and the colon; + /// * missing colon between header name and colon; + /// * characters that are not allowed in header values except for `\0` and `\r`. + /// + /// If an ignorable error is encountered, the parser tries to find the next + /// line in the input to resume parsing the rest of the headers. An error + /// will be emitted nonetheless if it finds `\0` or a lone `\r` while + /// looking for the next line. + pub fn http1_ignore_invalid_headers_in_responses( + &mut self, + val: bool, + ) -> &mut Builder { + self.conn_builder + .http1_ignore_invalid_headers_in_responses(val); + self + } + + /// Set whether HTTP/1 connections should try to use vectored writes, + /// or always flatten into a single buffer. + /// + /// Note that setting this to false may mean more copies of body data, + /// but may also improve performance when an IO transport doesn't + /// support vectored writes well, such as most TLS implementations. + /// + /// Setting this to true will force hyper to use queued strategy + /// which may eliminate unnecessary cloning on some TLS backends + /// + /// Default is `auto`. In this mode hyper will try to guess which + /// mode to use + pub fn http1_writev(&mut self, enabled: bool) -> &mut Builder { + self.conn_builder.http1_writev(enabled); + self + } + + /// Set whether HTTP/1 connections will write header names as title case at + /// the socket level. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + pub fn http1_title_case_headers(&mut self, val: bool) -> &mut Self { + self.conn_builder.http1_title_case_headers(val); + self + } + + /// Set whether to support preserving original header cases. + /// + /// Currently, this will record the original cases received, and store them + /// in a private extension on the `Response`. It will also look for and use + /// such an extension in any provided `Request`. + /// + /// Since the relevant extension is still private, there is no way to + /// interact with the original cases. The only effect this can have now is + /// to forward the cases in a proxy-like fashion. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + pub fn http1_preserve_header_case(&mut self, val: bool) -> &mut Self { + self.conn_builder.http1_preserve_header_case(val); + self + } + + /// Set whether HTTP/0.9 responses should be tolerated. + /// + /// Default is false. + pub fn http09_responses(&mut self, val: bool) -> &mut Self { + self.conn_builder.http09_responses(val); + self + } + + /// Set whether the connection **must** use HTTP/2. + /// + /// The destination must either allow HTTP2 Prior Knowledge, or the + /// `Connect` should be configured to do use ALPN to upgrade to `h2` + /// as part of the connection process. This will not make the `Client` + /// utilize ALPN by itself. + /// + /// Note that setting this to true prevents HTTP/1 from being allowed. + /// + /// Default is false. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_only(&mut self, val: bool) -> &mut Self { + self.client_config.ver = if val { Ver::Http2 } else { Ver::Auto }; + self + } + + /// Sets the [`SETTINGS_INITIAL_WINDOW_SIZE`][spec] option for HTTP2 + /// stream-level flow control. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + /// + /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_INITIAL_WINDOW_SIZE + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_initial_stream_window_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self { + self.conn_builder + .http2_initial_stream_window_size(sz.into()); + self + } + + /// Sets the max connection-level flow control for HTTP2 + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_initial_connection_window_size( + &mut self, + sz: impl Into<Option<u32>>, + ) -> &mut Self { + self.conn_builder + .http2_initial_connection_window_size(sz.into()); + self + } + + /// Sets whether to use an adaptive flow control. + /// + /// Enabling this will override the limits set in + /// `http2_initial_stream_window_size` and + /// `http2_initial_connection_window_size`. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_adaptive_window(&mut self, enabled: bool) -> &mut Self { + self.conn_builder.http2_adaptive_window(enabled); + self + } + + /// Sets the maximum frame size to use for HTTP2. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_max_frame_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self { + self.conn_builder.http2_max_frame_size(sz); + self + } + + /// Sets an interval for HTTP2 Ping frames should be sent to keep a + /// connection alive. + /// + /// Pass `None` to disable HTTP2 keep-alive. + /// + /// Default is currently disabled. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_keep_alive_interval( + &mut self, + interval: impl Into<Option<Duration>>, + ) -> &mut Self { + self.conn_builder.http2_keep_alive_interval(interval); + self + } + + /// Sets a timeout for receiving an acknowledgement of the keep-alive ping. + /// + /// If the ping is not acknowledged within the timeout, the connection will + /// be closed. Does nothing if `http2_keep_alive_interval` is disabled. + /// + /// Default is 20 seconds. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_keep_alive_timeout(&mut self, timeout: Duration) -> &mut Self { + self.conn_builder.http2_keep_alive_timeout(timeout); + self + } + + /// Sets whether HTTP2 keep-alive should apply while the connection is idle. + /// + /// If disabled, keep-alive pings are only sent while there are open + /// request/responses streams. If enabled, pings are also sent when no + /// streams are active. Does nothing if `http2_keep_alive_interval` is + /// disabled. + /// + /// Default is `false`. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_keep_alive_while_idle(&mut self, enabled: bool) -> &mut Self { + self.conn_builder.http2_keep_alive_while_idle(enabled); + self + } + + /// Sets the maximum number of HTTP2 concurrent locally reset streams. + /// + /// See the documentation of [`h2::client::Builder::max_concurrent_reset_streams`] for more + /// details. + /// + /// The default value is determined by the `h2` crate. + /// + /// [`h2::client::Builder::max_concurrent_reset_streams`]: https://docs.rs/h2/client/struct.Builder.html#method.max_concurrent_reset_streams + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_max_concurrent_reset_streams(&mut self, max: usize) -> &mut Self { + self.conn_builder.http2_max_concurrent_reset_streams(max); + self + } + + /// Set the maximum write buffer size for each HTTP/2 stream. + /// + /// Default is currently 1MB, but may change. + /// + /// # Panics + /// + /// The value must be no larger than `u32::MAX`. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_max_send_buf_size(&mut self, max: usize) -> &mut Self { + self.conn_builder.http2_max_send_buf_size(max); + self + } + + /// Set whether to retry requests that get disrupted before ever starting + /// to write. + /// + /// This means a request that is queued, and gets given an idle, reused + /// connection, and then encounters an error immediately as the idle + /// connection was found to be unusable. + /// + /// When this is set to `false`, the related `ResponseFuture` would instead + /// resolve to an `Error::Cancel`. + /// + /// Default is `true`. + #[inline] + pub fn retry_canceled_requests(&mut self, val: bool) -> &mut Self { + self.client_config.retry_canceled_requests = val; + self + } + + /// Set whether to automatically add the `Host` header to requests. + /// + /// If true, and a request does not include a `Host` header, one will be + /// added automatically, derived from the authority of the `Uri`. + /// + /// Default is `true`. + #[inline] + pub fn set_host(&mut self, val: bool) -> &mut Self { + self.client_config.set_host = val; + self + } + + /// Provide an executor to execute background `Connection` tasks. + pub fn executor<E>(&mut self, exec: E) -> &mut Self + where + E: Executor<BoxSendFuture> + Send + Sync + 'static, + { + self.conn_builder.executor(exec); + self + } + + /// Builder a client with this configuration and the default `HttpConnector`. + #[cfg(feature = "tcp")] + pub fn build_http<B>(&self) -> Client<HttpConnector, B> + where + B: HttpBody + Send, + B::Data: Send, + { + let mut connector = HttpConnector::new(); + if self.pool_config.is_enabled() { + connector.set_keepalive(self.pool_config.idle_timeout); + } + self.build(connector) + } + + /// Combine the configuration of this builder with a connector to create a `Client`. + pub fn build<C, B>(&self, connector: C) -> Client<C, B> + where + C: Connect + Clone, + B: HttpBody + Send, + B::Data: Send, + { + Client { + config: self.client_config, + conn_builder: self.conn_builder.clone(), + connector, + pool: Pool::new(self.pool_config, &self.conn_builder.exec), + } + } +} + +impl fmt::Debug for Builder { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Builder") + .field("client_config", &self.client_config) + .field("conn_builder", &self.conn_builder) + .field("pool_config", &self.pool_config) + .finish() + } +} + +#[cfg(test)] +mod unit_tests { + use super::*; + + #[test] + fn response_future_is_sync() { + fn assert_sync<T: Sync>() {} + assert_sync::<ResponseFuture>(); + } + + #[test] + fn set_relative_uri_with_implicit_path() { + let mut uri = "http://hyper.rs".parse().unwrap(); + origin_form(&mut uri); + assert_eq!(uri.to_string(), "/"); + } + + #[test] + fn test_origin_form() { + let mut uri = "http://hyper.rs/guides".parse().unwrap(); + origin_form(&mut uri); + assert_eq!(uri.to_string(), "/guides"); + + let mut uri = "http://hyper.rs/guides?foo=bar".parse().unwrap(); + origin_form(&mut uri); + assert_eq!(uri.to_string(), "/guides?foo=bar"); + } + + #[test] + fn test_absolute_form() { + let mut uri = "http://hyper.rs/guides".parse().unwrap(); + absolute_form(&mut uri); + assert_eq!(uri.to_string(), "http://hyper.rs/guides"); + + let mut uri = "https://hyper.rs/guides".parse().unwrap(); + absolute_form(&mut uri); + assert_eq!(uri.to_string(), "/guides"); + } + + #[test] + fn test_authority_form() { + let _ = pretty_env_logger::try_init(); + + let mut uri = "http://hyper.rs".parse().unwrap(); + authority_form(&mut uri); + assert_eq!(uri.to_string(), "hyper.rs"); + + let mut uri = "hyper.rs".parse().unwrap(); + authority_form(&mut uri); + assert_eq!(uri.to_string(), "hyper.rs"); + } + + #[test] + fn test_extract_domain_connect_no_port() { + let mut uri = "hyper.rs".parse().unwrap(); + let (scheme, host) = extract_domain(&mut uri, true).expect("extract domain"); + assert_eq!(scheme, *"http"); + assert_eq!(host, "hyper.rs"); + } + + #[test] + fn test_is_secure() { + assert_eq!( + is_schema_secure(&"http://hyper.rs".parse::<Uri>().unwrap()), + false + ); + assert_eq!(is_schema_secure(&"hyper.rs".parse::<Uri>().unwrap()), false); + assert_eq!( + is_schema_secure(&"wss://hyper.rs".parse::<Uri>().unwrap()), + true + ); + assert_eq!( + is_schema_secure(&"ws://hyper.rs".parse::<Uri>().unwrap()), + false + ); + } + + #[test] + fn test_get_non_default_port() { + assert!(get_non_default_port(&"http://hyper.rs".parse::<Uri>().unwrap()).is_none()); + assert!(get_non_default_port(&"http://hyper.rs:80".parse::<Uri>().unwrap()).is_none()); + assert!(get_non_default_port(&"https://hyper.rs:443".parse::<Uri>().unwrap()).is_none()); + assert!(get_non_default_port(&"hyper.rs:80".parse::<Uri>().unwrap()).is_none()); + + assert_eq!( + get_non_default_port(&"http://hyper.rs:123".parse::<Uri>().unwrap()) + .unwrap() + .as_u16(), + 123 + ); + assert_eq!( + get_non_default_port(&"https://hyper.rs:80".parse::<Uri>().unwrap()) + .unwrap() + .as_u16(), + 80 + ); + assert_eq!( + get_non_default_port(&"hyper.rs:123".parse::<Uri>().unwrap()) + .unwrap() + .as_u16(), + 123 + ); + } +} diff --git a/third_party/rust/hyper/src/client/conn.rs b/third_party/rust/hyper/src/client/conn.rs new file mode 100644 index 0000000000..3eb12b4204 --- /dev/null +++ b/third_party/rust/hyper/src/client/conn.rs @@ -0,0 +1,1113 @@ +//! Lower-level client connection API. +//! +//! The types in this module are to provide a lower-level API based around a +//! single connection. Connecting to a host, pooling connections, and the like +//! are not handled at this level. This module provides the building blocks to +//! customize those things externally. +//! +//! If don't have need to manage connections yourself, consider using the +//! higher-level [Client](super) API. +//! +//! ## Example +//! A simple example that uses the `SendRequest` struct to talk HTTP over a Tokio TCP stream +//! ```no_run +//! # #[cfg(all(feature = "client", feature = "http1", feature = "runtime"))] +//! # mod rt { +//! use tower::ServiceExt; +//! use http::{Request, StatusCode}; +//! use hyper::{client::conn, Body}; +//! use tokio::net::TcpStream; +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! let target_stream = TcpStream::connect("example.com:80").await?; +//! +//! let (mut request_sender, connection) = conn::handshake(target_stream).await?; +//! +//! // spawn a task to poll the connection and drive the HTTP state +//! tokio::spawn(async move { +//! if let Err(e) = connection.await { +//! eprintln!("Error in connection: {}", e); +//! } +//! }); +//! +//! let request = Request::builder() +//! // We need to manually add the host header because SendRequest does not +//! .header("Host", "example.com") +//! .method("GET") +//! .body(Body::from(""))?; +//! let response = request_sender.send_request(request).await?; +//! assert!(response.status() == StatusCode::OK); +//! +//! // To send via the same connection again, it may not work as it may not be ready, +//! // so we have to wait until the request_sender becomes ready. +//! request_sender.ready().await?; +//! let request = Request::builder() +//! .header("Host", "example.com") +//! .method("GET") +//! .body(Body::from(""))?; +//! let response = request_sender.send_request(request).await?; +//! assert!(response.status() == StatusCode::OK); +//! Ok(()) +//! } +//! +//! # } +//! ``` + +use std::error::Error as StdError; +use std::fmt; +#[cfg(not(all(feature = "http1", feature = "http2")))] +use std::marker::PhantomData; +use std::sync::Arc; +#[cfg(all(feature = "runtime", feature = "http2"))] +use std::time::Duration; + +use bytes::Bytes; +use futures_util::future::{self, Either, FutureExt as _}; +use httparse::ParserConfig; +use pin_project_lite::pin_project; +use tokio::io::{AsyncRead, AsyncWrite}; +use tower_service::Service; +use tracing::{debug, trace}; + +use super::dispatch; +use crate::body::HttpBody; +#[cfg(not(all(feature = "http1", feature = "http2")))] +use crate::common::Never; +use crate::common::{ + exec::{BoxSendFuture, Exec}, + task, Future, Pin, Poll, +}; +use crate::proto; +use crate::rt::Executor; +#[cfg(feature = "http1")] +use crate::upgrade::Upgraded; +use crate::{Body, Request, Response}; + +#[cfg(feature = "http1")] +type Http1Dispatcher<T, B> = + proto::dispatch::Dispatcher<proto::dispatch::Client<B>, B, T, proto::h1::ClientTransaction>; + +#[cfg(not(feature = "http1"))] +type Http1Dispatcher<T, B> = (Never, PhantomData<(T, Pin<Box<B>>)>); + +#[cfg(feature = "http2")] +type Http2ClientTask<B> = proto::h2::ClientTask<B>; + +#[cfg(not(feature = "http2"))] +type Http2ClientTask<B> = (Never, PhantomData<Pin<Box<B>>>); + +pin_project! { + #[project = ProtoClientProj] + enum ProtoClient<T, B> + where + B: HttpBody, + { + H1 { + #[pin] + h1: Http1Dispatcher<T, B>, + }, + H2 { + #[pin] + h2: Http2ClientTask<B>, + }, + } +} + +/// Returns a handshake future over some IO. +/// +/// This is a shortcut for `Builder::new().handshake(io)`. +/// See [`client::conn`](crate::client::conn) for more. +pub async fn handshake<T>( + io: T, +) -> crate::Result<(SendRequest<crate::Body>, Connection<T, crate::Body>)> +where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + Builder::new().handshake(io).await +} + +/// The sender side of an established connection. +pub struct SendRequest<B> { + dispatch: dispatch::Sender<Request<B>, Response<Body>>, +} + +/// A future that processes all HTTP state for the IO object. +/// +/// In most cases, this should just be spawned into an executor, so that it +/// can process incoming and outgoing messages, notice hangups, and the like. +#[must_use = "futures do nothing unless polled"] +pub struct Connection<T, B> +where + T: AsyncRead + AsyncWrite + Send + 'static, + B: HttpBody + 'static, +{ + inner: Option<ProtoClient<T, B>>, +} + +/// A builder to configure an HTTP connection. +/// +/// After setting options, the builder is used to create a handshake future. +#[derive(Clone, Debug)] +pub struct Builder { + pub(super) exec: Exec, + h09_responses: bool, + h1_parser_config: ParserConfig, + h1_writev: Option<bool>, + h1_title_case_headers: bool, + h1_preserve_header_case: bool, + #[cfg(feature = "ffi")] + h1_preserve_header_order: bool, + h1_read_buf_exact_size: Option<usize>, + h1_max_buf_size: Option<usize>, + #[cfg(feature = "ffi")] + h1_headers_raw: bool, + #[cfg(feature = "http2")] + h2_builder: proto::h2::client::Config, + version: Proto, +} + +#[derive(Clone, Debug)] +enum Proto { + #[cfg(feature = "http1")] + Http1, + #[cfg(feature = "http2")] + Http2, +} + +/// A future returned by `SendRequest::send_request`. +/// +/// Yields a `Response` if successful. +#[must_use = "futures do nothing unless polled"] +pub struct ResponseFuture { + inner: ResponseFutureState, +} + +enum ResponseFutureState { + Waiting(dispatch::Promise<Response<Body>>), + // Option is to be able to `take()` it in `poll` + Error(Option<crate::Error>), +} + +/// Deconstructed parts of a `Connection`. +/// +/// This allows taking apart a `Connection` at a later time, in order to +/// reclaim the IO object, and additional related pieces. +#[derive(Debug)] +pub struct Parts<T> { + /// The original IO object used in the handshake. + pub io: T, + /// A buffer of bytes that have been read but not processed as HTTP. + /// + /// For instance, if the `Connection` is used for an HTTP upgrade request, + /// it is possible the server sent back the first bytes of the new protocol + /// along with the response upgrade. + /// + /// You will want to check for any existing bytes if you plan to continue + /// communicating on the IO object. + pub read_buf: Bytes, + _inner: (), +} + +// ========== internal client api + +// A `SendRequest` that can be cloned to send HTTP2 requests. +// private for now, probably not a great idea of a type... +#[must_use = "futures do nothing unless polled"] +#[cfg(feature = "http2")] +pub(super) struct Http2SendRequest<B> { + dispatch: dispatch::UnboundedSender<Request<B>, Response<Body>>, +} + +// ===== impl SendRequest + +impl<B> SendRequest<B> { + /// Polls to determine whether this sender can be used yet for a request. + /// + /// If the associated connection is closed, this returns an Error. + pub fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + self.dispatch.poll_ready(cx) + } + + pub(super) async fn when_ready(self) -> crate::Result<Self> { + let mut me = Some(self); + future::poll_fn(move |cx| { + ready!(me.as_mut().unwrap().poll_ready(cx))?; + Poll::Ready(Ok(me.take().unwrap())) + }) + .await + } + + pub(super) fn is_ready(&self) -> bool { + self.dispatch.is_ready() + } + + pub(super) fn is_closed(&self) -> bool { + self.dispatch.is_closed() + } + + #[cfg(feature = "http2")] + pub(super) fn into_http2(self) -> Http2SendRequest<B> { + Http2SendRequest { + dispatch: self.dispatch.unbound(), + } + } +} + +impl<B> SendRequest<B> +where + B: HttpBody + 'static, +{ + /// Sends a `Request` on the associated connection. + /// + /// Returns a future that if successful, yields the `Response`. + /// + /// # Note + /// + /// There are some key differences in what automatic things the `Client` + /// does for you that will not be done here: + /// + /// - `Client` requires absolute-form `Uri`s, since the scheme and + /// authority are needed to connect. They aren't required here. + /// - Since the `Client` requires absolute-form `Uri`s, it can add + /// the `Host` header based on it. You must add a `Host` header yourself + /// before calling this method. + /// - Since absolute-form `Uri`s are not required, if received, they will + /// be serialized as-is. + /// + /// # Example + /// + /// ``` + /// # use http::header::HOST; + /// # use hyper::client::conn::SendRequest; + /// # use hyper::Body; + /// use hyper::Request; + /// + /// # async fn doc(mut tx: SendRequest<Body>) -> hyper::Result<()> { + /// // build a Request + /// let req = Request::builder() + /// .uri("/foo/bar") + /// .header(HOST, "hyper.rs") + /// .body(Body::empty()) + /// .unwrap(); + /// + /// // send it and await a Response + /// let res = tx.send_request(req).await?; + /// // assert the Response + /// assert!(res.status().is_success()); + /// # Ok(()) + /// # } + /// # fn main() {} + /// ``` + pub fn send_request(&mut self, req: Request<B>) -> ResponseFuture { + let inner = match self.dispatch.send(req) { + Ok(rx) => ResponseFutureState::Waiting(rx), + Err(_req) => { + debug!("connection was not ready"); + let err = crate::Error::new_canceled().with("connection was not ready"); + ResponseFutureState::Error(Some(err)) + } + }; + + ResponseFuture { inner } + } + + pub(super) fn send_request_retryable( + &mut self, + req: Request<B>, + ) -> impl Future<Output = Result<Response<Body>, (crate::Error, Option<Request<B>>)>> + Unpin + where + B: Send, + { + match self.dispatch.try_send(req) { + Ok(rx) => { + Either::Left(rx.then(move |res| { + match res { + Ok(Ok(res)) => future::ok(res), + Ok(Err(err)) => future::err(err), + // this is definite bug if it happens, but it shouldn't happen! + Err(_) => panic!("dispatch dropped without returning error"), + } + })) + } + Err(req) => { + debug!("connection was not ready"); + let err = crate::Error::new_canceled().with("connection was not ready"); + Either::Right(future::err((err, Some(req)))) + } + } + } +} + +impl<B> Service<Request<B>> for SendRequest<B> +where + B: HttpBody + 'static, +{ + type Response = Response<Body>; + type Error = crate::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + self.poll_ready(cx) + } + + fn call(&mut self, req: Request<B>) -> Self::Future { + self.send_request(req) + } +} + +impl<B> fmt::Debug for SendRequest<B> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SendRequest").finish() + } +} + +// ===== impl Http2SendRequest + +#[cfg(feature = "http2")] +impl<B> Http2SendRequest<B> { + pub(super) fn is_ready(&self) -> bool { + self.dispatch.is_ready() + } + + pub(super) fn is_closed(&self) -> bool { + self.dispatch.is_closed() + } +} + +#[cfg(feature = "http2")] +impl<B> Http2SendRequest<B> +where + B: HttpBody + 'static, +{ + pub(super) fn send_request_retryable( + &mut self, + req: Request<B>, + ) -> impl Future<Output = Result<Response<Body>, (crate::Error, Option<Request<B>>)>> + where + B: Send, + { + match self.dispatch.try_send(req) { + Ok(rx) => { + Either::Left(rx.then(move |res| { + match res { + Ok(Ok(res)) => future::ok(res), + Ok(Err(err)) => future::err(err), + // this is definite bug if it happens, but it shouldn't happen! + Err(_) => panic!("dispatch dropped without returning error"), + } + })) + } + Err(req) => { + debug!("connection was not ready"); + let err = crate::Error::new_canceled().with("connection was not ready"); + Either::Right(future::err((err, Some(req)))) + } + } + } +} + +#[cfg(feature = "http2")] +impl<B> fmt::Debug for Http2SendRequest<B> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Http2SendRequest").finish() + } +} + +#[cfg(feature = "http2")] +impl<B> Clone for Http2SendRequest<B> { + fn clone(&self) -> Self { + Http2SendRequest { + dispatch: self.dispatch.clone(), + } + } +} + +// ===== impl Connection + +impl<T, B> Connection<T, B> +where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + B: HttpBody + Unpin + Send + 'static, + B::Data: Send, + B::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + /// Return the inner IO object, and additional information. + /// + /// Only works for HTTP/1 connections. HTTP/2 connections will panic. + pub fn into_parts(self) -> Parts<T> { + match self.inner.expect("already upgraded") { + #[cfg(feature = "http1")] + ProtoClient::H1 { h1 } => { + let (io, read_buf, _) = h1.into_inner(); + Parts { + io, + read_buf, + _inner: (), + } + } + ProtoClient::H2 { .. } => { + panic!("http2 cannot into_inner"); + } + + #[cfg(not(feature = "http1"))] + ProtoClient::H1 { h1 } => match h1.0 {}, + } + } + + /// Poll the connection for completion, but without calling `shutdown` + /// on the underlying IO. + /// + /// This is useful to allow running a connection while doing an HTTP + /// upgrade. Once the upgrade is completed, the connection would be "done", + /// but it is not desired to actually shutdown the IO object. Instead you + /// would take it back using `into_parts`. + /// + /// Use [`poll_fn`](https://docs.rs/futures/0.1.25/futures/future/fn.poll_fn.html) + /// and [`try_ready!`](https://docs.rs/futures/0.1.25/futures/macro.try_ready.html) + /// to work with this function; or use the `without_shutdown` wrapper. + pub fn poll_without_shutdown(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + match *self.inner.as_mut().expect("already upgraded") { + #[cfg(feature = "http1")] + ProtoClient::H1 { ref mut h1 } => h1.poll_without_shutdown(cx), + #[cfg(feature = "http2")] + ProtoClient::H2 { ref mut h2, .. } => Pin::new(h2).poll(cx).map_ok(|_| ()), + + #[cfg(not(feature = "http1"))] + ProtoClient::H1 { ref mut h1 } => match h1.0 {}, + #[cfg(not(feature = "http2"))] + ProtoClient::H2 { ref mut h2, .. } => match h2.0 {}, + } + } + + /// Prevent shutdown of the underlying IO object at the end of service the request, + /// instead run `into_parts`. This is a convenience wrapper over `poll_without_shutdown`. + pub fn without_shutdown(self) -> impl Future<Output = crate::Result<Parts<T>>> { + let mut conn = Some(self); + future::poll_fn(move |cx| -> Poll<crate::Result<Parts<T>>> { + ready!(conn.as_mut().unwrap().poll_without_shutdown(cx))?; + Poll::Ready(Ok(conn.take().unwrap().into_parts())) + }) + } + + /// Returns whether the [extended CONNECT protocol][1] is enabled or not. + /// + /// This setting is configured by the server peer by sending the + /// [`SETTINGS_ENABLE_CONNECT_PROTOCOL` parameter][2] in a `SETTINGS` frame. + /// This method returns the currently acknowledged value received from the + /// remote. + /// + /// [1]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 + /// [2]: https://datatracker.ietf.org/doc/html/rfc8441#section-3 + #[cfg(feature = "http2")] + pub fn http2_is_extended_connect_protocol_enabled(&self) -> bool { + match self.inner.as_ref().unwrap() { + ProtoClient::H1 { .. } => false, + ProtoClient::H2 { h2 } => h2.is_extended_connect_protocol_enabled(), + } + } +} + +impl<T, B> Future for Connection<T, B> +where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + type Output = crate::Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + match ready!(Pin::new(self.inner.as_mut().unwrap()).poll(cx))? { + proto::Dispatched::Shutdown => Poll::Ready(Ok(())), + #[cfg(feature = "http1")] + proto::Dispatched::Upgrade(pending) => match self.inner.take() { + Some(ProtoClient::H1 { h1 }) => { + let (io, buf, _) = h1.into_inner(); + pending.fulfill(Upgraded::new(io, buf)); + Poll::Ready(Ok(())) + } + _ => { + drop(pending); + unreachable!("Upgrade expects h1"); + } + }, + } + } +} + +impl<T, B> fmt::Debug for Connection<T, B> +where + T: AsyncRead + AsyncWrite + fmt::Debug + Send + 'static, + B: HttpBody + 'static, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Connection").finish() + } +} + +// ===== impl Builder + +impl Builder { + /// Creates a new connection builder. + #[inline] + pub fn new() -> Builder { + Builder { + exec: Exec::Default, + h09_responses: false, + h1_writev: None, + h1_read_buf_exact_size: None, + h1_parser_config: Default::default(), + h1_title_case_headers: false, + h1_preserve_header_case: false, + #[cfg(feature = "ffi")] + h1_preserve_header_order: false, + h1_max_buf_size: None, + #[cfg(feature = "ffi")] + h1_headers_raw: false, + #[cfg(feature = "http2")] + h2_builder: Default::default(), + #[cfg(feature = "http1")] + version: Proto::Http1, + #[cfg(not(feature = "http1"))] + version: Proto::Http2, + } + } + + /// Provide an executor to execute background HTTP2 tasks. + pub fn executor<E>(&mut self, exec: E) -> &mut Builder + where + E: Executor<BoxSendFuture> + Send + Sync + 'static, + { + self.exec = Exec::Executor(Arc::new(exec)); + self + } + + /// Set whether HTTP/0.9 responses should be tolerated. + /// + /// Default is false. + pub fn http09_responses(&mut self, enabled: bool) -> &mut Builder { + self.h09_responses = enabled; + self + } + + /// Set whether HTTP/1 connections will accept spaces between header names + /// and the colon that follow them in responses. + /// + /// You probably don't need this, here is what [RFC 7230 Section 3.2.4.] has + /// to say about it: + /// + /// > No whitespace is allowed between the header field-name and colon. In + /// > the past, differences in the handling of such whitespace have led to + /// > security vulnerabilities in request routing and response handling. A + /// > server MUST reject any received request message that contains + /// > whitespace between a header field-name and colon with a response code + /// > of 400 (Bad Request). A proxy MUST remove any such whitespace from a + /// > response message before forwarding the message downstream. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + /// + /// [RFC 7230 Section 3.2.4.]: https://tools.ietf.org/html/rfc7230#section-3.2.4 + pub fn http1_allow_spaces_after_header_name_in_responses( + &mut self, + enabled: bool, + ) -> &mut Builder { + self.h1_parser_config + .allow_spaces_after_header_name_in_responses(enabled); + self + } + + /// Set whether HTTP/1 connections will accept obsolete line folding for + /// header values. + /// + /// Newline codepoints (`\r` and `\n`) will be transformed to spaces when + /// parsing. + /// + /// You probably don't need this, here is what [RFC 7230 Section 3.2.4.] has + /// to say about it: + /// + /// > A server that receives an obs-fold in a request message that is not + /// > within a message/http container MUST either reject the message by + /// > sending a 400 (Bad Request), preferably with a representation + /// > explaining that obsolete line folding is unacceptable, or replace + /// > each received obs-fold with one or more SP octets prior to + /// > interpreting the field value or forwarding the message downstream. + /// + /// > A proxy or gateway that receives an obs-fold in a response message + /// > that is not within a message/http container MUST either discard the + /// > message and replace it with a 502 (Bad Gateway) response, preferably + /// > with a representation explaining that unacceptable line folding was + /// > received, or replace each received obs-fold with one or more SP + /// > octets prior to interpreting the field value or forwarding the + /// > message downstream. + /// + /// > A user agent that receives an obs-fold in a response message that is + /// > not within a message/http container MUST replace each received + /// > obs-fold with one or more SP octets prior to interpreting the field + /// > value. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + /// + /// [RFC 7230 Section 3.2.4.]: https://tools.ietf.org/html/rfc7230#section-3.2.4 + pub fn http1_allow_obsolete_multiline_headers_in_responses( + &mut self, + enabled: bool, + ) -> &mut Builder { + self.h1_parser_config + .allow_obsolete_multiline_headers_in_responses(enabled); + self + } + + /// Set whether HTTP/1 connections will silently ignored malformed header lines. + /// + /// If this is enabled and and a header line does not start with a valid header + /// name, or does not include a colon at all, the line will be silently ignored + /// and no error will be reported. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + pub fn http1_ignore_invalid_headers_in_responses( + &mut self, + enabled: bool, + ) -> &mut Builder { + self.h1_parser_config + .ignore_invalid_headers_in_responses(enabled); + self + } + + /// Set whether HTTP/1 connections should try to use vectored writes, + /// or always flatten into a single buffer. + /// + /// Note that setting this to false may mean more copies of body data, + /// but may also improve performance when an IO transport doesn't + /// support vectored writes well, such as most TLS implementations. + /// + /// Setting this to true will force hyper to use queued strategy + /// which may eliminate unnecessary cloning on some TLS backends + /// + /// Default is `auto`. In this mode hyper will try to guess which + /// mode to use + pub fn http1_writev(&mut self, enabled: bool) -> &mut Builder { + self.h1_writev = Some(enabled); + self + } + + /// Set whether HTTP/1 connections will write header names as title case at + /// the socket level. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + pub fn http1_title_case_headers(&mut self, enabled: bool) -> &mut Builder { + self.h1_title_case_headers = enabled; + self + } + + /// Set whether to support preserving original header cases. + /// + /// Currently, this will record the original cases received, and store them + /// in a private extension on the `Response`. It will also look for and use + /// such an extension in any provided `Request`. + /// + /// Since the relevant extension is still private, there is no way to + /// interact with the original cases. The only effect this can have now is + /// to forward the cases in a proxy-like fashion. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + pub fn http1_preserve_header_case(&mut self, enabled: bool) -> &mut Builder { + self.h1_preserve_header_case = enabled; + self + } + + /// Set whether to support preserving original header order. + /// + /// Currently, this will record the order in which headers are received, and store this + /// ordering in a private extension on the `Response`. It will also look for and use + /// such an extension in any provided `Request`. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + #[cfg(feature = "ffi")] + pub fn http1_preserve_header_order(&mut self, enabled: bool) -> &mut Builder { + self.h1_preserve_header_order = enabled; + self + } + + /// Sets the exact size of the read buffer to *always* use. + /// + /// Note that setting this option unsets the `http1_max_buf_size` option. + /// + /// Default is an adaptive read buffer. + pub fn http1_read_buf_exact_size(&mut self, sz: Option<usize>) -> &mut Builder { + self.h1_read_buf_exact_size = sz; + self.h1_max_buf_size = None; + self + } + + /// Set the maximum buffer size for the connection. + /// + /// Default is ~400kb. + /// + /// Note that setting this option unsets the `http1_read_exact_buf_size` option. + /// + /// # Panics + /// + /// The minimum value allowed is 8192. This method panics if the passed `max` is less than the minimum. + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_max_buf_size(&mut self, max: usize) -> &mut Self { + assert!( + max >= proto::h1::MINIMUM_MAX_BUFFER_SIZE, + "the max_buf_size cannot be smaller than the minimum that h1 specifies." + ); + + self.h1_max_buf_size = Some(max); + self.h1_read_buf_exact_size = None; + self + } + + #[cfg(feature = "ffi")] + pub(crate) fn http1_headers_raw(&mut self, enabled: bool) -> &mut Self { + self.h1_headers_raw = enabled; + self + } + + /// Sets whether HTTP2 is required. + /// + /// Default is false. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_only(&mut self, enabled: bool) -> &mut Builder { + if enabled { + self.version = Proto::Http2 + } + self + } + + /// Sets the [`SETTINGS_INITIAL_WINDOW_SIZE`][spec] option for HTTP2 + /// stream-level flow control. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + /// + /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_INITIAL_WINDOW_SIZE + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_initial_stream_window_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self { + if let Some(sz) = sz.into() { + self.h2_builder.adaptive_window = false; + self.h2_builder.initial_stream_window_size = sz; + } + self + } + + /// Sets the max connection-level flow control for HTTP2 + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_initial_connection_window_size( + &mut self, + sz: impl Into<Option<u32>>, + ) -> &mut Self { + if let Some(sz) = sz.into() { + self.h2_builder.adaptive_window = false; + self.h2_builder.initial_conn_window_size = sz; + } + self + } + + /// Sets whether to use an adaptive flow control. + /// + /// Enabling this will override the limits set in + /// `http2_initial_stream_window_size` and + /// `http2_initial_connection_window_size`. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_adaptive_window(&mut self, enabled: bool) -> &mut Self { + use proto::h2::SPEC_WINDOW_SIZE; + + self.h2_builder.adaptive_window = enabled; + if enabled { + self.h2_builder.initial_conn_window_size = SPEC_WINDOW_SIZE; + self.h2_builder.initial_stream_window_size = SPEC_WINDOW_SIZE; + } + self + } + + /// Sets the maximum frame size to use for HTTP2. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_max_frame_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self { + if let Some(sz) = sz.into() { + self.h2_builder.max_frame_size = sz; + } + self + } + + /// Sets an interval for HTTP2 Ping frames should be sent to keep a + /// connection alive. + /// + /// Pass `None` to disable HTTP2 keep-alive. + /// + /// Default is currently disabled. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_keep_alive_interval( + &mut self, + interval: impl Into<Option<Duration>>, + ) -> &mut Self { + self.h2_builder.keep_alive_interval = interval.into(); + self + } + + /// Sets a timeout for receiving an acknowledgement of the keep-alive ping. + /// + /// If the ping is not acknowledged within the timeout, the connection will + /// be closed. Does nothing if `http2_keep_alive_interval` is disabled. + /// + /// Default is 20 seconds. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_keep_alive_timeout(&mut self, timeout: Duration) -> &mut Self { + self.h2_builder.keep_alive_timeout = timeout; + self + } + + /// Sets whether HTTP2 keep-alive should apply while the connection is idle. + /// + /// If disabled, keep-alive pings are only sent while there are open + /// request/responses streams. If enabled, pings are also sent when no + /// streams are active. Does nothing if `http2_keep_alive_interval` is + /// disabled. + /// + /// Default is `false`. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_keep_alive_while_idle(&mut self, enabled: bool) -> &mut Self { + self.h2_builder.keep_alive_while_idle = enabled; + self + } + + /// Sets the maximum number of HTTP2 concurrent locally reset streams. + /// + /// See the documentation of [`h2::client::Builder::max_concurrent_reset_streams`] for more + /// details. + /// + /// The default value is determined by the `h2` crate. + /// + /// [`h2::client::Builder::max_concurrent_reset_streams`]: https://docs.rs/h2/client/struct.Builder.html#method.max_concurrent_reset_streams + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_max_concurrent_reset_streams(&mut self, max: usize) -> &mut Self { + self.h2_builder.max_concurrent_reset_streams = Some(max); + self + } + + /// Set the maximum write buffer size for each HTTP/2 stream. + /// + /// Default is currently 1MB, but may change. + /// + /// # Panics + /// + /// The value must be no larger than `u32::MAX`. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_max_send_buf_size(&mut self, max: usize) -> &mut Self { + assert!(max <= std::u32::MAX as usize); + self.h2_builder.max_send_buffer_size = max; + self + } + + /// Constructs a connection with the configured options and IO. + /// See [`client::conn`](crate::client::conn) for more. + /// + /// Note, if [`Connection`] is not `await`-ed, [`SendRequest`] will + /// do nothing. + pub fn handshake<T, B>( + &self, + io: T, + ) -> impl Future<Output = crate::Result<(SendRequest<B>, Connection<T, B>)>> + where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + B: HttpBody + 'static, + B::Data: Send, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + { + let opts = self.clone(); + + async move { + trace!("client handshake {:?}", opts.version); + + let (tx, rx) = dispatch::channel(); + let proto = match opts.version { + #[cfg(feature = "http1")] + Proto::Http1 => { + let mut conn = proto::Conn::new(io); + conn.set_h1_parser_config(opts.h1_parser_config); + if let Some(writev) = opts.h1_writev { + if writev { + conn.set_write_strategy_queue(); + } else { + conn.set_write_strategy_flatten(); + } + } + if opts.h1_title_case_headers { + conn.set_title_case_headers(); + } + if opts.h1_preserve_header_case { + conn.set_preserve_header_case(); + } + #[cfg(feature = "ffi")] + if opts.h1_preserve_header_order { + conn.set_preserve_header_order(); + } + if opts.h09_responses { + conn.set_h09_responses(); + } + + #[cfg(feature = "ffi")] + conn.set_raw_headers(opts.h1_headers_raw); + + if let Some(sz) = opts.h1_read_buf_exact_size { + conn.set_read_buf_exact_size(sz); + } + if let Some(max) = opts.h1_max_buf_size { + conn.set_max_buf_size(max); + } + let cd = proto::h1::dispatch::Client::new(rx); + let dispatch = proto::h1::Dispatcher::new(cd, conn); + ProtoClient::H1 { h1: dispatch } + } + #[cfg(feature = "http2")] + Proto::Http2 => { + let h2 = + proto::h2::client::handshake(io, rx, &opts.h2_builder, opts.exec.clone()) + .await?; + ProtoClient::H2 { h2 } + } + }; + + Ok(( + SendRequest { dispatch: tx }, + Connection { inner: Some(proto) }, + )) + } + } +} + +// ===== impl ResponseFuture + +impl Future for ResponseFuture { + type Output = crate::Result<Response<Body>>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + match self.inner { + ResponseFutureState::Waiting(ref mut rx) => { + Pin::new(rx).poll(cx).map(|res| match res { + Ok(Ok(resp)) => Ok(resp), + Ok(Err(err)) => Err(err), + // this is definite bug if it happens, but it shouldn't happen! + Err(_canceled) => panic!("dispatch dropped without returning error"), + }) + } + ResponseFutureState::Error(ref mut err) => { + Poll::Ready(Err(err.take().expect("polled after ready"))) + } + } + } +} + +impl fmt::Debug for ResponseFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ResponseFuture").finish() + } +} + +// ===== impl ProtoClient + +impl<T, B> Future for ProtoClient<T, B> +where + T: AsyncRead + AsyncWrite + Send + Unpin + 'static, + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + type Output = crate::Result<proto::Dispatched>; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + match self.project() { + #[cfg(feature = "http1")] + ProtoClientProj::H1 { h1 } => h1.poll(cx), + #[cfg(feature = "http2")] + ProtoClientProj::H2 { h2, .. } => h2.poll(cx), + + #[cfg(not(feature = "http1"))] + ProtoClientProj::H1 { h1 } => match h1.0 {}, + #[cfg(not(feature = "http2"))] + ProtoClientProj::H2 { h2, .. } => match h2.0 {}, + } + } +} + +// assert trait markers + +trait AssertSend: Send {} +trait AssertSendSync: Send + Sync {} + +#[doc(hidden)] +impl<B: Send> AssertSendSync for SendRequest<B> {} + +#[doc(hidden)] +impl<T: Send, B: Send> AssertSend for Connection<T, B> +where + T: AsyncRead + AsyncWrite + Send + 'static, + B: HttpBody + 'static, + B::Data: Send, +{ +} + +#[doc(hidden)] +impl<T: Send + Sync, B: Send + Sync> AssertSendSync for Connection<T, B> +where + T: AsyncRead + AsyncWrite + Send + 'static, + B: HttpBody + 'static, + B::Data: Send + Sync + 'static, +{ +} + +#[doc(hidden)] +impl AssertSendSync for Builder {} + +#[doc(hidden)] +impl AssertSend for ResponseFuture {} diff --git a/third_party/rust/hyper/src/client/connect/dns.rs b/third_party/rust/hyper/src/client/connect/dns.rs new file mode 100644 index 0000000000..e4465078b3 --- /dev/null +++ b/third_party/rust/hyper/src/client/connect/dns.rs @@ -0,0 +1,425 @@ +//! DNS Resolution used by the `HttpConnector`. +//! +//! This module contains: +//! +//! - A [`GaiResolver`](GaiResolver) that is the default resolver for the +//! `HttpConnector`. +//! - The `Name` type used as an argument to custom resolvers. +//! +//! # Resolvers are `Service`s +//! +//! A resolver is just a +//! `Service<Name, Response = impl Iterator<Item = SocketAddr>>`. +//! +//! A simple resolver that ignores the name and always returns a specific +//! address: +//! +//! ```rust,ignore +//! use std::{convert::Infallible, iter, net::SocketAddr}; +//! +//! let resolver = tower::service_fn(|_name| async { +//! Ok::<_, Infallible>(iter::once(SocketAddr::from(([127, 0, 0, 1], 8080)))) +//! }); +//! ``` +use std::error::Error; +use std::future::Future; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}; +use std::pin::Pin; +use std::str::FromStr; +use std::task::{self, Poll}; +use std::{fmt, io, vec}; + +use tokio::task::JoinHandle; +use tower_service::Service; +use tracing::debug; + +pub(super) use self::sealed::Resolve; + +/// A domain name to resolve into IP addresses. +#[derive(Clone, Hash, Eq, PartialEq)] +pub struct Name { + host: Box<str>, +} + +/// A resolver using blocking `getaddrinfo` calls in a threadpool. +#[derive(Clone)] +pub struct GaiResolver { + _priv: (), +} + +/// An iterator of IP addresses returned from `getaddrinfo`. +pub struct GaiAddrs { + inner: SocketAddrs, +} + +/// A future to resolve a name returned by `GaiResolver`. +pub struct GaiFuture { + inner: JoinHandle<Result<SocketAddrs, io::Error>>, +} + +impl Name { + pub(super) fn new(host: Box<str>) -> Name { + Name { host } + } + + /// View the hostname as a string slice. + pub fn as_str(&self) -> &str { + &self.host + } +} + +impl fmt::Debug for Name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&self.host, f) + } +} + +impl fmt::Display for Name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.host, f) + } +} + +impl FromStr for Name { + type Err = InvalidNameError; + + fn from_str(host: &str) -> Result<Self, Self::Err> { + // Possibly add validation later + Ok(Name::new(host.into())) + } +} + +/// Error indicating a given string was not a valid domain name. +#[derive(Debug)] +pub struct InvalidNameError(()); + +impl fmt::Display for InvalidNameError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("Not a valid domain name") + } +} + +impl Error for InvalidNameError {} + +impl GaiResolver { + /// Construct a new `GaiResolver`. + pub fn new() -> Self { + GaiResolver { _priv: () } + } +} + +impl Service<Name> for GaiResolver { + type Response = GaiAddrs; + type Error = io::Error; + type Future = GaiFuture; + + fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, name: Name) -> Self::Future { + let blocking = tokio::task::spawn_blocking(move || { + debug!("resolving host={:?}", name.host); + (&*name.host, 0) + .to_socket_addrs() + .map(|i| SocketAddrs { iter: i }) + }); + + GaiFuture { inner: blocking } + } +} + +impl fmt::Debug for GaiResolver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("GaiResolver") + } +} + +impl Future for GaiFuture { + type Output = Result<GaiAddrs, io::Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + Pin::new(&mut self.inner).poll(cx).map(|res| match res { + Ok(Ok(addrs)) => Ok(GaiAddrs { inner: addrs }), + Ok(Err(err)) => Err(err), + Err(join_err) => { + if join_err.is_cancelled() { + Err(io::Error::new(io::ErrorKind::Interrupted, join_err)) + } else { + panic!("gai background task failed: {:?}", join_err) + } + } + }) + } +} + +impl fmt::Debug for GaiFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("GaiFuture") + } +} + +impl Drop for GaiFuture { + fn drop(&mut self) { + self.inner.abort(); + } +} + +impl Iterator for GaiAddrs { + type Item = SocketAddr; + + fn next(&mut self) -> Option<Self::Item> { + self.inner.next() + } +} + +impl fmt::Debug for GaiAddrs { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("GaiAddrs") + } +} + +pub(super) struct SocketAddrs { + iter: vec::IntoIter<SocketAddr>, +} + +impl SocketAddrs { + pub(super) fn new(addrs: Vec<SocketAddr>) -> Self { + SocketAddrs { + iter: addrs.into_iter(), + } + } + + pub(super) fn try_parse(host: &str, port: u16) -> Option<SocketAddrs> { + if let Ok(addr) = host.parse::<Ipv4Addr>() { + let addr = SocketAddrV4::new(addr, port); + return Some(SocketAddrs { + iter: vec![SocketAddr::V4(addr)].into_iter(), + }); + } + if let Ok(addr) = host.parse::<Ipv6Addr>() { + let addr = SocketAddrV6::new(addr, port, 0, 0); + return Some(SocketAddrs { + iter: vec![SocketAddr::V6(addr)].into_iter(), + }); + } + None + } + + #[inline] + fn filter(self, predicate: impl FnMut(&SocketAddr) -> bool) -> SocketAddrs { + SocketAddrs::new(self.iter.filter(predicate).collect()) + } + + pub(super) fn split_by_preference( + self, + local_addr_ipv4: Option<Ipv4Addr>, + local_addr_ipv6: Option<Ipv6Addr>, + ) -> (SocketAddrs, SocketAddrs) { + match (local_addr_ipv4, local_addr_ipv6) { + (Some(_), None) => (self.filter(SocketAddr::is_ipv4), SocketAddrs::new(vec![])), + (None, Some(_)) => (self.filter(SocketAddr::is_ipv6), SocketAddrs::new(vec![])), + _ => { + let preferring_v6 = self + .iter + .as_slice() + .first() + .map(SocketAddr::is_ipv6) + .unwrap_or(false); + + let (preferred, fallback) = self + .iter + .partition::<Vec<_>, _>(|addr| addr.is_ipv6() == preferring_v6); + + (SocketAddrs::new(preferred), SocketAddrs::new(fallback)) + } + } + } + + pub(super) fn is_empty(&self) -> bool { + self.iter.as_slice().is_empty() + } + + pub(super) fn len(&self) -> usize { + self.iter.as_slice().len() + } +} + +impl Iterator for SocketAddrs { + type Item = SocketAddr; + #[inline] + fn next(&mut self) -> Option<SocketAddr> { + self.iter.next() + } +} + +/* +/// A resolver using `getaddrinfo` calls via the `tokio_executor::threadpool::blocking` API. +/// +/// Unlike the `GaiResolver` this will not spawn dedicated threads, but only works when running on the +/// multi-threaded Tokio runtime. +#[cfg(feature = "runtime")] +#[derive(Clone, Debug)] +pub struct TokioThreadpoolGaiResolver(()); + +/// The future returned by `TokioThreadpoolGaiResolver`. +#[cfg(feature = "runtime")] +#[derive(Debug)] +pub struct TokioThreadpoolGaiFuture { + name: Name, +} + +#[cfg(feature = "runtime")] +impl TokioThreadpoolGaiResolver { + /// Creates a new DNS resolver that will use tokio threadpool's blocking + /// feature. + /// + /// **Requires** its futures to be run on the threadpool runtime. + pub fn new() -> Self { + TokioThreadpoolGaiResolver(()) + } +} + +#[cfg(feature = "runtime")] +impl Service<Name> for TokioThreadpoolGaiResolver { + type Response = GaiAddrs; + type Error = io::Error; + type Future = TokioThreadpoolGaiFuture; + + fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, name: Name) -> Self::Future { + TokioThreadpoolGaiFuture { name } + } +} + +#[cfg(feature = "runtime")] +impl Future for TokioThreadpoolGaiFuture { + type Output = Result<GaiAddrs, io::Error>; + + fn poll(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<Self::Output> { + match ready!(tokio_executor::threadpool::blocking(|| ( + self.name.as_str(), + 0 + ) + .to_socket_addrs())) + { + Ok(Ok(iter)) => Poll::Ready(Ok(GaiAddrs { + inner: IpAddrs { iter }, + })), + Ok(Err(e)) => Poll::Ready(Err(e)), + // a BlockingError, meaning not on a tokio_executor::threadpool :( + Err(e) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))), + } + } +} +*/ + +mod sealed { + use super::{SocketAddr, Name}; + use crate::common::{task, Future, Poll}; + use tower_service::Service; + + // "Trait alias" for `Service<Name, Response = Addrs>` + pub trait Resolve { + type Addrs: Iterator<Item = SocketAddr>; + type Error: Into<Box<dyn std::error::Error + Send + Sync>>; + type Future: Future<Output = Result<Self::Addrs, Self::Error>>; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>; + fn resolve(&mut self, name: Name) -> Self::Future; + } + + impl<S> Resolve for S + where + S: Service<Name>, + S::Response: Iterator<Item = SocketAddr>, + S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + { + type Addrs = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + Service::poll_ready(self, cx) + } + + fn resolve(&mut self, name: Name) -> Self::Future { + Service::call(self, name) + } + } +} + +pub(super) async fn resolve<R>(resolver: &mut R, name: Name) -> Result<R::Addrs, R::Error> +where + R: Resolve, +{ + futures_util::future::poll_fn(|cx| resolver.poll_ready(cx)).await?; + resolver.resolve(name).await +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{Ipv4Addr, Ipv6Addr}; + + #[test] + fn test_ip_addrs_split_by_preference() { + let ip_v4 = Ipv4Addr::new(127, 0, 0, 1); + let ip_v6 = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1); + let v4_addr = (ip_v4, 80).into(); + let v6_addr = (ip_v6, 80).into(); + + let (mut preferred, mut fallback) = SocketAddrs { + iter: vec![v4_addr, v6_addr].into_iter(), + } + .split_by_preference(None, None); + assert!(preferred.next().unwrap().is_ipv4()); + assert!(fallback.next().unwrap().is_ipv6()); + + let (mut preferred, mut fallback) = SocketAddrs { + iter: vec![v6_addr, v4_addr].into_iter(), + } + .split_by_preference(None, None); + assert!(preferred.next().unwrap().is_ipv6()); + assert!(fallback.next().unwrap().is_ipv4()); + + let (mut preferred, mut fallback) = SocketAddrs { + iter: vec![v4_addr, v6_addr].into_iter(), + } + .split_by_preference(Some(ip_v4), Some(ip_v6)); + assert!(preferred.next().unwrap().is_ipv4()); + assert!(fallback.next().unwrap().is_ipv6()); + + let (mut preferred, mut fallback) = SocketAddrs { + iter: vec![v6_addr, v4_addr].into_iter(), + } + .split_by_preference(Some(ip_v4), Some(ip_v6)); + assert!(preferred.next().unwrap().is_ipv6()); + assert!(fallback.next().unwrap().is_ipv4()); + + let (mut preferred, fallback) = SocketAddrs { + iter: vec![v4_addr, v6_addr].into_iter(), + } + .split_by_preference(Some(ip_v4), None); + assert!(preferred.next().unwrap().is_ipv4()); + assert!(fallback.is_empty()); + + let (mut preferred, fallback) = SocketAddrs { + iter: vec![v4_addr, v6_addr].into_iter(), + } + .split_by_preference(None, Some(ip_v6)); + assert!(preferred.next().unwrap().is_ipv6()); + assert!(fallback.is_empty()); + } + + #[test] + fn test_name_from_str() { + const DOMAIN: &str = "test.example.com"; + let name = Name::from_str(DOMAIN).expect("Should be a valid domain"); + assert_eq!(name.as_str(), DOMAIN); + assert_eq!(name.to_string(), DOMAIN); + } +} diff --git a/third_party/rust/hyper/src/client/connect/http.rs b/third_party/rust/hyper/src/client/connect/http.rs new file mode 100644 index 0000000000..afe7b155eb --- /dev/null +++ b/third_party/rust/hyper/src/client/connect/http.rs @@ -0,0 +1,1007 @@ +use std::error::Error as StdError; +use std::fmt; +use std::future::Future; +use std::io; +use std::marker::PhantomData; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{self, Poll}; +use std::time::Duration; + +use futures_util::future::Either; +use http::uri::{Scheme, Uri}; +use pin_project_lite::pin_project; +use tokio::net::{TcpSocket, TcpStream}; +use tokio::time::Sleep; +use tracing::{debug, trace, warn}; + +use super::dns::{self, resolve, GaiResolver, Resolve}; +use super::{Connected, Connection}; +//#[cfg(feature = "runtime")] use super::dns::TokioThreadpoolGaiResolver; + +/// A connector for the `http` scheme. +/// +/// Performs DNS resolution in a thread pool, and then connects over TCP. +/// +/// # Note +/// +/// Sets the [`HttpInfo`](HttpInfo) value on responses, which includes +/// transport information such as the remote socket address used. +#[cfg_attr(docsrs, doc(cfg(feature = "tcp")))] +#[derive(Clone)] +pub struct HttpConnector<R = GaiResolver> { + config: Arc<Config>, + resolver: R, +} + +/// Extra information about the transport when an HttpConnector is used. +/// +/// # Example +/// +/// ``` +/// # async fn doc() -> hyper::Result<()> { +/// use hyper::Uri; +/// use hyper::client::{Client, connect::HttpInfo}; +/// +/// let client = Client::new(); +/// let uri = Uri::from_static("http://example.com"); +/// +/// let res = client.get(uri).await?; +/// res +/// .extensions() +/// .get::<HttpInfo>() +/// .map(|info| { +/// println!("remote addr = {}", info.remote_addr()); +/// }); +/// # Ok(()) +/// # } +/// ``` +/// +/// # Note +/// +/// If a different connector is used besides [`HttpConnector`](HttpConnector), +/// this value will not exist in the extensions. Consult that specific +/// connector to see what "extra" information it might provide to responses. +#[derive(Clone, Debug)] +pub struct HttpInfo { + remote_addr: SocketAddr, + local_addr: SocketAddr, +} + +#[derive(Clone)] +struct Config { + connect_timeout: Option<Duration>, + enforce_http: bool, + happy_eyeballs_timeout: Option<Duration>, + keep_alive_timeout: Option<Duration>, + local_address_ipv4: Option<Ipv4Addr>, + local_address_ipv6: Option<Ipv6Addr>, + nodelay: bool, + reuse_address: bool, + send_buffer_size: Option<usize>, + recv_buffer_size: Option<usize>, +} + +// ===== impl HttpConnector ===== + +impl HttpConnector { + /// Construct a new HttpConnector. + pub fn new() -> HttpConnector { + HttpConnector::new_with_resolver(GaiResolver::new()) + } +} + +/* +#[cfg(feature = "runtime")] +impl HttpConnector<TokioThreadpoolGaiResolver> { + /// Construct a new HttpConnector using the `TokioThreadpoolGaiResolver`. + /// + /// This resolver **requires** the threadpool runtime to be used. + pub fn new_with_tokio_threadpool_resolver() -> Self { + HttpConnector::new_with_resolver(TokioThreadpoolGaiResolver::new()) + } +} +*/ + +impl<R> HttpConnector<R> { + /// Construct a new HttpConnector. + /// + /// Takes a [`Resolver`](crate::client::connect::dns#resolvers-are-services) to handle DNS lookups. + pub fn new_with_resolver(resolver: R) -> HttpConnector<R> { + HttpConnector { + config: Arc::new(Config { + connect_timeout: None, + enforce_http: true, + happy_eyeballs_timeout: Some(Duration::from_millis(300)), + keep_alive_timeout: None, + local_address_ipv4: None, + local_address_ipv6: None, + nodelay: false, + reuse_address: false, + send_buffer_size: None, + recv_buffer_size: None, + }), + resolver, + } + } + + /// Option to enforce all `Uri`s have the `http` scheme. + /// + /// Enabled by default. + #[inline] + pub fn enforce_http(&mut self, is_enforced: bool) { + self.config_mut().enforce_http = is_enforced; + } + + /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration. + /// + /// If `None`, the option will not be set. + /// + /// Default is `None`. + #[inline] + pub fn set_keepalive(&mut self, dur: Option<Duration>) { + self.config_mut().keep_alive_timeout = dur; + } + + /// Set that all sockets have `SO_NODELAY` set to the supplied value `nodelay`. + /// + /// Default is `false`. + #[inline] + pub fn set_nodelay(&mut self, nodelay: bool) { + self.config_mut().nodelay = nodelay; + } + + /// Sets the value of the SO_SNDBUF option on the socket. + #[inline] + pub fn set_send_buffer_size(&mut self, size: Option<usize>) { + self.config_mut().send_buffer_size = size; + } + + /// Sets the value of the SO_RCVBUF option on the socket. + #[inline] + pub fn set_recv_buffer_size(&mut self, size: Option<usize>) { + self.config_mut().recv_buffer_size = size; + } + + /// Set that all sockets are bound to the configured address before connection. + /// + /// If `None`, the sockets will not be bound. + /// + /// Default is `None`. + #[inline] + pub fn set_local_address(&mut self, addr: Option<IpAddr>) { + let (v4, v6) = match addr { + Some(IpAddr::V4(a)) => (Some(a), None), + Some(IpAddr::V6(a)) => (None, Some(a)), + _ => (None, None), + }; + + let cfg = self.config_mut(); + + cfg.local_address_ipv4 = v4; + cfg.local_address_ipv6 = v6; + } + + /// Set that all sockets are bound to the configured IPv4 or IPv6 address (depending on host's + /// preferences) before connection. + #[inline] + pub fn set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) { + let cfg = self.config_mut(); + + cfg.local_address_ipv4 = Some(addr_ipv4); + cfg.local_address_ipv6 = Some(addr_ipv6); + } + + /// Set the connect timeout. + /// + /// If a domain resolves to multiple IP addresses, the timeout will be + /// evenly divided across them. + /// + /// Default is `None`. + #[inline] + pub fn set_connect_timeout(&mut self, dur: Option<Duration>) { + self.config_mut().connect_timeout = dur; + } + + /// Set timeout for [RFC 6555 (Happy Eyeballs)][RFC 6555] algorithm. + /// + /// If hostname resolves to both IPv4 and IPv6 addresses and connection + /// cannot be established using preferred address family before timeout + /// elapses, then connector will in parallel attempt connection using other + /// address family. + /// + /// If `None`, parallel connection attempts are disabled. + /// + /// Default is 300 milliseconds. + /// + /// [RFC 6555]: https://tools.ietf.org/html/rfc6555 + #[inline] + pub fn set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>) { + self.config_mut().happy_eyeballs_timeout = dur; + } + + /// Set that all socket have `SO_REUSEADDR` set to the supplied value `reuse_address`. + /// + /// Default is `false`. + #[inline] + pub fn set_reuse_address(&mut self, reuse_address: bool) -> &mut Self { + self.config_mut().reuse_address = reuse_address; + self + } + + // private + + fn config_mut(&mut self) -> &mut Config { + // If the are HttpConnector clones, this will clone the inner + // config. So mutating the config won't ever affect previous + // clones. + Arc::make_mut(&mut self.config) + } +} + +static INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http"; +static INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing"; +static INVALID_MISSING_HOST: &str = "invalid URL, host is missing"; + +// R: Debug required for now to allow adding it to debug output later... +impl<R: fmt::Debug> fmt::Debug for HttpConnector<R> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("HttpConnector").finish() + } +} + +impl<R> tower_service::Service<Uri> for HttpConnector<R> +where + R: Resolve + Clone + Send + Sync + 'static, + R::Future: Send, +{ + type Response = TcpStream; + type Error = ConnectError; + type Future = HttpConnecting<R>; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?; + Poll::Ready(Ok(())) + } + + fn call(&mut self, dst: Uri) -> Self::Future { + let mut self_ = self.clone(); + HttpConnecting { + fut: Box::pin(async move { self_.call_async(dst).await }), + _marker: PhantomData, + } + } +} + +fn get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError> { + trace!( + "Http::connect; scheme={:?}, host={:?}, port={:?}", + dst.scheme(), + dst.host(), + dst.port(), + ); + + if config.enforce_http { + if dst.scheme() != Some(&Scheme::HTTP) { + return Err(ConnectError { + msg: INVALID_NOT_HTTP.into(), + cause: None, + }); + } + } else if dst.scheme().is_none() { + return Err(ConnectError { + msg: INVALID_MISSING_SCHEME.into(), + cause: None, + }); + } + + let host = match dst.host() { + Some(s) => s, + None => { + return Err(ConnectError { + msg: INVALID_MISSING_HOST.into(), + cause: None, + }) + } + }; + let port = match dst.port() { + Some(port) => port.as_u16(), + None => { + if dst.scheme() == Some(&Scheme::HTTPS) { + 443 + } else { + 80 + } + } + }; + + Ok((host, port)) +} + +impl<R> HttpConnector<R> +where + R: Resolve, +{ + async fn call_async(&mut self, dst: Uri) -> Result<TcpStream, ConnectError> { + let config = &self.config; + + let (host, port) = get_host_port(config, &dst)?; + let host = host.trim_start_matches('[').trim_end_matches(']'); + + // If the host is already an IP addr (v4 or v6), + // skip resolving the dns and start connecting right away. + let addrs = if let Some(addrs) = dns::SocketAddrs::try_parse(host, port) { + addrs + } else { + let addrs = resolve(&mut self.resolver, dns::Name::new(host.into())) + .await + .map_err(ConnectError::dns)?; + let addrs = addrs + .map(|mut addr| { + addr.set_port(port); + addr + }) + .collect(); + dns::SocketAddrs::new(addrs) + }; + + let c = ConnectingTcp::new(addrs, config); + + let sock = c.connect().await?; + + if let Err(e) = sock.set_nodelay(config.nodelay) { + warn!("tcp set_nodelay error: {}", e); + } + + Ok(sock) + } +} + +impl Connection for TcpStream { + fn connected(&self) -> Connected { + let connected = Connected::new(); + if let (Ok(remote_addr), Ok(local_addr)) = (self.peer_addr(), self.local_addr()) { + connected.extra(HttpInfo { remote_addr, local_addr }) + } else { + connected + } + } +} + +impl HttpInfo { + /// Get the remote address of the transport used. + pub fn remote_addr(&self) -> SocketAddr { + self.remote_addr + } + + /// Get the local address of the transport used. + pub fn local_addr(&self) -> SocketAddr { + self.local_addr + } +} + +pin_project! { + // Not publicly exported (so missing_docs doesn't trigger). + // + // We return this `Future` instead of the `Pin<Box<dyn Future>>` directly + // so that users don't rely on it fitting in a `Pin<Box<dyn Future>>` slot + // (and thus we can change the type in the future). + #[must_use = "futures do nothing unless polled"] + #[allow(missing_debug_implementations)] + pub struct HttpConnecting<R> { + #[pin] + fut: BoxConnecting, + _marker: PhantomData<R>, + } +} + +type ConnectResult = Result<TcpStream, ConnectError>; +type BoxConnecting = Pin<Box<dyn Future<Output = ConnectResult> + Send>>; + +impl<R: Resolve> Future for HttpConnecting<R> { + type Output = ConnectResult; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + self.project().fut.poll(cx) + } +} + +// Not publicly exported (so missing_docs doesn't trigger). +pub struct ConnectError { + msg: Box<str>, + cause: Option<Box<dyn StdError + Send + Sync>>, +} + +impl ConnectError { + fn new<S, E>(msg: S, cause: E) -> ConnectError + where + S: Into<Box<str>>, + E: Into<Box<dyn StdError + Send + Sync>>, + { + ConnectError { + msg: msg.into(), + cause: Some(cause.into()), + } + } + + fn dns<E>(cause: E) -> ConnectError + where + E: Into<Box<dyn StdError + Send + Sync>>, + { + ConnectError::new("dns error", cause) + } + + fn m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError + where + S: Into<Box<str>>, + E: Into<Box<dyn StdError + Send + Sync>>, + { + move |cause| ConnectError::new(msg, cause) + } +} + +impl fmt::Debug for ConnectError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if let Some(ref cause) = self.cause { + f.debug_tuple("ConnectError") + .field(&self.msg) + .field(cause) + .finish() + } else { + self.msg.fmt(f) + } + } +} + +impl fmt::Display for ConnectError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&self.msg)?; + + if let Some(ref cause) = self.cause { + write!(f, ": {}", cause)?; + } + + Ok(()) + } +} + +impl StdError for ConnectError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + self.cause.as_ref().map(|e| &**e as _) + } +} + +struct ConnectingTcp<'a> { + preferred: ConnectingTcpRemote, + fallback: Option<ConnectingTcpFallback>, + config: &'a Config, +} + +impl<'a> ConnectingTcp<'a> { + fn new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self { + if let Some(fallback_timeout) = config.happy_eyeballs_timeout { + let (preferred_addrs, fallback_addrs) = remote_addrs + .split_by_preference(config.local_address_ipv4, config.local_address_ipv6); + if fallback_addrs.is_empty() { + return ConnectingTcp { + preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout), + fallback: None, + config, + }; + } + + ConnectingTcp { + preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout), + fallback: Some(ConnectingTcpFallback { + delay: tokio::time::sleep(fallback_timeout), + remote: ConnectingTcpRemote::new(fallback_addrs, config.connect_timeout), + }), + config, + } + } else { + ConnectingTcp { + preferred: ConnectingTcpRemote::new(remote_addrs, config.connect_timeout), + fallback: None, + config, + } + } + } +} + +struct ConnectingTcpFallback { + delay: Sleep, + remote: ConnectingTcpRemote, +} + +struct ConnectingTcpRemote { + addrs: dns::SocketAddrs, + connect_timeout: Option<Duration>, +} + +impl ConnectingTcpRemote { + fn new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self { + let connect_timeout = connect_timeout.map(|t| t / (addrs.len() as u32)); + + Self { + addrs, + connect_timeout, + } + } +} + +impl ConnectingTcpRemote { + async fn connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError> { + let mut err = None; + for addr in &mut self.addrs { + debug!("connecting to {}", addr); + match connect(&addr, config, self.connect_timeout)?.await { + Ok(tcp) => { + debug!("connected to {}", addr); + return Ok(tcp); + } + Err(e) => { + trace!("connect error for {}: {:?}", addr, e); + err = Some(e); + } + } + } + + match err { + Some(e) => Err(e), + None => Err(ConnectError::new( + "tcp connect error", + std::io::Error::new(std::io::ErrorKind::NotConnected, "Network unreachable"), + )), + } + } +} + +fn bind_local_address( + socket: &socket2::Socket, + dst_addr: &SocketAddr, + local_addr_ipv4: &Option<Ipv4Addr>, + local_addr_ipv6: &Option<Ipv6Addr>, +) -> io::Result<()> { + match (*dst_addr, local_addr_ipv4, local_addr_ipv6) { + (SocketAddr::V4(_), Some(addr), _) => { + socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?; + } + (SocketAddr::V6(_), _, Some(addr)) => { + socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?; + } + _ => { + if cfg!(windows) { + // Windows requires a socket be bound before calling connect + let any: SocketAddr = match *dst_addr { + SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(), + SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(), + }; + socket.bind(&any.into())?; + } + } + } + + Ok(()) +} + +fn connect( + addr: &SocketAddr, + config: &Config, + connect_timeout: Option<Duration>, +) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError> { + // TODO(eliza): if Tokio's `TcpSocket` gains support for setting the + // keepalive timeout, it would be nice to use that instead of socket2, + // and avoid the unsafe `into_raw_fd`/`from_raw_fd` dance... + use socket2::{Domain, Protocol, Socket, TcpKeepalive, Type}; + use std::convert::TryInto; + + let domain = Domain::for_address(*addr); + let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP)) + .map_err(ConnectError::m("tcp open error"))?; + + // When constructing a Tokio `TcpSocket` from a raw fd/socket, the user is + // responsible for ensuring O_NONBLOCK is set. + socket + .set_nonblocking(true) + .map_err(ConnectError::m("tcp set_nonblocking error"))?; + + if let Some(dur) = config.keep_alive_timeout { + let conf = TcpKeepalive::new().with_time(dur); + if let Err(e) = socket.set_tcp_keepalive(&conf) { + warn!("tcp set_keepalive error: {}", e); + } + } + + bind_local_address( + &socket, + addr, + &config.local_address_ipv4, + &config.local_address_ipv6, + ) + .map_err(ConnectError::m("tcp bind local error"))?; + + #[cfg(unix)] + let socket = unsafe { + // Safety: `from_raw_fd` is only safe to call if ownership of the raw + // file descriptor is transferred. Since we call `into_raw_fd` on the + // socket2 socket, it gives up ownership of the fd and will not close + // it, so this is safe. + use std::os::unix::io::{FromRawFd, IntoRawFd}; + TcpSocket::from_raw_fd(socket.into_raw_fd()) + }; + #[cfg(windows)] + let socket = unsafe { + // Safety: `from_raw_socket` is only safe to call if ownership of the raw + // Windows SOCKET is transferred. Since we call `into_raw_socket` on the + // socket2 socket, it gives up ownership of the SOCKET and will not close + // it, so this is safe. + use std::os::windows::io::{FromRawSocket, IntoRawSocket}; + TcpSocket::from_raw_socket(socket.into_raw_socket()) + }; + + if config.reuse_address { + if let Err(e) = socket.set_reuseaddr(true) { + warn!("tcp set_reuse_address error: {}", e); + } + } + + if let Some(size) = config.send_buffer_size { + if let Err(e) = socket.set_send_buffer_size(size.try_into().unwrap_or(std::u32::MAX)) { + warn!("tcp set_buffer_size error: {}", e); + } + } + + if let Some(size) = config.recv_buffer_size { + if let Err(e) = socket.set_recv_buffer_size(size.try_into().unwrap_or(std::u32::MAX)) { + warn!("tcp set_recv_buffer_size error: {}", e); + } + } + + let connect = socket.connect(*addr); + Ok(async move { + match connect_timeout { + Some(dur) => match tokio::time::timeout(dur, connect).await { + Ok(Ok(s)) => Ok(s), + Ok(Err(e)) => Err(e), + Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)), + }, + None => connect.await, + } + .map_err(ConnectError::m("tcp connect error")) + }) +} + +impl ConnectingTcp<'_> { + async fn connect(mut self) -> Result<TcpStream, ConnectError> { + match self.fallback { + None => self.preferred.connect(self.config).await, + Some(mut fallback) => { + let preferred_fut = self.preferred.connect(self.config); + futures_util::pin_mut!(preferred_fut); + + let fallback_fut = fallback.remote.connect(self.config); + futures_util::pin_mut!(fallback_fut); + + let fallback_delay = fallback.delay; + futures_util::pin_mut!(fallback_delay); + + let (result, future) = + match futures_util::future::select(preferred_fut, fallback_delay).await { + Either::Left((result, _fallback_delay)) => { + (result, Either::Right(fallback_fut)) + } + Either::Right(((), preferred_fut)) => { + // Delay is done, start polling both the preferred and the fallback + futures_util::future::select(preferred_fut, fallback_fut) + .await + .factor_first() + } + }; + + if result.is_err() { + // Fallback to the remaining future (could be preferred or fallback) + // if we get an error + future.await + } else { + result + } + } + } + } +} + +#[cfg(test)] +mod tests { + use std::io; + + use ::http::Uri; + + use super::super::sealed::{Connect, ConnectSvc}; + use super::{Config, ConnectError, HttpConnector}; + + async fn connect<C>( + connector: C, + dst: Uri, + ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error> + where + C: Connect, + { + connector.connect(super::super::sealed::Internal, dst).await + } + + #[tokio::test] + async fn test_errors_enforce_http() { + let dst = "https://example.domain/foo/bar?baz".parse().unwrap(); + let connector = HttpConnector::new(); + + let err = connect(connector, dst).await.unwrap_err(); + assert_eq!(&*err.msg, super::INVALID_NOT_HTTP); + } + + #[cfg(any(target_os = "linux", target_os = "macos"))] + fn get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>) { + use std::net::{IpAddr, TcpListener}; + + let mut ip_v4 = None; + let mut ip_v6 = None; + + let ips = pnet_datalink::interfaces() + .into_iter() + .flat_map(|i| i.ips.into_iter().map(|n| n.ip())); + + for ip in ips { + match ip { + IpAddr::V4(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v4 = Some(ip), + IpAddr::V6(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v6 = Some(ip), + _ => (), + } + + if ip_v4.is_some() && ip_v6.is_some() { + break; + } + } + + (ip_v4, ip_v6) + } + + #[tokio::test] + async fn test_errors_missing_scheme() { + let dst = "example.domain".parse().unwrap(); + let mut connector = HttpConnector::new(); + connector.enforce_http(false); + + let err = connect(connector, dst).await.unwrap_err(); + assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME); + } + + // NOTE: pnet crate that we use in this test doesn't compile on Windows + #[cfg(any(target_os = "linux", target_os = "macos"))] + #[tokio::test] + async fn local_address() { + use std::net::{IpAddr, TcpListener}; + let _ = pretty_env_logger::try_init(); + + let (bind_ip_v4, bind_ip_v6) = get_local_ips(); + let server4 = TcpListener::bind("127.0.0.1:0").unwrap(); + let port = server4.local_addr().unwrap().port(); + let server6 = TcpListener::bind(&format!("[::1]:{}", port)).unwrap(); + + let assert_client_ip = |dst: String, server: TcpListener, expected_ip: IpAddr| async move { + let mut connector = HttpConnector::new(); + + match (bind_ip_v4, bind_ip_v6) { + (Some(v4), Some(v6)) => connector.set_local_addresses(v4, v6), + (Some(v4), None) => connector.set_local_address(Some(v4.into())), + (None, Some(v6)) => connector.set_local_address(Some(v6.into())), + _ => unreachable!(), + } + + connect(connector, dst.parse().unwrap()).await.unwrap(); + + let (_, client_addr) = server.accept().unwrap(); + + assert_eq!(client_addr.ip(), expected_ip); + }; + + if let Some(ip) = bind_ip_v4 { + assert_client_ip(format!("http://127.0.0.1:{}", port), server4, ip.into()).await; + } + + if let Some(ip) = bind_ip_v6 { + assert_client_ip(format!("http://[::1]:{}", port), server6, ip.into()).await; + } + } + + #[test] + #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)] + fn client_happy_eyeballs() { + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener}; + use std::time::{Duration, Instant}; + + use super::dns; + use super::ConnectingTcp; + + let _ = pretty_env_logger::try_init(); + let server4 = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = server4.local_addr().unwrap(); + let _server6 = TcpListener::bind(&format!("[::1]:{}", addr.port())).unwrap(); + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + let local_timeout = Duration::default(); + let unreachable_v4_timeout = measure_connect(unreachable_ipv4_addr()).1; + let unreachable_v6_timeout = measure_connect(unreachable_ipv6_addr()).1; + let fallback_timeout = std::cmp::max(unreachable_v4_timeout, unreachable_v6_timeout) + + Duration::from_millis(250); + + let scenarios = &[ + // Fast primary, without fallback. + (&[local_ipv4_addr()][..], 4, local_timeout, false), + (&[local_ipv6_addr()][..], 6, local_timeout, false), + // Fast primary, with (unused) fallback. + ( + &[local_ipv4_addr(), local_ipv6_addr()][..], + 4, + local_timeout, + false, + ), + ( + &[local_ipv6_addr(), local_ipv4_addr()][..], + 6, + local_timeout, + false, + ), + // Unreachable + fast primary, without fallback. + ( + &[unreachable_ipv4_addr(), local_ipv4_addr()][..], + 4, + unreachable_v4_timeout, + false, + ), + ( + &[unreachable_ipv6_addr(), local_ipv6_addr()][..], + 6, + unreachable_v6_timeout, + false, + ), + // Unreachable + fast primary, with (unused) fallback. + ( + &[ + unreachable_ipv4_addr(), + local_ipv4_addr(), + local_ipv6_addr(), + ][..], + 4, + unreachable_v4_timeout, + false, + ), + ( + &[ + unreachable_ipv6_addr(), + local_ipv6_addr(), + local_ipv4_addr(), + ][..], + 6, + unreachable_v6_timeout, + true, + ), + // Slow primary, with (used) fallback. + ( + &[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..], + 6, + fallback_timeout, + false, + ), + ( + &[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..], + 4, + fallback_timeout, + true, + ), + // Slow primary, with (used) unreachable + fast fallback. + ( + &[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..], + 6, + fallback_timeout + unreachable_v6_timeout, + false, + ), + ( + &[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..], + 4, + fallback_timeout + unreachable_v4_timeout, + true, + ), + ]; + + // Scenarios for IPv6 -> IPv4 fallback require that host can access IPv6 network. + // Otherwise, connection to "slow" IPv6 address will error-out immediately. + let ipv6_accessible = measure_connect(slow_ipv6_addr()).0; + + for &(hosts, family, timeout, needs_ipv6_access) in scenarios { + if needs_ipv6_access && !ipv6_accessible { + continue; + } + + let (start, stream) = rt + .block_on(async move { + let addrs = hosts + .iter() + .map(|host| (host.clone(), addr.port()).into()) + .collect(); + let cfg = Config { + local_address_ipv4: None, + local_address_ipv6: None, + connect_timeout: None, + keep_alive_timeout: None, + happy_eyeballs_timeout: Some(fallback_timeout), + nodelay: false, + reuse_address: false, + enforce_http: false, + send_buffer_size: None, + recv_buffer_size: None, + }; + let connecting_tcp = ConnectingTcp::new(dns::SocketAddrs::new(addrs), &cfg); + let start = Instant::now(); + Ok::<_, ConnectError>((start, ConnectingTcp::connect(connecting_tcp).await?)) + }) + .unwrap(); + let res = if stream.peer_addr().unwrap().is_ipv4() { + 4 + } else { + 6 + }; + let duration = start.elapsed(); + + // Allow actual duration to be +/- 150ms off. + let min_duration = if timeout >= Duration::from_millis(150) { + timeout - Duration::from_millis(150) + } else { + Duration::default() + }; + let max_duration = timeout + Duration::from_millis(150); + + assert_eq!(res, family); + assert!(duration >= min_duration); + assert!(duration <= max_duration); + } + + fn local_ipv4_addr() -> IpAddr { + Ipv4Addr::new(127, 0, 0, 1).into() + } + + fn local_ipv6_addr() -> IpAddr { + Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into() + } + + fn unreachable_ipv4_addr() -> IpAddr { + Ipv4Addr::new(127, 0, 0, 2).into() + } + + fn unreachable_ipv6_addr() -> IpAddr { + Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2).into() + } + + fn slow_ipv4_addr() -> IpAddr { + // RFC 6890 reserved IPv4 address. + Ipv4Addr::new(198, 18, 0, 25).into() + } + + fn slow_ipv6_addr() -> IpAddr { + // RFC 6890 reserved IPv6 address. + Ipv6Addr::new(2001, 2, 0, 0, 0, 0, 0, 254).into() + } + + fn measure_connect(addr: IpAddr) -> (bool, Duration) { + let start = Instant::now(); + let result = + std::net::TcpStream::connect_timeout(&(addr, 80).into(), Duration::from_secs(1)); + + let reachable = result.is_ok() || result.unwrap_err().kind() == io::ErrorKind::TimedOut; + let duration = start.elapsed(); + (reachable, duration) + } + } +} diff --git a/third_party/rust/hyper/src/client/connect/mod.rs b/third_party/rust/hyper/src/client/connect/mod.rs new file mode 100644 index 0000000000..862a0e65c1 --- /dev/null +++ b/third_party/rust/hyper/src/client/connect/mod.rs @@ -0,0 +1,412 @@ +//! Connectors used by the `Client`. +//! +//! This module contains: +//! +//! - A default [`HttpConnector`][] that does DNS resolution and establishes +//! connections over TCP. +//! - Types to build custom connectors. +//! +//! # Connectors +//! +//! A "connector" is a [`Service`][] that takes a [`Uri`][] destination, and +//! its `Response` is some type implementing [`AsyncRead`][], [`AsyncWrite`][], +//! and [`Connection`][]. +//! +//! ## Custom Connectors +//! +//! A simple connector that ignores the `Uri` destination and always returns +//! a TCP connection to the same address could be written like this: +//! +//! ```rust,ignore +//! let connector = tower::service_fn(|_dst| async { +//! tokio::net::TcpStream::connect("127.0.0.1:1337") +//! }) +//! ``` +//! +//! Or, fully written out: +//! +//! ``` +//! # #[cfg(feature = "runtime")] +//! # mod rt { +//! use std::{future::Future, net::SocketAddr, pin::Pin, task::{self, Poll}}; +//! use hyper::{service::Service, Uri}; +//! use tokio::net::TcpStream; +//! +//! #[derive(Clone)] +//! struct LocalConnector; +//! +//! impl Service<Uri> for LocalConnector { +//! type Response = TcpStream; +//! type Error = std::io::Error; +//! // We can't "name" an `async` generated future. +//! type Future = Pin<Box< +//! dyn Future<Output = Result<Self::Response, Self::Error>> + Send +//! >>; +//! +//! fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { +//! // This connector is always ready, but others might not be. +//! Poll::Ready(Ok(())) +//! } +//! +//! fn call(&mut self, _: Uri) -> Self::Future { +//! Box::pin(TcpStream::connect(SocketAddr::from(([127, 0, 0, 1], 1337)))) +//! } +//! } +//! # } +//! ``` +//! +//! It's worth noting that for `TcpStream`s, the [`HttpConnector`][] is a +//! better starting place to extend from. +//! +//! Using either of the above connector examples, it can be used with the +//! `Client` like this: +//! +//! ``` +//! # #[cfg(feature = "runtime")] +//! # fn rt () { +//! # let connector = hyper::client::HttpConnector::new(); +//! // let connector = ... +//! +//! let client = hyper::Client::builder() +//! .build::<_, hyper::Body>(connector); +//! # } +//! ``` +//! +//! +//! [`HttpConnector`]: HttpConnector +//! [`Service`]: crate::service::Service +//! [`Uri`]: ::http::Uri +//! [`AsyncRead`]: tokio::io::AsyncRead +//! [`AsyncWrite`]: tokio::io::AsyncWrite +//! [`Connection`]: Connection +use std::fmt; + +use ::http::Extensions; + +cfg_feature! { + #![feature = "tcp"] + + pub use self::http::{HttpConnector, HttpInfo}; + + pub mod dns; + mod http; +} + +cfg_feature! { + #![any(feature = "http1", feature = "http2")] + + pub use self::sealed::Connect; +} + +/// Describes a type returned by a connector. +pub trait Connection { + /// Return metadata describing the connection. + fn connected(&self) -> Connected; +} + +/// Extra information about the connected transport. +/// +/// This can be used to inform recipients about things like if ALPN +/// was used, or if connected to an HTTP proxy. +#[derive(Debug)] +pub struct Connected { + pub(super) alpn: Alpn, + pub(super) is_proxied: bool, + pub(super) extra: Option<Extra>, +} + +pub(super) struct Extra(Box<dyn ExtraInner>); + +#[derive(Clone, Copy, Debug, PartialEq)] +pub(super) enum Alpn { + H2, + None, +} + +impl Connected { + /// Create new `Connected` type with empty metadata. + pub fn new() -> Connected { + Connected { + alpn: Alpn::None, + is_proxied: false, + extra: None, + } + } + + /// Set whether the connected transport is to an HTTP proxy. + /// + /// This setting will affect if HTTP/1 requests written on the transport + /// will have the request-target in absolute-form or origin-form: + /// + /// - When `proxy(false)`: + /// + /// ```http + /// GET /guide HTTP/1.1 + /// ``` + /// + /// - When `proxy(true)`: + /// + /// ```http + /// GET http://hyper.rs/guide HTTP/1.1 + /// ``` + /// + /// Default is `false`. + pub fn proxy(mut self, is_proxied: bool) -> Connected { + self.is_proxied = is_proxied; + self + } + + /// Determines if the connected transport is to an HTTP proxy. + pub fn is_proxied(&self) -> bool { + self.is_proxied + } + + /// Set extra connection information to be set in the extensions of every `Response`. + pub fn extra<T: Clone + Send + Sync + 'static>(mut self, extra: T) -> Connected { + if let Some(prev) = self.extra { + self.extra = Some(Extra(Box::new(ExtraChain(prev.0, extra)))); + } else { + self.extra = Some(Extra(Box::new(ExtraEnvelope(extra)))); + } + self + } + + /// Copies the extra connection information into an `Extensions` map. + pub fn get_extras(&self, extensions: &mut Extensions) { + if let Some(extra) = &self.extra { + extra.set(extensions); + } + } + + /// Set that the connected transport negotiated HTTP/2 as its next protocol. + pub fn negotiated_h2(mut self) -> Connected { + self.alpn = Alpn::H2; + self + } + + /// Determines if the connected transport negotiated HTTP/2 as its next protocol. + pub fn is_negotiated_h2(&self) -> bool { + self.alpn == Alpn::H2 + } + + // Don't public expose that `Connected` is `Clone`, unsure if we want to + // keep that contract... + #[cfg(feature = "http2")] + pub(super) fn clone(&self) -> Connected { + Connected { + alpn: self.alpn.clone(), + is_proxied: self.is_proxied, + extra: self.extra.clone(), + } + } +} + +// ===== impl Extra ===== + +impl Extra { + pub(super) fn set(&self, res: &mut Extensions) { + self.0.set(res); + } +} + +impl Clone for Extra { + fn clone(&self) -> Extra { + Extra(self.0.clone_box()) + } +} + +impl fmt::Debug for Extra { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Extra").finish() + } +} + +trait ExtraInner: Send + Sync { + fn clone_box(&self) -> Box<dyn ExtraInner>; + fn set(&self, res: &mut Extensions); +} + +// This indirection allows the `Connected` to have a type-erased "extra" value, +// while that type still knows its inner extra type. This allows the correct +// TypeId to be used when inserting into `res.extensions_mut()`. +#[derive(Clone)] +struct ExtraEnvelope<T>(T); + +impl<T> ExtraInner for ExtraEnvelope<T> +where + T: Clone + Send + Sync + 'static, +{ + fn clone_box(&self) -> Box<dyn ExtraInner> { + Box::new(self.clone()) + } + + fn set(&self, res: &mut Extensions) { + res.insert(self.0.clone()); + } +} + +struct ExtraChain<T>(Box<dyn ExtraInner>, T); + +impl<T: Clone> Clone for ExtraChain<T> { + fn clone(&self) -> Self { + ExtraChain(self.0.clone_box(), self.1.clone()) + } +} + +impl<T> ExtraInner for ExtraChain<T> +where + T: Clone + Send + Sync + 'static, +{ + fn clone_box(&self) -> Box<dyn ExtraInner> { + Box::new(self.clone()) + } + + fn set(&self, res: &mut Extensions) { + self.0.set(res); + res.insert(self.1.clone()); + } +} + +#[cfg(any(feature = "http1", feature = "http2"))] +pub(super) mod sealed { + use std::error::Error as StdError; + + use ::http::Uri; + use tokio::io::{AsyncRead, AsyncWrite}; + + use super::Connection; + use crate::common::{Future, Unpin}; + + /// Connect to a destination, returning an IO transport. + /// + /// A connector receives a [`Uri`](::http::Uri) and returns a `Future` of the + /// ready connection. + /// + /// # Trait Alias + /// + /// This is really just an *alias* for the `tower::Service` trait, with + /// additional bounds set for convenience *inside* hyper. You don't actually + /// implement this trait, but `tower::Service<Uri>` instead. + // The `Sized` bound is to prevent creating `dyn Connect`, since they cannot + // fit the `Connect` bounds because of the blanket impl for `Service`. + pub trait Connect: Sealed + Sized { + #[doc(hidden)] + type _Svc: ConnectSvc; + #[doc(hidden)] + fn connect(self, internal_only: Internal, dst: Uri) -> <Self::_Svc as ConnectSvc>::Future; + } + + pub trait ConnectSvc { + type Connection: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static; + type Error: Into<Box<dyn StdError + Send + Sync>>; + type Future: Future<Output = Result<Self::Connection, Self::Error>> + Unpin + Send + 'static; + + fn connect(self, internal_only: Internal, dst: Uri) -> Self::Future; + } + + impl<S, T> Connect for S + where + S: tower_service::Service<Uri, Response = T> + Send + 'static, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + S::Future: Unpin + Send, + T: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static, + { + type _Svc = S; + + fn connect(self, _: Internal, dst: Uri) -> crate::service::Oneshot<S, Uri> { + crate::service::oneshot(self, dst) + } + } + + impl<S, T> ConnectSvc for S + where + S: tower_service::Service<Uri, Response = T> + Send + 'static, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + S::Future: Unpin + Send, + T: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static, + { + type Connection = T; + type Error = S::Error; + type Future = crate::service::Oneshot<S, Uri>; + + fn connect(self, _: Internal, dst: Uri) -> Self::Future { + crate::service::oneshot(self, dst) + } + } + + impl<S, T> Sealed for S + where + S: tower_service::Service<Uri, Response = T> + Send, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + S::Future: Unpin + Send, + T: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static, + { + } + + pub trait Sealed {} + #[allow(missing_debug_implementations)] + pub struct Internal; +} + +#[cfg(test)] +mod tests { + use super::Connected; + + #[derive(Clone, Debug, PartialEq)] + struct Ex1(usize); + + #[derive(Clone, Debug, PartialEq)] + struct Ex2(&'static str); + + #[derive(Clone, Debug, PartialEq)] + struct Ex3(&'static str); + + #[test] + fn test_connected_extra() { + let c1 = Connected::new().extra(Ex1(41)); + + let mut ex = ::http::Extensions::new(); + + assert_eq!(ex.get::<Ex1>(), None); + + c1.extra.as_ref().expect("c1 extra").set(&mut ex); + + assert_eq!(ex.get::<Ex1>(), Some(&Ex1(41))); + } + + #[test] + fn test_connected_extra_chain() { + // If a user composes connectors and at each stage, there's "extra" + // info to attach, it shouldn't override the previous extras. + + let c1 = Connected::new() + .extra(Ex1(45)) + .extra(Ex2("zoom")) + .extra(Ex3("pew pew")); + + let mut ex1 = ::http::Extensions::new(); + + assert_eq!(ex1.get::<Ex1>(), None); + assert_eq!(ex1.get::<Ex2>(), None); + assert_eq!(ex1.get::<Ex3>(), None); + + c1.extra.as_ref().expect("c1 extra").set(&mut ex1); + + assert_eq!(ex1.get::<Ex1>(), Some(&Ex1(45))); + assert_eq!(ex1.get::<Ex2>(), Some(&Ex2("zoom"))); + assert_eq!(ex1.get::<Ex3>(), Some(&Ex3("pew pew"))); + + // Just like extensions, inserting the same type overrides previous type. + let c2 = Connected::new() + .extra(Ex1(33)) + .extra(Ex2("hiccup")) + .extra(Ex1(99)); + + let mut ex2 = ::http::Extensions::new(); + + c2.extra.as_ref().expect("c2 extra").set(&mut ex2); + + assert_eq!(ex2.get::<Ex1>(), Some(&Ex1(99))); + assert_eq!(ex2.get::<Ex2>(), Some(&Ex2("hiccup"))); + } +} diff --git a/third_party/rust/hyper/src/client/dispatch.rs b/third_party/rust/hyper/src/client/dispatch.rs new file mode 100644 index 0000000000..0d70dbccea --- /dev/null +++ b/third_party/rust/hyper/src/client/dispatch.rs @@ -0,0 +1,436 @@ +#[cfg(feature = "http2")] +use std::future::Future; + +use futures_util::FutureExt; +use tokio::sync::{mpsc, oneshot}; + +#[cfg(feature = "http2")] +use crate::common::Pin; +use crate::common::{task, Poll}; + +pub(crate) type RetryPromise<T, U> = oneshot::Receiver<Result<U, (crate::Error, Option<T>)>>; +pub(crate) type Promise<T> = oneshot::Receiver<Result<T, crate::Error>>; + +pub(crate) fn channel<T, U>() -> (Sender<T, U>, Receiver<T, U>) { + let (tx, rx) = mpsc::unbounded_channel(); + let (giver, taker) = want::new(); + let tx = Sender { + buffered_once: false, + giver, + inner: tx, + }; + let rx = Receiver { inner: rx, taker }; + (tx, rx) +} + +/// A bounded sender of requests and callbacks for when responses are ready. +/// +/// While the inner sender is unbounded, the Giver is used to determine +/// if the Receiver is ready for another request. +pub(crate) struct Sender<T, U> { + /// One message is always allowed, even if the Receiver hasn't asked + /// for it yet. This boolean keeps track of whether we've sent one + /// without notice. + buffered_once: bool, + /// The Giver helps watch that the the Receiver side has been polled + /// when the queue is empty. This helps us know when a request and + /// response have been fully processed, and a connection is ready + /// for more. + giver: want::Giver, + /// Actually bounded by the Giver, plus `buffered_once`. + inner: mpsc::UnboundedSender<Envelope<T, U>>, +} + +/// An unbounded version. +/// +/// Cannot poll the Giver, but can still use it to determine if the Receiver +/// has been dropped. However, this version can be cloned. +#[cfg(feature = "http2")] +pub(crate) struct UnboundedSender<T, U> { + /// Only used for `is_closed`, since mpsc::UnboundedSender cannot be checked. + giver: want::SharedGiver, + inner: mpsc::UnboundedSender<Envelope<T, U>>, +} + +impl<T, U> Sender<T, U> { + pub(crate) fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + self.giver + .poll_want(cx) + .map_err(|_| crate::Error::new_closed()) + } + + pub(crate) fn is_ready(&self) -> bool { + self.giver.is_wanting() + } + + pub(crate) fn is_closed(&self) -> bool { + self.giver.is_canceled() + } + + fn can_send(&mut self) -> bool { + if self.giver.give() || !self.buffered_once { + // If the receiver is ready *now*, then of course we can send. + // + // If the receiver isn't ready yet, but we don't have anything + // in the channel yet, then allow one message. + self.buffered_once = true; + true + } else { + false + } + } + + pub(crate) fn try_send(&mut self, val: T) -> Result<RetryPromise<T, U>, T> { + if !self.can_send() { + return Err(val); + } + let (tx, rx) = oneshot::channel(); + self.inner + .send(Envelope(Some((val, Callback::Retry(Some(tx)))))) + .map(move |_| rx) + .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0) + } + + pub(crate) fn send(&mut self, val: T) -> Result<Promise<U>, T> { + if !self.can_send() { + return Err(val); + } + let (tx, rx) = oneshot::channel(); + self.inner + .send(Envelope(Some((val, Callback::NoRetry(Some(tx)))))) + .map(move |_| rx) + .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0) + } + + #[cfg(feature = "http2")] + pub(crate) fn unbound(self) -> UnboundedSender<T, U> { + UnboundedSender { + giver: self.giver.shared(), + inner: self.inner, + } + } +} + +#[cfg(feature = "http2")] +impl<T, U> UnboundedSender<T, U> { + pub(crate) fn is_ready(&self) -> bool { + !self.giver.is_canceled() + } + + pub(crate) fn is_closed(&self) -> bool { + self.giver.is_canceled() + } + + pub(crate) fn try_send(&mut self, val: T) -> Result<RetryPromise<T, U>, T> { + let (tx, rx) = oneshot::channel(); + self.inner + .send(Envelope(Some((val, Callback::Retry(Some(tx)))))) + .map(move |_| rx) + .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0) + } +} + +#[cfg(feature = "http2")] +impl<T, U> Clone for UnboundedSender<T, U> { + fn clone(&self) -> Self { + UnboundedSender { + giver: self.giver.clone(), + inner: self.inner.clone(), + } + } +} + +pub(crate) struct Receiver<T, U> { + inner: mpsc::UnboundedReceiver<Envelope<T, U>>, + taker: want::Taker, +} + +impl<T, U> Receiver<T, U> { + pub(crate) fn poll_recv( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Option<(T, Callback<T, U>)>> { + match self.inner.poll_recv(cx) { + Poll::Ready(item) => { + Poll::Ready(item.map(|mut env| env.0.take().expect("envelope not dropped"))) + } + Poll::Pending => { + self.taker.want(); + Poll::Pending + } + } + } + + #[cfg(feature = "http1")] + pub(crate) fn close(&mut self) { + self.taker.cancel(); + self.inner.close(); + } + + #[cfg(feature = "http1")] + pub(crate) fn try_recv(&mut self) -> Option<(T, Callback<T, U>)> { + match self.inner.recv().now_or_never() { + Some(Some(mut env)) => env.0.take(), + _ => None, + } + } +} + +impl<T, U> Drop for Receiver<T, U> { + fn drop(&mut self) { + // Notify the giver about the closure first, before dropping + // the mpsc::Receiver. + self.taker.cancel(); + } +} + +struct Envelope<T, U>(Option<(T, Callback<T, U>)>); + +impl<T, U> Drop for Envelope<T, U> { + fn drop(&mut self) { + if let Some((val, cb)) = self.0.take() { + cb.send(Err(( + crate::Error::new_canceled().with("connection closed"), + Some(val), + ))); + } + } +} + +pub(crate) enum Callback<T, U> { + Retry(Option<oneshot::Sender<Result<U, (crate::Error, Option<T>)>>>), + NoRetry(Option<oneshot::Sender<Result<U, crate::Error>>>), +} + +impl<T, U> Drop for Callback<T, U> { + fn drop(&mut self) { + // FIXME(nox): What errors do we want here? + let error = crate::Error::new_user_dispatch_gone().with(if std::thread::panicking() { + "user code panicked" + } else { + "runtime dropped the dispatch task" + }); + + match self { + Callback::Retry(tx) => { + if let Some(tx) = tx.take() { + let _ = tx.send(Err((error, None))); + } + } + Callback::NoRetry(tx) => { + if let Some(tx) = tx.take() { + let _ = tx.send(Err(error)); + } + } + } + } +} + +impl<T, U> Callback<T, U> { + #[cfg(feature = "http2")] + pub(crate) fn is_canceled(&self) -> bool { + match *self { + Callback::Retry(Some(ref tx)) => tx.is_closed(), + Callback::NoRetry(Some(ref tx)) => tx.is_closed(), + _ => unreachable!(), + } + } + + pub(crate) fn poll_canceled(&mut self, cx: &mut task::Context<'_>) -> Poll<()> { + match *self { + Callback::Retry(Some(ref mut tx)) => tx.poll_closed(cx), + Callback::NoRetry(Some(ref mut tx)) => tx.poll_closed(cx), + _ => unreachable!(), + } + } + + pub(crate) fn send(mut self, val: Result<U, (crate::Error, Option<T>)>) { + match self { + Callback::Retry(ref mut tx) => { + let _ = tx.take().unwrap().send(val); + } + Callback::NoRetry(ref mut tx) => { + let _ = tx.take().unwrap().send(val.map_err(|e| e.0)); + } + } + } + + #[cfg(feature = "http2")] + pub(crate) async fn send_when( + self, + mut when: impl Future<Output = Result<U, (crate::Error, Option<T>)>> + Unpin, + ) { + use futures_util::future; + use tracing::trace; + + let mut cb = Some(self); + + // "select" on this callback being canceled, and the future completing + future::poll_fn(move |cx| { + match Pin::new(&mut when).poll(cx) { + Poll::Ready(Ok(res)) => { + cb.take().expect("polled after complete").send(Ok(res)); + Poll::Ready(()) + } + Poll::Pending => { + // check if the callback is canceled + ready!(cb.as_mut().unwrap().poll_canceled(cx)); + trace!("send_when canceled"); + Poll::Ready(()) + } + Poll::Ready(Err(err)) => { + cb.take().expect("polled after complete").send(Err(err)); + Poll::Ready(()) + } + } + }) + .await + } +} + +#[cfg(test)] +mod tests { + #[cfg(feature = "nightly")] + extern crate test; + + use std::future::Future; + use std::pin::Pin; + use std::task::{Context, Poll}; + + use super::{channel, Callback, Receiver}; + + #[derive(Debug)] + struct Custom(i32); + + impl<T, U> Future for Receiver<T, U> { + type Output = Option<(T, Callback<T, U>)>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + self.poll_recv(cx) + } + } + + /// Helper to check if the future is ready after polling once. + struct PollOnce<'a, F>(&'a mut F); + + impl<F, T> Future for PollOnce<'_, F> + where + F: Future<Output = T> + Unpin, + { + type Output = Option<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + match Pin::new(&mut self.0).poll(cx) { + Poll::Ready(_) => Poll::Ready(Some(())), + Poll::Pending => Poll::Ready(None), + } + } + } + + #[tokio::test] + async fn drop_receiver_sends_cancel_errors() { + let _ = pretty_env_logger::try_init(); + + let (mut tx, mut rx) = channel::<Custom, ()>(); + + // must poll once for try_send to succeed + assert!(PollOnce(&mut rx).await.is_none(), "rx empty"); + + let promise = tx.try_send(Custom(43)).unwrap(); + drop(rx); + + let fulfilled = promise.await; + let err = fulfilled + .expect("fulfilled") + .expect_err("promise should error"); + match (err.0.kind(), err.1) { + (&crate::error::Kind::Canceled, Some(_)) => (), + e => panic!("expected Error::Cancel(_), found {:?}", e), + } + } + + #[tokio::test] + async fn sender_checks_for_want_on_send() { + let (mut tx, mut rx) = channel::<Custom, ()>(); + + // one is allowed to buffer, second is rejected + let _ = tx.try_send(Custom(1)).expect("1 buffered"); + tx.try_send(Custom(2)).expect_err("2 not ready"); + + assert!(PollOnce(&mut rx).await.is_some(), "rx once"); + + // Even though 1 has been popped, only 1 could be buffered for the + // lifetime of the channel. + tx.try_send(Custom(2)).expect_err("2 still not ready"); + + assert!(PollOnce(&mut rx).await.is_none(), "rx empty"); + + let _ = tx.try_send(Custom(2)).expect("2 ready"); + } + + #[cfg(feature = "http2")] + #[test] + fn unbounded_sender_doesnt_bound_on_want() { + let (tx, rx) = channel::<Custom, ()>(); + let mut tx = tx.unbound(); + + let _ = tx.try_send(Custom(1)).unwrap(); + let _ = tx.try_send(Custom(2)).unwrap(); + let _ = tx.try_send(Custom(3)).unwrap(); + + drop(rx); + + let _ = tx.try_send(Custom(4)).unwrap_err(); + } + + #[cfg(feature = "nightly")] + #[bench] + fn giver_queue_throughput(b: &mut test::Bencher) { + use crate::{Body, Request, Response}; + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + let (mut tx, mut rx) = channel::<Request<Body>, Response<Body>>(); + + b.iter(move || { + let _ = tx.send(Request::default()).unwrap(); + rt.block_on(async { + loop { + let poll_once = PollOnce(&mut rx); + let opt = poll_once.await; + if opt.is_none() { + break; + } + } + }); + }) + } + + #[cfg(feature = "nightly")] + #[bench] + fn giver_queue_not_ready(b: &mut test::Bencher) { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + let (_tx, mut rx) = channel::<i32, ()>(); + b.iter(move || { + rt.block_on(async { + let poll_once = PollOnce(&mut rx); + assert!(poll_once.await.is_none()); + }); + }) + } + + #[cfg(feature = "nightly")] + #[bench] + fn giver_queue_cancel(b: &mut test::Bencher) { + let (_tx, mut rx) = channel::<i32, ()>(); + + b.iter(move || { + rx.taker.cancel(); + }) + } +} diff --git a/third_party/rust/hyper/src/client/mod.rs b/third_party/rust/hyper/src/client/mod.rs new file mode 100644 index 0000000000..734bda8819 --- /dev/null +++ b/third_party/rust/hyper/src/client/mod.rs @@ -0,0 +1,68 @@ +//! HTTP Client +//! +//! There are two levels of APIs provided for construct HTTP clients: +//! +//! - The higher-level [`Client`](Client) type. +//! - The lower-level [`conn`](conn) module. +//! +//! # Client +//! +//! The [`Client`](Client) is the main way to send HTTP requests to a server. +//! The default `Client` provides these things on top of the lower-level API: +//! +//! - A default **connector**, able to resolve hostnames and connect to +//! destinations over plain-text TCP. +//! - A **pool** of existing connections, allowing better performance when +//! making multiple requests to the same hostname. +//! - Automatic setting of the `Host` header, based on the request `Uri`. +//! - Automatic request **retries** when a pooled connection is closed by the +//! server before any bytes have been written. +//! +//! Many of these features can configured, by making use of +//! [`Client::builder`](Client::builder). +//! +//! ## Example +//! +//! For a small example program simply fetching a URL, take a look at the +//! [full client example](https://github.com/hyperium/hyper/blob/master/examples/client.rs). +//! +//! ``` +//! # #[cfg(all(feature = "tcp", feature = "client", any(feature = "http1", feature = "http2")))] +//! # async fn fetch_httpbin() -> hyper::Result<()> { +//! use hyper::{body::HttpBody as _, Client, Uri}; +//! +//! let client = Client::new(); +//! +//! // Make a GET /ip to 'http://httpbin.org' +//! let res = client.get(Uri::from_static("http://httpbin.org/ip")).await?; +//! +//! // And then, if the request gets a response... +//! println!("status: {}", res.status()); +//! +//! // Concatenate the body stream into a single buffer... +//! let buf = hyper::body::to_bytes(res).await?; +//! +//! println!("body: {:?}", buf); +//! # Ok(()) +//! # } +//! # fn main () {} +//! ``` + +#[cfg(feature = "tcp")] +pub use self::connect::HttpConnector; + +pub mod connect; +#[cfg(all(test, feature = "runtime"))] +mod tests; + +cfg_feature! { + #![any(feature = "http1", feature = "http2")] + + pub use self::client::{Builder, Client, ResponseFuture}; + + mod client; + pub mod conn; + pub(super) mod dispatch; + mod pool; + pub mod service; +} diff --git a/third_party/rust/hyper/src/client/pool.rs b/third_party/rust/hyper/src/client/pool.rs new file mode 100644 index 0000000000..b9772d688d --- /dev/null +++ b/third_party/rust/hyper/src/client/pool.rs @@ -0,0 +1,1044 @@ +use std::collections::{HashMap, HashSet, VecDeque}; +use std::error::Error as StdError; +use std::fmt; +use std::ops::{Deref, DerefMut}; +use std::sync::{Arc, Mutex, Weak}; + +#[cfg(not(feature = "runtime"))] +use std::time::{Duration, Instant}; + +use futures_channel::oneshot; +#[cfg(feature = "runtime")] +use tokio::time::{Duration, Instant, Interval}; +use tracing::{debug, trace}; + +use super::client::Ver; +use crate::common::{exec::Exec, task, Future, Pin, Poll, Unpin}; + +// FIXME: allow() required due to `impl Trait` leaking types to this lint +#[allow(missing_debug_implementations)] +pub(super) struct Pool<T> { + // If the pool is disabled, this is None. + inner: Option<Arc<Mutex<PoolInner<T>>>>, +} + +// Before using a pooled connection, make sure the sender is not dead. +// +// This is a trait to allow the `client::pool::tests` to work for `i32`. +// +// See https://github.com/hyperium/hyper/issues/1429 +pub(super) trait Poolable: Unpin + Send + Sized + 'static { + fn is_open(&self) -> bool; + /// Reserve this connection. + /// + /// Allows for HTTP/2 to return a shared reservation. + fn reserve(self) -> Reservation<Self>; + fn can_share(&self) -> bool; +} + +/// When checking out a pooled connection, it might be that the connection +/// only supports a single reservation, or it might be usable for many. +/// +/// Specifically, HTTP/1 requires a unique reservation, but HTTP/2 can be +/// used for multiple requests. +// FIXME: allow() required due to `impl Trait` leaking types to this lint +#[allow(missing_debug_implementations)] +pub(super) enum Reservation<T> { + /// This connection could be used multiple times, the first one will be + /// reinserted into the `idle` pool, and the second will be given to + /// the `Checkout`. + #[cfg(feature = "http2")] + Shared(T, T), + /// This connection requires unique access. It will be returned after + /// use is complete. + Unique(T), +} + +/// Simple type alias in case the key type needs to be adjusted. +pub(super) type Key = (http::uri::Scheme, http::uri::Authority); //Arc<String>; + +struct PoolInner<T> { + // A flag that a connection is being established, and the connection + // should be shared. This prevents making multiple HTTP/2 connections + // to the same host. + connecting: HashSet<Key>, + // These are internal Conns sitting in the event loop in the KeepAlive + // state, waiting to receive a new Request to send on the socket. + idle: HashMap<Key, Vec<Idle<T>>>, + max_idle_per_host: usize, + // These are outstanding Checkouts that are waiting for a socket to be + // able to send a Request one. This is used when "racing" for a new + // connection. + // + // The Client starts 2 tasks, 1 to connect a new socket, and 1 to wait + // for the Pool to receive an idle Conn. When a Conn becomes idle, + // this list is checked for any parked Checkouts, and tries to notify + // them that the Conn could be used instead of waiting for a brand new + // connection. + waiters: HashMap<Key, VecDeque<oneshot::Sender<T>>>, + // A oneshot channel is used to allow the interval to be notified when + // the Pool completely drops. That way, the interval can cancel immediately. + #[cfg(feature = "runtime")] + idle_interval_ref: Option<oneshot::Sender<crate::common::Never>>, + #[cfg(feature = "runtime")] + exec: Exec, + timeout: Option<Duration>, +} + +// This is because `Weak::new()` *allocates* space for `T`, even if it +// doesn't need it! +struct WeakOpt<T>(Option<Weak<T>>); + +#[derive(Clone, Copy, Debug)] +pub(super) struct Config { + pub(super) idle_timeout: Option<Duration>, + pub(super) max_idle_per_host: usize, +} + +impl Config { + pub(super) fn is_enabled(&self) -> bool { + self.max_idle_per_host > 0 + } +} + +impl<T> Pool<T> { + pub(super) fn new(config: Config, __exec: &Exec) -> Pool<T> { + let inner = if config.is_enabled() { + Some(Arc::new(Mutex::new(PoolInner { + connecting: HashSet::new(), + idle: HashMap::new(), + #[cfg(feature = "runtime")] + idle_interval_ref: None, + max_idle_per_host: config.max_idle_per_host, + waiters: HashMap::new(), + #[cfg(feature = "runtime")] + exec: __exec.clone(), + timeout: config.idle_timeout, + }))) + } else { + None + }; + + Pool { inner } + } + + fn is_enabled(&self) -> bool { + self.inner.is_some() + } + + #[cfg(test)] + pub(super) fn no_timer(&self) { + // Prevent an actual interval from being created for this pool... + #[cfg(feature = "runtime")] + { + let mut inner = self.inner.as_ref().unwrap().lock().unwrap(); + assert!(inner.idle_interval_ref.is_none(), "timer already spawned"); + let (tx, _) = oneshot::channel(); + inner.idle_interval_ref = Some(tx); + } + } +} + +impl<T: Poolable> Pool<T> { + /// Returns a `Checkout` which is a future that resolves if an idle + /// connection becomes available. + pub(super) fn checkout(&self, key: Key) -> Checkout<T> { + Checkout { + key, + pool: self.clone(), + waiter: None, + } + } + + /// Ensure that there is only ever 1 connecting task for HTTP/2 + /// connections. This does nothing for HTTP/1. + pub(super) fn connecting(&self, key: &Key, ver: Ver) -> Option<Connecting<T>> { + if ver == Ver::Http2 { + if let Some(ref enabled) = self.inner { + let mut inner = enabled.lock().unwrap(); + return if inner.connecting.insert(key.clone()) { + let connecting = Connecting { + key: key.clone(), + pool: WeakOpt::downgrade(enabled), + }; + Some(connecting) + } else { + trace!("HTTP/2 connecting already in progress for {:?}", key); + None + }; + } + } + + // else + Some(Connecting { + key: key.clone(), + // in HTTP/1's case, there is never a lock, so we don't + // need to do anything in Drop. + pool: WeakOpt::none(), + }) + } + + #[cfg(test)] + fn locked(&self) -> std::sync::MutexGuard<'_, PoolInner<T>> { + self.inner.as_ref().expect("enabled").lock().expect("lock") + } + + /* Used in client/tests.rs... + #[cfg(feature = "runtime")] + #[cfg(test)] + pub(super) fn h1_key(&self, s: &str) -> Key { + Arc::new(s.to_string()) + } + + #[cfg(feature = "runtime")] + #[cfg(test)] + pub(super) fn idle_count(&self, key: &Key) -> usize { + self + .locked() + .idle + .get(key) + .map(|list| list.len()) + .unwrap_or(0) + } + */ + + pub(super) fn pooled( + &self, + #[cfg_attr(not(feature = "http2"), allow(unused_mut))] mut connecting: Connecting<T>, + value: T, + ) -> Pooled<T> { + let (value, pool_ref) = if let Some(ref enabled) = self.inner { + match value.reserve() { + #[cfg(feature = "http2")] + Reservation::Shared(to_insert, to_return) => { + let mut inner = enabled.lock().unwrap(); + inner.put(connecting.key.clone(), to_insert, enabled); + // Do this here instead of Drop for Connecting because we + // already have a lock, no need to lock the mutex twice. + inner.connected(&connecting.key); + // prevent the Drop of Connecting from repeating inner.connected() + connecting.pool = WeakOpt::none(); + + // Shared reservations don't need a reference to the pool, + // since the pool always keeps a copy. + (to_return, WeakOpt::none()) + } + Reservation::Unique(value) => { + // Unique reservations must take a reference to the pool + // since they hope to reinsert once the reservation is + // completed + (value, WeakOpt::downgrade(enabled)) + } + } + } else { + // If pool is not enabled, skip all the things... + + // The Connecting should have had no pool ref + debug_assert!(connecting.pool.upgrade().is_none()); + + (value, WeakOpt::none()) + }; + Pooled { + key: connecting.key.clone(), + is_reused: false, + pool: pool_ref, + value: Some(value), + } + } + + fn reuse(&self, key: &Key, value: T) -> Pooled<T> { + debug!("reuse idle connection for {:?}", key); + // TODO: unhack this + // In Pool::pooled(), which is used for inserting brand new connections, + // there's some code that adjusts the pool reference taken depending + // on if the Reservation can be shared or is unique. By the time + // reuse() is called, the reservation has already been made, and + // we just have the final value, without knowledge of if this is + // unique or shared. So, the hack is to just assume Ver::Http2 means + // shared... :( + let mut pool_ref = WeakOpt::none(); + if !value.can_share() { + if let Some(ref enabled) = self.inner { + pool_ref = WeakOpt::downgrade(enabled); + } + } + + Pooled { + is_reused: true, + key: key.clone(), + pool: pool_ref, + value: Some(value), + } + } +} + +/// Pop off this list, looking for a usable connection that hasn't expired. +struct IdlePopper<'a, T> { + key: &'a Key, + list: &'a mut Vec<Idle<T>>, +} + +impl<'a, T: Poolable + 'a> IdlePopper<'a, T> { + fn pop(self, expiration: &Expiration) -> Option<Idle<T>> { + while let Some(entry) = self.list.pop() { + // If the connection has been closed, or is older than our idle + // timeout, simply drop it and keep looking... + if !entry.value.is_open() { + trace!("removing closed connection for {:?}", self.key); + continue; + } + // TODO: Actually, since the `idle` list is pushed to the end always, + // that would imply that if *this* entry is expired, then anything + // "earlier" in the list would *have* to be expired also... Right? + // + // In that case, we could just break out of the loop and drop the + // whole list... + if expiration.expires(entry.idle_at) { + trace!("removing expired connection for {:?}", self.key); + continue; + } + + let value = match entry.value.reserve() { + #[cfg(feature = "http2")] + Reservation::Shared(to_reinsert, to_checkout) => { + self.list.push(Idle { + idle_at: Instant::now(), + value: to_reinsert, + }); + to_checkout + } + Reservation::Unique(unique) => unique, + }; + + return Some(Idle { + idle_at: entry.idle_at, + value, + }); + } + + None + } +} + +impl<T: Poolable> PoolInner<T> { + fn put(&mut self, key: Key, value: T, __pool_ref: &Arc<Mutex<PoolInner<T>>>) { + if value.can_share() && self.idle.contains_key(&key) { + trace!("put; existing idle HTTP/2 connection for {:?}", key); + return; + } + trace!("put; add idle connection for {:?}", key); + let mut remove_waiters = false; + let mut value = Some(value); + if let Some(waiters) = self.waiters.get_mut(&key) { + while let Some(tx) = waiters.pop_front() { + if !tx.is_canceled() { + let reserved = value.take().expect("value already sent"); + let reserved = match reserved.reserve() { + #[cfg(feature = "http2")] + Reservation::Shared(to_keep, to_send) => { + value = Some(to_keep); + to_send + } + Reservation::Unique(uniq) => uniq, + }; + match tx.send(reserved) { + Ok(()) => { + if value.is_none() { + break; + } else { + continue; + } + } + Err(e) => { + value = Some(e); + } + } + } + + trace!("put; removing canceled waiter for {:?}", key); + } + remove_waiters = waiters.is_empty(); + } + if remove_waiters { + self.waiters.remove(&key); + } + + match value { + Some(value) => { + // borrow-check scope... + { + let idle_list = self.idle.entry(key.clone()).or_insert_with(Vec::new); + if self.max_idle_per_host <= idle_list.len() { + trace!("max idle per host for {:?}, dropping connection", key); + return; + } + + debug!("pooling idle connection for {:?}", key); + idle_list.push(Idle { + value, + idle_at: Instant::now(), + }); + } + + #[cfg(feature = "runtime")] + { + self.spawn_idle_interval(__pool_ref); + } + } + None => trace!("put; found waiter for {:?}", key), + } + } + + /// A `Connecting` task is complete. Not necessarily successfully, + /// but the lock is going away, so clean up. + fn connected(&mut self, key: &Key) { + let existed = self.connecting.remove(key); + debug_assert!(existed, "Connecting dropped, key not in pool.connecting"); + // cancel any waiters. if there are any, it's because + // this Connecting task didn't complete successfully. + // those waiters would never receive a connection. + self.waiters.remove(key); + } + + #[cfg(feature = "runtime")] + fn spawn_idle_interval(&mut self, pool_ref: &Arc<Mutex<PoolInner<T>>>) { + let (dur, rx) = { + if self.idle_interval_ref.is_some() { + return; + } + + if let Some(dur) = self.timeout { + let (tx, rx) = oneshot::channel(); + self.idle_interval_ref = Some(tx); + (dur, rx) + } else { + return; + } + }; + + let interval = IdleTask { + interval: tokio::time::interval(dur), + pool: WeakOpt::downgrade(pool_ref), + pool_drop_notifier: rx, + }; + + self.exec.execute(interval); + } +} + +impl<T> PoolInner<T> { + /// Any `FutureResponse`s that were created will have made a `Checkout`, + /// and possibly inserted into the pool that it is waiting for an idle + /// connection. If a user ever dropped that future, we need to clean out + /// those parked senders. + fn clean_waiters(&mut self, key: &Key) { + let mut remove_waiters = false; + if let Some(waiters) = self.waiters.get_mut(key) { + waiters.retain(|tx| !tx.is_canceled()); + remove_waiters = waiters.is_empty(); + } + if remove_waiters { + self.waiters.remove(key); + } + } +} + +#[cfg(feature = "runtime")] +impl<T: Poolable> PoolInner<T> { + /// This should *only* be called by the IdleTask + fn clear_expired(&mut self) { + let dur = self.timeout.expect("interval assumes timeout"); + + let now = Instant::now(); + //self.last_idle_check_at = now; + + self.idle.retain(|key, values| { + values.retain(|entry| { + if !entry.value.is_open() { + trace!("idle interval evicting closed for {:?}", key); + return false; + } + + // Avoid `Instant::sub` to avoid issues like rust-lang/rust#86470. + if now.saturating_duration_since(entry.idle_at) > dur { + trace!("idle interval evicting expired for {:?}", key); + return false; + } + + // Otherwise, keep this value... + true + }); + + // returning false evicts this key/val + !values.is_empty() + }); + } +} + +impl<T> Clone for Pool<T> { + fn clone(&self) -> Pool<T> { + Pool { + inner: self.inner.clone(), + } + } +} + +/// A wrapped poolable value that tries to reinsert to the Pool on Drop. +// Note: The bounds `T: Poolable` is needed for the Drop impl. +pub(super) struct Pooled<T: Poolable> { + value: Option<T>, + is_reused: bool, + key: Key, + pool: WeakOpt<Mutex<PoolInner<T>>>, +} + +impl<T: Poolable> Pooled<T> { + pub(super) fn is_reused(&self) -> bool { + self.is_reused + } + + pub(super) fn is_pool_enabled(&self) -> bool { + self.pool.0.is_some() + } + + fn as_ref(&self) -> &T { + self.value.as_ref().expect("not dropped") + } + + fn as_mut(&mut self) -> &mut T { + self.value.as_mut().expect("not dropped") + } +} + +impl<T: Poolable> Deref for Pooled<T> { + type Target = T; + fn deref(&self) -> &T { + self.as_ref() + } +} + +impl<T: Poolable> DerefMut for Pooled<T> { + fn deref_mut(&mut self) -> &mut T { + self.as_mut() + } +} + +impl<T: Poolable> Drop for Pooled<T> { + fn drop(&mut self) { + if let Some(value) = self.value.take() { + if !value.is_open() { + // If we *already* know the connection is done here, + // it shouldn't be re-inserted back into the pool. + return; + } + + if let Some(pool) = self.pool.upgrade() { + if let Ok(mut inner) = pool.lock() { + inner.put(self.key.clone(), value, &pool); + } + } else if !value.can_share() { + trace!("pool dropped, dropping pooled ({:?})", self.key); + } + // Ver::Http2 is already in the Pool (or dead), so we wouldn't + // have an actual reference to the Pool. + } + } +} + +impl<T: Poolable> fmt::Debug for Pooled<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Pooled").field("key", &self.key).finish() + } +} + +struct Idle<T> { + idle_at: Instant, + value: T, +} + +// FIXME: allow() required due to `impl Trait` leaking types to this lint +#[allow(missing_debug_implementations)] +pub(super) struct Checkout<T> { + key: Key, + pool: Pool<T>, + waiter: Option<oneshot::Receiver<T>>, +} + +#[derive(Debug)] +pub(super) struct CheckoutIsClosedError; + +impl StdError for CheckoutIsClosedError {} + +impl fmt::Display for CheckoutIsClosedError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("checked out connection was closed") + } +} + +impl<T: Poolable> Checkout<T> { + fn poll_waiter( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Option<crate::Result<Pooled<T>>>> { + if let Some(mut rx) = self.waiter.take() { + match Pin::new(&mut rx).poll(cx) { + Poll::Ready(Ok(value)) => { + if value.is_open() { + Poll::Ready(Some(Ok(self.pool.reuse(&self.key, value)))) + } else { + Poll::Ready(Some(Err( + crate::Error::new_canceled().with(CheckoutIsClosedError) + ))) + } + } + Poll::Pending => { + self.waiter = Some(rx); + Poll::Pending + } + Poll::Ready(Err(_canceled)) => Poll::Ready(Some(Err( + crate::Error::new_canceled().with("request has been canceled") + ))), + } + } else { + Poll::Ready(None) + } + } + + fn checkout(&mut self, cx: &mut task::Context<'_>) -> Option<Pooled<T>> { + let entry = { + let mut inner = self.pool.inner.as_ref()?.lock().unwrap(); + let expiration = Expiration::new(inner.timeout); + let maybe_entry = inner.idle.get_mut(&self.key).and_then(|list| { + trace!("take? {:?}: expiration = {:?}", self.key, expiration.0); + // A block to end the mutable borrow on list, + // so the map below can check is_empty() + { + let popper = IdlePopper { + key: &self.key, + list, + }; + popper.pop(&expiration) + } + .map(|e| (e, list.is_empty())) + }); + + let (entry, empty) = if let Some((e, empty)) = maybe_entry { + (Some(e), empty) + } else { + // No entry found means nuke the list for sure. + (None, true) + }; + if empty { + //TODO: This could be done with the HashMap::entry API instead. + inner.idle.remove(&self.key); + } + + if entry.is_none() && self.waiter.is_none() { + let (tx, mut rx) = oneshot::channel(); + trace!("checkout waiting for idle connection: {:?}", self.key); + inner + .waiters + .entry(self.key.clone()) + .or_insert_with(VecDeque::new) + .push_back(tx); + + // register the waker with this oneshot + assert!(Pin::new(&mut rx).poll(cx).is_pending()); + self.waiter = Some(rx); + } + + entry + }; + + entry.map(|e| self.pool.reuse(&self.key, e.value)) + } +} + +impl<T: Poolable> Future for Checkout<T> { + type Output = crate::Result<Pooled<T>>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + if let Some(pooled) = ready!(self.poll_waiter(cx)?) { + return Poll::Ready(Ok(pooled)); + } + + if let Some(pooled) = self.checkout(cx) { + Poll::Ready(Ok(pooled)) + } else if !self.pool.is_enabled() { + Poll::Ready(Err(crate::Error::new_canceled().with("pool is disabled"))) + } else { + // There's a new waiter, already registered in self.checkout() + debug_assert!(self.waiter.is_some()); + Poll::Pending + } + } +} + +impl<T> Drop for Checkout<T> { + fn drop(&mut self) { + if self.waiter.take().is_some() { + trace!("checkout dropped for {:?}", self.key); + if let Some(Ok(mut inner)) = self.pool.inner.as_ref().map(|i| i.lock()) { + inner.clean_waiters(&self.key); + } + } + } +} + +// FIXME: allow() required due to `impl Trait` leaking types to this lint +#[allow(missing_debug_implementations)] +pub(super) struct Connecting<T: Poolable> { + key: Key, + pool: WeakOpt<Mutex<PoolInner<T>>>, +} + +impl<T: Poolable> Connecting<T> { + pub(super) fn alpn_h2(self, pool: &Pool<T>) -> Option<Self> { + debug_assert!( + self.pool.0.is_none(), + "Connecting::alpn_h2 but already Http2" + ); + + pool.connecting(&self.key, Ver::Http2) + } +} + +impl<T: Poolable> Drop for Connecting<T> { + fn drop(&mut self) { + if let Some(pool) = self.pool.upgrade() { + // No need to panic on drop, that could abort! + if let Ok(mut inner) = pool.lock() { + inner.connected(&self.key); + } + } + } +} + +struct Expiration(Option<Duration>); + +impl Expiration { + fn new(dur: Option<Duration>) -> Expiration { + Expiration(dur) + } + + fn expires(&self, instant: Instant) -> bool { + match self.0 { + // Avoid `Instant::elapsed` to avoid issues like rust-lang/rust#86470. + Some(timeout) => Instant::now().saturating_duration_since(instant) > timeout, + None => false, + } + } +} + +#[cfg(feature = "runtime")] +pin_project_lite::pin_project! { + struct IdleTask<T> { + #[pin] + interval: Interval, + pool: WeakOpt<Mutex<PoolInner<T>>>, + // This allows the IdleTask to be notified as soon as the entire + // Pool is fully dropped, and shutdown. This channel is never sent on, + // but Err(Canceled) will be received when the Pool is dropped. + #[pin] + pool_drop_notifier: oneshot::Receiver<crate::common::Never>, + } +} + +#[cfg(feature = "runtime")] +impl<T: Poolable + 'static> Future for IdleTask<T> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + let mut this = self.project(); + loop { + match this.pool_drop_notifier.as_mut().poll(cx) { + Poll::Ready(Ok(n)) => match n {}, + Poll::Pending => (), + Poll::Ready(Err(_canceled)) => { + trace!("pool closed, canceling idle interval"); + return Poll::Ready(()); + } + } + + ready!(this.interval.as_mut().poll_tick(cx)); + + if let Some(inner) = this.pool.upgrade() { + if let Ok(mut inner) = inner.lock() { + trace!("idle interval checking for expired"); + inner.clear_expired(); + continue; + } + } + return Poll::Ready(()); + } + } +} + +impl<T> WeakOpt<T> { + fn none() -> Self { + WeakOpt(None) + } + + fn downgrade(arc: &Arc<T>) -> Self { + WeakOpt(Some(Arc::downgrade(arc))) + } + + fn upgrade(&self) -> Option<Arc<T>> { + self.0.as_ref().and_then(Weak::upgrade) + } +} + +#[cfg(test)] +mod tests { + use std::task::Poll; + use std::time::Duration; + + use super::{Connecting, Key, Pool, Poolable, Reservation, WeakOpt}; + use crate::common::{exec::Exec, task, Future, Pin}; + + /// Test unique reservations. + #[derive(Debug, PartialEq, Eq)] + struct Uniq<T>(T); + + impl<T: Send + 'static + Unpin> Poolable for Uniq<T> { + fn is_open(&self) -> bool { + true + } + + fn reserve(self) -> Reservation<Self> { + Reservation::Unique(self) + } + + fn can_share(&self) -> bool { + false + } + } + + fn c<T: Poolable>(key: Key) -> Connecting<T> { + Connecting { + key, + pool: WeakOpt::none(), + } + } + + fn host_key(s: &str) -> Key { + (http::uri::Scheme::HTTP, s.parse().expect("host key")) + } + + fn pool_no_timer<T>() -> Pool<T> { + pool_max_idle_no_timer(::std::usize::MAX) + } + + fn pool_max_idle_no_timer<T>(max_idle: usize) -> Pool<T> { + let pool = Pool::new( + super::Config { + idle_timeout: Some(Duration::from_millis(100)), + max_idle_per_host: max_idle, + }, + &Exec::Default, + ); + pool.no_timer(); + pool + } + + #[tokio::test] + async fn test_pool_checkout_smoke() { + let pool = pool_no_timer(); + let key = host_key("foo"); + let pooled = pool.pooled(c(key.clone()), Uniq(41)); + + drop(pooled); + + match pool.checkout(key).await { + Ok(pooled) => assert_eq!(*pooled, Uniq(41)), + Err(_) => panic!("not ready"), + }; + } + + /// Helper to check if the future is ready after polling once. + struct PollOnce<'a, F>(&'a mut F); + + impl<F, T, U> Future for PollOnce<'_, F> + where + F: Future<Output = Result<T, U>> + Unpin, + { + type Output = Option<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + match Pin::new(&mut self.0).poll(cx) { + Poll::Ready(Ok(_)) => Poll::Ready(Some(())), + Poll::Ready(Err(_)) => Poll::Ready(Some(())), + Poll::Pending => Poll::Ready(None), + } + } + } + + #[tokio::test] + async fn test_pool_checkout_returns_none_if_expired() { + let pool = pool_no_timer(); + let key = host_key("foo"); + let pooled = pool.pooled(c(key.clone()), Uniq(41)); + + drop(pooled); + tokio::time::sleep(pool.locked().timeout.unwrap()).await; + let mut checkout = pool.checkout(key); + let poll_once = PollOnce(&mut checkout); + let is_not_ready = poll_once.await.is_none(); + assert!(is_not_ready); + } + + #[cfg(feature = "runtime")] + #[tokio::test] + async fn test_pool_checkout_removes_expired() { + let pool = pool_no_timer(); + let key = host_key("foo"); + + pool.pooled(c(key.clone()), Uniq(41)); + pool.pooled(c(key.clone()), Uniq(5)); + pool.pooled(c(key.clone()), Uniq(99)); + + assert_eq!( + pool.locked().idle.get(&key).map(|entries| entries.len()), + Some(3) + ); + tokio::time::sleep(pool.locked().timeout.unwrap()).await; + + let mut checkout = pool.checkout(key.clone()); + let poll_once = PollOnce(&mut checkout); + // checkout.await should clean out the expired + poll_once.await; + assert!(pool.locked().idle.get(&key).is_none()); + } + + #[test] + fn test_pool_max_idle_per_host() { + let pool = pool_max_idle_no_timer(2); + let key = host_key("foo"); + + pool.pooled(c(key.clone()), Uniq(41)); + pool.pooled(c(key.clone()), Uniq(5)); + pool.pooled(c(key.clone()), Uniq(99)); + + // pooled and dropped 3, max_idle should only allow 2 + assert_eq!( + pool.locked().idle.get(&key).map(|entries| entries.len()), + Some(2) + ); + } + + #[cfg(feature = "runtime")] + #[tokio::test] + async fn test_pool_timer_removes_expired() { + let _ = pretty_env_logger::try_init(); + tokio::time::pause(); + + let pool = Pool::new( + super::Config { + idle_timeout: Some(Duration::from_millis(10)), + max_idle_per_host: std::usize::MAX, + }, + &Exec::Default, + ); + + let key = host_key("foo"); + + pool.pooled(c(key.clone()), Uniq(41)); + pool.pooled(c(key.clone()), Uniq(5)); + pool.pooled(c(key.clone()), Uniq(99)); + + assert_eq!( + pool.locked().idle.get(&key).map(|entries| entries.len()), + Some(3) + ); + + // Let the timer tick passed the expiration... + tokio::time::advance(Duration::from_millis(30)).await; + // Yield so the Interval can reap... + tokio::task::yield_now().await; + + assert!(pool.locked().idle.get(&key).is_none()); + } + + #[tokio::test] + async fn test_pool_checkout_task_unparked() { + use futures_util::future::join; + use futures_util::FutureExt; + + let pool = pool_no_timer(); + let key = host_key("foo"); + let pooled = pool.pooled(c(key.clone()), Uniq(41)); + + let checkout = join(pool.checkout(key), async { + // the checkout future will park first, + // and then this lazy future will be polled, which will insert + // the pooled back into the pool + // + // this test makes sure that doing so will unpark the checkout + drop(pooled); + }) + .map(|(entry, _)| entry); + + assert_eq!(*checkout.await.unwrap(), Uniq(41)); + } + + #[tokio::test] + async fn test_pool_checkout_drop_cleans_up_waiters() { + let pool = pool_no_timer::<Uniq<i32>>(); + let key = host_key("foo"); + + let mut checkout1 = pool.checkout(key.clone()); + let mut checkout2 = pool.checkout(key.clone()); + + let poll_once1 = PollOnce(&mut checkout1); + let poll_once2 = PollOnce(&mut checkout2); + + // first poll needed to get into Pool's parked + poll_once1.await; + assert_eq!(pool.locked().waiters.get(&key).unwrap().len(), 1); + poll_once2.await; + assert_eq!(pool.locked().waiters.get(&key).unwrap().len(), 2); + + // on drop, clean up Pool + drop(checkout1); + assert_eq!(pool.locked().waiters.get(&key).unwrap().len(), 1); + + drop(checkout2); + assert!(pool.locked().waiters.get(&key).is_none()); + } + + #[derive(Debug)] + struct CanClose { + #[allow(unused)] + val: i32, + closed: bool, + } + + impl Poolable for CanClose { + fn is_open(&self) -> bool { + !self.closed + } + + fn reserve(self) -> Reservation<Self> { + Reservation::Unique(self) + } + + fn can_share(&self) -> bool { + false + } + } + + #[test] + fn pooled_drop_if_closed_doesnt_reinsert() { + let pool = pool_no_timer(); + let key = host_key("foo"); + pool.pooled( + c(key.clone()), + CanClose { + val: 57, + closed: true, + }, + ); + + assert!(!pool.locked().idle.contains_key(&key)); + } +} diff --git a/third_party/rust/hyper/src/client/service.rs b/third_party/rust/hyper/src/client/service.rs new file mode 100644 index 0000000000..406f61edc9 --- /dev/null +++ b/third_party/rust/hyper/src/client/service.rs @@ -0,0 +1,89 @@ +//! Utilities used to interact with the Tower ecosystem. +//! +//! This module provides `Connect` which hook-ins into the Tower ecosystem. + +use std::error::Error as StdError; +use std::future::Future; +use std::marker::PhantomData; + +use tracing::debug; + +use super::conn::{Builder, SendRequest}; +use crate::{ + body::HttpBody, + common::{task, Pin, Poll}, + service::{MakeConnection, Service}, +}; + +/// Creates a connection via `SendRequest`. +/// +/// This accepts a `hyper::client::conn::Builder` and provides +/// a `MakeService` implementation to create connections from some +/// target `T`. +#[derive(Debug)] +pub struct Connect<C, B, T> { + inner: C, + builder: Builder, + _pd: PhantomData<fn(T, B)>, +} + +impl<C, B, T> Connect<C, B, T> { + /// Create a new `Connect` with some inner connector `C` and a connection + /// builder. + pub fn new(inner: C, builder: Builder) -> Self { + Self { + inner, + builder, + _pd: PhantomData, + } + } +} + +impl<C, B, T> Service<T> for Connect<C, B, T> +where + C: MakeConnection<T>, + C::Connection: Unpin + Send + 'static, + C::Future: Send + 'static, + C::Error: Into<Box<dyn StdError + Send + Sync>> + Send, + B: HttpBody + Unpin + Send + 'static, + B::Data: Send + Unpin, + B::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + type Response = SendRequest<B>; + type Error = crate::Error; + type Future = + Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner + .poll_ready(cx) + .map_err(|e| crate::Error::new(crate::error::Kind::Connect).with(e.into())) + } + + fn call(&mut self, req: T) -> Self::Future { + let builder = self.builder.clone(); + let io = self.inner.make_connection(req); + + let fut = async move { + match io.await { + Ok(io) => match builder.handshake(io).await { + Ok((sr, conn)) => { + builder.exec.execute(async move { + if let Err(e) = conn.await { + debug!("connection error: {:?}", e); + } + }); + Ok(sr) + } + Err(e) => Err(e), + }, + Err(e) => { + let err = crate::Error::new(crate::error::Kind::Connect).with(e.into()); + Err(err) + } + } + }; + + Box::pin(fut) + } +} diff --git a/third_party/rust/hyper/src/client/tests.rs b/third_party/rust/hyper/src/client/tests.rs new file mode 100644 index 0000000000..0a281a637d --- /dev/null +++ b/third_party/rust/hyper/src/client/tests.rs @@ -0,0 +1,286 @@ +use std::io; + +use futures_util::future; +use tokio::net::TcpStream; + +use super::Client; + +#[tokio::test] +async fn client_connect_uri_argument() { + let connector = tower::service_fn(|dst: http::Uri| { + assert_eq!(dst.scheme(), Some(&http::uri::Scheme::HTTP)); + assert_eq!(dst.host(), Some("example.local")); + assert_eq!(dst.port(), None); + assert_eq!(dst.path(), "/", "path should be removed"); + + future::err::<TcpStream, _>(io::Error::new(io::ErrorKind::Other, "expect me")) + }); + + let client = Client::builder().build::<_, crate::Body>(connector); + let _ = client + .get("http://example.local/and/a/path".parse().unwrap()) + .await + .expect_err("response should fail"); +} + +/* +// FIXME: re-implement tests with `async/await` +#[test] +fn retryable_request() { + let _ = pretty_env_logger::try_init(); + + let mut rt = Runtime::new().expect("new rt"); + let mut connector = MockConnector::new(); + + let sock1 = connector.mock("http://mock.local"); + let sock2 = connector.mock("http://mock.local"); + + let client = Client::builder() + .build::<_, crate::Body>(connector); + + client.pool.no_timer(); + + { + + let req = Request::builder() + .uri("http://mock.local/a") + .body(Default::default()) + .unwrap(); + let res1 = client.request(req); + let srv1 = poll_fn(|| { + try_ready!(sock1.read(&mut [0u8; 512])); + try_ready!(sock1.write(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")); + Ok(Async::Ready(())) + }).map_err(|e: std::io::Error| panic!("srv1 poll_fn error: {}", e)); + rt.block_on(res1.join(srv1)).expect("res1"); + } + drop(sock1); + + let req = Request::builder() + .uri("http://mock.local/b") + .body(Default::default()) + .unwrap(); + let res2 = client.request(req) + .map(|res| { + assert_eq!(res.status().as_u16(), 222); + }); + let srv2 = poll_fn(|| { + try_ready!(sock2.read(&mut [0u8; 512])); + try_ready!(sock2.write(b"HTTP/1.1 222 OK\r\nContent-Length: 0\r\n\r\n")); + Ok(Async::Ready(())) + }).map_err(|e: std::io::Error| panic!("srv2 poll_fn error: {}", e)); + + rt.block_on(res2.join(srv2)).expect("res2"); +} + +#[test] +fn conn_reset_after_write() { + let _ = pretty_env_logger::try_init(); + + let mut rt = Runtime::new().expect("new rt"); + let mut connector = MockConnector::new(); + + let sock1 = connector.mock("http://mock.local"); + + let client = Client::builder() + .build::<_, crate::Body>(connector); + + client.pool.no_timer(); + + { + let req = Request::builder() + .uri("http://mock.local/a") + .body(Default::default()) + .unwrap(); + let res1 = client.request(req); + let srv1 = poll_fn(|| { + try_ready!(sock1.read(&mut [0u8; 512])); + try_ready!(sock1.write(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")); + Ok(Async::Ready(())) + }).map_err(|e: std::io::Error| panic!("srv1 poll_fn error: {}", e)); + rt.block_on(res1.join(srv1)).expect("res1"); + } + + let req = Request::builder() + .uri("http://mock.local/a") + .body(Default::default()) + .unwrap(); + let res2 = client.request(req); + let mut sock1 = Some(sock1); + let srv2 = poll_fn(|| { + // We purposefully keep the socket open until the client + // has written the second request, and THEN disconnect. + // + // Not because we expect servers to be jerks, but to trigger + // state where we write on an assumedly good connection, and + // only reset the close AFTER we wrote bytes. + try_ready!(sock1.as_mut().unwrap().read(&mut [0u8; 512])); + sock1.take(); + Ok(Async::Ready(())) + }).map_err(|e: std::io::Error| panic!("srv2 poll_fn error: {}", e)); + let err = rt.block_on(res2.join(srv2)).expect_err("res2"); + assert!(err.is_incomplete_message(), "{:?}", err); +} + +#[test] +fn checkout_win_allows_connect_future_to_be_pooled() { + let _ = pretty_env_logger::try_init(); + + let mut rt = Runtime::new().expect("new rt"); + let mut connector = MockConnector::new(); + + + let (tx, rx) = oneshot::channel::<()>(); + let sock1 = connector.mock("http://mock.local"); + let sock2 = connector.mock_fut("http://mock.local", rx); + + let client = Client::builder() + .build::<_, crate::Body>(connector); + + client.pool.no_timer(); + + let uri = "http://mock.local/a".parse::<crate::Uri>().expect("uri parse"); + + // First request just sets us up to have a connection able to be put + // back in the pool. *However*, it doesn't insert immediately. The + // body has 1 pending byte, and we will only drain in request 2, once + // the connect future has been started. + let mut body = { + let res1 = client.get(uri.clone()) + .map(|res| res.into_body().concat2()); + let srv1 = poll_fn(|| { + try_ready!(sock1.read(&mut [0u8; 512])); + // Chunked is used so as to force 2 body reads. + try_ready!(sock1.write(b"\ + HTTP/1.1 200 OK\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + 1\r\nx\r\n\ + 0\r\n\r\n\ + ")); + Ok(Async::Ready(())) + }).map_err(|e: std::io::Error| panic!("srv1 poll_fn error: {}", e)); + + rt.block_on(res1.join(srv1)).expect("res1").0 + }; + + + // The second request triggers the only mocked connect future, but then + // the drained body allows the first socket to go back to the pool, + // "winning" the checkout race. + { + let res2 = client.get(uri.clone()); + let drain = poll_fn(move || { + body.poll() + }); + let srv2 = poll_fn(|| { + try_ready!(sock1.read(&mut [0u8; 512])); + try_ready!(sock1.write(b"HTTP/1.1 200 OK\r\nConnection: close\r\n\r\nx")); + Ok(Async::Ready(())) + }).map_err(|e: std::io::Error| panic!("srv2 poll_fn error: {}", e)); + + rt.block_on(res2.join(drain).join(srv2)).expect("res2"); + } + + // "Release" the mocked connect future, and let the runtime spin once so + // it's all setup... + { + let mut tx = Some(tx); + let client = &client; + let key = client.pool.h1_key("http://mock.local"); + let mut tick_cnt = 0; + let fut = poll_fn(move || { + tx.take(); + + if client.pool.idle_count(&key) == 0 { + tick_cnt += 1; + assert!(tick_cnt < 10, "ticked too many times waiting for idle"); + trace!("no idle yet; tick count: {}", tick_cnt); + ::futures::task::current().notify(); + Ok(Async::NotReady) + } else { + Ok::<_, ()>(Async::Ready(())) + } + }); + rt.block_on(fut).unwrap(); + } + + // Third request just tests out that the "loser" connection was pooled. If + // it isn't, this will panic since the MockConnector doesn't have any more + // mocks to give out. + { + let res3 = client.get(uri); + let srv3 = poll_fn(|| { + try_ready!(sock2.read(&mut [0u8; 512])); + try_ready!(sock2.write(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")); + Ok(Async::Ready(())) + }).map_err(|e: std::io::Error| panic!("srv3 poll_fn error: {}", e)); + + rt.block_on(res3.join(srv3)).expect("res3"); + } +} + +#[cfg(feature = "nightly")] +#[bench] +fn bench_http1_get_0b(b: &mut test::Bencher) { + let _ = pretty_env_logger::try_init(); + + let mut rt = Runtime::new().expect("new rt"); + let mut connector = MockConnector::new(); + + + let client = Client::builder() + .build::<_, crate::Body>(connector.clone()); + + client.pool.no_timer(); + + let uri = Uri::from_static("http://mock.local/a"); + + b.iter(move || { + let sock1 = connector.mock("http://mock.local"); + let res1 = client + .get(uri.clone()) + .and_then(|res| { + res.into_body().for_each(|_| Ok(())) + }); + let srv1 = poll_fn(|| { + try_ready!(sock1.read(&mut [0u8; 512])); + try_ready!(sock1.write(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")); + Ok(Async::Ready(())) + }).map_err(|e: std::io::Error| panic!("srv1 poll_fn error: {}", e)); + rt.block_on(res1.join(srv1)).expect("res1"); + }); +} + +#[cfg(feature = "nightly")] +#[bench] +fn bench_http1_get_10b(b: &mut test::Bencher) { + let _ = pretty_env_logger::try_init(); + + let mut rt = Runtime::new().expect("new rt"); + let mut connector = MockConnector::new(); + + + let client = Client::builder() + .build::<_, crate::Body>(connector.clone()); + + client.pool.no_timer(); + + let uri = Uri::from_static("http://mock.local/a"); + + b.iter(move || { + let sock1 = connector.mock("http://mock.local"); + let res1 = client + .get(uri.clone()) + .and_then(|res| { + res.into_body().for_each(|_| Ok(())) + }); + let srv1 = poll_fn(|| { + try_ready!(sock1.read(&mut [0u8; 512])); + try_ready!(sock1.write(b"HTTP/1.1 200 OK\r\nContent-Length: 10\r\n\r\n0123456789")); + Ok(Async::Ready(())) + }).map_err(|e: std::io::Error| panic!("srv1 poll_fn error: {}", e)); + rt.block_on(res1.join(srv1)).expect("res1"); + }); +} +*/ diff --git a/third_party/rust/hyper/src/common/buf.rs b/third_party/rust/hyper/src/common/buf.rs new file mode 100644 index 0000000000..64e9333ead --- /dev/null +++ b/third_party/rust/hyper/src/common/buf.rs @@ -0,0 +1,151 @@ +use std::collections::VecDeque; +use std::io::IoSlice; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +pub(crate) struct BufList<T> { + bufs: VecDeque<T>, +} + +impl<T: Buf> BufList<T> { + pub(crate) fn new() -> BufList<T> { + BufList { + bufs: VecDeque::new(), + } + } + + #[inline] + pub(crate) fn push(&mut self, buf: T) { + debug_assert!(buf.has_remaining()); + self.bufs.push_back(buf); + } + + #[inline] + #[cfg(feature = "http1")] + pub(crate) fn bufs_cnt(&self) -> usize { + self.bufs.len() + } +} + +impl<T: Buf> Buf for BufList<T> { + #[inline] + fn remaining(&self) -> usize { + self.bufs.iter().map(|buf| buf.remaining()).sum() + } + + #[inline] + fn chunk(&self) -> &[u8] { + self.bufs.front().map(Buf::chunk).unwrap_or_default() + } + + #[inline] + fn advance(&mut self, mut cnt: usize) { + while cnt > 0 { + { + let front = &mut self.bufs[0]; + let rem = front.remaining(); + if rem > cnt { + front.advance(cnt); + return; + } else { + front.advance(rem); + cnt -= rem; + } + } + self.bufs.pop_front(); + } + } + + #[inline] + fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { + if dst.is_empty() { + return 0; + } + let mut vecs = 0; + for buf in &self.bufs { + vecs += buf.chunks_vectored(&mut dst[vecs..]); + if vecs == dst.len() { + break; + } + } + vecs + } + + #[inline] + fn copy_to_bytes(&mut self, len: usize) -> Bytes { + // Our inner buffer may have an optimized version of copy_to_bytes, and if the whole + // request can be fulfilled by the front buffer, we can take advantage. + match self.bufs.front_mut() { + Some(front) if front.remaining() == len => { + let b = front.copy_to_bytes(len); + self.bufs.pop_front(); + b + } + Some(front) if front.remaining() > len => front.copy_to_bytes(len), + _ => { + assert!(len <= self.remaining(), "`len` greater than remaining"); + let mut bm = BytesMut::with_capacity(len); + bm.put(self.take(len)); + bm.freeze() + } + } + } +} + +#[cfg(test)] +mod tests { + use std::ptr; + + use super::*; + + fn hello_world_buf() -> BufList<Bytes> { + BufList { + bufs: vec![Bytes::from("Hello"), Bytes::from(" "), Bytes::from("World")].into(), + } + } + + #[test] + fn to_bytes_shorter() { + let mut bufs = hello_world_buf(); + let old_ptr = bufs.chunk().as_ptr(); + let start = bufs.copy_to_bytes(4); + assert_eq!(start, "Hell"); + assert!(ptr::eq(old_ptr, start.as_ptr())); + assert_eq!(bufs.chunk(), b"o"); + assert!(ptr::eq(old_ptr.wrapping_add(4), bufs.chunk().as_ptr())); + assert_eq!(bufs.remaining(), 7); + } + + #[test] + fn to_bytes_eq() { + let mut bufs = hello_world_buf(); + let old_ptr = bufs.chunk().as_ptr(); + let start = bufs.copy_to_bytes(5); + assert_eq!(start, "Hello"); + assert!(ptr::eq(old_ptr, start.as_ptr())); + assert_eq!(bufs.chunk(), b" "); + assert_eq!(bufs.remaining(), 6); + } + + #[test] + fn to_bytes_longer() { + let mut bufs = hello_world_buf(); + let start = bufs.copy_to_bytes(7); + assert_eq!(start, "Hello W"); + assert_eq!(bufs.remaining(), 4); + } + + #[test] + fn one_long_buf_to_bytes() { + let mut buf = BufList::new(); + buf.push(b"Hello World" as &[_]); + assert_eq!(buf.copy_to_bytes(5), "Hello"); + assert_eq!(buf.chunk(), b" World"); + } + + #[test] + #[should_panic(expected = "`len` greater than remaining")] + fn buf_to_bytes_too_many() { + hello_world_buf().copy_to_bytes(42); + } +} diff --git a/third_party/rust/hyper/src/common/date.rs b/third_party/rust/hyper/src/common/date.rs new file mode 100644 index 0000000000..a436fc07c0 --- /dev/null +++ b/third_party/rust/hyper/src/common/date.rs @@ -0,0 +1,124 @@ +use std::cell::RefCell; +use std::fmt::{self, Write}; +use std::str; +use std::time::{Duration, SystemTime}; + +#[cfg(feature = "http2")] +use http::header::HeaderValue; +use httpdate::HttpDate; + +// "Sun, 06 Nov 1994 08:49:37 GMT".len() +pub(crate) const DATE_VALUE_LENGTH: usize = 29; + +#[cfg(feature = "http1")] +pub(crate) fn extend(dst: &mut Vec<u8>) { + CACHED.with(|cache| { + dst.extend_from_slice(cache.borrow().buffer()); + }) +} + +#[cfg(feature = "http1")] +pub(crate) fn update() { + CACHED.with(|cache| { + cache.borrow_mut().check(); + }) +} + +#[cfg(feature = "http2")] +pub(crate) fn update_and_header_value() -> HeaderValue { + CACHED.with(|cache| { + let mut cache = cache.borrow_mut(); + cache.check(); + HeaderValue::from_bytes(cache.buffer()).expect("Date format should be valid HeaderValue") + }) +} + +struct CachedDate { + bytes: [u8; DATE_VALUE_LENGTH], + pos: usize, + next_update: SystemTime, +} + +thread_local!(static CACHED: RefCell<CachedDate> = RefCell::new(CachedDate::new())); + +impl CachedDate { + fn new() -> Self { + let mut cache = CachedDate { + bytes: [0; DATE_VALUE_LENGTH], + pos: 0, + next_update: SystemTime::now(), + }; + cache.update(cache.next_update); + cache + } + + fn buffer(&self) -> &[u8] { + &self.bytes[..] + } + + fn check(&mut self) { + let now = SystemTime::now(); + if now > self.next_update { + self.update(now); + } + } + + fn update(&mut self, now: SystemTime) { + self.render(now); + self.next_update = now + Duration::new(1, 0); + } + + fn render(&mut self, now: SystemTime) { + self.pos = 0; + let _ = write!(self, "{}", HttpDate::from(now)); + debug_assert!(self.pos == DATE_VALUE_LENGTH); + } +} + +impl fmt::Write for CachedDate { + fn write_str(&mut self, s: &str) -> fmt::Result { + let len = s.len(); + self.bytes[self.pos..self.pos + len].copy_from_slice(s.as_bytes()); + self.pos += len; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(feature = "nightly")] + use test::Bencher; + + #[test] + fn test_date_len() { + assert_eq!(DATE_VALUE_LENGTH, "Sun, 06 Nov 1994 08:49:37 GMT".len()); + } + + #[cfg(feature = "nightly")] + #[bench] + fn bench_date_check(b: &mut Bencher) { + let mut date = CachedDate::new(); + // cache the first update + date.check(); + + b.iter(|| { + date.check(); + }); + } + + #[cfg(feature = "nightly")] + #[bench] + fn bench_date_render(b: &mut Bencher) { + let mut date = CachedDate::new(); + let now = SystemTime::now(); + date.render(now); + b.bytes = date.buffer().len() as u64; + + b.iter(|| { + date.render(now); + test::black_box(&date); + }); + } +} diff --git a/third_party/rust/hyper/src/common/drain.rs b/third_party/rust/hyper/src/common/drain.rs new file mode 100644 index 0000000000..174da876df --- /dev/null +++ b/third_party/rust/hyper/src/common/drain.rs @@ -0,0 +1,217 @@ +use std::mem; + +use pin_project_lite::pin_project; +use tokio::sync::watch; + +use super::{task, Future, Pin, Poll}; + +pub(crate) fn channel() -> (Signal, Watch) { + let (tx, rx) = watch::channel(()); + (Signal { tx }, Watch { rx }) +} + +pub(crate) struct Signal { + tx: watch::Sender<()>, +} + +pub(crate) struct Draining(Pin<Box<dyn Future<Output = ()> + Send + Sync>>); + +#[derive(Clone)] +pub(crate) struct Watch { + rx: watch::Receiver<()>, +} + +pin_project! { + #[allow(missing_debug_implementations)] + pub struct Watching<F, FN> { + #[pin] + future: F, + state: State<FN>, + watch: Pin<Box<dyn Future<Output = ()> + Send + Sync>>, + _rx: watch::Receiver<()>, + } +} + +enum State<F> { + Watch(F), + Draining, +} + +impl Signal { + pub(crate) fn drain(self) -> Draining { + let _ = self.tx.send(()); + Draining(Box::pin(async move { self.tx.closed().await })) + } +} + +impl Future for Draining { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + Pin::new(&mut self.as_mut().0).poll(cx) + } +} + +impl Watch { + pub(crate) fn watch<F, FN>(self, future: F, on_drain: FN) -> Watching<F, FN> + where + F: Future, + FN: FnOnce(Pin<&mut F>), + { + let Self { mut rx } = self; + let _rx = rx.clone(); + Watching { + future, + state: State::Watch(on_drain), + watch: Box::pin(async move { + let _ = rx.changed().await; + }), + // Keep the receiver alive until the future completes, so that + // dropping it can signal that draining has completed. + _rx, + } + } +} + +impl<F, FN> Future for Watching<F, FN> +where + F: Future, + FN: FnOnce(Pin<&mut F>), +{ + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + let mut me = self.project(); + loop { + match mem::replace(me.state, State::Draining) { + State::Watch(on_drain) => { + match Pin::new(&mut me.watch).poll(cx) { + Poll::Ready(()) => { + // Drain has been triggered! + on_drain(me.future.as_mut()); + } + Poll::Pending => { + *me.state = State::Watch(on_drain); + return me.future.poll(cx); + } + } + } + State::Draining => return me.future.poll(cx), + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + struct TestMe { + draining: bool, + finished: bool, + poll_cnt: usize, + } + + impl Future for TestMe { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> { + self.poll_cnt += 1; + if self.finished { + Poll::Ready(()) + } else { + Poll::Pending + } + } + } + + #[test] + fn watch() { + let mut mock = tokio_test::task::spawn(()); + mock.enter(|cx, _| { + let (tx, rx) = channel(); + let fut = TestMe { + draining: false, + finished: false, + poll_cnt: 0, + }; + + let mut watch = rx.watch(fut, |mut fut| { + fut.draining = true; + }); + + assert_eq!(watch.future.poll_cnt, 0); + + // First poll should poll the inner future + assert!(Pin::new(&mut watch).poll(cx).is_pending()); + assert_eq!(watch.future.poll_cnt, 1); + + // Second poll should poll the inner future again + assert!(Pin::new(&mut watch).poll(cx).is_pending()); + assert_eq!(watch.future.poll_cnt, 2); + + let mut draining = tx.drain(); + // Drain signaled, but needs another poll to be noticed. + assert!(!watch.future.draining); + assert_eq!(watch.future.poll_cnt, 2); + + // Now, poll after drain has been signaled. + assert!(Pin::new(&mut watch).poll(cx).is_pending()); + assert_eq!(watch.future.poll_cnt, 3); + assert!(watch.future.draining); + + // Draining is not ready until watcher completes + assert!(Pin::new(&mut draining).poll(cx).is_pending()); + + // Finishing up the watch future + watch.future.finished = true; + assert!(Pin::new(&mut watch).poll(cx).is_ready()); + assert_eq!(watch.future.poll_cnt, 4); + drop(watch); + + assert!(Pin::new(&mut draining).poll(cx).is_ready()); + }) + } + + #[test] + fn watch_clones() { + let mut mock = tokio_test::task::spawn(()); + mock.enter(|cx, _| { + let (tx, rx) = channel(); + + let fut1 = TestMe { + draining: false, + finished: false, + poll_cnt: 0, + }; + let fut2 = TestMe { + draining: false, + finished: false, + poll_cnt: 0, + }; + + let watch1 = rx.clone().watch(fut1, |mut fut| { + fut.draining = true; + }); + let watch2 = rx.watch(fut2, |mut fut| { + fut.draining = true; + }); + + let mut draining = tx.drain(); + + // Still 2 outstanding watchers + assert!(Pin::new(&mut draining).poll(cx).is_pending()); + + // drop 1 for whatever reason + drop(watch1); + + // Still not ready, 1 other watcher still pending + assert!(Pin::new(&mut draining).poll(cx).is_pending()); + + drop(watch2); + + // Now all watchers are gone, draining is complete + assert!(Pin::new(&mut draining).poll(cx).is_ready()); + }); + } +} diff --git a/third_party/rust/hyper/src/common/exec.rs b/third_party/rust/hyper/src/common/exec.rs new file mode 100644 index 0000000000..76f616184b --- /dev/null +++ b/third_party/rust/hyper/src/common/exec.rs @@ -0,0 +1,145 @@ +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +#[cfg(all(feature = "server", any(feature = "http1", feature = "http2")))] +use crate::body::Body; +#[cfg(feature = "server")] +use crate::body::HttpBody; +#[cfg(all(feature = "http2", feature = "server"))] +use crate::proto::h2::server::H2Stream; +use crate::rt::Executor; +#[cfg(all(feature = "server", any(feature = "http1", feature = "http2")))] +use crate::server::server::{new_svc::NewSvcTask, Watcher}; +#[cfg(all(feature = "server", any(feature = "http1", feature = "http2")))] +use crate::service::HttpService; + +#[cfg(feature = "server")] +pub trait ConnStreamExec<F, B: HttpBody>: Clone { + fn execute_h2stream(&mut self, fut: H2Stream<F, B>); +} + +#[cfg(all(feature = "server", any(feature = "http1", feature = "http2")))] +pub trait NewSvcExec<I, N, S: HttpService<Body>, E, W: Watcher<I, S, E>>: Clone { + fn execute_new_svc(&mut self, fut: NewSvcTask<I, N, S, E, W>); +} + +pub(crate) type BoxSendFuture = Pin<Box<dyn Future<Output = ()> + Send>>; + +// Either the user provides an executor for background tasks, or we use +// `tokio::spawn`. +#[derive(Clone)] +pub enum Exec { + Default, + Executor(Arc<dyn Executor<BoxSendFuture> + Send + Sync>), +} + +// ===== impl Exec ===== + +impl Exec { + pub(crate) fn execute<F>(&self, fut: F) + where + F: Future<Output = ()> + Send + 'static, + { + match *self { + Exec::Default => { + #[cfg(feature = "tcp")] + { + tokio::task::spawn(fut); + } + #[cfg(not(feature = "tcp"))] + { + // If no runtime, we need an executor! + panic!("executor must be set") + } + } + Exec::Executor(ref e) => { + e.execute(Box::pin(fut)); + } + } + } +} + +impl fmt::Debug for Exec { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Exec").finish() + } +} + +#[cfg(feature = "server")] +impl<F, B> ConnStreamExec<F, B> for Exec +where + H2Stream<F, B>: Future<Output = ()> + Send + 'static, + B: HttpBody, +{ + fn execute_h2stream(&mut self, fut: H2Stream<F, B>) { + self.execute(fut) + } +} + +#[cfg(all(feature = "server", any(feature = "http1", feature = "http2")))] +impl<I, N, S, E, W> NewSvcExec<I, N, S, E, W> for Exec +where + NewSvcTask<I, N, S, E, W>: Future<Output = ()> + Send + 'static, + S: HttpService<Body>, + W: Watcher<I, S, E>, +{ + fn execute_new_svc(&mut self, fut: NewSvcTask<I, N, S, E, W>) { + self.execute(fut) + } +} + +// ==== impl Executor ===== + +#[cfg(feature = "server")] +impl<E, F, B> ConnStreamExec<F, B> for E +where + E: Executor<H2Stream<F, B>> + Clone, + H2Stream<F, B>: Future<Output = ()>, + B: HttpBody, +{ + fn execute_h2stream(&mut self, fut: H2Stream<F, B>) { + self.execute(fut) + } +} + +#[cfg(all(feature = "server", any(feature = "http1", feature = "http2")))] +impl<I, N, S, E, W> NewSvcExec<I, N, S, E, W> for E +where + E: Executor<NewSvcTask<I, N, S, E, W>> + Clone, + NewSvcTask<I, N, S, E, W>: Future<Output = ()>, + S: HttpService<Body>, + W: Watcher<I, S, E>, +{ + fn execute_new_svc(&mut self, fut: NewSvcTask<I, N, S, E, W>) { + self.execute(fut) + } +} + +// If http2 is not enable, we just have a stub here, so that the trait bounds +// that *would* have been needed are still checked. Why? +// +// Because enabling `http2` shouldn't suddenly add new trait bounds that cause +// a compilation error. +#[cfg(not(feature = "http2"))] +#[allow(missing_debug_implementations)] +pub struct H2Stream<F, B>(std::marker::PhantomData<(F, B)>); + +#[cfg(not(feature = "http2"))] +impl<F, B, E> Future for H2Stream<F, B> +where + F: Future<Output = Result<http::Response<B>, E>>, + B: crate::body::HttpBody, + B::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + E: Into<Box<dyn std::error::Error + Send + Sync>>, +{ + type Output = (); + + fn poll( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Self::Output> { + unreachable!() + } +} diff --git a/third_party/rust/hyper/src/common/io/mod.rs b/third_party/rust/hyper/src/common/io/mod.rs new file mode 100644 index 0000000000..2e6d506153 --- /dev/null +++ b/third_party/rust/hyper/src/common/io/mod.rs @@ -0,0 +1,3 @@ +mod rewind; + +pub(crate) use self::rewind::Rewind; diff --git a/third_party/rust/hyper/src/common/io/rewind.rs b/third_party/rust/hyper/src/common/io/rewind.rs new file mode 100644 index 0000000000..0afef5f7ea --- /dev/null +++ b/third_party/rust/hyper/src/common/io/rewind.rs @@ -0,0 +1,155 @@ +use std::marker::Unpin; +use std::{cmp, io}; + +use bytes::{Buf, Bytes}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use crate::common::{task, Pin, Poll}; + +/// Combine a buffer with an IO, rewinding reads to use the buffer. +#[derive(Debug)] +pub(crate) struct Rewind<T> { + pre: Option<Bytes>, + inner: T, +} + +impl<T> Rewind<T> { + #[cfg(any(all(feature = "http2", feature = "server"), test))] + pub(crate) fn new(io: T) -> Self { + Rewind { + pre: None, + inner: io, + } + } + + pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self { + Rewind { + pre: Some(buf), + inner: io, + } + } + + #[cfg(any(all(feature = "http1", feature = "http2", feature = "server"), test))] + pub(crate) fn rewind(&mut self, bs: Bytes) { + debug_assert!(self.pre.is_none()); + self.pre = Some(bs); + } + + pub(crate) fn into_inner(self) -> (T, Bytes) { + (self.inner, self.pre.unwrap_or_else(Bytes::new)) + } + + // pub(crate) fn get_mut(&mut self) -> &mut T { + // &mut self.inner + // } +} + +impl<T> AsyncRead for Rewind<T> +where + T: AsyncRead + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + if let Some(mut prefix) = self.pre.take() { + // If there are no remaining bytes, let the bytes get dropped. + if !prefix.is_empty() { + let copy_len = cmp::min(prefix.len(), buf.remaining()); + // TODO: There should be a way to do following two lines cleaner... + buf.put_slice(&prefix[..copy_len]); + prefix.advance(copy_len); + // Put back what's left + if !prefix.is_empty() { + self.pre = Some(prefix); + } + + return Poll::Ready(Ok(())); + } + } + Pin::new(&mut self.inner).poll_read(cx, buf) + } +} + +impl<T> AsyncWrite for Rewind<T> +where + T: AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.inner).poll_write(cx, buf) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } +} + +#[cfg(test)] +mod tests { + // FIXME: re-implement tests with `async/await`, this import should + // trigger a warning to remind us + use super::Rewind; + use bytes::Bytes; + use tokio::io::AsyncReadExt; + + #[tokio::test] + async fn partial_rewind() { + let underlying = [104, 101, 108, 108, 111]; + + let mock = tokio_test::io::Builder::new().read(&underlying).build(); + + let mut stream = Rewind::new(mock); + + // Read off some bytes, ensure we filled o1 + let mut buf = [0; 2]; + stream.read_exact(&mut buf).await.expect("read1"); + + // Rewind the stream so that it is as if we never read in the first place. + stream.rewind(Bytes::copy_from_slice(&buf[..])); + + let mut buf = [0; 5]; + stream.read_exact(&mut buf).await.expect("read1"); + + // At this point we should have read everything that was in the MockStream + assert_eq!(&buf, &underlying); + } + + #[tokio::test] + async fn full_rewind() { + let underlying = [104, 101, 108, 108, 111]; + + let mock = tokio_test::io::Builder::new().read(&underlying).build(); + + let mut stream = Rewind::new(mock); + + let mut buf = [0; 5]; + stream.read_exact(&mut buf).await.expect("read1"); + + // Rewind the stream so that it is as if we never read in the first place. + stream.rewind(Bytes::copy_from_slice(&buf[..])); + + let mut buf = [0; 5]; + stream.read_exact(&mut buf).await.expect("read1"); + } +} diff --git a/third_party/rust/hyper/src/common/lazy.rs b/third_party/rust/hyper/src/common/lazy.rs new file mode 100644 index 0000000000..2722077303 --- /dev/null +++ b/third_party/rust/hyper/src/common/lazy.rs @@ -0,0 +1,76 @@ +use pin_project_lite::pin_project; + +use super::{task, Future, Pin, Poll}; + +pub(crate) trait Started: Future { + fn started(&self) -> bool; +} + +pub(crate) fn lazy<F, R>(func: F) -> Lazy<F, R> +where + F: FnOnce() -> R, + R: Future + Unpin, +{ + Lazy { + inner: Inner::Init { func }, + } +} + +// FIXME: allow() required due to `impl Trait` leaking types to this lint +pin_project! { + #[allow(missing_debug_implementations)] + pub(crate) struct Lazy<F, R> { + #[pin] + inner: Inner<F, R>, + } +} + +pin_project! { + #[project = InnerProj] + #[project_replace = InnerProjReplace] + enum Inner<F, R> { + Init { func: F }, + Fut { #[pin] fut: R }, + Empty, + } +} + +impl<F, R> Started for Lazy<F, R> +where + F: FnOnce() -> R, + R: Future, +{ + fn started(&self) -> bool { + match self.inner { + Inner::Init { .. } => false, + Inner::Fut { .. } | Inner::Empty => true, + } + } +} + +impl<F, R> Future for Lazy<F, R> +where + F: FnOnce() -> R, + R: Future, +{ + type Output = R::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + let mut this = self.project(); + + if let InnerProj::Fut { fut } = this.inner.as_mut().project() { + return fut.poll(cx); + } + + match this.inner.as_mut().project_replace(Inner::Empty) { + InnerProjReplace::Init { func } => { + this.inner.set(Inner::Fut { fut: func() }); + if let InnerProj::Fut { fut } = this.inner.project() { + return fut.poll(cx); + } + unreachable!() + } + _ => unreachable!("lazy state wrong"), + } + } +} diff --git a/third_party/rust/hyper/src/common/mod.rs b/third_party/rust/hyper/src/common/mod.rs new file mode 100644 index 0000000000..e38c6f5c7a --- /dev/null +++ b/third_party/rust/hyper/src/common/mod.rs @@ -0,0 +1,39 @@ +macro_rules! ready { + ($e:expr) => { + match $e { + std::task::Poll::Ready(v) => v, + std::task::Poll::Pending => return std::task::Poll::Pending, + } + }; +} + +pub(crate) mod buf; +#[cfg(all(feature = "server", any(feature = "http1", feature = "http2")))] +pub(crate) mod date; +#[cfg(all(feature = "server", any(feature = "http1", feature = "http2")))] +pub(crate) mod drain; +#[cfg(any(feature = "http1", feature = "http2", feature = "server"))] +pub(crate) mod exec; +pub(crate) mod io; +#[cfg(all(feature = "client", any(feature = "http1", feature = "http2")))] +mod lazy; +mod never; +#[cfg(any( + feature = "stream", + all(feature = "client", any(feature = "http1", feature = "http2")) +))] +pub(crate) mod sync_wrapper; +pub(crate) mod task; +pub(crate) mod watch; + +#[cfg(all(feature = "client", any(feature = "http1", feature = "http2")))] +pub(crate) use self::lazy::{lazy, Started as Lazy}; +#[cfg(any(feature = "http1", feature = "http2", feature = "runtime"))] +pub(crate) use self::never::Never; +pub(crate) use self::task::Poll; + +// group up types normally needed for `Future` +cfg_proto! { + pub(crate) use std::marker::Unpin; +} +pub(crate) use std::{future::Future, pin::Pin}; diff --git a/third_party/rust/hyper/src/common/never.rs b/third_party/rust/hyper/src/common/never.rs new file mode 100644 index 0000000000..f143caf60f --- /dev/null +++ b/third_party/rust/hyper/src/common/never.rs @@ -0,0 +1,21 @@ +//! An uninhabitable type meaning it can never happen. +//! +//! To be replaced with `!` once it is stable. + +use std::error::Error; +use std::fmt; + +#[derive(Debug)] +pub(crate) enum Never {} + +impl fmt::Display for Never { + fn fmt(&self, _: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self {} + } +} + +impl Error for Never { + fn description(&self) -> &str { + match *self {} + } +} diff --git a/third_party/rust/hyper/src/common/sync_wrapper.rs b/third_party/rust/hyper/src/common/sync_wrapper.rs new file mode 100644 index 0000000000..704d1a6712 --- /dev/null +++ b/third_party/rust/hyper/src/common/sync_wrapper.rs @@ -0,0 +1,110 @@ +/* + * This is a copy of the sync_wrapper crate. + */ + +/// A mutual exclusion primitive that relies on static type information only +/// +/// In some cases synchronization can be proven statically: whenever you hold an exclusive `&mut` +/// reference, the Rust type system ensures that no other part of the program can hold another +/// reference to the data. Therefore it is safe to access it even if the current thread obtained +/// this reference via a channel. Whenever this is the case, the overhead of allocating and locking +/// a [`Mutex`] can be avoided by using this static version. +/// +/// One example where this is often applicable is [`Future`], which requires an exclusive reference +/// for its [`poll`] method: While a given `Future` implementation may not be safe to access by +/// multiple threads concurrently, the executor can only run the `Future` on one thread at any +/// given time, making it [`Sync`] in practice as long as the implementation is `Send`. You can +/// therefore use the sync wrapper to prove that your data structure is `Sync` even though it +/// contains such a `Future`. +/// +/// # Example +/// +/// ```ignore +/// use hyper::common::sync_wrapper::SyncWrapper; +/// use std::future::Future; +/// +/// struct MyThing { +/// future: SyncWrapper<Box<dyn Future<Output = String> + Send>>, +/// } +/// +/// impl MyThing { +/// // all accesses to `self.future` now require an exclusive reference or ownership +/// } +/// +/// fn assert_sync<T: Sync>() {} +/// +/// assert_sync::<MyThing>(); +/// ``` +/// +/// [`Mutex`]: https://doc.rust-lang.org/std/sync/struct.Mutex.html +/// [`Future`]: https://doc.rust-lang.org/std/future/trait.Future.html +/// [`poll`]: https://doc.rust-lang.org/std/future/trait.Future.html#method.poll +/// [`Sync`]: https://doc.rust-lang.org/std/marker/trait.Sync.html +#[repr(transparent)] +pub(crate) struct SyncWrapper<T>(T); + +impl<T> SyncWrapper<T> { + /// Creates a new SyncWrapper containing the given value. + /// + /// # Examples + /// + /// ```ignore + /// use hyper::common::sync_wrapper::SyncWrapper; + /// + /// let wrapped = SyncWrapper::new(42); + /// ``` + pub(crate) fn new(value: T) -> Self { + Self(value) + } + + /// Acquires a reference to the protected value. + /// + /// This is safe because it requires an exclusive reference to the wrapper. Therefore this method + /// neither panics nor does it return an error. This is in contrast to [`Mutex::get_mut`] which + /// returns an error if another thread panicked while holding the lock. It is not recommended + /// to send an exclusive reference to a potentially damaged value to another thread for further + /// processing. + /// + /// [`Mutex::get_mut`]: https://doc.rust-lang.org/std/sync/struct.Mutex.html#method.get_mut + /// + /// # Examples + /// + /// ```ignore + /// use hyper::common::sync_wrapper::SyncWrapper; + /// + /// let mut wrapped = SyncWrapper::new(42); + /// let value = wrapped.get_mut(); + /// *value = 0; + /// assert_eq!(*wrapped.get_mut(), 0); + /// ``` + pub(crate) fn get_mut(&mut self) -> &mut T { + &mut self.0 + } + + /// Consumes this wrapper, returning the underlying data. + /// + /// This is safe because it requires ownership of the wrapper, aherefore this method will neither + /// panic nor does it return an error. This is in contrast to [`Mutex::into_inner`] which + /// returns an error if another thread panicked while holding the lock. It is not recommended + /// to send an exclusive reference to a potentially damaged value to another thread for further + /// processing. + /// + /// [`Mutex::into_inner`]: https://doc.rust-lang.org/std/sync/struct.Mutex.html#method.into_inner + /// + /// # Examples + /// + /// ```ignore + /// use hyper::common::sync_wrapper::SyncWrapper; + /// + /// let mut wrapped = SyncWrapper::new(42); + /// assert_eq!(wrapped.into_inner(), 42); + /// ``` + #[allow(dead_code)] + pub(crate) fn into_inner(self) -> T { + self.0 + } +} + +// this is safe because the only operations permitted on this data structure require exclusive +// access or ownership +unsafe impl<T: Send> Sync for SyncWrapper<T> {} diff --git a/third_party/rust/hyper/src/common/task.rs b/third_party/rust/hyper/src/common/task.rs new file mode 100644 index 0000000000..ec70c957d6 --- /dev/null +++ b/third_party/rust/hyper/src/common/task.rs @@ -0,0 +1,12 @@ +#[cfg(feature = "http1")] +use super::Never; +pub(crate) use std::task::{Context, Poll}; + +/// A function to help "yield" a future, such that it is re-scheduled immediately. +/// +/// Useful for spin counts, so a future doesn't hog too much time. +#[cfg(feature = "http1")] +pub(crate) fn yield_now(cx: &mut Context<'_>) -> Poll<Never> { + cx.waker().wake_by_ref(); + Poll::Pending +} diff --git a/third_party/rust/hyper/src/common/watch.rs b/third_party/rust/hyper/src/common/watch.rs new file mode 100644 index 0000000000..ba17d551cb --- /dev/null +++ b/third_party/rust/hyper/src/common/watch.rs @@ -0,0 +1,73 @@ +//! An SPSC broadcast channel. +//! +//! - The value can only be a `usize`. +//! - The consumer is only notified if the value is different. +//! - The value `0` is reserved for closed. + +use futures_util::task::AtomicWaker; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; +use std::task; + +type Value = usize; + +pub(crate) const CLOSED: usize = 0; + +pub(crate) fn channel(initial: Value) -> (Sender, Receiver) { + debug_assert!( + initial != CLOSED, + "watch::channel initial state of 0 is reserved" + ); + + let shared = Arc::new(Shared { + value: AtomicUsize::new(initial), + waker: AtomicWaker::new(), + }); + + ( + Sender { + shared: shared.clone(), + }, + Receiver { shared }, + ) +} + +pub(crate) struct Sender { + shared: Arc<Shared>, +} + +pub(crate) struct Receiver { + shared: Arc<Shared>, +} + +struct Shared { + value: AtomicUsize, + waker: AtomicWaker, +} + +impl Sender { + pub(crate) fn send(&mut self, value: Value) { + if self.shared.value.swap(value, Ordering::SeqCst) != value { + self.shared.waker.wake(); + } + } +} + +impl Drop for Sender { + fn drop(&mut self) { + self.send(CLOSED); + } +} + +impl Receiver { + pub(crate) fn load(&mut self, cx: &mut task::Context<'_>) -> Value { + self.shared.waker.register(cx.waker()); + self.shared.value.load(Ordering::SeqCst) + } + + pub(crate) fn peek(&self) -> Value { + self.shared.value.load(Ordering::Relaxed) + } +} diff --git a/third_party/rust/hyper/src/error.rs b/third_party/rust/hyper/src/error.rs new file mode 100644 index 0000000000..468f24cb7a --- /dev/null +++ b/third_party/rust/hyper/src/error.rs @@ -0,0 +1,641 @@ +//! Error and Result module. +use std::error::Error as StdError; +use std::fmt; + +/// Result type often returned from methods that can have hyper `Error`s. +pub type Result<T> = std::result::Result<T, Error>; + +type Cause = Box<dyn StdError + Send + Sync>; + +/// Represents errors that can occur handling HTTP streams. +pub struct Error { + inner: Box<ErrorImpl>, +} + +struct ErrorImpl { + kind: Kind, + cause: Option<Cause>, +} + +#[derive(Debug)] +pub(super) enum Kind { + Parse(Parse), + User(User), + /// A message reached EOF, but is not complete. + #[allow(unused)] + IncompleteMessage, + /// A connection received a message (or bytes) when not waiting for one. + #[cfg(feature = "http1")] + UnexpectedMessage, + /// A pending item was dropped before ever being processed. + Canceled, + /// Indicates a channel (client or body sender) is closed. + ChannelClosed, + /// An `io::Error` that occurred while trying to read or write to a network stream. + #[cfg(any(feature = "http1", feature = "http2"))] + Io, + /// Error occurred while connecting. + #[allow(unused)] + Connect, + /// Error creating a TcpListener. + #[cfg(all(feature = "tcp", feature = "server"))] + Listen, + /// Error accepting on an Incoming stream. + #[cfg(any(feature = "http1", feature = "http2"))] + #[cfg(feature = "server")] + Accept, + /// User took too long to send headers + #[cfg(all(feature = "http1", feature = "server", feature = "runtime"))] + HeaderTimeout, + /// Error while reading a body from connection. + #[cfg(any(feature = "http1", feature = "http2", feature = "stream"))] + Body, + /// Error while writing a body to connection. + #[cfg(any(feature = "http1", feature = "http2"))] + BodyWrite, + /// Error calling AsyncWrite::shutdown() + #[cfg(feature = "http1")] + Shutdown, + + /// A general error from h2. + #[cfg(feature = "http2")] + Http2, +} + +#[derive(Debug)] +pub(super) enum Parse { + Method, + Version, + #[cfg(feature = "http1")] + VersionH2, + Uri, + #[cfg_attr(not(all(feature = "http1", feature = "server")), allow(unused))] + UriTooLong, + Header(Header), + TooLarge, + Status, + #[cfg_attr(debug_assertions, allow(unused))] + Internal, +} + +#[derive(Debug)] +pub(super) enum Header { + Token, + #[cfg(feature = "http1")] + ContentLengthInvalid, + #[cfg(all(feature = "http1", feature = "server"))] + TransferEncodingInvalid, + #[cfg(feature = "http1")] + TransferEncodingUnexpected, +} + +#[derive(Debug)] +pub(super) enum User { + /// Error calling user's HttpBody::poll_data(). + #[cfg(any(feature = "http1", feature = "http2"))] + Body, + /// The user aborted writing of the outgoing body. + BodyWriteAborted, + /// Error calling user's MakeService. + #[cfg(any(feature = "http1", feature = "http2"))] + #[cfg(feature = "server")] + MakeService, + /// Error from future of user's Service. + #[cfg(any(feature = "http1", feature = "http2"))] + Service, + /// User tried to send a certain header in an unexpected context. + /// + /// For example, sending both `content-length` and `transfer-encoding`. + #[cfg(any(feature = "http1", feature = "http2"))] + #[cfg(feature = "server")] + UnexpectedHeader, + /// User tried to create a Request with bad version. + #[cfg(any(feature = "http1", feature = "http2"))] + #[cfg(feature = "client")] + UnsupportedVersion, + /// User tried to create a CONNECT Request with the Client. + #[cfg(any(feature = "http1", feature = "http2"))] + #[cfg(feature = "client")] + UnsupportedRequestMethod, + /// User tried to respond with a 1xx (not 101) response code. + #[cfg(feature = "http1")] + #[cfg(feature = "server")] + UnsupportedStatusCode, + /// User tried to send a Request with Client with non-absolute URI. + #[cfg(any(feature = "http1", feature = "http2"))] + #[cfg(feature = "client")] + AbsoluteUriRequired, + + /// User tried polling for an upgrade that doesn't exist. + NoUpgrade, + + /// User polled for an upgrade, but low-level API is not using upgrades. + #[cfg(feature = "http1")] + ManualUpgrade, + + /// User called `server::Connection::without_shutdown()` on an HTTP/2 conn. + #[cfg(feature = "server")] + WithoutShutdownNonHttp1, + + /// The dispatch task is gone. + #[cfg(feature = "client")] + DispatchGone, + + /// User aborted in an FFI callback. + #[cfg(feature = "ffi")] + AbortedByCallback, +} + +// Sentinel type to indicate the error was caused by a timeout. +#[derive(Debug)] +pub(super) struct TimedOut; + +impl Error { + /// Returns true if this was an HTTP parse error. + pub fn is_parse(&self) -> bool { + matches!(self.inner.kind, Kind::Parse(_)) + } + + /// Returns true if this was an HTTP parse error caused by a message that was too large. + pub fn is_parse_too_large(&self) -> bool { + matches!( + self.inner.kind, + Kind::Parse(Parse::TooLarge) | Kind::Parse(Parse::UriTooLong) + ) + } + + /// Returns true if this was an HTTP parse error caused by an invalid response status code or + /// reason phrase. + pub fn is_parse_status(&self) -> bool { + matches!(self.inner.kind, Kind::Parse(Parse::Status)) + } + + /// Returns true if this error was caused by user code. + pub fn is_user(&self) -> bool { + matches!(self.inner.kind, Kind::User(_)) + } + + /// Returns true if this was about a `Request` that was canceled. + pub fn is_canceled(&self) -> bool { + matches!(self.inner.kind, Kind::Canceled) + } + + /// Returns true if a sender's channel is closed. + pub fn is_closed(&self) -> bool { + matches!(self.inner.kind, Kind::ChannelClosed) + } + + /// Returns true if this was an error from `Connect`. + pub fn is_connect(&self) -> bool { + matches!(self.inner.kind, Kind::Connect) + } + + /// Returns true if the connection closed before a message could complete. + pub fn is_incomplete_message(&self) -> bool { + matches!(self.inner.kind, Kind::IncompleteMessage) + } + + /// Returns true if the body write was aborted. + pub fn is_body_write_aborted(&self) -> bool { + matches!(self.inner.kind, Kind::User(User::BodyWriteAborted)) + } + + /// Returns true if the error was caused by a timeout. + pub fn is_timeout(&self) -> bool { + self.find_source::<TimedOut>().is_some() + } + + /// Consumes the error, returning its cause. + pub fn into_cause(self) -> Option<Box<dyn StdError + Send + Sync>> { + self.inner.cause + } + + pub(super) fn new(kind: Kind) -> Error { + Error { + inner: Box::new(ErrorImpl { kind, cause: None }), + } + } + + pub(super) fn with<C: Into<Cause>>(mut self, cause: C) -> Error { + self.inner.cause = Some(cause.into()); + self + } + + #[cfg(any(all(feature = "http1", feature = "server"), feature = "ffi"))] + pub(super) fn kind(&self) -> &Kind { + &self.inner.kind + } + + pub(crate) fn find_source<E: StdError + 'static>(&self) -> Option<&E> { + let mut cause = self.source(); + while let Some(err) = cause { + if let Some(ref typed) = err.downcast_ref() { + return Some(typed); + } + cause = err.source(); + } + + // else + None + } + + #[cfg(feature = "http2")] + pub(super) fn h2_reason(&self) -> h2::Reason { + // Find an h2::Reason somewhere in the cause stack, if it exists, + // otherwise assume an INTERNAL_ERROR. + self.find_source::<h2::Error>() + .and_then(|h2_err| h2_err.reason()) + .unwrap_or(h2::Reason::INTERNAL_ERROR) + } + + pub(super) fn new_canceled() -> Error { + Error::new(Kind::Canceled) + } + + #[cfg(feature = "http1")] + pub(super) fn new_incomplete() -> Error { + Error::new(Kind::IncompleteMessage) + } + + #[cfg(feature = "http1")] + pub(super) fn new_too_large() -> Error { + Error::new(Kind::Parse(Parse::TooLarge)) + } + + #[cfg(feature = "http1")] + pub(super) fn new_version_h2() -> Error { + Error::new(Kind::Parse(Parse::VersionH2)) + } + + #[cfg(feature = "http1")] + pub(super) fn new_unexpected_message() -> Error { + Error::new(Kind::UnexpectedMessage) + } + + #[cfg(any(feature = "http1", feature = "http2"))] + pub(super) fn new_io(cause: std::io::Error) -> Error { + Error::new(Kind::Io).with(cause) + } + + #[cfg(all(feature = "server", feature = "tcp"))] + pub(super) fn new_listen<E: Into<Cause>>(cause: E) -> Error { + Error::new(Kind::Listen).with(cause) + } + + #[cfg(any(feature = "http1", feature = "http2"))] + #[cfg(feature = "server")] + pub(super) fn new_accept<E: Into<Cause>>(cause: E) -> Error { + Error::new(Kind::Accept).with(cause) + } + + #[cfg(any(feature = "http1", feature = "http2"))] + #[cfg(feature = "client")] + pub(super) fn new_connect<E: Into<Cause>>(cause: E) -> Error { + Error::new(Kind::Connect).with(cause) + } + + pub(super) fn new_closed() -> Error { + Error::new(Kind::ChannelClosed) + } + + #[cfg(any(feature = "http1", feature = "http2", feature = "stream"))] + pub(super) fn new_body<E: Into<Cause>>(cause: E) -> Error { + Error::new(Kind::Body).with(cause) + } + + #[cfg(any(feature = "http1", feature = "http2"))] + pub(super) fn new_body_write<E: Into<Cause>>(cause: E) -> Error { + Error::new(Kind::BodyWrite).with(cause) + } + + pub(super) fn new_body_write_aborted() -> Error { + Error::new(Kind::User(User::BodyWriteAborted)) + } + + fn new_user(user: User) -> Error { + Error::new(Kind::User(user)) + } + + #[cfg(any(feature = "http1", feature = "http2"))] + #[cfg(feature = "server")] + pub(super) fn new_user_header() -> Error { + Error::new_user(User::UnexpectedHeader) + } + + #[cfg(all(feature = "http1", feature = "server", feature = "runtime"))] + pub(super) fn new_header_timeout() -> Error { + Error::new(Kind::HeaderTimeout) + } + + #[cfg(any(feature = "http1", feature = "http2"))] + #[cfg(feature = "client")] + pub(super) fn new_user_unsupported_version() -> Error { + Error::new_user(User::UnsupportedVersion) + } + + #[cfg(any(feature = "http1", feature = "http2"))] + #[cfg(feature = "client")] + pub(super) fn new_user_unsupported_request_method() -> Error { + Error::new_user(User::UnsupportedRequestMethod) + } + + #[cfg(feature = "http1")] + #[cfg(feature = "server")] + pub(super) fn new_user_unsupported_status_code() -> Error { + Error::new_user(User::UnsupportedStatusCode) + } + + #[cfg(any(feature = "http1", feature = "http2"))] + #[cfg(feature = "client")] + pub(super) fn new_user_absolute_uri_required() -> Error { + Error::new_user(User::AbsoluteUriRequired) + } + + pub(super) fn new_user_no_upgrade() -> Error { + Error::new_user(User::NoUpgrade) + } + + #[cfg(feature = "http1")] + pub(super) fn new_user_manual_upgrade() -> Error { + Error::new_user(User::ManualUpgrade) + } + + #[cfg(any(feature = "http1", feature = "http2"))] + #[cfg(feature = "server")] + pub(super) fn new_user_make_service<E: Into<Cause>>(cause: E) -> Error { + Error::new_user(User::MakeService).with(cause) + } + + #[cfg(any(feature = "http1", feature = "http2"))] + pub(super) fn new_user_service<E: Into<Cause>>(cause: E) -> Error { + Error::new_user(User::Service).with(cause) + } + + #[cfg(any(feature = "http1", feature = "http2"))] + pub(super) fn new_user_body<E: Into<Cause>>(cause: E) -> Error { + Error::new_user(User::Body).with(cause) + } + + #[cfg(feature = "server")] + pub(super) fn new_without_shutdown_not_h1() -> Error { + Error::new(Kind::User(User::WithoutShutdownNonHttp1)) + } + + #[cfg(feature = "http1")] + pub(super) fn new_shutdown(cause: std::io::Error) -> Error { + Error::new(Kind::Shutdown).with(cause) + } + + #[cfg(feature = "ffi")] + pub(super) fn new_user_aborted_by_callback() -> Error { + Error::new_user(User::AbortedByCallback) + } + + #[cfg(feature = "client")] + pub(super) fn new_user_dispatch_gone() -> Error { + Error::new(Kind::User(User::DispatchGone)) + } + + #[cfg(feature = "http2")] + pub(super) fn new_h2(cause: ::h2::Error) -> Error { + if cause.is_io() { + Error::new_io(cause.into_io().expect("h2::Error::is_io")) + } else { + Error::new(Kind::Http2).with(cause) + } + } + + /// The error's standalone message, without the message from the source. + pub fn message(&self) -> impl fmt::Display + '_ { + self.description() + } + + fn description(&self) -> &str { + match self.inner.kind { + Kind::Parse(Parse::Method) => "invalid HTTP method parsed", + Kind::Parse(Parse::Version) => "invalid HTTP version parsed", + #[cfg(feature = "http1")] + Kind::Parse(Parse::VersionH2) => "invalid HTTP version parsed (found HTTP2 preface)", + Kind::Parse(Parse::Uri) => "invalid URI", + Kind::Parse(Parse::UriTooLong) => "URI too long", + Kind::Parse(Parse::Header(Header::Token)) => "invalid HTTP header parsed", + #[cfg(feature = "http1")] + Kind::Parse(Parse::Header(Header::ContentLengthInvalid)) => { + "invalid content-length parsed" + } + #[cfg(all(feature = "http1", feature = "server"))] + Kind::Parse(Parse::Header(Header::TransferEncodingInvalid)) => { + "invalid transfer-encoding parsed" + } + #[cfg(feature = "http1")] + Kind::Parse(Parse::Header(Header::TransferEncodingUnexpected)) => { + "unexpected transfer-encoding parsed" + } + Kind::Parse(Parse::TooLarge) => "message head is too large", + Kind::Parse(Parse::Status) => "invalid HTTP status-code parsed", + Kind::Parse(Parse::Internal) => { + "internal error inside Hyper and/or its dependencies, please report" + } + Kind::IncompleteMessage => "connection closed before message completed", + #[cfg(feature = "http1")] + Kind::UnexpectedMessage => "received unexpected message from connection", + Kind::ChannelClosed => "channel closed", + Kind::Connect => "error trying to connect", + Kind::Canceled => "operation was canceled", + #[cfg(all(feature = "server", feature = "tcp"))] + Kind::Listen => "error creating server listener", + #[cfg(any(feature = "http1", feature = "http2"))] + #[cfg(feature = "server")] + Kind::Accept => "error accepting connection", + #[cfg(all(feature = "http1", feature = "server", feature = "runtime"))] + Kind::HeaderTimeout => "read header from client timeout", + #[cfg(any(feature = "http1", feature = "http2", feature = "stream"))] + Kind::Body => "error reading a body from connection", + #[cfg(any(feature = "http1", feature = "http2"))] + Kind::BodyWrite => "error writing a body to connection", + #[cfg(feature = "http1")] + Kind::Shutdown => "error shutting down connection", + #[cfg(feature = "http2")] + Kind::Http2 => "http2 error", + #[cfg(any(feature = "http1", feature = "http2"))] + Kind::Io => "connection error", + + #[cfg(any(feature = "http1", feature = "http2"))] + Kind::User(User::Body) => "error from user's HttpBody stream", + Kind::User(User::BodyWriteAborted) => "user body write aborted", + #[cfg(any(feature = "http1", feature = "http2"))] + #[cfg(feature = "server")] + Kind::User(User::MakeService) => "error from user's MakeService", + #[cfg(any(feature = "http1", feature = "http2"))] + Kind::User(User::Service) => "error from user's Service", + #[cfg(any(feature = "http1", feature = "http2"))] + #[cfg(feature = "server")] + Kind::User(User::UnexpectedHeader) => "user sent unexpected header", + #[cfg(any(feature = "http1", feature = "http2"))] + #[cfg(feature = "client")] + Kind::User(User::UnsupportedVersion) => "request has unsupported HTTP version", + #[cfg(any(feature = "http1", feature = "http2"))] + #[cfg(feature = "client")] + Kind::User(User::UnsupportedRequestMethod) => "request has unsupported HTTP method", + #[cfg(feature = "http1")] + #[cfg(feature = "server")] + Kind::User(User::UnsupportedStatusCode) => { + "response has 1xx status code, not supported by server" + } + #[cfg(any(feature = "http1", feature = "http2"))] + #[cfg(feature = "client")] + Kind::User(User::AbsoluteUriRequired) => "client requires absolute-form URIs", + Kind::User(User::NoUpgrade) => "no upgrade available", + #[cfg(feature = "http1")] + Kind::User(User::ManualUpgrade) => "upgrade expected but low level API in use", + #[cfg(feature = "server")] + Kind::User(User::WithoutShutdownNonHttp1) => { + "without_shutdown() called on a non-HTTP/1 connection" + } + #[cfg(feature = "client")] + Kind::User(User::DispatchGone) => "dispatch task is gone", + #[cfg(feature = "ffi")] + Kind::User(User::AbortedByCallback) => "operation aborted by an application callback", + } + } +} + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut f = f.debug_tuple("hyper::Error"); + f.field(&self.inner.kind); + if let Some(ref cause) = self.inner.cause { + f.field(cause); + } + f.finish() + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if let Some(ref cause) = self.inner.cause { + write!(f, "{}: {}", self.description(), cause) + } else { + f.write_str(self.description()) + } + } +} + +impl StdError for Error { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + self.inner + .cause + .as_ref() + .map(|cause| &**cause as &(dyn StdError + 'static)) + } +} + +#[doc(hidden)] +impl From<Parse> for Error { + fn from(err: Parse) -> Error { + Error::new(Kind::Parse(err)) + } +} + +#[cfg(feature = "http1")] +impl Parse { + pub(crate) fn content_length_invalid() -> Self { + Parse::Header(Header::ContentLengthInvalid) + } + + #[cfg(all(feature = "http1", feature = "server"))] + pub(crate) fn transfer_encoding_invalid() -> Self { + Parse::Header(Header::TransferEncodingInvalid) + } + + pub(crate) fn transfer_encoding_unexpected() -> Self { + Parse::Header(Header::TransferEncodingUnexpected) + } +} + +impl From<httparse::Error> for Parse { + fn from(err: httparse::Error) -> Parse { + match err { + httparse::Error::HeaderName + | httparse::Error::HeaderValue + | httparse::Error::NewLine + | httparse::Error::Token => Parse::Header(Header::Token), + httparse::Error::Status => Parse::Status, + httparse::Error::TooManyHeaders => Parse::TooLarge, + httparse::Error::Version => Parse::Version, + } + } +} + +impl From<http::method::InvalidMethod> for Parse { + fn from(_: http::method::InvalidMethod) -> Parse { + Parse::Method + } +} + +impl From<http::status::InvalidStatusCode> for Parse { + fn from(_: http::status::InvalidStatusCode) -> Parse { + Parse::Status + } +} + +impl From<http::uri::InvalidUri> for Parse { + fn from(_: http::uri::InvalidUri) -> Parse { + Parse::Uri + } +} + +impl From<http::uri::InvalidUriParts> for Parse { + fn from(_: http::uri::InvalidUriParts) -> Parse { + Parse::Uri + } +} + +#[doc(hidden)] +trait AssertSendSync: Send + Sync + 'static {} +#[doc(hidden)] +impl AssertSendSync for Error {} + +// ===== impl TimedOut ==== + +impl fmt::Display for TimedOut { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("operation timed out") + } +} + +impl StdError for TimedOut {} + +#[cfg(test)] +mod tests { + use super::*; + use std::mem; + + #[test] + fn error_size_of() { + assert_eq!(mem::size_of::<Error>(), mem::size_of::<usize>()); + } + + #[cfg(feature = "http2")] + #[test] + fn h2_reason_unknown() { + let closed = Error::new_closed(); + assert_eq!(closed.h2_reason(), h2::Reason::INTERNAL_ERROR); + } + + #[cfg(feature = "http2")] + #[test] + fn h2_reason_one_level() { + let body_err = Error::new_user_body(h2::Error::from(h2::Reason::ENHANCE_YOUR_CALM)); + assert_eq!(body_err.h2_reason(), h2::Reason::ENHANCE_YOUR_CALM); + } + + #[cfg(feature = "http2")] + #[test] + fn h2_reason_nested() { + let recvd = Error::new_h2(h2::Error::from(h2::Reason::HTTP_1_1_REQUIRED)); + // Suppose a user were proxying the received error + let svc_err = Error::new_user_service(recvd); + assert_eq!(svc_err.h2_reason(), h2::Reason::HTTP_1_1_REQUIRED); + } +} diff --git a/third_party/rust/hyper/src/ext.rs b/third_party/rust/hyper/src/ext.rs new file mode 100644 index 0000000000..224206dd66 --- /dev/null +++ b/third_party/rust/hyper/src/ext.rs @@ -0,0 +1,228 @@ +//! HTTP extensions. + +use bytes::Bytes; +#[cfg(any(feature = "http1", feature = "ffi"))] +use http::header::HeaderName; +#[cfg(feature = "http1")] +use http::header::{IntoHeaderName, ValueIter}; +use http::HeaderMap; +#[cfg(feature = "ffi")] +use std::collections::HashMap; +#[cfg(feature = "http2")] +use std::fmt; + +#[cfg(any(feature = "http1", feature = "ffi"))] +mod h1_reason_phrase; +#[cfg(any(feature = "http1", feature = "ffi"))] +pub use h1_reason_phrase::ReasonPhrase; + +#[cfg(feature = "http2")] +/// Represents the `:protocol` pseudo-header used by +/// the [Extended CONNECT Protocol]. +/// +/// [Extended CONNECT Protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 +#[derive(Clone, Eq, PartialEq)] +pub struct Protocol { + inner: h2::ext::Protocol, +} + +#[cfg(feature = "http2")] +impl Protocol { + /// Converts a static string to a protocol name. + pub const fn from_static(value: &'static str) -> Self { + Self { + inner: h2::ext::Protocol::from_static(value), + } + } + + /// Returns a str representation of the header. + pub fn as_str(&self) -> &str { + self.inner.as_str() + } + + #[cfg(feature = "server")] + pub(crate) fn from_inner(inner: h2::ext::Protocol) -> Self { + Self { inner } + } + + pub(crate) fn into_inner(self) -> h2::ext::Protocol { + self.inner + } +} + +#[cfg(feature = "http2")] +impl<'a> From<&'a str> for Protocol { + fn from(value: &'a str) -> Self { + Self { + inner: h2::ext::Protocol::from(value), + } + } +} + +#[cfg(feature = "http2")] +impl AsRef<[u8]> for Protocol { + fn as_ref(&self) -> &[u8] { + self.inner.as_ref() + } +} + +#[cfg(feature = "http2")] +impl fmt::Debug for Protocol { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.inner.fmt(f) + } +} + +/// A map from header names to their original casing as received in an HTTP message. +/// +/// If an HTTP/1 response `res` is parsed on a connection whose option +/// [`http1_preserve_header_case`] was set to true and the response included +/// the following headers: +/// +/// ```ignore +/// x-Bread: Baguette +/// X-BREAD: Pain +/// x-bread: Ficelle +/// ``` +/// +/// Then `res.extensions().get::<HeaderCaseMap>()` will return a map with: +/// +/// ```ignore +/// HeaderCaseMap({ +/// "x-bread": ["x-Bread", "X-BREAD", "x-bread"], +/// }) +/// ``` +/// +/// [`http1_preserve_header_case`]: /client/struct.Client.html#method.http1_preserve_header_case +#[derive(Clone, Debug)] +pub(crate) struct HeaderCaseMap(HeaderMap<Bytes>); + +#[cfg(feature = "http1")] +impl HeaderCaseMap { + /// Returns a view of all spellings associated with that header name, + /// in the order they were found. + pub(crate) fn get_all<'a>( + &'a self, + name: &HeaderName, + ) -> impl Iterator<Item = impl AsRef<[u8]> + 'a> + 'a { + self.get_all_internal(name).into_iter() + } + + /// Returns a view of all spellings associated with that header name, + /// in the order they were found. + pub(crate) fn get_all_internal<'a>(&'a self, name: &HeaderName) -> ValueIter<'_, Bytes> { + self.0.get_all(name).into_iter() + } + + pub(crate) fn default() -> Self { + Self(Default::default()) + } + + #[cfg(any(test, feature = "ffi"))] + pub(crate) fn insert(&mut self, name: HeaderName, orig: Bytes) { + self.0.insert(name, orig); + } + + pub(crate) fn append<N>(&mut self, name: N, orig: Bytes) + where + N: IntoHeaderName, + { + self.0.append(name, orig); + } +} + +#[cfg(feature = "ffi")] +#[derive(Clone, Debug)] +/// Hashmap<Headername, numheaders with that name> +pub(crate) struct OriginalHeaderOrder { + /// Stores how many entries a Headername maps to. This is used + /// for accounting. + num_entries: HashMap<HeaderName, usize>, + /// Stores the ordering of the headers. ex: `vec[i] = (headerName, idx)`, + /// The vector is ordered such that the ith element + /// represents the ith header that came in off the line. + /// The `HeaderName` and `idx` are then used elsewhere to index into + /// the multi map that stores the header values. + entry_order: Vec<(HeaderName, usize)>, +} + +#[cfg(all(feature = "http1", feature = "ffi"))] +impl OriginalHeaderOrder { + pub(crate) fn default() -> Self { + OriginalHeaderOrder { + num_entries: HashMap::new(), + entry_order: Vec::new(), + } + } + + pub(crate) fn insert(&mut self, name: HeaderName) { + if !self.num_entries.contains_key(&name) { + let idx = 0; + self.num_entries.insert(name.clone(), 1); + self.entry_order.push((name, idx)); + } + // Replacing an already existing element does not + // change ordering, so we only care if its the first + // header name encountered + } + + pub(crate) fn append<N>(&mut self, name: N) + where + N: IntoHeaderName + Into<HeaderName> + Clone, + { + let name: HeaderName = name.into(); + let idx; + if self.num_entries.contains_key(&name) { + idx = self.num_entries[&name]; + *self.num_entries.get_mut(&name).unwrap() += 1; + } else { + idx = 0; + self.num_entries.insert(name.clone(), 1); + } + self.entry_order.push((name, idx)); + } + + // No doc test is run here because `RUSTFLAGS='--cfg hyper_unstable_ffi'` + // is needed to compile. Once ffi is stablized `no_run` should be removed + // here. + /// This returns an iterator that provides header names and indexes + /// in the original order received. + /// + /// # Examples + /// ```no_run + /// use hyper::ext::OriginalHeaderOrder; + /// use hyper::header::{HeaderName, HeaderValue, HeaderMap}; + /// + /// let mut h_order = OriginalHeaderOrder::default(); + /// let mut h_map = Headermap::new(); + /// + /// let name1 = b"Set-CookiE"; + /// let value1 = b"a=b"; + /// h_map.append(name1); + /// h_order.append(name1); + /// + /// let name2 = b"Content-Encoding"; + /// let value2 = b"gzip"; + /// h_map.append(name2, value2); + /// h_order.append(name2); + /// + /// let name3 = b"SET-COOKIE"; + /// let value3 = b"c=d"; + /// h_map.append(name3, value3); + /// h_order.append(name3) + /// + /// let mut iter = h_order.get_in_order() + /// + /// let (name, idx) = iter.next(); + /// assert_eq!(b"a=b", h_map.get_all(name).nth(idx).unwrap()); + /// + /// let (name, idx) = iter.next(); + /// assert_eq!(b"gzip", h_map.get_all(name).nth(idx).unwrap()); + /// + /// let (name, idx) = iter.next(); + /// assert_eq!(b"c=d", h_map.get_all(name).nth(idx).unwrap()); + /// ``` + pub(crate) fn get_in_order(&self) -> impl Iterator<Item = &(HeaderName, usize)> { + self.entry_order.iter() + } +} diff --git a/third_party/rust/hyper/src/ext/h1_reason_phrase.rs b/third_party/rust/hyper/src/ext/h1_reason_phrase.rs new file mode 100644 index 0000000000..021b632b6d --- /dev/null +++ b/third_party/rust/hyper/src/ext/h1_reason_phrase.rs @@ -0,0 +1,221 @@ +use std::convert::TryFrom; + +use bytes::Bytes; + +/// A reason phrase in an HTTP/1 response. +/// +/// # Clients +/// +/// For clients, a `ReasonPhrase` will be present in the extensions of the `http::Response` returned +/// for a request if the reason phrase is different from the canonical reason phrase for the +/// response's status code. For example, if a server returns `HTTP/1.1 200 Awesome`, the +/// `ReasonPhrase` will be present and contain `Awesome`, but if a server returns `HTTP/1.1 200 OK`, +/// the response will not contain a `ReasonPhrase`. +/// +/// ```no_run +/// # #[cfg(all(feature = "tcp", feature = "client", feature = "http1"))] +/// # async fn fake_fetch() -> hyper::Result<()> { +/// use hyper::{Client, Uri}; +/// use hyper::ext::ReasonPhrase; +/// +/// let res = Client::new().get(Uri::from_static("http://example.com/non_canonical_reason")).await?; +/// +/// // Print out the non-canonical reason phrase, if it has one... +/// if let Some(reason) = res.extensions().get::<ReasonPhrase>() { +/// println!("non-canonical reason: {}", std::str::from_utf8(reason.as_bytes()).unwrap()); +/// } +/// # Ok(()) +/// # } +/// ``` +/// +/// # Servers +/// +/// When a `ReasonPhrase` is present in the extensions of the `http::Response` written by a server, +/// its contents will be written in place of the canonical reason phrase when responding via HTTP/1. +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct ReasonPhrase(Bytes); + +impl ReasonPhrase { + /// Gets the reason phrase as bytes. + pub fn as_bytes(&self) -> &[u8] { + &self.0 + } + + /// Converts a static byte slice to a reason phrase. + pub fn from_static(reason: &'static [u8]) -> Self { + // TODO: this can be made const once MSRV is >= 1.57.0 + if find_invalid_byte(reason).is_some() { + panic!("invalid byte in static reason phrase"); + } + Self(Bytes::from_static(reason)) + } + + /// Converts a `Bytes` directly into a `ReasonPhrase` without validating. + /// + /// Use with care; invalid bytes in a reason phrase can cause serious security problems if + /// emitted in a response. + pub unsafe fn from_bytes_unchecked(reason: Bytes) -> Self { + Self(reason) + } +} + +impl TryFrom<&[u8]> for ReasonPhrase { + type Error = InvalidReasonPhrase; + + fn try_from(reason: &[u8]) -> Result<Self, Self::Error> { + if let Some(bad_byte) = find_invalid_byte(reason) { + Err(InvalidReasonPhrase { bad_byte }) + } else { + Ok(Self(Bytes::copy_from_slice(reason))) + } + } +} + +impl TryFrom<Vec<u8>> for ReasonPhrase { + type Error = InvalidReasonPhrase; + + fn try_from(reason: Vec<u8>) -> Result<Self, Self::Error> { + if let Some(bad_byte) = find_invalid_byte(&reason) { + Err(InvalidReasonPhrase { bad_byte }) + } else { + Ok(Self(Bytes::from(reason))) + } + } +} + +impl TryFrom<String> for ReasonPhrase { + type Error = InvalidReasonPhrase; + + fn try_from(reason: String) -> Result<Self, Self::Error> { + if let Some(bad_byte) = find_invalid_byte(reason.as_bytes()) { + Err(InvalidReasonPhrase { bad_byte }) + } else { + Ok(Self(Bytes::from(reason))) + } + } +} + +impl TryFrom<Bytes> for ReasonPhrase { + type Error = InvalidReasonPhrase; + + fn try_from(reason: Bytes) -> Result<Self, Self::Error> { + if let Some(bad_byte) = find_invalid_byte(&reason) { + Err(InvalidReasonPhrase { bad_byte }) + } else { + Ok(Self(reason)) + } + } +} + +impl Into<Bytes> for ReasonPhrase { + fn into(self) -> Bytes { + self.0 + } +} + +impl AsRef<[u8]> for ReasonPhrase { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +/// Error indicating an invalid byte when constructing a `ReasonPhrase`. +/// +/// See [the spec][spec] for details on allowed bytes. +/// +/// [spec]: https://httpwg.org/http-core/draft-ietf-httpbis-messaging-latest.html#rfc.section.4.p.7 +#[derive(Debug)] +pub struct InvalidReasonPhrase { + bad_byte: u8, +} + +impl std::fmt::Display for InvalidReasonPhrase { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Invalid byte in reason phrase: {}", self.bad_byte) + } +} + +impl std::error::Error for InvalidReasonPhrase {} + +const fn is_valid_byte(b: u8) -> bool { + // See https://www.rfc-editor.org/rfc/rfc5234.html#appendix-B.1 + const fn is_vchar(b: u8) -> bool { + 0x21 <= b && b <= 0x7E + } + + // See https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#fields.values + // + // The 0xFF comparison is technically redundant, but it matches the text of the spec more + // clearly and will be optimized away. + #[allow(unused_comparisons)] + const fn is_obs_text(b: u8) -> bool { + 0x80 <= b && b <= 0xFF + } + + // See https://httpwg.org/http-core/draft-ietf-httpbis-messaging-latest.html#rfc.section.4.p.7 + b == b'\t' || b == b' ' || is_vchar(b) || is_obs_text(b) +} + +const fn find_invalid_byte(bytes: &[u8]) -> Option<u8> { + let mut i = 0; + while i < bytes.len() { + let b = bytes[i]; + if !is_valid_byte(b) { + return Some(b); + } + i += 1; + } + None +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn basic_valid() { + const PHRASE: &'static [u8] = b"OK"; + assert_eq!(ReasonPhrase::from_static(PHRASE).as_bytes(), PHRASE); + assert_eq!(ReasonPhrase::try_from(PHRASE).unwrap().as_bytes(), PHRASE); + } + + #[test] + fn empty_valid() { + const PHRASE: &'static [u8] = b""; + assert_eq!(ReasonPhrase::from_static(PHRASE).as_bytes(), PHRASE); + assert_eq!(ReasonPhrase::try_from(PHRASE).unwrap().as_bytes(), PHRASE); + } + + #[test] + fn obs_text_valid() { + const PHRASE: &'static [u8] = b"hyp\xe9r"; + assert_eq!(ReasonPhrase::from_static(PHRASE).as_bytes(), PHRASE); + assert_eq!(ReasonPhrase::try_from(PHRASE).unwrap().as_bytes(), PHRASE); + } + + const NEWLINE_PHRASE: &'static [u8] = b"hyp\ner"; + + #[test] + #[should_panic] + fn newline_invalid_panic() { + ReasonPhrase::from_static(NEWLINE_PHRASE); + } + + #[test] + fn newline_invalid_err() { + assert!(ReasonPhrase::try_from(NEWLINE_PHRASE).is_err()); + } + + const CR_PHRASE: &'static [u8] = b"hyp\rer"; + + #[test] + #[should_panic] + fn cr_invalid_panic() { + ReasonPhrase::from_static(CR_PHRASE); + } + + #[test] + fn cr_invalid_err() { + assert!(ReasonPhrase::try_from(CR_PHRASE).is_err()); + } +} diff --git a/third_party/rust/hyper/src/ffi/body.rs b/third_party/rust/hyper/src/ffi/body.rs new file mode 100644 index 0000000000..39ba5beffb --- /dev/null +++ b/third_party/rust/hyper/src/ffi/body.rs @@ -0,0 +1,229 @@ +use std::ffi::c_void; +use std::mem::ManuallyDrop; +use std::ptr; +use std::task::{Context, Poll}; + +use http::HeaderMap; +use libc::{c_int, size_t}; + +use super::task::{hyper_context, hyper_task, hyper_task_return_type, AsTaskType}; +use super::{UserDataPointer, HYPER_ITER_CONTINUE}; +use crate::body::{Body, Bytes, HttpBody as _}; + +/// A streaming HTTP body. +pub struct hyper_body(pub(super) Body); + +/// A buffer of bytes that is sent or received on a `hyper_body`. +pub struct hyper_buf(pub(crate) Bytes); + +pub(crate) struct UserBody { + data_func: hyper_body_data_callback, + userdata: *mut c_void, +} + +// ===== Body ===== + +type hyper_body_foreach_callback = extern "C" fn(*mut c_void, *const hyper_buf) -> c_int; + +type hyper_body_data_callback = + extern "C" fn(*mut c_void, *mut hyper_context<'_>, *mut *mut hyper_buf) -> c_int; + +ffi_fn! { + /// Create a new "empty" body. + /// + /// If not configured, this body acts as an empty payload. + fn hyper_body_new() -> *mut hyper_body { + Box::into_raw(Box::new(hyper_body(Body::empty()))) + } ?= ptr::null_mut() +} + +ffi_fn! { + /// Free a `hyper_body *`. + fn hyper_body_free(body: *mut hyper_body) { + drop(non_null!(Box::from_raw(body) ?= ())); + } +} + +ffi_fn! { + /// Return a task that will poll the body for the next buffer of data. + /// + /// The task value may have different types depending on the outcome: + /// + /// - `HYPER_TASK_BUF`: Success, and more data was received. + /// - `HYPER_TASK_ERROR`: An error retrieving the data. + /// - `HYPER_TASK_EMPTY`: The body has finished streaming data. + /// + /// This does not consume the `hyper_body *`, so it may be used to again. + /// However, it MUST NOT be used or freed until the related task completes. + fn hyper_body_data(body: *mut hyper_body) -> *mut hyper_task { + // This doesn't take ownership of the Body, so don't allow destructor + let mut body = ManuallyDrop::new(non_null!(Box::from_raw(body) ?= ptr::null_mut())); + + Box::into_raw(hyper_task::boxed(async move { + body.0.data().await.map(|res| res.map(hyper_buf)) + })) + } ?= ptr::null_mut() +} + +ffi_fn! { + /// Return a task that will poll the body and execute the callback with each + /// body chunk that is received. + /// + /// The `hyper_buf` pointer is only a borrowed reference, it cannot live outside + /// the execution of the callback. You must make a copy to retain it. + /// + /// The callback should return `HYPER_ITER_CONTINUE` to continue iterating + /// chunks as they are received, or `HYPER_ITER_BREAK` to cancel. + /// + /// This will consume the `hyper_body *`, you shouldn't use it anymore or free it. + fn hyper_body_foreach(body: *mut hyper_body, func: hyper_body_foreach_callback, userdata: *mut c_void) -> *mut hyper_task { + let mut body = non_null!(Box::from_raw(body) ?= ptr::null_mut()); + let userdata = UserDataPointer(userdata); + + Box::into_raw(hyper_task::boxed(async move { + while let Some(item) = body.0.data().await { + let chunk = item?; + if HYPER_ITER_CONTINUE != func(userdata.0, &hyper_buf(chunk)) { + return Err(crate::Error::new_user_aborted_by_callback()); + } + } + Ok(()) + })) + } ?= ptr::null_mut() +} + +ffi_fn! { + /// Set userdata on this body, which will be passed to callback functions. + fn hyper_body_set_userdata(body: *mut hyper_body, userdata: *mut c_void) { + let b = non_null!(&mut *body ?= ()); + b.0.as_ffi_mut().userdata = userdata; + } +} + +ffi_fn! { + /// Set the data callback for this body. + /// + /// The callback is called each time hyper needs to send more data for the + /// body. It is passed the value from `hyper_body_set_userdata`. + /// + /// If there is data available, the `hyper_buf **` argument should be set + /// to a `hyper_buf *` containing the data, and `HYPER_POLL_READY` should + /// be returned. + /// + /// Returning `HYPER_POLL_READY` while the `hyper_buf **` argument points + /// to `NULL` will indicate the body has completed all data. + /// + /// If there is more data to send, but it isn't yet available, a + /// `hyper_waker` should be saved from the `hyper_context *` argument, and + /// `HYPER_POLL_PENDING` should be returned. You must wake the saved waker + /// to signal the task when data is available. + /// + /// If some error has occurred, you can return `HYPER_POLL_ERROR` to abort + /// the body. + fn hyper_body_set_data_func(body: *mut hyper_body, func: hyper_body_data_callback) { + let b = non_null!{ &mut *body ?= () }; + b.0.as_ffi_mut().data_func = func; + } +} + +// ===== impl UserBody ===== + +impl UserBody { + pub(crate) fn new() -> UserBody { + UserBody { + data_func: data_noop, + userdata: std::ptr::null_mut(), + } + } + + pub(crate) fn poll_data(&mut self, cx: &mut Context<'_>) -> Poll<Option<crate::Result<Bytes>>> { + let mut out = std::ptr::null_mut(); + match (self.data_func)(self.userdata, hyper_context::wrap(cx), &mut out) { + super::task::HYPER_POLL_READY => { + if out.is_null() { + Poll::Ready(None) + } else { + let buf = unsafe { Box::from_raw(out) }; + Poll::Ready(Some(Ok(buf.0))) + } + } + super::task::HYPER_POLL_PENDING => Poll::Pending, + super::task::HYPER_POLL_ERROR => { + Poll::Ready(Some(Err(crate::Error::new_body_write_aborted()))) + } + unexpected => Poll::Ready(Some(Err(crate::Error::new_body_write(format!( + "unexpected hyper_body_data_func return code {}", + unexpected + ))))), + } + } + + pub(crate) fn poll_trailers( + &mut self, + _cx: &mut Context<'_>, + ) -> Poll<crate::Result<Option<HeaderMap>>> { + Poll::Ready(Ok(None)) + } +} + +/// cbindgen:ignore +extern "C" fn data_noop( + _userdata: *mut c_void, + _: *mut hyper_context<'_>, + _: *mut *mut hyper_buf, +) -> c_int { + super::task::HYPER_POLL_READY +} + +unsafe impl Send for UserBody {} +unsafe impl Sync for UserBody {} + +// ===== Bytes ===== + +ffi_fn! { + /// Create a new `hyper_buf *` by copying the provided bytes. + /// + /// This makes an owned copy of the bytes, so the `buf` argument can be + /// freed or changed afterwards. + /// + /// This returns `NULL` if allocating a new buffer fails. + fn hyper_buf_copy(buf: *const u8, len: size_t) -> *mut hyper_buf { + let slice = unsafe { + std::slice::from_raw_parts(buf, len) + }; + Box::into_raw(Box::new(hyper_buf(Bytes::copy_from_slice(slice)))) + } ?= ptr::null_mut() +} + +ffi_fn! { + /// Get a pointer to the bytes in this buffer. + /// + /// This should be used in conjunction with `hyper_buf_len` to get the length + /// of the bytes data. + /// + /// This pointer is borrowed data, and not valid once the `hyper_buf` is + /// consumed/freed. + fn hyper_buf_bytes(buf: *const hyper_buf) -> *const u8 { + unsafe { (*buf).0.as_ptr() } + } ?= ptr::null() +} + +ffi_fn! { + /// Get the length of the bytes this buffer contains. + fn hyper_buf_len(buf: *const hyper_buf) -> size_t { + unsafe { (*buf).0.len() } + } +} + +ffi_fn! { + /// Free this buffer. + fn hyper_buf_free(buf: *mut hyper_buf) { + drop(unsafe { Box::from_raw(buf) }); + } +} + +unsafe impl AsTaskType for hyper_buf { + fn as_task_type(&self) -> hyper_task_return_type { + hyper_task_return_type::HYPER_TASK_BUF + } +} diff --git a/third_party/rust/hyper/src/ffi/client.rs b/third_party/rust/hyper/src/ffi/client.rs new file mode 100644 index 0000000000..4cdb257e30 --- /dev/null +++ b/third_party/rust/hyper/src/ffi/client.rs @@ -0,0 +1,181 @@ +use std::ptr; +use std::sync::Arc; + +use libc::c_int; + +use crate::client::conn; +use crate::rt::Executor as _; + +use super::error::hyper_code; +use super::http_types::{hyper_request, hyper_response}; +use super::io::hyper_io; +use super::task::{hyper_executor, hyper_task, hyper_task_return_type, AsTaskType, WeakExec}; + +/// An options builder to configure an HTTP client connection. +pub struct hyper_clientconn_options { + builder: conn::Builder, + /// Use a `Weak` to prevent cycles. + exec: WeakExec, +} + +/// An HTTP client connection handle. +/// +/// These are used to send a request on a single connection. It's possible to +/// send multiple requests on a single connection, such as when HTTP/1 +/// keep-alive or HTTP/2 is used. +pub struct hyper_clientconn { + tx: conn::SendRequest<crate::Body>, +} + +// ===== impl hyper_clientconn ===== + +ffi_fn! { + /// Starts an HTTP client connection handshake using the provided IO transport + /// and options. + /// + /// Both the `io` and the `options` are consumed in this function call. + /// + /// The returned `hyper_task *` must be polled with an executor until the + /// handshake completes, at which point the value can be taken. + fn hyper_clientconn_handshake(io: *mut hyper_io, options: *mut hyper_clientconn_options) -> *mut hyper_task { + let options = non_null! { Box::from_raw(options) ?= ptr::null_mut() }; + let io = non_null! { Box::from_raw(io) ?= ptr::null_mut() }; + + Box::into_raw(hyper_task::boxed(async move { + options.builder.handshake::<_, crate::Body>(io) + .await + .map(|(tx, conn)| { + options.exec.execute(Box::pin(async move { + let _ = conn.await; + })); + hyper_clientconn { tx } + }) + })) + } ?= std::ptr::null_mut() +} + +ffi_fn! { + /// Send a request on the client connection. + /// + /// Returns a task that needs to be polled until it is ready. When ready, the + /// task yields a `hyper_response *`. + fn hyper_clientconn_send(conn: *mut hyper_clientconn, req: *mut hyper_request) -> *mut hyper_task { + let mut req = non_null! { Box::from_raw(req) ?= ptr::null_mut() }; + + // Update request with original-case map of headers + req.finalize_request(); + + let fut = non_null! { &mut *conn ?= ptr::null_mut() }.tx.send_request(req.0); + + let fut = async move { + fut.await.map(hyper_response::wrap) + }; + + Box::into_raw(hyper_task::boxed(fut)) + } ?= std::ptr::null_mut() +} + +ffi_fn! { + /// Free a `hyper_clientconn *`. + fn hyper_clientconn_free(conn: *mut hyper_clientconn) { + drop(non_null! { Box::from_raw(conn) ?= () }); + } +} + +unsafe impl AsTaskType for hyper_clientconn { + fn as_task_type(&self) -> hyper_task_return_type { + hyper_task_return_type::HYPER_TASK_CLIENTCONN + } +} + +// ===== impl hyper_clientconn_options ===== + +ffi_fn! { + /// Creates a new set of HTTP clientconn options to be used in a handshake. + fn hyper_clientconn_options_new() -> *mut hyper_clientconn_options { + let builder = conn::Builder::new(); + + Box::into_raw(Box::new(hyper_clientconn_options { + builder, + exec: WeakExec::new(), + })) + } ?= std::ptr::null_mut() +} + +ffi_fn! { + /// Set the whether or not header case is preserved. + /// + /// Pass `0` to allow lowercase normalization (default), `1` to retain original case. + fn hyper_clientconn_options_set_preserve_header_case(opts: *mut hyper_clientconn_options, enabled: c_int) { + let opts = non_null! { &mut *opts ?= () }; + opts.builder.http1_preserve_header_case(enabled != 0); + } +} + +ffi_fn! { + /// Set the whether or not header order is preserved. + /// + /// Pass `0` to allow reordering (default), `1` to retain original ordering. + fn hyper_clientconn_options_set_preserve_header_order(opts: *mut hyper_clientconn_options, enabled: c_int) { + let opts = non_null! { &mut *opts ?= () }; + opts.builder.http1_preserve_header_order(enabled != 0); + } +} + +ffi_fn! { + /// Free a `hyper_clientconn_options *`. + fn hyper_clientconn_options_free(opts: *mut hyper_clientconn_options) { + drop(non_null! { Box::from_raw(opts) ?= () }); + } +} + +ffi_fn! { + /// Set the client background task executor. + /// + /// This does not consume the `options` or the `exec`. + fn hyper_clientconn_options_exec(opts: *mut hyper_clientconn_options, exec: *const hyper_executor) { + let opts = non_null! { &mut *opts ?= () }; + + let exec = non_null! { Arc::from_raw(exec) ?= () }; + let weak_exec = hyper_executor::downgrade(&exec); + std::mem::forget(exec); + + opts.builder.executor(weak_exec.clone()); + opts.exec = weak_exec; + } +} + +ffi_fn! { + /// Set the whether to use HTTP2. + /// + /// Pass `0` to disable, `1` to enable. + fn hyper_clientconn_options_http2(opts: *mut hyper_clientconn_options, enabled: c_int) -> hyper_code { + #[cfg(feature = "http2")] + { + let opts = non_null! { &mut *opts ?= hyper_code::HYPERE_INVALID_ARG }; + opts.builder.http2_only(enabled != 0); + hyper_code::HYPERE_OK + } + + #[cfg(not(feature = "http2"))] + { + drop(opts); + drop(enabled); + hyper_code::HYPERE_FEATURE_NOT_ENABLED + } + } +} + +ffi_fn! { + /// Set the whether to include a copy of the raw headers in responses + /// received on this connection. + /// + /// Pass `0` to disable, `1` to enable. + /// + /// If enabled, see `hyper_response_headers_raw()` for usage. + fn hyper_clientconn_options_headers_raw(opts: *mut hyper_clientconn_options, enabled: c_int) -> hyper_code { + let opts = non_null! { &mut *opts ?= hyper_code::HYPERE_INVALID_ARG }; + opts.builder.http1_headers_raw(enabled != 0); + hyper_code::HYPERE_OK + } +} diff --git a/third_party/rust/hyper/src/ffi/error.rs b/third_party/rust/hyper/src/ffi/error.rs new file mode 100644 index 0000000000..015e595aee --- /dev/null +++ b/third_party/rust/hyper/src/ffi/error.rs @@ -0,0 +1,85 @@ +use libc::size_t; + +/// A more detailed error object returned by some hyper functions. +pub struct hyper_error(crate::Error); + +/// A return code for many of hyper's methods. +#[repr(C)] +pub enum hyper_code { + /// All is well. + HYPERE_OK, + /// General error, details in the `hyper_error *`. + HYPERE_ERROR, + /// A function argument was invalid. + HYPERE_INVALID_ARG, + /// The IO transport returned an EOF when one wasn't expected. + /// + /// This typically means an HTTP request or response was expected, but the + /// connection closed cleanly without sending (all of) it. + HYPERE_UNEXPECTED_EOF, + /// Aborted by a user supplied callback. + HYPERE_ABORTED_BY_CALLBACK, + /// An optional hyper feature was not enabled. + #[cfg_attr(feature = "http2", allow(unused))] + HYPERE_FEATURE_NOT_ENABLED, + /// The peer sent an HTTP message that could not be parsed. + HYPERE_INVALID_PEER_MESSAGE, +} + +// ===== impl hyper_error ===== + +impl hyper_error { + fn code(&self) -> hyper_code { + use crate::error::Kind as ErrorKind; + use crate::error::User; + + match self.0.kind() { + ErrorKind::Parse(_) => hyper_code::HYPERE_INVALID_PEER_MESSAGE, + ErrorKind::IncompleteMessage => hyper_code::HYPERE_UNEXPECTED_EOF, + ErrorKind::User(User::AbortedByCallback) => hyper_code::HYPERE_ABORTED_BY_CALLBACK, + // TODO: add more variants + _ => hyper_code::HYPERE_ERROR, + } + } + + fn print_to(&self, dst: &mut [u8]) -> usize { + use std::io::Write; + + let mut dst = std::io::Cursor::new(dst); + + // A write! error doesn't matter. As much as possible will have been + // written, and the Cursor position will know how far that is (even + // if that is zero). + let _ = write!(dst, "{}", &self.0); + dst.position() as usize + } +} + +ffi_fn! { + /// Frees a `hyper_error`. + fn hyper_error_free(err: *mut hyper_error) { + drop(non_null!(Box::from_raw(err) ?= ())); + } +} + +ffi_fn! { + /// Get an equivalent `hyper_code` from this error. + fn hyper_error_code(err: *const hyper_error) -> hyper_code { + non_null!(&*err ?= hyper_code::HYPERE_INVALID_ARG).code() + } +} + +ffi_fn! { + /// Print the details of this error to a buffer. + /// + /// The `dst_len` value must be the maximum length that the buffer can + /// store. + /// + /// The return value is number of bytes that were written to `dst`. + fn hyper_error_print(err: *const hyper_error, dst: *mut u8, dst_len: size_t) -> size_t { + let dst = unsafe { + std::slice::from_raw_parts_mut(dst, dst_len) + }; + non_null!(&*err ?= 0).print_to(dst) + } +} diff --git a/third_party/rust/hyper/src/ffi/http_types.rs b/third_party/rust/hyper/src/ffi/http_types.rs new file mode 100644 index 0000000000..ea10f139cb --- /dev/null +++ b/third_party/rust/hyper/src/ffi/http_types.rs @@ -0,0 +1,657 @@ +use bytes::Bytes; +use libc::{c_int, size_t}; +use std::ffi::c_void; + +use super::body::{hyper_body, hyper_buf}; +use super::error::hyper_code; +use super::task::{hyper_task_return_type, AsTaskType}; +use super::{UserDataPointer, HYPER_ITER_CONTINUE}; +use crate::ext::{HeaderCaseMap, OriginalHeaderOrder, ReasonPhrase}; +use crate::header::{HeaderName, HeaderValue}; +use crate::{Body, HeaderMap, Method, Request, Response, Uri}; + +/// An HTTP request. +pub struct hyper_request(pub(super) Request<Body>); + +/// An HTTP response. +pub struct hyper_response(pub(super) Response<Body>); + +/// An HTTP header map. +/// +/// These can be part of a request or response. +pub struct hyper_headers { + pub(super) headers: HeaderMap, + orig_casing: HeaderCaseMap, + orig_order: OriginalHeaderOrder, +} + +pub(crate) struct RawHeaders(pub(crate) hyper_buf); + +pub(crate) struct OnInformational { + func: hyper_request_on_informational_callback, + data: UserDataPointer, +} + +type hyper_request_on_informational_callback = extern "C" fn(*mut c_void, *mut hyper_response); + +// ===== impl hyper_request ===== + +ffi_fn! { + /// Construct a new HTTP request. + fn hyper_request_new() -> *mut hyper_request { + Box::into_raw(Box::new(hyper_request(Request::new(Body::empty())))) + } ?= std::ptr::null_mut() +} + +ffi_fn! { + /// Free an HTTP request if not going to send it on a client. + fn hyper_request_free(req: *mut hyper_request) { + drop(non_null!(Box::from_raw(req) ?= ())); + } +} + +ffi_fn! { + /// Set the HTTP Method of the request. + fn hyper_request_set_method(req: *mut hyper_request, method: *const u8, method_len: size_t) -> hyper_code { + let bytes = unsafe { + std::slice::from_raw_parts(method, method_len as usize) + }; + let req = non_null!(&mut *req ?= hyper_code::HYPERE_INVALID_ARG); + match Method::from_bytes(bytes) { + Ok(m) => { + *req.0.method_mut() = m; + hyper_code::HYPERE_OK + }, + Err(_) => { + hyper_code::HYPERE_INVALID_ARG + } + } + } +} + +ffi_fn! { + /// Set the URI of the request. + /// + /// The request's URI is best described as the `request-target` from the RFCs. So in HTTP/1, + /// whatever is set will get sent as-is in the first line (GET $uri HTTP/1.1). It + /// supports the 4 defined variants, origin-form, absolute-form, authority-form, and + /// asterisk-form. + /// + /// The underlying type was built to efficiently support HTTP/2 where the request-target is + /// split over :scheme, :authority, and :path. As such, each part can be set explicitly, or the + /// type can parse a single contiguous string and if a scheme is found, that slot is "set". If + /// the string just starts with a path, only the path portion is set. All pseudo headers that + /// have been parsed/set are sent when the connection type is HTTP/2. + /// + /// To set each slot explicitly, use `hyper_request_set_uri_parts`. + fn hyper_request_set_uri(req: *mut hyper_request, uri: *const u8, uri_len: size_t) -> hyper_code { + let bytes = unsafe { + std::slice::from_raw_parts(uri, uri_len as usize) + }; + let req = non_null!(&mut *req ?= hyper_code::HYPERE_INVALID_ARG); + match Uri::from_maybe_shared(bytes) { + Ok(u) => { + *req.0.uri_mut() = u; + hyper_code::HYPERE_OK + }, + Err(_) => { + hyper_code::HYPERE_INVALID_ARG + } + } + } +} + +ffi_fn! { + /// Set the URI of the request with separate scheme, authority, and + /// path/query strings. + /// + /// Each of `scheme`, `authority`, and `path_and_query` should either be + /// null, to skip providing a component, or point to a UTF-8 encoded + /// string. If any string pointer argument is non-null, its corresponding + /// `len` parameter must be set to the string's length. + fn hyper_request_set_uri_parts( + req: *mut hyper_request, + scheme: *const u8, + scheme_len: size_t, + authority: *const u8, + authority_len: size_t, + path_and_query: *const u8, + path_and_query_len: size_t + ) -> hyper_code { + let mut builder = Uri::builder(); + if !scheme.is_null() { + let scheme_bytes = unsafe { + std::slice::from_raw_parts(scheme, scheme_len as usize) + }; + builder = builder.scheme(scheme_bytes); + } + if !authority.is_null() { + let authority_bytes = unsafe { + std::slice::from_raw_parts(authority, authority_len as usize) + }; + builder = builder.authority(authority_bytes); + } + if !path_and_query.is_null() { + let path_and_query_bytes = unsafe { + std::slice::from_raw_parts(path_and_query, path_and_query_len as usize) + }; + builder = builder.path_and_query(path_and_query_bytes); + } + match builder.build() { + Ok(u) => { + *unsafe { &mut *req }.0.uri_mut() = u; + hyper_code::HYPERE_OK + }, + Err(_) => { + hyper_code::HYPERE_INVALID_ARG + } + } + } +} + +ffi_fn! { + /// Set the preferred HTTP version of the request. + /// + /// The version value should be one of the `HYPER_HTTP_VERSION_` constants. + /// + /// Note that this won't change the major HTTP version of the connection, + /// since that is determined at the handshake step. + fn hyper_request_set_version(req: *mut hyper_request, version: c_int) -> hyper_code { + use http::Version; + + let req = non_null!(&mut *req ?= hyper_code::HYPERE_INVALID_ARG); + *req.0.version_mut() = match version { + super::HYPER_HTTP_VERSION_NONE => Version::HTTP_11, + super::HYPER_HTTP_VERSION_1_0 => Version::HTTP_10, + super::HYPER_HTTP_VERSION_1_1 => Version::HTTP_11, + super::HYPER_HTTP_VERSION_2 => Version::HTTP_2, + _ => { + // We don't know this version + return hyper_code::HYPERE_INVALID_ARG; + } + }; + hyper_code::HYPERE_OK + } +} + +ffi_fn! { + /// Gets a reference to the HTTP headers of this request + /// + /// This is not an owned reference, so it should not be accessed after the + /// `hyper_request` has been consumed. + fn hyper_request_headers(req: *mut hyper_request) -> *mut hyper_headers { + hyper_headers::get_or_default(unsafe { &mut *req }.0.extensions_mut()) + } ?= std::ptr::null_mut() +} + +ffi_fn! { + /// Set the body of the request. + /// + /// The default is an empty body. + /// + /// This takes ownership of the `hyper_body *`, you must not use it or + /// free it after setting it on the request. + fn hyper_request_set_body(req: *mut hyper_request, body: *mut hyper_body) -> hyper_code { + let body = non_null!(Box::from_raw(body) ?= hyper_code::HYPERE_INVALID_ARG); + let req = non_null!(&mut *req ?= hyper_code::HYPERE_INVALID_ARG); + *req.0.body_mut() = body.0; + hyper_code::HYPERE_OK + } +} + +ffi_fn! { + /// Set an informational (1xx) response callback. + /// + /// The callback is called each time hyper receives an informational (1xx) + /// response for this request. + /// + /// The third argument is an opaque user data pointer, which is passed to + /// the callback each time. + /// + /// The callback is passed the `void *` data pointer, and a + /// `hyper_response *` which can be inspected as any other response. The + /// body of the response will always be empty. + /// + /// NOTE: The `hyper_response *` is just borrowed data, and will not + /// be valid after the callback finishes. You must copy any data you wish + /// to persist. + fn hyper_request_on_informational(req: *mut hyper_request, callback: hyper_request_on_informational_callback, data: *mut c_void) -> hyper_code { + let ext = OnInformational { + func: callback, + data: UserDataPointer(data), + }; + let req = non_null!(&mut *req ?= hyper_code::HYPERE_INVALID_ARG); + req.0.extensions_mut().insert(ext); + hyper_code::HYPERE_OK + } +} + +impl hyper_request { + pub(super) fn finalize_request(&mut self) { + if let Some(headers) = self.0.extensions_mut().remove::<hyper_headers>() { + *self.0.headers_mut() = headers.headers; + self.0.extensions_mut().insert(headers.orig_casing); + self.0.extensions_mut().insert(headers.orig_order); + } + } +} + +// ===== impl hyper_response ===== + +ffi_fn! { + /// Free an HTTP response after using it. + fn hyper_response_free(resp: *mut hyper_response) { + drop(non_null!(Box::from_raw(resp) ?= ())); + } +} + +ffi_fn! { + /// Get the HTTP-Status code of this response. + /// + /// It will always be within the range of 100-599. + fn hyper_response_status(resp: *const hyper_response) -> u16 { + non_null!(&*resp ?= 0).0.status().as_u16() + } +} + +ffi_fn! { + /// Get a pointer to the reason-phrase of this response. + /// + /// This buffer is not null-terminated. + /// + /// This buffer is owned by the response, and should not be used after + /// the response has been freed. + /// + /// Use `hyper_response_reason_phrase_len()` to get the length of this + /// buffer. + fn hyper_response_reason_phrase(resp: *const hyper_response) -> *const u8 { + non_null!(&*resp ?= std::ptr::null()).reason_phrase().as_ptr() + } ?= std::ptr::null() +} + +ffi_fn! { + /// Get the length of the reason-phrase of this response. + /// + /// Use `hyper_response_reason_phrase()` to get the buffer pointer. + fn hyper_response_reason_phrase_len(resp: *const hyper_response) -> size_t { + non_null!(&*resp ?= 0).reason_phrase().len() + } +} + +ffi_fn! { + /// Get a reference to the full raw headers of this response. + /// + /// You must have enabled `hyper_clientconn_options_headers_raw()`, or this + /// will return NULL. + /// + /// The returned `hyper_buf *` is just a reference, owned by the response. + /// You need to make a copy if you wish to use it after freeing the + /// response. + /// + /// The buffer is not null-terminated, see the `hyper_buf` functions for + /// getting the bytes and length. + fn hyper_response_headers_raw(resp: *const hyper_response) -> *const hyper_buf { + let resp = non_null!(&*resp ?= std::ptr::null()); + match resp.0.extensions().get::<RawHeaders>() { + Some(raw) => &raw.0, + None => std::ptr::null(), + } + } ?= std::ptr::null() +} + +ffi_fn! { + /// Get the HTTP version used by this response. + /// + /// The returned value could be: + /// + /// - `HYPER_HTTP_VERSION_1_0` + /// - `HYPER_HTTP_VERSION_1_1` + /// - `HYPER_HTTP_VERSION_2` + /// - `HYPER_HTTP_VERSION_NONE` if newer (or older). + fn hyper_response_version(resp: *const hyper_response) -> c_int { + use http::Version; + + match non_null!(&*resp ?= 0).0.version() { + Version::HTTP_10 => super::HYPER_HTTP_VERSION_1_0, + Version::HTTP_11 => super::HYPER_HTTP_VERSION_1_1, + Version::HTTP_2 => super::HYPER_HTTP_VERSION_2, + _ => super::HYPER_HTTP_VERSION_NONE, + } + } +} + +ffi_fn! { + /// Gets a reference to the HTTP headers of this response. + /// + /// This is not an owned reference, so it should not be accessed after the + /// `hyper_response` has been freed. + fn hyper_response_headers(resp: *mut hyper_response) -> *mut hyper_headers { + hyper_headers::get_or_default(unsafe { &mut *resp }.0.extensions_mut()) + } ?= std::ptr::null_mut() +} + +ffi_fn! { + /// Take ownership of the body of this response. + /// + /// It is safe to free the response even after taking ownership of its body. + fn hyper_response_body(resp: *mut hyper_response) -> *mut hyper_body { + let body = std::mem::take(non_null!(&mut *resp ?= std::ptr::null_mut()).0.body_mut()); + Box::into_raw(Box::new(hyper_body(body))) + } ?= std::ptr::null_mut() +} + +impl hyper_response { + pub(super) fn wrap(mut resp: Response<Body>) -> hyper_response { + let headers = std::mem::take(resp.headers_mut()); + let orig_casing = resp + .extensions_mut() + .remove::<HeaderCaseMap>() + .unwrap_or_else(HeaderCaseMap::default); + let orig_order = resp + .extensions_mut() + .remove::<OriginalHeaderOrder>() + .unwrap_or_else(OriginalHeaderOrder::default); + resp.extensions_mut().insert(hyper_headers { + headers, + orig_casing, + orig_order, + }); + + hyper_response(resp) + } + + fn reason_phrase(&self) -> &[u8] { + if let Some(reason) = self.0.extensions().get::<ReasonPhrase>() { + return reason.as_bytes(); + } + + if let Some(reason) = self.0.status().canonical_reason() { + return reason.as_bytes(); + } + + &[] + } +} + +unsafe impl AsTaskType for hyper_response { + fn as_task_type(&self) -> hyper_task_return_type { + hyper_task_return_type::HYPER_TASK_RESPONSE + } +} + +// ===== impl Headers ===== + +type hyper_headers_foreach_callback = + extern "C" fn(*mut c_void, *const u8, size_t, *const u8, size_t) -> c_int; + +impl hyper_headers { + pub(super) fn get_or_default(ext: &mut http::Extensions) -> &mut hyper_headers { + if let None = ext.get_mut::<hyper_headers>() { + ext.insert(hyper_headers::default()); + } + + ext.get_mut::<hyper_headers>().unwrap() + } +} + +ffi_fn! { + /// Iterates the headers passing each name and value pair to the callback. + /// + /// The `userdata` pointer is also passed to the callback. + /// + /// The callback should return `HYPER_ITER_CONTINUE` to keep iterating, or + /// `HYPER_ITER_BREAK` to stop. + fn hyper_headers_foreach(headers: *const hyper_headers, func: hyper_headers_foreach_callback, userdata: *mut c_void) { + let headers = non_null!(&*headers ?= ()); + // For each header name/value pair, there may be a value in the casemap + // that corresponds to the HeaderValue. So, we iterator all the keys, + // and for each one, try to pair the originally cased name with the value. + // + // TODO: consider adding http::HeaderMap::entries() iterator + let mut ordered_iter = headers.orig_order.get_in_order().peekable(); + if ordered_iter.peek().is_some() { + for (name, idx) in ordered_iter { + let (name_ptr, name_len) = if let Some(orig_name) = headers.orig_casing.get_all(name).nth(*idx) { + (orig_name.as_ref().as_ptr(), orig_name.as_ref().len()) + } else { + ( + name.as_str().as_bytes().as_ptr(), + name.as_str().as_bytes().len(), + ) + }; + + let val_ptr; + let val_len; + if let Some(value) = headers.headers.get_all(name).iter().nth(*idx) { + val_ptr = value.as_bytes().as_ptr(); + val_len = value.as_bytes().len(); + } else { + // Stop iterating, something has gone wrong. + return; + } + + if HYPER_ITER_CONTINUE != func(userdata, name_ptr, name_len, val_ptr, val_len) { + return; + } + } + } else { + for name in headers.headers.keys() { + let mut names = headers.orig_casing.get_all(name); + + for value in headers.headers.get_all(name) { + let (name_ptr, name_len) = if let Some(orig_name) = names.next() { + (orig_name.as_ref().as_ptr(), orig_name.as_ref().len()) + } else { + ( + name.as_str().as_bytes().as_ptr(), + name.as_str().as_bytes().len(), + ) + }; + + let val_ptr = value.as_bytes().as_ptr(); + let val_len = value.as_bytes().len(); + + if HYPER_ITER_CONTINUE != func(userdata, name_ptr, name_len, val_ptr, val_len) { + return; + } + } + } + } + } +} + +ffi_fn! { + /// Sets the header with the provided name to the provided value. + /// + /// This overwrites any previous value set for the header. + fn hyper_headers_set(headers: *mut hyper_headers, name: *const u8, name_len: size_t, value: *const u8, value_len: size_t) -> hyper_code { + let headers = non_null!(&mut *headers ?= hyper_code::HYPERE_INVALID_ARG); + match unsafe { raw_name_value(name, name_len, value, value_len) } { + Ok((name, value, orig_name)) => { + headers.headers.insert(&name, value); + headers.orig_casing.insert(name.clone(), orig_name.clone()); + headers.orig_order.insert(name); + hyper_code::HYPERE_OK + } + Err(code) => code, + } + } +} + +ffi_fn! { + /// Adds the provided value to the list of the provided name. + /// + /// If there were already existing values for the name, this will append the + /// new value to the internal list. + fn hyper_headers_add(headers: *mut hyper_headers, name: *const u8, name_len: size_t, value: *const u8, value_len: size_t) -> hyper_code { + let headers = non_null!(&mut *headers ?= hyper_code::HYPERE_INVALID_ARG); + + match unsafe { raw_name_value(name, name_len, value, value_len) } { + Ok((name, value, orig_name)) => { + headers.headers.append(&name, value); + headers.orig_casing.append(&name, orig_name.clone()); + headers.orig_order.append(name); + hyper_code::HYPERE_OK + } + Err(code) => code, + } + } +} + +impl Default for hyper_headers { + fn default() -> Self { + Self { + headers: Default::default(), + orig_casing: HeaderCaseMap::default(), + orig_order: OriginalHeaderOrder::default(), + } + } +} + +unsafe fn raw_name_value( + name: *const u8, + name_len: size_t, + value: *const u8, + value_len: size_t, +) -> Result<(HeaderName, HeaderValue, Bytes), hyper_code> { + let name = std::slice::from_raw_parts(name, name_len); + let orig_name = Bytes::copy_from_slice(name); + let name = match HeaderName::from_bytes(name) { + Ok(name) => name, + Err(_) => return Err(hyper_code::HYPERE_INVALID_ARG), + }; + let value = std::slice::from_raw_parts(value, value_len); + let value = match HeaderValue::from_bytes(value) { + Ok(val) => val, + Err(_) => return Err(hyper_code::HYPERE_INVALID_ARG), + }; + + Ok((name, value, orig_name)) +} + +// ===== impl OnInformational ===== + +impl OnInformational { + pub(crate) fn call(&mut self, resp: Response<Body>) { + let mut resp = hyper_response::wrap(resp); + (self.func)(self.data.0, &mut resp); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_headers_foreach_cases_preserved() { + let mut headers = hyper_headers::default(); + + let name1 = b"Set-CookiE"; + let value1 = b"a=b"; + hyper_headers_add( + &mut headers, + name1.as_ptr(), + name1.len(), + value1.as_ptr(), + value1.len(), + ); + + let name2 = b"SET-COOKIE"; + let value2 = b"c=d"; + hyper_headers_add( + &mut headers, + name2.as_ptr(), + name2.len(), + value2.as_ptr(), + value2.len(), + ); + + let mut vec = Vec::<u8>::new(); + hyper_headers_foreach(&headers, concat, &mut vec as *mut _ as *mut c_void); + + assert_eq!(vec, b"Set-CookiE: a=b\r\nSET-COOKIE: c=d\r\n"); + + extern "C" fn concat( + vec: *mut c_void, + name: *const u8, + name_len: usize, + value: *const u8, + value_len: usize, + ) -> c_int { + unsafe { + let vec = &mut *(vec as *mut Vec<u8>); + let name = std::slice::from_raw_parts(name, name_len); + let value = std::slice::from_raw_parts(value, value_len); + vec.extend(name); + vec.extend(b": "); + vec.extend(value); + vec.extend(b"\r\n"); + } + HYPER_ITER_CONTINUE + } + } + + #[cfg(all(feature = "http1", feature = "ffi"))] + #[test] + fn test_headers_foreach_order_preserved() { + let mut headers = hyper_headers::default(); + + let name1 = b"Set-CookiE"; + let value1 = b"a=b"; + hyper_headers_add( + &mut headers, + name1.as_ptr(), + name1.len(), + value1.as_ptr(), + value1.len(), + ); + + let name2 = b"Content-Encoding"; + let value2 = b"gzip"; + hyper_headers_add( + &mut headers, + name2.as_ptr(), + name2.len(), + value2.as_ptr(), + value2.len(), + ); + + let name3 = b"SET-COOKIE"; + let value3 = b"c=d"; + hyper_headers_add( + &mut headers, + name3.as_ptr(), + name3.len(), + value3.as_ptr(), + value3.len(), + ); + + let mut vec = Vec::<u8>::new(); + hyper_headers_foreach(&headers, concat, &mut vec as *mut _ as *mut c_void); + + println!("{}", std::str::from_utf8(&vec).unwrap()); + assert_eq!( + vec, + b"Set-CookiE: a=b\r\nContent-Encoding: gzip\r\nSET-COOKIE: c=d\r\n" + ); + + extern "C" fn concat( + vec: *mut c_void, + name: *const u8, + name_len: usize, + value: *const u8, + value_len: usize, + ) -> c_int { + unsafe { + let vec = &mut *(vec as *mut Vec<u8>); + let name = std::slice::from_raw_parts(name, name_len); + let value = std::slice::from_raw_parts(value, value_len); + vec.extend(name); + vec.extend(b": "); + vec.extend(value); + vec.extend(b"\r\n"); + } + HYPER_ITER_CONTINUE + } + } +} diff --git a/third_party/rust/hyper/src/ffi/io.rs b/third_party/rust/hyper/src/ffi/io.rs new file mode 100644 index 0000000000..bff666dbcf --- /dev/null +++ b/third_party/rust/hyper/src/ffi/io.rs @@ -0,0 +1,178 @@ +use std::ffi::c_void; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use libc::size_t; +use tokio::io::{AsyncRead, AsyncWrite}; + +use super::task::hyper_context; + +/// Sentinel value to return from a read or write callback that the operation +/// is pending. +pub const HYPER_IO_PENDING: size_t = 0xFFFFFFFF; +/// Sentinel value to return from a read or write callback that the operation +/// has errored. +pub const HYPER_IO_ERROR: size_t = 0xFFFFFFFE; + +type hyper_io_read_callback = + extern "C" fn(*mut c_void, *mut hyper_context<'_>, *mut u8, size_t) -> size_t; +type hyper_io_write_callback = + extern "C" fn(*mut c_void, *mut hyper_context<'_>, *const u8, size_t) -> size_t; + +/// An IO object used to represent a socket or similar concept. +pub struct hyper_io { + read: hyper_io_read_callback, + write: hyper_io_write_callback, + userdata: *mut c_void, +} + +ffi_fn! { + /// Create a new IO type used to represent a transport. + /// + /// The read and write functions of this transport should be set with + /// `hyper_io_set_read` and `hyper_io_set_write`. + fn hyper_io_new() -> *mut hyper_io { + Box::into_raw(Box::new(hyper_io { + read: read_noop, + write: write_noop, + userdata: std::ptr::null_mut(), + })) + } ?= std::ptr::null_mut() +} + +ffi_fn! { + /// Free an unused `hyper_io *`. + /// + /// This is typically only useful if you aren't going to pass ownership + /// of the IO handle to hyper, such as with `hyper_clientconn_handshake()`. + fn hyper_io_free(io: *mut hyper_io) { + drop(non_null!(Box::from_raw(io) ?= ())); + } +} + +ffi_fn! { + /// Set the user data pointer for this IO to some value. + /// + /// This value is passed as an argument to the read and write callbacks. + fn hyper_io_set_userdata(io: *mut hyper_io, data: *mut c_void) { + non_null!(&mut *io ?= ()).userdata = data; + } +} + +ffi_fn! { + /// Set the read function for this IO transport. + /// + /// Data that is read from the transport should be put in the `buf` pointer, + /// up to `buf_len` bytes. The number of bytes read should be the return value. + /// + /// It is undefined behavior to try to access the bytes in the `buf` pointer, + /// unless you have already written them yourself. It is also undefined behavior + /// to return that more bytes have been written than actually set on the `buf`. + /// + /// If there is no data currently available, a waker should be claimed from + /// the `ctx` and registered with whatever polling mechanism is used to signal + /// when data is available later on. The return value should be + /// `HYPER_IO_PENDING`. + /// + /// If there is an irrecoverable error reading data, then `HYPER_IO_ERROR` + /// should be the return value. + fn hyper_io_set_read(io: *mut hyper_io, func: hyper_io_read_callback) { + non_null!(&mut *io ?= ()).read = func; + } +} + +ffi_fn! { + /// Set the write function for this IO transport. + /// + /// Data from the `buf` pointer should be written to the transport, up to + /// `buf_len` bytes. The number of bytes written should be the return value. + /// + /// If no data can currently be written, the `waker` should be cloned and + /// registered with whatever polling mechanism is used to signal when data + /// is available later on. The return value should be `HYPER_IO_PENDING`. + /// + /// Yeet. + /// + /// If there is an irrecoverable error reading data, then `HYPER_IO_ERROR` + /// should be the return value. + fn hyper_io_set_write(io: *mut hyper_io, func: hyper_io_write_callback) { + non_null!(&mut *io ?= ()).write = func; + } +} + +/// cbindgen:ignore +extern "C" fn read_noop( + _userdata: *mut c_void, + _: *mut hyper_context<'_>, + _buf: *mut u8, + _buf_len: size_t, +) -> size_t { + 0 +} + +/// cbindgen:ignore +extern "C" fn write_noop( + _userdata: *mut c_void, + _: *mut hyper_context<'_>, + _buf: *const u8, + _buf_len: size_t, +) -> size_t { + 0 +} + +impl AsyncRead for hyper_io { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + let buf_ptr = unsafe { buf.unfilled_mut() }.as_mut_ptr() as *mut u8; + let buf_len = buf.remaining(); + + match (self.read)(self.userdata, hyper_context::wrap(cx), buf_ptr, buf_len) { + HYPER_IO_PENDING => Poll::Pending, + HYPER_IO_ERROR => Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::Other, + "io error", + ))), + ok => { + // We have to trust that the user's read callback actually + // filled in that many bytes... :( + unsafe { buf.assume_init(ok) }; + buf.advance(ok); + Poll::Ready(Ok(())) + } + } + } +} + +impl AsyncWrite for hyper_io { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<std::io::Result<usize>> { + let buf_ptr = buf.as_ptr(); + let buf_len = buf.len(); + + match (self.write)(self.userdata, hyper_context::wrap(cx), buf_ptr, buf_len) { + HYPER_IO_PENDING => Poll::Pending, + HYPER_IO_ERROR => Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::Other, + "io error", + ))), + ok => Poll::Ready(Ok(ok)), + } + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<std::io::Result<()>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<std::io::Result<()>> { + Poll::Ready(Ok(())) + } +} + +unsafe impl Send for hyper_io {} +unsafe impl Sync for hyper_io {} diff --git a/third_party/rust/hyper/src/ffi/macros.rs b/third_party/rust/hyper/src/ffi/macros.rs new file mode 100644 index 0000000000..022711baaa --- /dev/null +++ b/third_party/rust/hyper/src/ffi/macros.rs @@ -0,0 +1,53 @@ +macro_rules! ffi_fn { + ($(#[$doc:meta])* fn $name:ident($($arg:ident: $arg_ty:ty),*) -> $ret:ty $body:block ?= $default:expr) => { + $(#[$doc])* + #[no_mangle] + pub extern fn $name($($arg: $arg_ty),*) -> $ret { + use std::panic::{self, AssertUnwindSafe}; + + match panic::catch_unwind(AssertUnwindSafe(move || $body)) { + Ok(v) => v, + Err(_) => { + $default + } + } + } + }; + + ($(#[$doc:meta])* fn $name:ident($($arg:ident: $arg_ty:ty),*) -> $ret:ty $body:block) => { + ffi_fn!($(#[$doc])* fn $name($($arg: $arg_ty),*) -> $ret $body ?= { + eprintln!("panic unwind caught, aborting"); + std::process::abort() + }); + }; + + ($(#[$doc:meta])* fn $name:ident($($arg:ident: $arg_ty:ty),*) $body:block ?= $default:expr) => { + ffi_fn!($(#[$doc])* fn $name($($arg: $arg_ty),*) -> () $body ?= $default); + }; + + ($(#[$doc:meta])* fn $name:ident($($arg:ident: $arg_ty:ty),*) $body:block) => { + ffi_fn!($(#[$doc])* fn $name($($arg: $arg_ty),*) -> () $body); + }; +} + +macro_rules! non_null { + ($ptr:ident, $eval:expr, $err:expr) => {{ + debug_assert!(!$ptr.is_null(), "{:?} must not be null", stringify!($ptr)); + if $ptr.is_null() { + return $err; + } + unsafe { $eval } + }}; + (&*$ptr:ident ?= $err:expr) => {{ + non_null!($ptr, &*$ptr, $err) + }}; + (&mut *$ptr:ident ?= $err:expr) => {{ + non_null!($ptr, &mut *$ptr, $err) + }}; + (Box::from_raw($ptr:ident) ?= $err:expr) => {{ + non_null!($ptr, Box::from_raw($ptr), $err) + }}; + (Arc::from_raw($ptr:ident) ?= $err:expr) => {{ + non_null!($ptr, Arc::from_raw($ptr), $err) + }}; +} diff --git a/third_party/rust/hyper/src/ffi/mod.rs b/third_party/rust/hyper/src/ffi/mod.rs new file mode 100644 index 0000000000..fd67a880a6 --- /dev/null +++ b/third_party/rust/hyper/src/ffi/mod.rs @@ -0,0 +1,94 @@ +// We have a lot of c-types in here, stop warning about their names! +#![allow(non_camel_case_types)] +// fmt::Debug isn't helpful on FFI types +#![allow(missing_debug_implementations)] +// unreachable_pub warns `#[no_mangle] pub extern fn` in private mod. +#![allow(unreachable_pub)] + +//! # hyper C API +//! +//! This part of the documentation describes the C API for hyper. That is, how +//! to *use* the hyper library in C code. This is **not** a regular Rust +//! module, and thus it is not accessible in Rust. +//! +//! ## Unstable +//! +//! The C API of hyper is currently **unstable**, which means it's not part of +//! the semver contract as the rest of the Rust API is. Because of that, it's +//! only accessible if `--cfg hyper_unstable_ffi` is passed to `rustc` when +//! compiling. The easiest way to do that is setting the `RUSTFLAGS` +//! environment variable. +//! +//! ## Building +//! +//! The C API is part of the Rust library, but isn't compiled by default. Using +//! `cargo`, it can be compiled with the following command: +//! +//! ```notrust +//! RUSTFLAGS="--cfg hyper_unstable_ffi" cargo build --features client,http1,http2,ffi +//! ``` + +// We may eventually allow the FFI to be enabled without `client` or `http1`, +// that is why we don't auto enable them as `ffi = ["client", "http1"]` in +// the `Cargo.toml`. +// +// But for now, give a clear message that this compile error is expected. +#[cfg(not(all(feature = "client", feature = "http1")))] +compile_error!("The `ffi` feature currently requires the `client` and `http1` features."); + +#[cfg(not(hyper_unstable_ffi))] +compile_error!( + "\ + The `ffi` feature is unstable, and requires the \ + `RUSTFLAGS='--cfg hyper_unstable_ffi'` environment variable to be set.\ +" +); + +#[macro_use] +mod macros; + +mod body; +mod client; +mod error; +mod http_types; +mod io; +mod task; + +pub use self::body::*; +pub use self::client::*; +pub use self::error::*; +pub use self::http_types::*; +pub use self::io::*; +pub use self::task::*; + +/// Return in iter functions to continue iterating. +pub const HYPER_ITER_CONTINUE: libc::c_int = 0; +/// Return in iter functions to stop iterating. +#[allow(unused)] +pub const HYPER_ITER_BREAK: libc::c_int = 1; + +/// An HTTP Version that is unspecified. +pub const HYPER_HTTP_VERSION_NONE: libc::c_int = 0; +/// The HTTP/1.0 version. +pub const HYPER_HTTP_VERSION_1_0: libc::c_int = 10; +/// The HTTP/1.1 version. +pub const HYPER_HTTP_VERSION_1_1: libc::c_int = 11; +/// The HTTP/2 version. +pub const HYPER_HTTP_VERSION_2: libc::c_int = 20; + +struct UserDataPointer(*mut std::ffi::c_void); + +// We don't actually know anything about this pointer, it's up to the user +// to do the right thing. +unsafe impl Send for UserDataPointer {} +unsafe impl Sync for UserDataPointer {} + +/// cbindgen:ignore +static VERSION_CSTR: &str = concat!(env!("CARGO_PKG_VERSION"), "\0"); + +ffi_fn! { + /// Returns a static ASCII (null terminated) string of the hyper version. + fn hyper_version() -> *const libc::c_char { + VERSION_CSTR.as_ptr() as _ + } ?= std::ptr::null() +} diff --git a/third_party/rust/hyper/src/ffi/task.rs b/third_party/rust/hyper/src/ffi/task.rs new file mode 100644 index 0000000000..ef54fe408f --- /dev/null +++ b/third_party/rust/hyper/src/ffi/task.rs @@ -0,0 +1,411 @@ +use std::ffi::c_void; +use std::future::Future; +use std::pin::Pin; +use std::ptr; +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, Mutex, Weak, +}; +use std::task::{Context, Poll}; + +use futures_util::stream::{FuturesUnordered, Stream}; +use libc::c_int; + +use super::error::hyper_code; +use super::UserDataPointer; + +type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>; +type BoxAny = Box<dyn AsTaskType + Send + Sync>; + +/// Return in a poll function to indicate it was ready. +pub const HYPER_POLL_READY: c_int = 0; +/// Return in a poll function to indicate it is still pending. +/// +/// The passed in `hyper_waker` should be registered to wake up the task at +/// some later point. +pub const HYPER_POLL_PENDING: c_int = 1; +/// Return in a poll function indicate an error. +pub const HYPER_POLL_ERROR: c_int = 3; + +/// A task executor for `hyper_task`s. +pub struct hyper_executor { + /// The executor of all task futures. + /// + /// There should never be contention on the mutex, as it is only locked + /// to drive the futures. However, we cannot guarantee proper usage from + /// `hyper_executor_poll()`, which in C could potentially be called inside + /// one of the stored futures. The mutex isn't re-entrant, so doing so + /// would result in a deadlock, but that's better than data corruption. + driver: Mutex<FuturesUnordered<TaskFuture>>, + + /// The queue of futures that need to be pushed into the `driver`. + /// + /// This is has a separate mutex since `spawn` could be called from inside + /// a future, which would mean the driver's mutex is already locked. + spawn_queue: Mutex<Vec<TaskFuture>>, + + /// This is used to track when a future calls `wake` while we are within + /// `hyper_executor::poll_next`. + is_woken: Arc<ExecWaker>, +} + +#[derive(Clone)] +pub(crate) struct WeakExec(Weak<hyper_executor>); + +struct ExecWaker(AtomicBool); + +/// An async task. +pub struct hyper_task { + future: BoxFuture<BoxAny>, + output: Option<BoxAny>, + userdata: UserDataPointer, +} + +struct TaskFuture { + task: Option<Box<hyper_task>>, +} + +/// An async context for a task that contains the related waker. +pub struct hyper_context<'a>(Context<'a>); + +/// A waker that is saved and used to waken a pending task. +pub struct hyper_waker { + waker: std::task::Waker, +} + +/// A descriptor for what type a `hyper_task` value is. +#[repr(C)] +pub enum hyper_task_return_type { + /// The value of this task is null (does not imply an error). + HYPER_TASK_EMPTY, + /// The value of this task is `hyper_error *`. + HYPER_TASK_ERROR, + /// The value of this task is `hyper_clientconn *`. + HYPER_TASK_CLIENTCONN, + /// The value of this task is `hyper_response *`. + HYPER_TASK_RESPONSE, + /// The value of this task is `hyper_buf *`. + HYPER_TASK_BUF, +} + +pub(crate) unsafe trait AsTaskType { + fn as_task_type(&self) -> hyper_task_return_type; +} + +pub(crate) trait IntoDynTaskType { + fn into_dyn_task_type(self) -> BoxAny; +} + +// ===== impl hyper_executor ===== + +impl hyper_executor { + fn new() -> Arc<hyper_executor> { + Arc::new(hyper_executor { + driver: Mutex::new(FuturesUnordered::new()), + spawn_queue: Mutex::new(Vec::new()), + is_woken: Arc::new(ExecWaker(AtomicBool::new(false))), + }) + } + + pub(crate) fn downgrade(exec: &Arc<hyper_executor>) -> WeakExec { + WeakExec(Arc::downgrade(exec)) + } + + fn spawn(&self, task: Box<hyper_task>) { + self.spawn_queue + .lock() + .unwrap() + .push(TaskFuture { task: Some(task) }); + } + + fn poll_next(&self) -> Option<Box<hyper_task>> { + // Drain the queue first. + self.drain_queue(); + + let waker = futures_util::task::waker_ref(&self.is_woken); + let mut cx = Context::from_waker(&waker); + + loop { + match Pin::new(&mut *self.driver.lock().unwrap()).poll_next(&mut cx) { + Poll::Ready(val) => return val, + Poll::Pending => { + // Check if any of the pending tasks tried to spawn + // some new tasks. If so, drain into the driver and loop. + if self.drain_queue() { + continue; + } + + // If the driver called `wake` while we were polling, + // we should poll again immediately! + if self.is_woken.0.swap(false, Ordering::SeqCst) { + continue; + } + + return None; + } + } + } + } + + fn drain_queue(&self) -> bool { + let mut queue = self.spawn_queue.lock().unwrap(); + if queue.is_empty() { + return false; + } + + let driver = self.driver.lock().unwrap(); + + for task in queue.drain(..) { + driver.push(task); + } + + true + } +} + +impl futures_util::task::ArcWake for ExecWaker { + fn wake_by_ref(me: &Arc<ExecWaker>) { + me.0.store(true, Ordering::SeqCst); + } +} + +// ===== impl WeakExec ===== + +impl WeakExec { + pub(crate) fn new() -> Self { + WeakExec(Weak::new()) + } +} + +impl crate::rt::Executor<BoxFuture<()>> for WeakExec { + fn execute(&self, fut: BoxFuture<()>) { + if let Some(exec) = self.0.upgrade() { + exec.spawn(hyper_task::boxed(fut)); + } + } +} + +ffi_fn! { + /// Creates a new task executor. + fn hyper_executor_new() -> *const hyper_executor { + Arc::into_raw(hyper_executor::new()) + } ?= ptr::null() +} + +ffi_fn! { + /// Frees an executor and any incomplete tasks still part of it. + fn hyper_executor_free(exec: *const hyper_executor) { + drop(non_null!(Arc::from_raw(exec) ?= ())); + } +} + +ffi_fn! { + /// Push a task onto the executor. + /// + /// The executor takes ownership of the task, it should not be accessed + /// again unless returned back to the user with `hyper_executor_poll`. + fn hyper_executor_push(exec: *const hyper_executor, task: *mut hyper_task) -> hyper_code { + let exec = non_null!(&*exec ?= hyper_code::HYPERE_INVALID_ARG); + let task = non_null!(Box::from_raw(task) ?= hyper_code::HYPERE_INVALID_ARG); + exec.spawn(task); + hyper_code::HYPERE_OK + } +} + +ffi_fn! { + /// Polls the executor, trying to make progress on any tasks that have notified + /// that they are ready again. + /// + /// If ready, returns a task from the executor that has completed. + /// + /// If there are no ready tasks, this returns `NULL`. + fn hyper_executor_poll(exec: *const hyper_executor) -> *mut hyper_task { + let exec = non_null!(&*exec ?= ptr::null_mut()); + match exec.poll_next() { + Some(task) => Box::into_raw(task), + None => ptr::null_mut(), + } + } ?= ptr::null_mut() +} + +// ===== impl hyper_task ===== + +impl hyper_task { + pub(crate) fn boxed<F>(fut: F) -> Box<hyper_task> + where + F: Future + Send + 'static, + F::Output: IntoDynTaskType + Send + Sync + 'static, + { + Box::new(hyper_task { + future: Box::pin(async move { fut.await.into_dyn_task_type() }), + output: None, + userdata: UserDataPointer(ptr::null_mut()), + }) + } + + fn output_type(&self) -> hyper_task_return_type { + match self.output { + None => hyper_task_return_type::HYPER_TASK_EMPTY, + Some(ref val) => val.as_task_type(), + } + } +} + +impl Future for TaskFuture { + type Output = Box<hyper_task>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + match Pin::new(&mut self.task.as_mut().unwrap().future).poll(cx) { + Poll::Ready(val) => { + let mut task = self.task.take().unwrap(); + task.output = Some(val); + Poll::Ready(task) + } + Poll::Pending => Poll::Pending, + } + } +} + +ffi_fn! { + /// Free a task. + fn hyper_task_free(task: *mut hyper_task) { + drop(non_null!(Box::from_raw(task) ?= ())); + } +} + +ffi_fn! { + /// Takes the output value of this task. + /// + /// This must only be called once polling the task on an executor has finished + /// this task. + /// + /// Use `hyper_task_type` to determine the type of the `void *` return value. + fn hyper_task_value(task: *mut hyper_task) -> *mut c_void { + let task = non_null!(&mut *task ?= ptr::null_mut()); + + if let Some(val) = task.output.take() { + let p = Box::into_raw(val) as *mut c_void; + // protect from returning fake pointers to empty types + if p == std::ptr::NonNull::<c_void>::dangling().as_ptr() { + ptr::null_mut() + } else { + p + } + } else { + ptr::null_mut() + } + } ?= ptr::null_mut() +} + +ffi_fn! { + /// Query the return type of this task. + fn hyper_task_type(task: *mut hyper_task) -> hyper_task_return_type { + // instead of blowing up spectacularly, just say this null task + // doesn't have a value to retrieve. + non_null!(&*task ?= hyper_task_return_type::HYPER_TASK_EMPTY).output_type() + } +} + +ffi_fn! { + /// Set a user data pointer to be associated with this task. + /// + /// This value will be passed to task callbacks, and can be checked later + /// with `hyper_task_userdata`. + fn hyper_task_set_userdata(task: *mut hyper_task, userdata: *mut c_void) { + if task.is_null() { + return; + } + + unsafe { (*task).userdata = UserDataPointer(userdata) }; + } +} + +ffi_fn! { + /// Retrieve the userdata that has been set via `hyper_task_set_userdata`. + fn hyper_task_userdata(task: *mut hyper_task) -> *mut c_void { + non_null!(&*task ?= ptr::null_mut()).userdata.0 + } ?= ptr::null_mut() +} + +// ===== impl AsTaskType ===== + +unsafe impl AsTaskType for () { + fn as_task_type(&self) -> hyper_task_return_type { + hyper_task_return_type::HYPER_TASK_EMPTY + } +} + +unsafe impl AsTaskType for crate::Error { + fn as_task_type(&self) -> hyper_task_return_type { + hyper_task_return_type::HYPER_TASK_ERROR + } +} + +impl<T> IntoDynTaskType for T +where + T: AsTaskType + Send + Sync + 'static, +{ + fn into_dyn_task_type(self) -> BoxAny { + Box::new(self) + } +} + +impl<T> IntoDynTaskType for crate::Result<T> +where + T: IntoDynTaskType + Send + Sync + 'static, +{ + fn into_dyn_task_type(self) -> BoxAny { + match self { + Ok(val) => val.into_dyn_task_type(), + Err(err) => Box::new(err), + } + } +} + +impl<T> IntoDynTaskType for Option<T> +where + T: IntoDynTaskType + Send + Sync + 'static, +{ + fn into_dyn_task_type(self) -> BoxAny { + match self { + Some(val) => val.into_dyn_task_type(), + None => ().into_dyn_task_type(), + } + } +} + +// ===== impl hyper_context ===== + +impl hyper_context<'_> { + pub(crate) fn wrap<'a, 'b>(cx: &'a mut Context<'b>) -> &'a mut hyper_context<'b> { + // A struct with only one field has the same layout as that field. + unsafe { std::mem::transmute::<&mut Context<'_>, &mut hyper_context<'_>>(cx) } + } +} + +ffi_fn! { + /// Copies a waker out of the task context. + fn hyper_context_waker(cx: *mut hyper_context<'_>) -> *mut hyper_waker { + let waker = non_null!(&mut *cx ?= ptr::null_mut()).0.waker().clone(); + Box::into_raw(Box::new(hyper_waker { waker })) + } ?= ptr::null_mut() +} + +// ===== impl hyper_waker ===== + +ffi_fn! { + /// Free a waker that hasn't been woken. + fn hyper_waker_free(waker: *mut hyper_waker) { + drop(non_null!(Box::from_raw(waker) ?= ())); + } +} + +ffi_fn! { + /// Wake up the task associated with a waker. + /// + /// NOTE: This consumes the waker. You should not use or free the waker afterwards. + fn hyper_waker_wake(waker: *mut hyper_waker) { + let waker = non_null!(Box::from_raw(waker) ?= ()); + waker.waker.wake(); + } +} diff --git a/third_party/rust/hyper/src/headers.rs b/third_party/rust/hyper/src/headers.rs new file mode 100644 index 0000000000..8407be185f --- /dev/null +++ b/third_party/rust/hyper/src/headers.rs @@ -0,0 +1,154 @@ +#[cfg(feature = "http1")] +use bytes::BytesMut; +use http::header::CONTENT_LENGTH; +use http::header::{HeaderValue, ValueIter}; +use http::HeaderMap; +#[cfg(all(feature = "http2", feature = "client"))] +use http::Method; + +#[cfg(feature = "http1")] +pub(super) fn connection_keep_alive(value: &HeaderValue) -> bool { + connection_has(value, "keep-alive") +} + +#[cfg(feature = "http1")] +pub(super) fn connection_close(value: &HeaderValue) -> bool { + connection_has(value, "close") +} + +#[cfg(feature = "http1")] +fn connection_has(value: &HeaderValue, needle: &str) -> bool { + if let Ok(s) = value.to_str() { + for val in s.split(',') { + if val.trim().eq_ignore_ascii_case(needle) { + return true; + } + } + } + false +} + +#[cfg(all(feature = "http1", feature = "server"))] +pub(super) fn content_length_parse(value: &HeaderValue) -> Option<u64> { + from_digits(value.as_bytes()) +} + +pub(super) fn content_length_parse_all(headers: &HeaderMap) -> Option<u64> { + content_length_parse_all_values(headers.get_all(CONTENT_LENGTH).into_iter()) +} + +pub(super) fn content_length_parse_all_values(values: ValueIter<'_, HeaderValue>) -> Option<u64> { + // If multiple Content-Length headers were sent, everything can still + // be alright if they all contain the same value, and all parse + // correctly. If not, then it's an error. + + let mut content_length: Option<u64> = None; + for h in values { + if let Ok(line) = h.to_str() { + for v in line.split(',') { + if let Some(n) = from_digits(v.trim().as_bytes()) { + if content_length.is_none() { + content_length = Some(n) + } else if content_length != Some(n) { + return None; + } + } else { + return None + } + } + } else { + return None + } + } + + return content_length +} + +fn from_digits(bytes: &[u8]) -> Option<u64> { + // cannot use FromStr for u64, since it allows a signed prefix + let mut result = 0u64; + const RADIX: u64 = 10; + + if bytes.is_empty() { + return None; + } + + for &b in bytes { + // can't use char::to_digit, since we haven't verified these bytes + // are utf-8. + match b { + b'0'..=b'9' => { + result = result.checked_mul(RADIX)?; + result = result.checked_add((b - b'0') as u64)?; + }, + _ => { + // not a DIGIT, get outta here! + return None; + } + } + } + + Some(result) +} + +#[cfg(all(feature = "http2", feature = "client"))] +pub(super) fn method_has_defined_payload_semantics(method: &Method) -> bool { + match *method { + Method::GET | Method::HEAD | Method::DELETE | Method::CONNECT => false, + _ => true, + } +} + +#[cfg(feature = "http2")] +pub(super) fn set_content_length_if_missing(headers: &mut HeaderMap, len: u64) { + headers + .entry(CONTENT_LENGTH) + .or_insert_with(|| HeaderValue::from(len)); +} + +#[cfg(feature = "http1")] +pub(super) fn transfer_encoding_is_chunked(headers: &HeaderMap) -> bool { + is_chunked(headers.get_all(http::header::TRANSFER_ENCODING).into_iter()) +} + +#[cfg(feature = "http1")] +pub(super) fn is_chunked(mut encodings: ValueIter<'_, HeaderValue>) -> bool { + // chunked must always be the last encoding, according to spec + if let Some(line) = encodings.next_back() { + return is_chunked_(line); + } + + false +} + +#[cfg(feature = "http1")] +pub(super) fn is_chunked_(value: &HeaderValue) -> bool { + // chunked must always be the last encoding, according to spec + if let Ok(s) = value.to_str() { + if let Some(encoding) = s.rsplit(',').next() { + return encoding.trim().eq_ignore_ascii_case("chunked"); + } + } + + false +} + +#[cfg(feature = "http1")] +pub(super) fn add_chunked(mut entry: http::header::OccupiedEntry<'_, HeaderValue>) { + const CHUNKED: &str = "chunked"; + + if let Some(line) = entry.iter_mut().next_back() { + // + 2 for ", " + let new_cap = line.as_bytes().len() + CHUNKED.len() + 2; + let mut buf = BytesMut::with_capacity(new_cap); + buf.extend_from_slice(line.as_bytes()); + buf.extend_from_slice(b", "); + buf.extend_from_slice(CHUNKED.as_bytes()); + + *line = HeaderValue::from_maybe_shared(buf.freeze()) + .expect("original header value plus ascii is valid"); + return; + } + + entry.insert(HeaderValue::from_static(CHUNKED)); +} diff --git a/third_party/rust/hyper/src/lib.rs b/third_party/rust/hyper/src/lib.rs new file mode 100644 index 0000000000..3a2202dff6 --- /dev/null +++ b/third_party/rust/hyper/src/lib.rs @@ -0,0 +1,109 @@ +#![deny(missing_docs)] +#![deny(missing_debug_implementations)] +#![cfg_attr(test, deny(rust_2018_idioms))] +#![cfg_attr(all(test, feature = "full"), deny(unreachable_pub))] +#![cfg_attr(all(test, feature = "full"), deny(warnings))] +#![cfg_attr(all(test, feature = "nightly"), feature(test))] +#![cfg_attr(docsrs, feature(doc_cfg))] + +//! # hyper +//! +//! hyper is a **fast** and **correct** HTTP implementation written in and for Rust. +//! +//! ## Features +//! +//! - HTTP/1 and HTTP/2 +//! - Asynchronous design +//! - Leading in performance +//! - Tested and **correct** +//! - Extensive production use +//! - [Client](client/index.html) and [Server](server/index.html) APIs +//! +//! If just starting out, **check out the [Guides](https://hyper.rs/guides) +//! first.** +//! +//! ## "Low-level" +//! +//! hyper is a lower-level HTTP library, meant to be a building block +//! for libraries and applications. +//! +//! If looking for just a convenient HTTP client, consider the +//! [reqwest](https://crates.io/crates/reqwest) crate. +//! +//! # Optional Features +//! +//! hyper uses a set of [feature flags] to reduce the amount of compiled code. +//! It is possible to just enable certain features over others. By default, +//! hyper does not enable any features but allows one to enable a subset for +//! their use case. Below is a list of the available feature flags. You may +//! also notice above each function, struct and trait there is listed one or +//! more feature flags that are required for that item to be used. +//! +//! If you are new to hyper it is possible to enable the `full` feature flag +//! which will enable all public APIs. Beware though that this will pull in +//! many extra dependencies that you may not need. +//! +//! The following optional features are available: +//! +//! - `http1`: Enables HTTP/1 support. +//! - `http2`: Enables HTTP/2 support. +//! - `client`: Enables the HTTP `client`. +//! - `server`: Enables the HTTP `server`. +//! - `runtime`: Enables convenient integration with `tokio`, providing +//! connectors and acceptors for TCP, and a default executor. +//! - `tcp`: Enables convenient implementations over TCP (using tokio). +//! - `stream`: Provides `futures::Stream` capabilities. +//! +//! [feature flags]: https://doc.rust-lang.org/cargo/reference/manifest.html#the-features-section + +#[doc(hidden)] +pub use http; + +#[cfg(all(test, feature = "nightly"))] +extern crate test; + +pub use crate::http::{header, Method, Request, Response, StatusCode, Uri, Version}; + +#[doc(no_inline)] +pub use crate::http::HeaderMap; + +pub use crate::body::Body; +pub use crate::error::{Error, Result}; + +#[macro_use] +mod cfg; +#[macro_use] +mod common; +pub mod body; +mod error; +pub mod ext; +#[cfg(test)] +mod mock; +pub mod rt; +pub mod service; +pub mod upgrade; + +#[cfg(feature = "ffi")] +pub mod ffi; + +cfg_proto! { + mod headers; + mod proto; +} + +cfg_feature! { + #![feature = "client"] + + pub mod client; + #[cfg(any(feature = "http1", feature = "http2"))] + #[doc(no_inline)] + pub use crate::client::Client; +} + +cfg_feature! { + #![feature = "server"] + + pub mod server; + #[doc(no_inline)] + pub use crate::server::Server; +} diff --git a/third_party/rust/hyper/src/mock.rs b/third_party/rust/hyper/src/mock.rs new file mode 100644 index 0000000000..1dd57de319 --- /dev/null +++ b/third_party/rust/hyper/src/mock.rs @@ -0,0 +1,235 @@ +// FIXME: re-implement tests with `async/await` +/* +#[cfg(feature = "runtime")] +use std::collections::HashMap; +use std::cmp; +use std::io::{self, Read, Write}; +#[cfg(feature = "runtime")] +use std::sync::{Arc, Mutex}; + +use bytes::Buf; +use futures::{Async, Poll}; +#[cfg(feature = "runtime")] +use futures::Future; +use futures::task::{self, Task}; +use tokio_io::{AsyncRead, AsyncWrite}; + +#[cfg(feature = "runtime")] +use crate::client::connect::{Connect, Connected, Destination}; + + + +#[cfg(feature = "runtime")] +pub struct Duplex { + inner: Arc<Mutex<DuplexInner>>, +} + +#[cfg(feature = "runtime")] +struct DuplexInner { + handle_read_task: Option<Task>, + read: AsyncIo<MockCursor>, + write: AsyncIo<MockCursor>, +} + +#[cfg(feature = "runtime")] +impl Duplex { + pub(crate) fn channel() -> (Duplex, DuplexHandle) { + let mut inner = DuplexInner { + handle_read_task: None, + read: AsyncIo::new_buf(Vec::new(), 0), + write: AsyncIo::new_buf(Vec::new(), std::usize::MAX), + }; + + inner.read.park_tasks(true); + inner.write.park_tasks(true); + + let inner = Arc::new(Mutex::new(inner)); + + let duplex = Duplex { + inner: inner.clone(), + }; + let handle = DuplexHandle { + inner: inner, + }; + + (duplex, handle) + } +} + +#[cfg(feature = "runtime")] +impl Read for Duplex { + fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + self.inner.lock().unwrap().read.read(buf) + } +} + +#[cfg(feature = "runtime")] +impl Write for Duplex { + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + let mut inner = self.inner.lock().unwrap(); + let ret = inner.write.write(buf); + if let Some(task) = inner.handle_read_task.take() { + trace!("waking DuplexHandle read"); + task.notify(); + } + ret + } + + fn flush(&mut self) -> io::Result<()> { + self.inner.lock().unwrap().write.flush() + } +} + +#[cfg(feature = "runtime")] +impl AsyncRead for Duplex { +} + +#[cfg(feature = "runtime")] +impl AsyncWrite for Duplex { + fn shutdown(&mut self) -> Poll<(), io::Error> { + Ok(().into()) + } + + fn write_buf<B: Buf>(&mut self, buf: &mut B) -> Poll<usize, io::Error> { + let mut inner = self.inner.lock().unwrap(); + if let Some(task) = inner.handle_read_task.take() { + task.notify(); + } + inner.write.write_buf(buf) + } +} + +#[cfg(feature = "runtime")] +pub struct DuplexHandle { + inner: Arc<Mutex<DuplexInner>>, +} + +#[cfg(feature = "runtime")] +impl DuplexHandle { + pub fn read(&self, buf: &mut [u8]) -> Poll<usize, io::Error> { + let mut inner = self.inner.lock().unwrap(); + assert!(buf.len() >= inner.write.inner.len()); + if inner.write.inner.is_empty() { + trace!("DuplexHandle read parking"); + inner.handle_read_task = Some(task::current()); + return Ok(Async::NotReady); + } + inner.write.read(buf).map(Async::Ready) + } + + pub fn write(&self, bytes: &[u8]) -> Poll<usize, io::Error> { + let mut inner = self.inner.lock().unwrap(); + assert_eq!(inner.read.inner.pos, 0); + assert_eq!(inner.read.inner.vec.len(), 0, "write but read isn't empty"); + inner + .read + .inner + .vec + .extend(bytes); + inner.read.block_in(bytes.len()); + Ok(Async::Ready(bytes.len())) + } +} + +#[cfg(feature = "runtime")] +impl Drop for DuplexHandle { + fn drop(&mut self) { + trace!("mock duplex handle drop"); + if !::std::thread::panicking() { + let mut inner = self.inner.lock().unwrap(); + inner.read.close(); + inner.write.close(); + } + } +} + +#[cfg(feature = "runtime")] +type BoxedConnectFut = Box<dyn Future<Item=(Duplex, Connected), Error=io::Error> + Send>; + +#[cfg(feature = "runtime")] +#[derive(Clone)] +pub struct MockConnector { + mocks: Arc<Mutex<MockedConnections>>, +} + +#[cfg(feature = "runtime")] +struct MockedConnections(HashMap<String, Vec<BoxedConnectFut>>); + +#[cfg(feature = "runtime")] +impl MockConnector { + pub fn new() -> MockConnector { + MockConnector { + mocks: Arc::new(Mutex::new(MockedConnections(HashMap::new()))), + } + } + + pub fn mock(&mut self, key: &str) -> DuplexHandle { + use futures::future; + self.mock_fut(key, future::ok::<_, ()>(())) + } + + pub fn mock_fut<F>(&mut self, key: &str, fut: F) -> DuplexHandle + where + F: Future + Send + 'static, + { + self.mock_opts(key, Connected::new(), fut) + } + + pub fn mock_opts<F>(&mut self, key: &str, connected: Connected, fut: F) -> DuplexHandle + where + F: Future + Send + 'static, + { + let key = key.to_owned(); + + let (duplex, handle) = Duplex::channel(); + + let fut = Box::new(fut.then(move |_| { + trace!("MockConnector mocked fut ready"); + Ok((duplex, connected)) + })); + self.mocks.lock().unwrap().0.entry(key) + .or_insert(Vec::new()) + .push(fut); + + handle + } +} + +#[cfg(feature = "runtime")] +impl Connect for MockConnector { + type Transport = Duplex; + type Error = io::Error; + type Future = BoxedConnectFut; + + fn connect(&self, dst: Destination) -> Self::Future { + trace!("mock connect: {:?}", dst); + let key = format!("{}://{}{}", dst.scheme(), dst.host(), if let Some(port) = dst.port() { + format!(":{}", port) + } else { + "".to_owned() + }); + let mut mocks = self.mocks.lock().unwrap(); + let mocks = mocks.0.get_mut(&key) + .expect(&format!("unknown mocks uri: {}", key)); + assert!(!mocks.is_empty(), "no additional mocks for {}", key); + mocks.remove(0) + } +} + + +#[cfg(feature = "runtime")] +impl Drop for MockedConnections { + fn drop(&mut self) { + if !::std::thread::panicking() { + for (key, mocks) in self.0.iter() { + assert_eq!( + mocks.len(), + 0, + "not all mocked connects for {:?} were used", + key, + ); + } + } + } +} +*/ diff --git a/third_party/rust/hyper/src/proto/h1/conn.rs b/third_party/rust/hyper/src/proto/h1/conn.rs new file mode 100644 index 0000000000..5ebff2803e --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/conn.rs @@ -0,0 +1,1425 @@ +use std::fmt; +use std::io; +use std::marker::PhantomData; +#[cfg(all(feature = "server", feature = "runtime"))] +use std::time::Duration; + +use bytes::{Buf, Bytes}; +use http::header::{HeaderValue, CONNECTION}; +use http::{HeaderMap, Method, Version}; +use httparse::ParserConfig; +use tokio::io::{AsyncRead, AsyncWrite}; +#[cfg(all(feature = "server", feature = "runtime"))] +use tokio::time::Sleep; +use tracing::{debug, error, trace}; + +use super::io::Buffered; +use super::{Decoder, Encode, EncodedBuf, Encoder, Http1Transaction, ParseContext, Wants}; +use crate::body::DecodedLength; +use crate::common::{task, Pin, Poll, Unpin}; +use crate::headers::connection_keep_alive; +use crate::proto::{BodyLength, MessageHead}; + +const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; + +/// This handles a connection, which will have been established over an +/// `AsyncRead + AsyncWrite` (like a socket), and will likely include multiple +/// `Transaction`s over HTTP. +/// +/// The connection will determine when a message begins and ends as well as +/// determine if this connection can be kept alive after the message, +/// or if it is complete. +pub(crate) struct Conn<I, B, T> { + io: Buffered<I, EncodedBuf<B>>, + state: State, + _marker: PhantomData<fn(T)>, +} + +impl<I, B, T> Conn<I, B, T> +where + I: AsyncRead + AsyncWrite + Unpin, + B: Buf, + T: Http1Transaction, +{ + pub(crate) fn new(io: I) -> Conn<I, B, T> { + Conn { + io: Buffered::new(io), + state: State { + allow_half_close: false, + cached_headers: None, + error: None, + keep_alive: KA::Busy, + method: None, + h1_parser_config: ParserConfig::default(), + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout: None, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_fut: None, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_running: false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + title_case_headers: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: None, + #[cfg(feature = "ffi")] + raw_headers: false, + notify_read: false, + reading: Reading::Init, + writing: Writing::Init, + upgrade: None, + // We assume a modern world where the remote speaks HTTP/1.1. + // If they tell us otherwise, we'll downgrade in `read_head`. + version: Version::HTTP_11, + }, + _marker: PhantomData, + } + } + + #[cfg(feature = "server")] + pub(crate) fn set_flush_pipeline(&mut self, enabled: bool) { + self.io.set_flush_pipeline(enabled); + } + + pub(crate) fn set_write_strategy_queue(&mut self) { + self.io.set_write_strategy_queue(); + } + + pub(crate) fn set_max_buf_size(&mut self, max: usize) { + self.io.set_max_buf_size(max); + } + + #[cfg(feature = "client")] + pub(crate) fn set_read_buf_exact_size(&mut self, sz: usize) { + self.io.set_read_buf_exact_size(sz); + } + + pub(crate) fn set_write_strategy_flatten(&mut self) { + self.io.set_write_strategy_flatten(); + } + + #[cfg(feature = "client")] + pub(crate) fn set_h1_parser_config(&mut self, parser_config: ParserConfig) { + self.state.h1_parser_config = parser_config; + } + + pub(crate) fn set_title_case_headers(&mut self) { + self.state.title_case_headers = true; + } + + pub(crate) fn set_preserve_header_case(&mut self) { + self.state.preserve_header_case = true; + } + + #[cfg(feature = "ffi")] + pub(crate) fn set_preserve_header_order(&mut self) { + self.state.preserve_header_order = true; + } + + #[cfg(feature = "client")] + pub(crate) fn set_h09_responses(&mut self) { + self.state.h09_responses = true; + } + + #[cfg(all(feature = "server", feature = "runtime"))] + pub(crate) fn set_http1_header_read_timeout(&mut self, val: Duration) { + self.state.h1_header_read_timeout = Some(val); + } + + #[cfg(feature = "server")] + pub(crate) fn set_allow_half_close(&mut self) { + self.state.allow_half_close = true; + } + + #[cfg(feature = "ffi")] + pub(crate) fn set_raw_headers(&mut self, enabled: bool) { + self.state.raw_headers = enabled; + } + + pub(crate) fn into_inner(self) -> (I, Bytes) { + self.io.into_inner() + } + + pub(crate) fn pending_upgrade(&mut self) -> Option<crate::upgrade::Pending> { + self.state.upgrade.take() + } + + pub(crate) fn is_read_closed(&self) -> bool { + self.state.is_read_closed() + } + + pub(crate) fn is_write_closed(&self) -> bool { + self.state.is_write_closed() + } + + pub(crate) fn can_read_head(&self) -> bool { + if !matches!(self.state.reading, Reading::Init) { + return false; + } + + if T::should_read_first() { + return true; + } + + !matches!(self.state.writing, Writing::Init) + } + + pub(crate) fn can_read_body(&self) -> bool { + match self.state.reading { + Reading::Body(..) | Reading::Continue(..) => true, + _ => false, + } + } + + fn should_error_on_eof(&self) -> bool { + // If we're idle, it's probably just the connection closing gracefully. + T::should_error_on_parse_eof() && !self.state.is_idle() + } + + fn has_h2_prefix(&self) -> bool { + let read_buf = self.io.read_buf(); + read_buf.len() >= 24 && read_buf[..24] == *H2_PREFACE + } + + pub(super) fn poll_read_head( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Option<crate::Result<(MessageHead<T::Incoming>, DecodedLength, Wants)>>> { + debug_assert!(self.can_read_head()); + trace!("Conn::read_head"); + + let msg = match ready!(self.io.parse::<T>( + cx, + ParseContext { + cached_headers: &mut self.state.cached_headers, + req_method: &mut self.state.method, + h1_parser_config: self.state.h1_parser_config.clone(), + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout: self.state.h1_header_read_timeout, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_fut: &mut self.state.h1_header_read_timeout_fut, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_running: &mut self.state.h1_header_read_timeout_running, + preserve_header_case: self.state.preserve_header_case, + #[cfg(feature = "ffi")] + preserve_header_order: self.state.preserve_header_order, + h09_responses: self.state.h09_responses, + #[cfg(feature = "ffi")] + on_informational: &mut self.state.on_informational, + #[cfg(feature = "ffi")] + raw_headers: self.state.raw_headers, + } + )) { + Ok(msg) => msg, + Err(e) => return self.on_read_head_error(e), + }; + + // Note: don't deconstruct `msg` into local variables, it appears + // the optimizer doesn't remove the extra copies. + + debug!("incoming body is {}", msg.decode); + + // Prevent accepting HTTP/0.9 responses after the initial one, if any. + self.state.h09_responses = false; + + // Drop any OnInformational callbacks, we're done there! + #[cfg(feature = "ffi")] + { + self.state.on_informational = None; + } + + self.state.busy(); + self.state.keep_alive &= msg.keep_alive; + self.state.version = msg.head.version; + + let mut wants = if msg.wants_upgrade { + Wants::UPGRADE + } else { + Wants::EMPTY + }; + + if msg.decode == DecodedLength::ZERO { + if msg.expect_continue { + debug!("ignoring expect-continue since body is empty"); + } + self.state.reading = Reading::KeepAlive; + if !T::should_read_first() { + self.try_keep_alive(cx); + } + } else if msg.expect_continue { + self.state.reading = Reading::Continue(Decoder::new(msg.decode)); + wants = wants.add(Wants::EXPECT); + } else { + self.state.reading = Reading::Body(Decoder::new(msg.decode)); + } + + Poll::Ready(Some(Ok((msg.head, msg.decode, wants)))) + } + + fn on_read_head_error<Z>(&mut self, e: crate::Error) -> Poll<Option<crate::Result<Z>>> { + // If we are currently waiting on a message, then an empty + // message should be reported as an error. If not, it is just + // the connection closing gracefully. + let must_error = self.should_error_on_eof(); + self.close_read(); + self.io.consume_leading_lines(); + let was_mid_parse = e.is_parse() || !self.io.read_buf().is_empty(); + if was_mid_parse || must_error { + // We check if the buf contains the h2 Preface + debug!( + "parse error ({}) with {} bytes", + e, + self.io.read_buf().len() + ); + match self.on_parse_error(e) { + Ok(()) => Poll::Pending, // XXX: wat? + Err(e) => Poll::Ready(Some(Err(e))), + } + } else { + debug!("read eof"); + self.close_write(); + Poll::Ready(None) + } + } + + pub(crate) fn poll_read_body( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Option<io::Result<Bytes>>> { + debug_assert!(self.can_read_body()); + + let (reading, ret) = match self.state.reading { + Reading::Body(ref mut decoder) => { + match ready!(decoder.decode(cx, &mut self.io)) { + Ok(slice) => { + let (reading, chunk) = if decoder.is_eof() { + debug!("incoming body completed"); + ( + Reading::KeepAlive, + if !slice.is_empty() { + Some(Ok(slice)) + } else { + None + }, + ) + } else if slice.is_empty() { + error!("incoming body unexpectedly ended"); + // This should be unreachable, since all 3 decoders + // either set eof=true or return an Err when reading + // an empty slice... + (Reading::Closed, None) + } else { + return Poll::Ready(Some(Ok(slice))); + }; + (reading, Poll::Ready(chunk)) + } + Err(e) => { + debug!("incoming body decode error: {}", e); + (Reading::Closed, Poll::Ready(Some(Err(e)))) + } + } + } + Reading::Continue(ref decoder) => { + // Write the 100 Continue if not already responded... + if let Writing::Init = self.state.writing { + trace!("automatically sending 100 Continue"); + let cont = b"HTTP/1.1 100 Continue\r\n\r\n"; + self.io.headers_buf().extend_from_slice(cont); + } + + // And now recurse once in the Reading::Body state... + self.state.reading = Reading::Body(decoder.clone()); + return self.poll_read_body(cx); + } + _ => unreachable!("poll_read_body invalid state: {:?}", self.state.reading), + }; + + self.state.reading = reading; + self.try_keep_alive(cx); + ret + } + + pub(crate) fn wants_read_again(&mut self) -> bool { + let ret = self.state.notify_read; + self.state.notify_read = false; + ret + } + + pub(crate) fn poll_read_keep_alive( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<crate::Result<()>> { + debug_assert!(!self.can_read_head() && !self.can_read_body()); + + if self.is_read_closed() { + Poll::Pending + } else if self.is_mid_message() { + self.mid_message_detect_eof(cx) + } else { + self.require_empty_read(cx) + } + } + + fn is_mid_message(&self) -> bool { + !matches!( + (&self.state.reading, &self.state.writing), + (&Reading::Init, &Writing::Init) + ) + } + + // This will check to make sure the io object read is empty. + // + // This should only be called for Clients wanting to enter the idle + // state. + fn require_empty_read(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + debug_assert!(!self.can_read_head() && !self.can_read_body() && !self.is_read_closed()); + debug_assert!(!self.is_mid_message()); + debug_assert!(T::is_client()); + + if !self.io.read_buf().is_empty() { + debug!("received an unexpected {} bytes", self.io.read_buf().len()); + return Poll::Ready(Err(crate::Error::new_unexpected_message())); + } + + let num_read = ready!(self.force_io_read(cx)).map_err(crate::Error::new_io)?; + + if num_read == 0 { + let ret = if self.should_error_on_eof() { + trace!("found unexpected EOF on busy connection: {:?}", self.state); + Poll::Ready(Err(crate::Error::new_incomplete())) + } else { + trace!("found EOF on idle connection, closing"); + Poll::Ready(Ok(())) + }; + + // order is important: should_error needs state BEFORE close_read + self.state.close_read(); + return ret; + } + + debug!( + "received unexpected {} bytes on an idle connection", + num_read + ); + Poll::Ready(Err(crate::Error::new_unexpected_message())) + } + + fn mid_message_detect_eof(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + debug_assert!(!self.can_read_head() && !self.can_read_body() && !self.is_read_closed()); + debug_assert!(self.is_mid_message()); + + if self.state.allow_half_close || !self.io.read_buf().is_empty() { + return Poll::Pending; + } + + let num_read = ready!(self.force_io_read(cx)).map_err(crate::Error::new_io)?; + + if num_read == 0 { + trace!("found unexpected EOF on busy connection: {:?}", self.state); + self.state.close_read(); + Poll::Ready(Err(crate::Error::new_incomplete())) + } else { + Poll::Ready(Ok(())) + } + } + + fn force_io_read(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<usize>> { + debug_assert!(!self.state.is_read_closed()); + + let result = ready!(self.io.poll_read_from_io(cx)); + Poll::Ready(result.map_err(|e| { + trace!("force_io_read; io error = {:?}", e); + self.state.close(); + e + })) + } + + fn maybe_notify(&mut self, cx: &mut task::Context<'_>) { + // its possible that we returned NotReady from poll() without having + // exhausted the underlying Io. We would have done this when we + // determined we couldn't keep reading until we knew how writing + // would finish. + + match self.state.reading { + Reading::Continue(..) | Reading::Body(..) | Reading::KeepAlive | Reading::Closed => { + return + } + Reading::Init => (), + }; + + match self.state.writing { + Writing::Body(..) => return, + Writing::Init | Writing::KeepAlive | Writing::Closed => (), + } + + if !self.io.is_read_blocked() { + if self.io.read_buf().is_empty() { + match self.io.poll_read_from_io(cx) { + Poll::Ready(Ok(n)) => { + if n == 0 { + trace!("maybe_notify; read eof"); + if self.state.is_idle() { + self.state.close(); + } else { + self.close_read() + } + return; + } + } + Poll::Pending => { + trace!("maybe_notify; read_from_io blocked"); + return; + } + Poll::Ready(Err(e)) => { + trace!("maybe_notify; read_from_io error: {}", e); + self.state.close(); + self.state.error = Some(crate::Error::new_io(e)); + } + } + } + self.state.notify_read = true; + } + } + + fn try_keep_alive(&mut self, cx: &mut task::Context<'_>) { + self.state.try_keep_alive::<T>(); + self.maybe_notify(cx); + } + + pub(crate) fn can_write_head(&self) -> bool { + if !T::should_read_first() && matches!(self.state.reading, Reading::Closed) { + return false; + } + + match self.state.writing { + Writing::Init => self.io.can_headers_buf(), + _ => false, + } + } + + pub(crate) fn can_write_body(&self) -> bool { + match self.state.writing { + Writing::Body(..) => true, + Writing::Init | Writing::KeepAlive | Writing::Closed => false, + } + } + + pub(crate) fn can_buffer_body(&self) -> bool { + self.io.can_buffer() + } + + pub(crate) fn write_head(&mut self, head: MessageHead<T::Outgoing>, body: Option<BodyLength>) { + if let Some(encoder) = self.encode_head(head, body) { + self.state.writing = if !encoder.is_eof() { + Writing::Body(encoder) + } else if encoder.is_last() { + Writing::Closed + } else { + Writing::KeepAlive + }; + } + } + + pub(crate) fn write_full_msg(&mut self, head: MessageHead<T::Outgoing>, body: B) { + if let Some(encoder) = + self.encode_head(head, Some(BodyLength::Known(body.remaining() as u64))) + { + let is_last = encoder.is_last(); + // Make sure we don't write a body if we weren't actually allowed + // to do so, like because its a HEAD request. + if !encoder.is_eof() { + encoder.danger_full_buf(body, self.io.write_buf()); + } + self.state.writing = if is_last { + Writing::Closed + } else { + Writing::KeepAlive + } + } + } + + fn encode_head( + &mut self, + mut head: MessageHead<T::Outgoing>, + body: Option<BodyLength>, + ) -> Option<Encoder> { + debug_assert!(self.can_write_head()); + + if !T::should_read_first() { + self.state.busy(); + } + + self.enforce_version(&mut head); + + let buf = self.io.headers_buf(); + match super::role::encode_headers::<T>( + Encode { + head: &mut head, + body, + #[cfg(feature = "server")] + keep_alive: self.state.wants_keep_alive(), + req_method: &mut self.state.method, + title_case_headers: self.state.title_case_headers, + }, + buf, + ) { + Ok(encoder) => { + debug_assert!(self.state.cached_headers.is_none()); + debug_assert!(head.headers.is_empty()); + self.state.cached_headers = Some(head.headers); + + #[cfg(feature = "ffi")] + { + self.state.on_informational = + head.extensions.remove::<crate::ffi::OnInformational>(); + } + + Some(encoder) + } + Err(err) => { + self.state.error = Some(err); + self.state.writing = Writing::Closed; + None + } + } + } + + // Fix keep-alive when Connection: keep-alive header is not present + fn fix_keep_alive(&mut self, head: &mut MessageHead<T::Outgoing>) { + let outgoing_is_keep_alive = head + .headers + .get(CONNECTION) + .map(connection_keep_alive) + .unwrap_or(false); + + if !outgoing_is_keep_alive { + match head.version { + // If response is version 1.0 and keep-alive is not present in the response, + // disable keep-alive so the server closes the connection + Version::HTTP_10 => self.state.disable_keep_alive(), + // If response is version 1.1 and keep-alive is wanted, add + // Connection: keep-alive header when not present + Version::HTTP_11 => { + if self.state.wants_keep_alive() { + head.headers + .insert(CONNECTION, HeaderValue::from_static("keep-alive")); + } + } + _ => (), + } + } + } + + // If we know the remote speaks an older version, we try to fix up any messages + // to work with our older peer. + fn enforce_version(&mut self, head: &mut MessageHead<T::Outgoing>) { + if let Version::HTTP_10 = self.state.version { + // Fixes response or connection when keep-alive header is not present + self.fix_keep_alive(head); + // If the remote only knows HTTP/1.0, we should force ourselves + // to do only speak HTTP/1.0 as well. + head.version = Version::HTTP_10; + } + // If the remote speaks HTTP/1.1, then it *should* be fine with + // both HTTP/1.0 and HTTP/1.1 from us. So again, we just let + // the user's headers be. + } + + pub(crate) fn write_body(&mut self, chunk: B) { + debug_assert!(self.can_write_body() && self.can_buffer_body()); + // empty chunks should be discarded at Dispatcher level + debug_assert!(chunk.remaining() != 0); + + let state = match self.state.writing { + Writing::Body(ref mut encoder) => { + self.io.buffer(encoder.encode(chunk)); + + if !encoder.is_eof() { + return; + } + + if encoder.is_last() { + Writing::Closed + } else { + Writing::KeepAlive + } + } + _ => unreachable!("write_body invalid state: {:?}", self.state.writing), + }; + + self.state.writing = state; + } + + pub(crate) fn write_body_and_end(&mut self, chunk: B) { + debug_assert!(self.can_write_body() && self.can_buffer_body()); + // empty chunks should be discarded at Dispatcher level + debug_assert!(chunk.remaining() != 0); + + let state = match self.state.writing { + Writing::Body(ref encoder) => { + let can_keep_alive = encoder.encode_and_end(chunk, self.io.write_buf()); + if can_keep_alive { + Writing::KeepAlive + } else { + Writing::Closed + } + } + _ => unreachable!("write_body invalid state: {:?}", self.state.writing), + }; + + self.state.writing = state; + } + + pub(crate) fn end_body(&mut self) -> crate::Result<()> { + debug_assert!(self.can_write_body()); + + let encoder = match self.state.writing { + Writing::Body(ref mut enc) => enc, + _ => return Ok(()), + }; + + // end of stream, that means we should try to eof + match encoder.end() { + Ok(end) => { + if let Some(end) = end { + self.io.buffer(end); + } + + self.state.writing = if encoder.is_last() || encoder.is_close_delimited() { + Writing::Closed + } else { + Writing::KeepAlive + }; + + Ok(()) + } + Err(not_eof) => { + self.state.writing = Writing::Closed; + Err(crate::Error::new_body_write_aborted().with(not_eof)) + } + } + } + + // When we get a parse error, depending on what side we are, we might be able + // to write a response before closing the connection. + // + // - Client: there is nothing we can do + // - Server: if Response hasn't been written yet, we can send a 4xx response + fn on_parse_error(&mut self, err: crate::Error) -> crate::Result<()> { + if let Writing::Init = self.state.writing { + if self.has_h2_prefix() { + return Err(crate::Error::new_version_h2()); + } + if let Some(msg) = T::on_error(&err) { + // Drop the cached headers so as to not trigger a debug + // assert in `write_head`... + self.state.cached_headers.take(); + self.write_head(msg, None); + self.state.error = Some(err); + return Ok(()); + } + } + + // fallback is pass the error back up + Err(err) + } + + pub(crate) fn poll_flush(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + ready!(Pin::new(&mut self.io).poll_flush(cx))?; + self.try_keep_alive(cx); + trace!("flushed({}): {:?}", T::LOG, self.state); + Poll::Ready(Ok(())) + } + + pub(crate) fn poll_shutdown(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + match ready!(Pin::new(self.io.io_mut()).poll_shutdown(cx)) { + Ok(()) => { + trace!("shut down IO complete"); + Poll::Ready(Ok(())) + } + Err(e) => { + debug!("error shutting down IO: {}", e); + Poll::Ready(Err(e)) + } + } + } + + /// If the read side can be cheaply drained, do so. Otherwise, close. + pub(super) fn poll_drain_or_close_read(&mut self, cx: &mut task::Context<'_>) { + if let Reading::Continue(ref decoder) = self.state.reading { + // skip sending the 100-continue + // just move forward to a read, in case a tiny body was included + self.state.reading = Reading::Body(decoder.clone()); + } + + let _ = self.poll_read_body(cx); + + // If still in Reading::Body, just give up + match self.state.reading { + Reading::Init | Reading::KeepAlive => trace!("body drained"), + _ => self.close_read(), + } + } + + pub(crate) fn close_read(&mut self) { + self.state.close_read(); + } + + pub(crate) fn close_write(&mut self) { + self.state.close_write(); + } + + #[cfg(feature = "server")] + pub(crate) fn disable_keep_alive(&mut self) { + if self.state.is_idle() { + trace!("disable_keep_alive; closing idle connection"); + self.state.close(); + } else { + trace!("disable_keep_alive; in-progress connection"); + self.state.disable_keep_alive(); + } + } + + pub(crate) fn take_error(&mut self) -> crate::Result<()> { + if let Some(err) = self.state.error.take() { + Err(err) + } else { + Ok(()) + } + } + + pub(super) fn on_upgrade(&mut self) -> crate::upgrade::OnUpgrade { + trace!("{}: prepare possible HTTP upgrade", T::LOG); + self.state.prepare_upgrade() + } +} + +impl<I, B: Buf, T> fmt::Debug for Conn<I, B, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Conn") + .field("state", &self.state) + .field("io", &self.io) + .finish() + } +} + +// B and T are never pinned +impl<I: Unpin, B, T> Unpin for Conn<I, B, T> {} + +struct State { + allow_half_close: bool, + /// Re-usable HeaderMap to reduce allocating new ones. + cached_headers: Option<HeaderMap>, + /// If an error occurs when there wasn't a direct way to return it + /// back to the user, this is set. + error: Option<crate::Error>, + /// Current keep-alive status. + keep_alive: KA, + /// If mid-message, the HTTP Method that started it. + /// + /// This is used to know things such as if the message can include + /// a body or not. + method: Option<Method>, + h1_parser_config: ParserConfig, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout: Option<Duration>, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_fut: Option<Pin<Box<Sleep>>>, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_running: bool, + preserve_header_case: bool, + #[cfg(feature = "ffi")] + preserve_header_order: bool, + title_case_headers: bool, + h09_responses: bool, + /// If set, called with each 1xx informational response received for + /// the current request. MUST be unset after a non-1xx response is + /// received. + #[cfg(feature = "ffi")] + on_informational: Option<crate::ffi::OnInformational>, + #[cfg(feature = "ffi")] + raw_headers: bool, + /// Set to true when the Dispatcher should poll read operations + /// again. See the `maybe_notify` method for more. + notify_read: bool, + /// State of allowed reads + reading: Reading, + /// State of allowed writes + writing: Writing, + /// An expected pending HTTP upgrade. + upgrade: Option<crate::upgrade::Pending>, + /// Either HTTP/1.0 or 1.1 connection + version: Version, +} + +#[derive(Debug)] +enum Reading { + Init, + Continue(Decoder), + Body(Decoder), + KeepAlive, + Closed, +} + +enum Writing { + Init, + Body(Encoder), + KeepAlive, + Closed, +} + +impl fmt::Debug for State { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = f.debug_struct("State"); + builder + .field("reading", &self.reading) + .field("writing", &self.writing) + .field("keep_alive", &self.keep_alive); + + // Only show error field if it's interesting... + if let Some(ref error) = self.error { + builder.field("error", error); + } + + if self.allow_half_close { + builder.field("allow_half_close", &true); + } + + // Purposefully leaving off other fields.. + + builder.finish() + } +} + +impl fmt::Debug for Writing { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Writing::Init => f.write_str("Init"), + Writing::Body(ref enc) => f.debug_tuple("Body").field(enc).finish(), + Writing::KeepAlive => f.write_str("KeepAlive"), + Writing::Closed => f.write_str("Closed"), + } + } +} + +impl std::ops::BitAndAssign<bool> for KA { + fn bitand_assign(&mut self, enabled: bool) { + if !enabled { + trace!("remote disabling keep-alive"); + *self = KA::Disabled; + } + } +} + +#[derive(Clone, Copy, Debug)] +enum KA { + Idle, + Busy, + Disabled, +} + +impl Default for KA { + fn default() -> KA { + KA::Busy + } +} + +impl KA { + fn idle(&mut self) { + *self = KA::Idle; + } + + fn busy(&mut self) { + *self = KA::Busy; + } + + fn disable(&mut self) { + *self = KA::Disabled; + } + + fn status(&self) -> KA { + *self + } +} + +impl State { + fn close(&mut self) { + trace!("State::close()"); + self.reading = Reading::Closed; + self.writing = Writing::Closed; + self.keep_alive.disable(); + } + + fn close_read(&mut self) { + trace!("State::close_read()"); + self.reading = Reading::Closed; + self.keep_alive.disable(); + } + + fn close_write(&mut self) { + trace!("State::close_write()"); + self.writing = Writing::Closed; + self.keep_alive.disable(); + } + + fn wants_keep_alive(&self) -> bool { + if let KA::Disabled = self.keep_alive.status() { + false + } else { + true + } + } + + fn try_keep_alive<T: Http1Transaction>(&mut self) { + match (&self.reading, &self.writing) { + (&Reading::KeepAlive, &Writing::KeepAlive) => { + if let KA::Busy = self.keep_alive.status() { + self.idle::<T>(); + } else { + trace!( + "try_keep_alive({}): could keep-alive, but status = {:?}", + T::LOG, + self.keep_alive + ); + self.close(); + } + } + (&Reading::Closed, &Writing::KeepAlive) | (&Reading::KeepAlive, &Writing::Closed) => { + self.close() + } + _ => (), + } + } + + fn disable_keep_alive(&mut self) { + self.keep_alive.disable() + } + + fn busy(&mut self) { + if let KA::Disabled = self.keep_alive.status() { + return; + } + self.keep_alive.busy(); + } + + fn idle<T: Http1Transaction>(&mut self) { + debug_assert!(!self.is_idle(), "State::idle() called while idle"); + + self.method = None; + self.keep_alive.idle(); + + if !self.is_idle() { + self.close(); + return; + } + + self.reading = Reading::Init; + self.writing = Writing::Init; + + // !T::should_read_first() means Client. + // + // If Client connection has just gone idle, the Dispatcher + // should try the poll loop one more time, so as to poll the + // pending requests stream. + if !T::should_read_first() { + self.notify_read = true; + } + } + + fn is_idle(&self) -> bool { + matches!(self.keep_alive.status(), KA::Idle) + } + + fn is_read_closed(&self) -> bool { + matches!(self.reading, Reading::Closed) + } + + fn is_write_closed(&self) -> bool { + matches!(self.writing, Writing::Closed) + } + + fn prepare_upgrade(&mut self) -> crate::upgrade::OnUpgrade { + let (tx, rx) = crate::upgrade::pending(); + self.upgrade = Some(tx); + rx + } +} + +#[cfg(test)] +mod tests { + #[cfg(feature = "nightly")] + #[bench] + fn bench_read_head_short(b: &mut ::test::Bencher) { + use super::*; + let s = b"GET / HTTP/1.1\r\nHost: localhost:8080\r\n\r\n"; + let len = s.len(); + b.bytes = len as u64; + + // an empty IO, we'll be skipping and using the read buffer anyways + let io = tokio_test::io::Builder::new().build(); + let mut conn = Conn::<_, bytes::Bytes, crate::proto::h1::ServerTransaction>::new(io); + *conn.io.read_buf_mut() = ::bytes::BytesMut::from(&s[..]); + conn.state.cached_headers = Some(HeaderMap::with_capacity(2)); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + b.iter(|| { + rt.block_on(futures_util::future::poll_fn(|cx| { + match conn.poll_read_head(cx) { + Poll::Ready(Some(Ok(x))) => { + ::test::black_box(&x); + let mut headers = x.0.headers; + headers.clear(); + conn.state.cached_headers = Some(headers); + } + f => panic!("expected Ready(Some(Ok(..))): {:?}", f), + } + + conn.io.read_buf_mut().reserve(1); + unsafe { + conn.io.read_buf_mut().set_len(len); + } + conn.state.reading = Reading::Init; + Poll::Ready(()) + })); + }); + } + + /* + //TODO: rewrite these using dispatch... someday... + use futures::{Async, Future, Stream, Sink}; + use futures::future; + + use proto::{self, ClientTransaction, MessageHead, ServerTransaction}; + use super::super::Encoder; + use mock::AsyncIo; + + use super::{Conn, Decoder, Reading, Writing}; + use ::uri::Uri; + + use std::str::FromStr; + + #[test] + fn test_conn_init_read() { + let good_message = b"GET / HTTP/1.1\r\n\r\n".to_vec(); + let len = good_message.len(); + let io = AsyncIo::new_buf(good_message, len); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + + match conn.poll().unwrap() { + Async::Ready(Some(Frame::Message { message, body: false })) => { + assert_eq!(message, MessageHead { + subject: ::proto::RequestLine(::Get, Uri::from_str("/").unwrap()), + .. MessageHead::default() + }) + }, + f => panic!("frame is not Frame::Message: {:?}", f) + } + } + + #[test] + fn test_conn_parse_partial() { + let _: Result<(), ()> = future::lazy(|| { + let good_message = b"GET / HTTP/1.1\r\nHost: foo.bar\r\n\r\n".to_vec(); + let io = AsyncIo::new_buf(good_message, 10); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + assert!(conn.poll().unwrap().is_not_ready()); + conn.io.io_mut().block_in(50); + let async = conn.poll().unwrap(); + assert!(async.is_ready()); + match async { + Async::Ready(Some(Frame::Message { .. })) => (), + f => panic!("frame is not Message: {:?}", f), + } + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_init_read_eof_idle() { + let io = AsyncIo::new_buf(vec![], 1); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.idle(); + + match conn.poll().unwrap() { + Async::Ready(None) => {}, + other => panic!("frame is not None: {:?}", other) + } + } + + #[test] + fn test_conn_init_read_eof_idle_partial_parse() { + let io = AsyncIo::new_buf(b"GET / HTTP/1.1".to_vec(), 100); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.idle(); + + match conn.poll() { + Err(ref err) if err.kind() == std::io::ErrorKind::UnexpectedEof => {}, + other => panic!("unexpected frame: {:?}", other) + } + } + + #[test] + fn test_conn_init_read_eof_busy() { + let _: Result<(), ()> = future::lazy(|| { + // server ignores + let io = AsyncIo::new_eof(); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.busy(); + + match conn.poll().unwrap() { + Async::Ready(None) => {}, + other => panic!("unexpected frame: {:?}", other) + } + + // client + let io = AsyncIo::new_eof(); + let mut conn = Conn::<_, proto::Bytes, ClientTransaction>::new(io); + conn.state.busy(); + + match conn.poll() { + Err(ref err) if err.kind() == std::io::ErrorKind::UnexpectedEof => {}, + other => panic!("unexpected frame: {:?}", other) + } + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_body_finish_read_eof() { + let _: Result<(), ()> = future::lazy(|| { + let io = AsyncIo::new_eof(); + let mut conn = Conn::<_, proto::Bytes, ClientTransaction>::new(io); + conn.state.busy(); + conn.state.writing = Writing::KeepAlive; + conn.state.reading = Reading::Body(Decoder::length(0)); + + match conn.poll() { + Ok(Async::Ready(Some(Frame::Body { chunk: None }))) => (), + other => panic!("unexpected frame: {:?}", other) + } + + // conn eofs, but tokio-proto will call poll() again, before calling flush() + // the conn eof in this case is perfectly fine + + match conn.poll() { + Ok(Async::Ready(None)) => (), + other => panic!("unexpected frame: {:?}", other) + } + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_message_empty_body_read_eof() { + let _: Result<(), ()> = future::lazy(|| { + let io = AsyncIo::new_buf(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n".to_vec(), 1024); + let mut conn = Conn::<_, proto::Bytes, ClientTransaction>::new(io); + conn.state.busy(); + conn.state.writing = Writing::KeepAlive; + + match conn.poll() { + Ok(Async::Ready(Some(Frame::Message { body: false, .. }))) => (), + other => panic!("unexpected frame: {:?}", other) + } + + // conn eofs, but tokio-proto will call poll() again, before calling flush() + // the conn eof in this case is perfectly fine + + match conn.poll() { + Ok(Async::Ready(None)) => (), + other => panic!("unexpected frame: {:?}", other) + } + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_read_body_end() { + let _: Result<(), ()> = future::lazy(|| { + let io = AsyncIo::new_buf(b"POST / HTTP/1.1\r\nContent-Length: 5\r\n\r\n12345".to_vec(), 1024); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.busy(); + + match conn.poll() { + Ok(Async::Ready(Some(Frame::Message { body: true, .. }))) => (), + other => panic!("unexpected frame: {:?}", other) + } + + match conn.poll() { + Ok(Async::Ready(Some(Frame::Body { chunk: Some(_) }))) => (), + other => panic!("unexpected frame: {:?}", other) + } + + // When the body is done, `poll` MUST return a `Body` frame with chunk set to `None` + match conn.poll() { + Ok(Async::Ready(Some(Frame::Body { chunk: None }))) => (), + other => panic!("unexpected frame: {:?}", other) + } + + match conn.poll() { + Ok(Async::NotReady) => (), + other => panic!("unexpected frame: {:?}", other) + } + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_closed_read() { + let io = AsyncIo::new_buf(vec![], 0); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.close(); + + match conn.poll().unwrap() { + Async::Ready(None) => {}, + other => panic!("frame is not None: {:?}", other) + } + } + + #[test] + fn test_conn_body_write_length() { + let _ = pretty_env_logger::try_init(); + let _: Result<(), ()> = future::lazy(|| { + let io = AsyncIo::new_buf(vec![], 0); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + let max = super::super::io::DEFAULT_MAX_BUFFER_SIZE + 4096; + conn.state.writing = Writing::Body(Encoder::length((max * 2) as u64)); + + assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'a'; max].into()) }).unwrap().is_ready()); + assert!(!conn.can_buffer_body()); + + assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'b'; 1024 * 8].into()) }).unwrap().is_not_ready()); + + conn.io.io_mut().block_in(1024 * 3); + assert!(conn.poll_complete().unwrap().is_not_ready()); + conn.io.io_mut().block_in(1024 * 3); + assert!(conn.poll_complete().unwrap().is_not_ready()); + conn.io.io_mut().block_in(max * 2); + assert!(conn.poll_complete().unwrap().is_ready()); + + assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'c'; 1024 * 8].into()) }).unwrap().is_ready()); + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_body_write_chunked() { + let _: Result<(), ()> = future::lazy(|| { + let io = AsyncIo::new_buf(vec![], 4096); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.writing = Writing::Body(Encoder::chunked()); + + assert!(conn.start_send(Frame::Body { chunk: Some("headers".into()) }).unwrap().is_ready()); + assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'x'; 8192].into()) }).unwrap().is_ready()); + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_body_flush() { + let _: Result<(), ()> = future::lazy(|| { + let io = AsyncIo::new_buf(vec![], 1024 * 1024 * 5); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.writing = Writing::Body(Encoder::length(1024 * 1024)); + assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'a'; 1024 * 1024].into()) }).unwrap().is_ready()); + assert!(!conn.can_buffer_body()); + conn.io.io_mut().block_in(1024 * 1024 * 5); + assert!(conn.poll_complete().unwrap().is_ready()); + assert!(conn.can_buffer_body()); + assert!(conn.io.io_mut().flushed()); + + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_parking() { + use std::sync::Arc; + use futures::executor::Notify; + use futures::executor::NotifyHandle; + + struct Car { + permit: bool, + } + impl Notify for Car { + fn notify(&self, _id: usize) { + assert!(self.permit, "unparked without permit"); + } + } + + fn car(permit: bool) -> NotifyHandle { + Arc::new(Car { + permit: permit, + }).into() + } + + // test that once writing is done, unparks + let f = future::lazy(|| { + let io = AsyncIo::new_buf(vec![], 4096); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.reading = Reading::KeepAlive; + assert!(conn.poll().unwrap().is_not_ready()); + + conn.state.writing = Writing::KeepAlive; + assert!(conn.poll_complete().unwrap().is_ready()); + Ok::<(), ()>(()) + }); + ::futures::executor::spawn(f).poll_future_notify(&car(true), 0).unwrap(); + + + // test that flushing when not waiting on read doesn't unpark + let f = future::lazy(|| { + let io = AsyncIo::new_buf(vec![], 4096); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.writing = Writing::KeepAlive; + assert!(conn.poll_complete().unwrap().is_ready()); + Ok::<(), ()>(()) + }); + ::futures::executor::spawn(f).poll_future_notify(&car(false), 0).unwrap(); + + + // test that flushing and writing isn't done doesn't unpark + let f = future::lazy(|| { + let io = AsyncIo::new_buf(vec![], 4096); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.reading = Reading::KeepAlive; + assert!(conn.poll().unwrap().is_not_ready()); + conn.state.writing = Writing::Body(Encoder::length(5_000)); + assert!(conn.poll_complete().unwrap().is_ready()); + Ok::<(), ()>(()) + }); + ::futures::executor::spawn(f).poll_future_notify(&car(false), 0).unwrap(); + } + + #[test] + fn test_conn_closed_write() { + let io = AsyncIo::new_buf(vec![], 0); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.close(); + + match conn.start_send(Frame::Body { chunk: Some(b"foobar".to_vec().into()) }) { + Err(_e) => {}, + other => panic!("did not return Err: {:?}", other) + } + + assert!(conn.state.is_write_closed()); + } + + #[test] + fn test_conn_write_empty_chunk() { + let io = AsyncIo::new_buf(vec![], 0); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.writing = Writing::KeepAlive; + + assert!(conn.start_send(Frame::Body { chunk: None }).unwrap().is_ready()); + assert!(conn.start_send(Frame::Body { chunk: Some(Vec::new().into()) }).unwrap().is_ready()); + conn.start_send(Frame::Body { chunk: Some(vec![b'a'].into()) }).unwrap_err(); + } + */ +} diff --git a/third_party/rust/hyper/src/proto/h1/decode.rs b/third_party/rust/hyper/src/proto/h1/decode.rs new file mode 100644 index 0000000000..1e3a38effc --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/decode.rs @@ -0,0 +1,731 @@ +use std::error::Error as StdError; +use std::fmt; +use std::io; +use std::usize; + +use bytes::Bytes; +use tracing::{debug, trace}; + +use crate::common::{task, Poll}; + +use super::io::MemRead; +use super::DecodedLength; + +use self::Kind::{Chunked, Eof, Length}; + +/// Decoders to handle different Transfer-Encodings. +/// +/// If a message body does not include a Transfer-Encoding, it *should* +/// include a Content-Length header. +#[derive(Clone, PartialEq)] +pub(crate) struct Decoder { + kind: Kind, +} + +#[derive(Debug, Clone, Copy, PartialEq)] +enum Kind { + /// A Reader used when a Content-Length header is passed with a positive integer. + Length(u64), + /// A Reader used when Transfer-Encoding is `chunked`. + Chunked(ChunkedState, u64), + /// A Reader used for responses that don't indicate a length or chunked. + /// + /// The bool tracks when EOF is seen on the transport. + /// + /// Note: This should only used for `Response`s. It is illegal for a + /// `Request` to be made with both `Content-Length` and + /// `Transfer-Encoding: chunked` missing, as explained from the spec: + /// + /// > If a Transfer-Encoding header field is present in a response and + /// > the chunked transfer coding is not the final encoding, the + /// > message body length is determined by reading the connection until + /// > it is closed by the server. If a Transfer-Encoding header field + /// > is present in a request and the chunked transfer coding is not + /// > the final encoding, the message body length cannot be determined + /// > reliably; the server MUST respond with the 400 (Bad Request) + /// > status code and then close the connection. + Eof(bool), +} + +#[derive(Debug, PartialEq, Clone, Copy)] +enum ChunkedState { + Size, + SizeLws, + Extension, + SizeLf, + Body, + BodyCr, + BodyLf, + Trailer, + TrailerLf, + EndCr, + EndLf, + End, +} + +impl Decoder { + // constructors + + pub(crate) fn length(x: u64) -> Decoder { + Decoder { + kind: Kind::Length(x), + } + } + + pub(crate) fn chunked() -> Decoder { + Decoder { + kind: Kind::Chunked(ChunkedState::Size, 0), + } + } + + pub(crate) fn eof() -> Decoder { + Decoder { + kind: Kind::Eof(false), + } + } + + pub(super) fn new(len: DecodedLength) -> Self { + match len { + DecodedLength::CHUNKED => Decoder::chunked(), + DecodedLength::CLOSE_DELIMITED => Decoder::eof(), + length => Decoder::length(length.danger_len()), + } + } + + // methods + + pub(crate) fn is_eof(&self) -> bool { + matches!(self.kind, Length(0) | Chunked(ChunkedState::End, _) | Eof(true)) + } + + pub(crate) fn decode<R: MemRead>( + &mut self, + cx: &mut task::Context<'_>, + body: &mut R, + ) -> Poll<Result<Bytes, io::Error>> { + trace!("decode; state={:?}", self.kind); + match self.kind { + Length(ref mut remaining) => { + if *remaining == 0 { + Poll::Ready(Ok(Bytes::new())) + } else { + let to_read = *remaining as usize; + let buf = ready!(body.read_mem(cx, to_read))?; + let num = buf.as_ref().len() as u64; + if num > *remaining { + *remaining = 0; + } else if num == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + IncompleteBody, + ))); + } else { + *remaining -= num; + } + Poll::Ready(Ok(buf)) + } + } + Chunked(ref mut state, ref mut size) => { + loop { + let mut buf = None; + // advances the chunked state + *state = ready!(state.step(cx, body, size, &mut buf))?; + if *state == ChunkedState::End { + trace!("end of chunked"); + return Poll::Ready(Ok(Bytes::new())); + } + if let Some(buf) = buf { + return Poll::Ready(Ok(buf)); + } + } + } + Eof(ref mut is_eof) => { + if *is_eof { + Poll::Ready(Ok(Bytes::new())) + } else { + // 8192 chosen because its about 2 packets, there probably + // won't be that much available, so don't have MemReaders + // allocate buffers to big + body.read_mem(cx, 8192).map_ok(|slice| { + *is_eof = slice.is_empty(); + slice + }) + } + } + } + } + + #[cfg(test)] + async fn decode_fut<R: MemRead>(&mut self, body: &mut R) -> Result<Bytes, io::Error> { + futures_util::future::poll_fn(move |cx| self.decode(cx, body)).await + } +} + +impl fmt::Debug for Decoder { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&self.kind, f) + } +} + +macro_rules! byte ( + ($rdr:ident, $cx:expr) => ({ + let buf = ready!($rdr.read_mem($cx, 1))?; + if !buf.is_empty() { + buf[0] + } else { + return Poll::Ready(Err(io::Error::new(io::ErrorKind::UnexpectedEof, + "unexpected EOF during chunk size line"))); + } + }) +); + +impl ChunkedState { + fn step<R: MemRead>( + &self, + cx: &mut task::Context<'_>, + body: &mut R, + size: &mut u64, + buf: &mut Option<Bytes>, + ) -> Poll<Result<ChunkedState, io::Error>> { + use self::ChunkedState::*; + match *self { + Size => ChunkedState::read_size(cx, body, size), + SizeLws => ChunkedState::read_size_lws(cx, body), + Extension => ChunkedState::read_extension(cx, body), + SizeLf => ChunkedState::read_size_lf(cx, body, *size), + Body => ChunkedState::read_body(cx, body, size, buf), + BodyCr => ChunkedState::read_body_cr(cx, body), + BodyLf => ChunkedState::read_body_lf(cx, body), + Trailer => ChunkedState::read_trailer(cx, body), + TrailerLf => ChunkedState::read_trailer_lf(cx, body), + EndCr => ChunkedState::read_end_cr(cx, body), + EndLf => ChunkedState::read_end_lf(cx, body), + End => Poll::Ready(Ok(ChunkedState::End)), + } + } + fn read_size<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + size: &mut u64, + ) -> Poll<Result<ChunkedState, io::Error>> { + trace!("Read chunk hex size"); + + macro_rules! or_overflow { + ($e:expr) => ( + match $e { + Some(val) => val, + None => return Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid chunk size: overflow", + ))), + } + ) + } + + let radix = 16; + match byte!(rdr, cx) { + b @ b'0'..=b'9' => { + *size = or_overflow!(size.checked_mul(radix)); + *size = or_overflow!(size.checked_add((b - b'0') as u64)); + } + b @ b'a'..=b'f' => { + *size = or_overflow!(size.checked_mul(radix)); + *size = or_overflow!(size.checked_add((b + 10 - b'a') as u64)); + } + b @ b'A'..=b'F' => { + *size = or_overflow!(size.checked_mul(radix)); + *size = or_overflow!(size.checked_add((b + 10 - b'A') as u64)); + } + b'\t' | b' ' => return Poll::Ready(Ok(ChunkedState::SizeLws)), + b';' => return Poll::Ready(Ok(ChunkedState::Extension)), + b'\r' => return Poll::Ready(Ok(ChunkedState::SizeLf)), + _ => { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size line: Invalid Size", + ))); + } + } + Poll::Ready(Ok(ChunkedState::Size)) + } + fn read_size_lws<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + trace!("read_size_lws"); + match byte!(rdr, cx) { + // LWS can follow the chunk size, but no more digits can come + b'\t' | b' ' => Poll::Ready(Ok(ChunkedState::SizeLws)), + b';' => Poll::Ready(Ok(ChunkedState::Extension)), + b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size linear white space", + ))), + } + } + fn read_extension<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + trace!("read_extension"); + // We don't care about extensions really at all. Just ignore them. + // They "end" at the next CRLF. + // + // However, some implementations may not check for the CR, so to save + // them from themselves, we reject extensions containing plain LF as + // well. + match byte!(rdr, cx) { + b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), + b'\n' => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid chunk extension contains newline", + ))), + _ => Poll::Ready(Ok(ChunkedState::Extension)), // no supported extensions + } + } + fn read_size_lf<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + size: u64, + ) -> Poll<Result<ChunkedState, io::Error>> { + trace!("Chunk size is {:?}", size); + match byte!(rdr, cx) { + b'\n' => { + if size == 0 { + Poll::Ready(Ok(ChunkedState::EndCr)) + } else { + debug!("incoming chunked header: {0:#X} ({0} bytes)", size); + Poll::Ready(Ok(ChunkedState::Body)) + } + } + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size LF", + ))), + } + } + + fn read_body<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + rem: &mut u64, + buf: &mut Option<Bytes>, + ) -> Poll<Result<ChunkedState, io::Error>> { + trace!("Chunked read, remaining={:?}", rem); + + // cap remaining bytes at the max capacity of usize + let rem_cap = match *rem { + r if r > usize::MAX as u64 => usize::MAX, + r => r as usize, + }; + + let to_read = rem_cap; + let slice = ready!(rdr.read_mem(cx, to_read))?; + let count = slice.len(); + + if count == 0 { + *rem = 0; + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + IncompleteBody, + ))); + } + *buf = Some(slice); + *rem -= count as u64; + + if *rem > 0 { + Poll::Ready(Ok(ChunkedState::Body)) + } else { + Poll::Ready(Ok(ChunkedState::BodyCr)) + } + } + fn read_body_cr<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + match byte!(rdr, cx) { + b'\r' => Poll::Ready(Ok(ChunkedState::BodyLf)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk body CR", + ))), + } + } + fn read_body_lf<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + match byte!(rdr, cx) { + b'\n' => Poll::Ready(Ok(ChunkedState::Size)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk body LF", + ))), + } + } + + fn read_trailer<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + trace!("read_trailer"); + match byte!(rdr, cx) { + b'\r' => Poll::Ready(Ok(ChunkedState::TrailerLf)), + _ => Poll::Ready(Ok(ChunkedState::Trailer)), + } + } + fn read_trailer_lf<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + match byte!(rdr, cx) { + b'\n' => Poll::Ready(Ok(ChunkedState::EndCr)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid trailer end LF", + ))), + } + } + + fn read_end_cr<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + match byte!(rdr, cx) { + b'\r' => Poll::Ready(Ok(ChunkedState::EndLf)), + _ => Poll::Ready(Ok(ChunkedState::Trailer)), + } + } + fn read_end_lf<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + match byte!(rdr, cx) { + b'\n' => Poll::Ready(Ok(ChunkedState::End)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk end LF", + ))), + } + } +} + +#[derive(Debug)] +struct IncompleteBody; + +impl fmt::Display for IncompleteBody { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "end of file before message length reached") + } +} + +impl StdError for IncompleteBody {} + +#[cfg(test)] +mod tests { + use super::*; + use std::pin::Pin; + use std::time::Duration; + use tokio::io::{AsyncRead, ReadBuf}; + + impl<'a> MemRead for &'a [u8] { + fn read_mem(&mut self, _: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> { + let n = std::cmp::min(len, self.len()); + if n > 0 { + let (a, b) = self.split_at(n); + let buf = Bytes::copy_from_slice(a); + *self = b; + Poll::Ready(Ok(buf)) + } else { + Poll::Ready(Ok(Bytes::new())) + } + } + } + + impl<'a> MemRead for &'a mut (dyn AsyncRead + Unpin) { + fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> { + let mut v = vec![0; len]; + let mut buf = ReadBuf::new(&mut v); + ready!(Pin::new(self).poll_read(cx, &mut buf)?); + Poll::Ready(Ok(Bytes::copy_from_slice(&buf.filled()))) + } + } + + #[cfg(feature = "nightly")] + impl MemRead for Bytes { + fn read_mem(&mut self, _: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> { + let n = std::cmp::min(len, self.len()); + let ret = self.split_to(n); + Poll::Ready(Ok(ret)) + } + } + + /* + use std::io; + use std::io::Write; + use super::Decoder; + use super::ChunkedState; + use futures::{Async, Poll}; + use bytes::{BytesMut, Bytes}; + use crate::mock::AsyncIo; + */ + + #[tokio::test] + async fn test_read_chunk_size() { + use std::io::ErrorKind::{InvalidData, InvalidInput, UnexpectedEof}; + + async fn read(s: &str) -> u64 { + let mut state = ChunkedState::Size; + let rdr = &mut s.as_bytes(); + let mut size = 0; + loop { + let result = + futures_util::future::poll_fn(|cx| state.step(cx, rdr, &mut size, &mut None)) + .await; + let desc = format!("read_size failed for {:?}", s); + state = result.expect(desc.as_str()); + if state == ChunkedState::Body || state == ChunkedState::EndCr { + break; + } + } + size + } + + async fn read_err(s: &str, expected_err: io::ErrorKind) { + let mut state = ChunkedState::Size; + let rdr = &mut s.as_bytes(); + let mut size = 0; + loop { + let result = + futures_util::future::poll_fn(|cx| state.step(cx, rdr, &mut size, &mut None)) + .await; + state = match result { + Ok(s) => s, + Err(e) => { + assert!( + expected_err == e.kind(), + "Reading {:?}, expected {:?}, but got {:?}", + s, + expected_err, + e.kind() + ); + return; + } + }; + if state == ChunkedState::Body || state == ChunkedState::End { + panic!("Was Ok. Expected Err for {:?}", s); + } + } + } + + assert_eq!(1, read("1\r\n").await); + assert_eq!(1, read("01\r\n").await); + assert_eq!(0, read("0\r\n").await); + assert_eq!(0, read("00\r\n").await); + assert_eq!(10, read("A\r\n").await); + assert_eq!(10, read("a\r\n").await); + assert_eq!(255, read("Ff\r\n").await); + assert_eq!(255, read("Ff \r\n").await); + // Missing LF or CRLF + read_err("F\rF", InvalidInput).await; + read_err("F", UnexpectedEof).await; + // Invalid hex digit + read_err("X\r\n", InvalidInput).await; + read_err("1X\r\n", InvalidInput).await; + read_err("-\r\n", InvalidInput).await; + read_err("-1\r\n", InvalidInput).await; + // Acceptable (if not fully valid) extensions do not influence the size + assert_eq!(1, read("1;extension\r\n").await); + assert_eq!(10, read("a;ext name=value\r\n").await); + assert_eq!(1, read("1;extension;extension2\r\n").await); + assert_eq!(1, read("1;;; ;\r\n").await); + assert_eq!(2, read("2; extension...\r\n").await); + assert_eq!(3, read("3 ; extension=123\r\n").await); + assert_eq!(3, read("3 ;\r\n").await); + assert_eq!(3, read("3 ; \r\n").await); + // Invalid extensions cause an error + read_err("1 invalid extension\r\n", InvalidInput).await; + read_err("1 A\r\n", InvalidInput).await; + read_err("1;no CRLF", UnexpectedEof).await; + read_err("1;reject\nnewlines\r\n", InvalidData).await; + // Overflow + read_err("f0000000000000003\r\n", InvalidData).await; + } + + #[tokio::test] + async fn test_read_sized_early_eof() { + let mut bytes = &b"foo bar"[..]; + let mut decoder = Decoder::length(10); + assert_eq!(decoder.decode_fut(&mut bytes).await.unwrap().len(), 7); + let e = decoder.decode_fut(&mut bytes).await.unwrap_err(); + assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof); + } + + #[tokio::test] + async fn test_read_chunked_early_eof() { + let mut bytes = &b"\ + 9\r\n\ + foo bar\ + "[..]; + let mut decoder = Decoder::chunked(); + assert_eq!(decoder.decode_fut(&mut bytes).await.unwrap().len(), 7); + let e = decoder.decode_fut(&mut bytes).await.unwrap_err(); + assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof); + } + + #[tokio::test] + async fn test_read_chunked_single_read() { + let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\n"[..]; + let buf = Decoder::chunked() + .decode_fut(&mut mock_buf) + .await + .expect("decode"); + assert_eq!(16, buf.len()); + let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String"); + assert_eq!("1234567890abcdef", &result); + } + + #[tokio::test] + async fn test_read_chunked_trailer_with_missing_lf() { + let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\nbad\r\r\n"[..]; + let mut decoder = Decoder::chunked(); + decoder.decode_fut(&mut mock_buf).await.expect("decode"); + let e = decoder.decode_fut(&mut mock_buf).await.unwrap_err(); + assert_eq!(e.kind(), io::ErrorKind::InvalidInput); + } + + #[tokio::test] + async fn test_read_chunked_after_eof() { + let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\n\r\n"[..]; + let mut decoder = Decoder::chunked(); + + // normal read + let buf = decoder.decode_fut(&mut mock_buf).await.unwrap(); + assert_eq!(16, buf.len()); + let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String"); + assert_eq!("1234567890abcdef", &result); + + // eof read + let buf = decoder.decode_fut(&mut mock_buf).await.expect("decode"); + assert_eq!(0, buf.len()); + + // ensure read after eof also returns eof + let buf = decoder.decode_fut(&mut mock_buf).await.expect("decode"); + assert_eq!(0, buf.len()); + } + + // perform an async read using a custom buffer size and causing a blocking + // read at the specified byte + async fn read_async(mut decoder: Decoder, content: &[u8], block_at: usize) -> String { + let mut outs = Vec::new(); + + let mut ins = if block_at == 0 { + tokio_test::io::Builder::new() + .wait(Duration::from_millis(10)) + .read(content) + .build() + } else { + tokio_test::io::Builder::new() + .read(&content[..block_at]) + .wait(Duration::from_millis(10)) + .read(&content[block_at..]) + .build() + }; + + let mut ins = &mut ins as &mut (dyn AsyncRead + Unpin); + + loop { + let buf = decoder + .decode_fut(&mut ins) + .await + .expect("unexpected decode error"); + if buf.is_empty() { + break; // eof + } + outs.extend(buf.as_ref()); + } + + String::from_utf8(outs).expect("decode String") + } + + // iterate over the different ways that this async read could go. + // tests blocking a read at each byte along the content - The shotgun approach + async fn all_async_cases(content: &str, expected: &str, decoder: Decoder) { + let content_len = content.len(); + for block_at in 0..content_len { + let actual = read_async(decoder.clone(), content.as_bytes(), block_at).await; + assert_eq!(expected, &actual) //, "Failed async. Blocking at {}", block_at); + } + } + + #[tokio::test] + async fn test_read_length_async() { + let content = "foobar"; + all_async_cases(content, content, Decoder::length(content.len() as u64)).await; + } + + #[tokio::test] + async fn test_read_chunked_async() { + let content = "3\r\nfoo\r\n3\r\nbar\r\n0\r\n\r\n"; + let expected = "foobar"; + all_async_cases(content, expected, Decoder::chunked()).await; + } + + #[tokio::test] + async fn test_read_eof_async() { + let content = "foobar"; + all_async_cases(content, content, Decoder::eof()).await; + } + + #[cfg(feature = "nightly")] + #[bench] + fn bench_decode_chunked_1kb(b: &mut test::Bencher) { + let rt = new_runtime(); + + const LEN: usize = 1024; + let mut vec = Vec::new(); + vec.extend(format!("{:x}\r\n", LEN).as_bytes()); + vec.extend(&[0; LEN][..]); + vec.extend(b"\r\n"); + let content = Bytes::from(vec); + + b.bytes = LEN as u64; + + b.iter(|| { + let mut decoder = Decoder::chunked(); + rt.block_on(async { + let mut raw = content.clone(); + let chunk = decoder.decode_fut(&mut raw).await.unwrap(); + assert_eq!(chunk.len(), LEN); + }); + }); + } + + #[cfg(feature = "nightly")] + #[bench] + fn bench_decode_length_1kb(b: &mut test::Bencher) { + let rt = new_runtime(); + + const LEN: usize = 1024; + let content = Bytes::from(&[0; LEN][..]); + b.bytes = LEN as u64; + + b.iter(|| { + let mut decoder = Decoder::length(LEN as u64); + rt.block_on(async { + let mut raw = content.clone(); + let chunk = decoder.decode_fut(&mut raw).await.unwrap(); + assert_eq!(chunk.len(), LEN); + }); + }); + } + + #[cfg(feature = "nightly")] + fn new_runtime() -> tokio::runtime::Runtime { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("rt build") + } +} diff --git a/third_party/rust/hyper/src/proto/h1/dispatch.rs b/third_party/rust/hyper/src/proto/h1/dispatch.rs new file mode 100644 index 0000000000..677131bfdd --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/dispatch.rs @@ -0,0 +1,750 @@ +use std::error::Error as StdError; + +use bytes::{Buf, Bytes}; +use http::Request; +use tokio::io::{AsyncRead, AsyncWrite}; +use tracing::{debug, trace}; + +use super::{Http1Transaction, Wants}; +use crate::body::{Body, DecodedLength, HttpBody}; +use crate::common::{task, Future, Pin, Poll, Unpin}; +use crate::proto::{ + BodyLength, Conn, Dispatched, MessageHead, RequestHead, +}; +use crate::upgrade::OnUpgrade; + +pub(crate) struct Dispatcher<D, Bs: HttpBody, I, T> { + conn: Conn<I, Bs::Data, T>, + dispatch: D, + body_tx: Option<crate::body::Sender>, + body_rx: Pin<Box<Option<Bs>>>, + is_closing: bool, +} + +pub(crate) trait Dispatch { + type PollItem; + type PollBody; + type PollError; + type RecvItem; + fn poll_msg( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), Self::PollError>>>; + fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, Body)>) -> crate::Result<()>; + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), ()>>; + fn should_poll(&self) -> bool; +} + +cfg_server! { + use crate::service::HttpService; + + pub(crate) struct Server<S: HttpService<B>, B> { + in_flight: Pin<Box<Option<S::Future>>>, + pub(crate) service: S, + } +} + +cfg_client! { + pin_project_lite::pin_project! { + pub(crate) struct Client<B> { + callback: Option<crate::client::dispatch::Callback<Request<B>, http::Response<Body>>>, + #[pin] + rx: ClientRx<B>, + rx_closed: bool, + } + } + + type ClientRx<B> = crate::client::dispatch::Receiver<Request<B>, http::Response<Body>>; +} + +impl<D, Bs, I, T> Dispatcher<D, Bs, I, T> +where + D: Dispatch< + PollItem = MessageHead<T::Outgoing>, + PollBody = Bs, + RecvItem = MessageHead<T::Incoming>, + > + Unpin, + D::PollError: Into<Box<dyn StdError + Send + Sync>>, + I: AsyncRead + AsyncWrite + Unpin, + T: Http1Transaction + Unpin, + Bs: HttpBody + 'static, + Bs::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + pub(crate) fn new(dispatch: D, conn: Conn<I, Bs::Data, T>) -> Self { + Dispatcher { + conn, + dispatch, + body_tx: None, + body_rx: Box::pin(None), + is_closing: false, + } + } + + #[cfg(feature = "server")] + pub(crate) fn disable_keep_alive(&mut self) { + self.conn.disable_keep_alive(); + if self.conn.is_write_closed() { + self.close(); + } + } + + pub(crate) fn into_inner(self) -> (I, Bytes, D) { + let (io, buf) = self.conn.into_inner(); + (io, buf, self.dispatch) + } + + /// Run this dispatcher until HTTP says this connection is done, + /// but don't call `AsyncWrite::shutdown` on the underlying IO. + /// + /// This is useful for old-style HTTP upgrades, but ignores + /// newer-style upgrade API. + pub(crate) fn poll_without_shutdown( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<crate::Result<()>> + where + Self: Unpin, + { + Pin::new(self).poll_catch(cx, false).map_ok(|ds| { + if let Dispatched::Upgrade(pending) = ds { + pending.manual(); + } + }) + } + + fn poll_catch( + &mut self, + cx: &mut task::Context<'_>, + should_shutdown: bool, + ) -> Poll<crate::Result<Dispatched>> { + Poll::Ready(ready!(self.poll_inner(cx, should_shutdown)).or_else(|e| { + // An error means we're shutting down either way. + // We just try to give the error to the user, + // and close the connection with an Ok. If we + // cannot give it to the user, then return the Err. + self.dispatch.recv_msg(Err(e))?; + Ok(Dispatched::Shutdown) + })) + } + + fn poll_inner( + &mut self, + cx: &mut task::Context<'_>, + should_shutdown: bool, + ) -> Poll<crate::Result<Dispatched>> { + T::update_date(); + + ready!(self.poll_loop(cx))?; + + if self.is_done() { + if let Some(pending) = self.conn.pending_upgrade() { + self.conn.take_error()?; + return Poll::Ready(Ok(Dispatched::Upgrade(pending))); + } else if should_shutdown { + ready!(self.conn.poll_shutdown(cx)).map_err(crate::Error::new_shutdown)?; + } + self.conn.take_error()?; + Poll::Ready(Ok(Dispatched::Shutdown)) + } else { + Poll::Pending + } + } + + fn poll_loop(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + // Limit the looping on this connection, in case it is ready far too + // often, so that other futures don't starve. + // + // 16 was chosen arbitrarily, as that is number of pipelined requests + // benchmarks often use. Perhaps it should be a config option instead. + for _ in 0..16 { + let _ = self.poll_read(cx)?; + let _ = self.poll_write(cx)?; + let _ = self.poll_flush(cx)?; + + // This could happen if reading paused before blocking on IO, + // such as getting to the end of a framed message, but then + // writing/flushing set the state back to Init. In that case, + // if the read buffer still had bytes, we'd want to try poll_read + // again, or else we wouldn't ever be woken up again. + // + // Using this instead of task::current() and notify() inside + // the Conn is noticeably faster in pipelined benchmarks. + if !self.conn.wants_read_again() { + //break; + return Poll::Ready(Ok(())); + } + } + + trace!("poll_loop yielding (self = {:p})", self); + + task::yield_now(cx).map(|never| match never {}) + } + + fn poll_read(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + loop { + if self.is_closing { + return Poll::Ready(Ok(())); + } else if self.conn.can_read_head() { + ready!(self.poll_read_head(cx))?; + } else if let Some(mut body) = self.body_tx.take() { + if self.conn.can_read_body() { + match body.poll_ready(cx) { + Poll::Ready(Ok(())) => (), + Poll::Pending => { + self.body_tx = Some(body); + return Poll::Pending; + } + Poll::Ready(Err(_canceled)) => { + // user doesn't care about the body + // so we should stop reading + trace!("body receiver dropped before eof, draining or closing"); + self.conn.poll_drain_or_close_read(cx); + continue; + } + } + match self.conn.poll_read_body(cx) { + Poll::Ready(Some(Ok(chunk))) => match body.try_send_data(chunk) { + Ok(()) => { + self.body_tx = Some(body); + } + Err(_canceled) => { + if self.conn.can_read_body() { + trace!("body receiver dropped before eof, closing"); + self.conn.close_read(); + } + } + }, + Poll::Ready(None) => { + // just drop, the body will close automatically + } + Poll::Pending => { + self.body_tx = Some(body); + return Poll::Pending; + } + Poll::Ready(Some(Err(e))) => { + body.send_error(crate::Error::new_body(e)); + } + } + } else { + // just drop, the body will close automatically + } + } else { + return self.conn.poll_read_keep_alive(cx); + } + } + } + + fn poll_read_head(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + // can dispatch receive, or does it still care about, an incoming message? + match ready!(self.dispatch.poll_ready(cx)) { + Ok(()) => (), + Err(()) => { + trace!("dispatch no longer receiving messages"); + self.close(); + return Poll::Ready(Ok(())); + } + } + // dispatch is ready for a message, try to read one + match ready!(self.conn.poll_read_head(cx)) { + Some(Ok((mut head, body_len, wants))) => { + let body = match body_len { + DecodedLength::ZERO => Body::empty(), + other => { + let (tx, rx) = Body::new_channel(other, wants.contains(Wants::EXPECT)); + self.body_tx = Some(tx); + rx + } + }; + if wants.contains(Wants::UPGRADE) { + let upgrade = self.conn.on_upgrade(); + debug_assert!(!upgrade.is_none(), "empty upgrade"); + debug_assert!(head.extensions.get::<OnUpgrade>().is_none(), "OnUpgrade already set"); + head.extensions.insert(upgrade); + } + self.dispatch.recv_msg(Ok((head, body)))?; + Poll::Ready(Ok(())) + } + Some(Err(err)) => { + debug!("read_head error: {}", err); + self.dispatch.recv_msg(Err(err))?; + // if here, the dispatcher gave the user the error + // somewhere else. we still need to shutdown, but + // not as a second error. + self.close(); + Poll::Ready(Ok(())) + } + None => { + // read eof, the write side will have been closed too unless + // allow_read_close was set to true, in which case just do + // nothing... + debug_assert!(self.conn.is_read_closed()); + if self.conn.is_write_closed() { + self.close(); + } + Poll::Ready(Ok(())) + } + } + } + + fn poll_write(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + loop { + if self.is_closing { + return Poll::Ready(Ok(())); + } else if self.body_rx.is_none() + && self.conn.can_write_head() + && self.dispatch.should_poll() + { + if let Some(msg) = ready!(Pin::new(&mut self.dispatch).poll_msg(cx)) { + let (head, mut body) = msg.map_err(crate::Error::new_user_service)?; + + // Check if the body knows its full data immediately. + // + // If so, we can skip a bit of bookkeeping that streaming + // bodies need to do. + if let Some(full) = crate::body::take_full_data(&mut body) { + self.conn.write_full_msg(head, full); + return Poll::Ready(Ok(())); + } + + let body_type = if body.is_end_stream() { + self.body_rx.set(None); + None + } else { + let btype = body + .size_hint() + .exact() + .map(BodyLength::Known) + .or_else(|| Some(BodyLength::Unknown)); + self.body_rx.set(Some(body)); + btype + }; + self.conn.write_head(head, body_type); + } else { + self.close(); + return Poll::Ready(Ok(())); + } + } else if !self.conn.can_buffer_body() { + ready!(self.poll_flush(cx))?; + } else { + // A new scope is needed :( + if let (Some(mut body), clear_body) = + OptGuard::new(self.body_rx.as_mut()).guard_mut() + { + debug_assert!(!*clear_body, "opt guard defaults to keeping body"); + if !self.conn.can_write_body() { + trace!( + "no more write body allowed, user body is_end_stream = {}", + body.is_end_stream(), + ); + *clear_body = true; + continue; + } + + let item = ready!(body.as_mut().poll_data(cx)); + if let Some(item) = item { + let chunk = item.map_err(|e| { + *clear_body = true; + crate::Error::new_user_body(e) + })?; + let eos = body.is_end_stream(); + if eos { + *clear_body = true; + if chunk.remaining() == 0 { + trace!("discarding empty chunk"); + self.conn.end_body()?; + } else { + self.conn.write_body_and_end(chunk); + } + } else { + if chunk.remaining() == 0 { + trace!("discarding empty chunk"); + continue; + } + self.conn.write_body(chunk); + } + } else { + *clear_body = true; + self.conn.end_body()?; + } + } else { + return Poll::Pending; + } + } + } + } + + fn poll_flush(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + self.conn.poll_flush(cx).map_err(|err| { + debug!("error writing: {}", err); + crate::Error::new_body_write(err) + }) + } + + fn close(&mut self) { + self.is_closing = true; + self.conn.close_read(); + self.conn.close_write(); + } + + fn is_done(&self) -> bool { + if self.is_closing { + return true; + } + + let read_done = self.conn.is_read_closed(); + + if !T::should_read_first() && read_done { + // a client that cannot read may was well be done. + true + } else { + let write_done = self.conn.is_write_closed() + || (!self.dispatch.should_poll() && self.body_rx.is_none()); + read_done && write_done + } + } +} + +impl<D, Bs, I, T> Future for Dispatcher<D, Bs, I, T> +where + D: Dispatch< + PollItem = MessageHead<T::Outgoing>, + PollBody = Bs, + RecvItem = MessageHead<T::Incoming>, + > + Unpin, + D::PollError: Into<Box<dyn StdError + Send + Sync>>, + I: AsyncRead + AsyncWrite + Unpin, + T: Http1Transaction + Unpin, + Bs: HttpBody + 'static, + Bs::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + type Output = crate::Result<Dispatched>; + + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + self.poll_catch(cx, true) + } +} + +// ===== impl OptGuard ===== + +/// A drop guard to allow a mutable borrow of an Option while being able to +/// set whether the `Option` should be cleared on drop. +struct OptGuard<'a, T>(Pin<&'a mut Option<T>>, bool); + +impl<'a, T> OptGuard<'a, T> { + fn new(pin: Pin<&'a mut Option<T>>) -> Self { + OptGuard(pin, false) + } + + fn guard_mut(&mut self) -> (Option<Pin<&mut T>>, &mut bool) { + (self.0.as_mut().as_pin_mut(), &mut self.1) + } +} + +impl<'a, T> Drop for OptGuard<'a, T> { + fn drop(&mut self) { + if self.1 { + self.0.set(None); + } + } +} + +// ===== impl Server ===== + +cfg_server! { + impl<S, B> Server<S, B> + where + S: HttpService<B>, + { + pub(crate) fn new(service: S) -> Server<S, B> { + Server { + in_flight: Box::pin(None), + service, + } + } + + pub(crate) fn into_service(self) -> S { + self.service + } + } + + // Service is never pinned + impl<S: HttpService<B>, B> Unpin for Server<S, B> {} + + impl<S, Bs> Dispatch for Server<S, Body> + where + S: HttpService<Body, ResBody = Bs>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + Bs: HttpBody, + { + type PollItem = MessageHead<http::StatusCode>; + type PollBody = Bs; + type PollError = S::Error; + type RecvItem = RequestHead; + + fn poll_msg( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), Self::PollError>>> { + let mut this = self.as_mut(); + let ret = if let Some(ref mut fut) = this.in_flight.as_mut().as_pin_mut() { + let resp = ready!(fut.as_mut().poll(cx)?); + let (parts, body) = resp.into_parts(); + let head = MessageHead { + version: parts.version, + subject: parts.status, + headers: parts.headers, + extensions: parts.extensions, + }; + Poll::Ready(Some(Ok((head, body)))) + } else { + unreachable!("poll_msg shouldn't be called if no inflight"); + }; + + // Since in_flight finished, remove it + this.in_flight.set(None); + ret + } + + fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, Body)>) -> crate::Result<()> { + let (msg, body) = msg?; + let mut req = Request::new(body); + *req.method_mut() = msg.subject.0; + *req.uri_mut() = msg.subject.1; + *req.headers_mut() = msg.headers; + *req.version_mut() = msg.version; + *req.extensions_mut() = msg.extensions; + let fut = self.service.call(req); + self.in_flight.set(Some(fut)); + Ok(()) + } + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), ()>> { + if self.in_flight.is_some() { + Poll::Pending + } else { + self.service.poll_ready(cx).map_err(|_e| { + // FIXME: return error value. + trace!("service closed"); + }) + } + } + + fn should_poll(&self) -> bool { + self.in_flight.is_some() + } + } +} + +// ===== impl Client ===== + +cfg_client! { + impl<B> Client<B> { + pub(crate) fn new(rx: ClientRx<B>) -> Client<B> { + Client { + callback: None, + rx, + rx_closed: false, + } + } + } + + impl<B> Dispatch for Client<B> + where + B: HttpBody, + { + type PollItem = RequestHead; + type PollBody = B; + type PollError = crate::common::Never; + type RecvItem = crate::proto::ResponseHead; + + fn poll_msg( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), crate::common::Never>>> { + let mut this = self.as_mut(); + debug_assert!(!this.rx_closed); + match this.rx.poll_recv(cx) { + Poll::Ready(Some((req, mut cb))) => { + // check that future hasn't been canceled already + match cb.poll_canceled(cx) { + Poll::Ready(()) => { + trace!("request canceled"); + Poll::Ready(None) + } + Poll::Pending => { + let (parts, body) = req.into_parts(); + let head = RequestHead { + version: parts.version, + subject: crate::proto::RequestLine(parts.method, parts.uri), + headers: parts.headers, + extensions: parts.extensions, + }; + this.callback = Some(cb); + Poll::Ready(Some(Ok((head, body)))) + } + } + } + Poll::Ready(None) => { + // user has dropped sender handle + trace!("client tx closed"); + this.rx_closed = true; + Poll::Ready(None) + } + Poll::Pending => Poll::Pending, + } + } + + fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, Body)>) -> crate::Result<()> { + match msg { + Ok((msg, body)) => { + if let Some(cb) = self.callback.take() { + let res = msg.into_response(body); + cb.send(Ok(res)); + Ok(()) + } else { + // Getting here is likely a bug! An error should have happened + // in Conn::require_empty_read() before ever parsing a + // full message! + Err(crate::Error::new_unexpected_message()) + } + } + Err(err) => { + if let Some(cb) = self.callback.take() { + cb.send(Err((err, None))); + Ok(()) + } else if !self.rx_closed { + self.rx.close(); + if let Some((req, cb)) = self.rx.try_recv() { + trace!("canceling queued request with connection error: {}", err); + // in this case, the message was never even started, so it's safe to tell + // the user that the request was completely canceled + cb.send(Err((crate::Error::new_canceled().with(err), Some(req)))); + Ok(()) + } else { + Err(err) + } + } else { + Err(err) + } + } + } + } + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), ()>> { + match self.callback { + Some(ref mut cb) => match cb.poll_canceled(cx) { + Poll::Ready(()) => { + trace!("callback receiver has dropped"); + Poll::Ready(Err(())) + } + Poll::Pending => Poll::Ready(Ok(())), + }, + None => Poll::Ready(Err(())), + } + } + + fn should_poll(&self) -> bool { + self.callback.is_none() + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::proto::h1::ClientTransaction; + use std::time::Duration; + + #[test] + fn client_read_bytes_before_writing_request() { + let _ = pretty_env_logger::try_init(); + + tokio_test::task::spawn(()).enter(|cx, _| { + let (io, mut handle) = tokio_test::io::Builder::new().build_with_handle(); + + // Block at 0 for now, but we will release this response before + // the request is ready to write later... + let (mut tx, rx) = crate::client::dispatch::channel(); + let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(io); + let mut dispatcher = Dispatcher::new(Client::new(rx), conn); + + // First poll is needed to allow tx to send... + assert!(Pin::new(&mut dispatcher).poll(cx).is_pending()); + + // Unblock our IO, which has a response before we've sent request! + // + handle.read(b"HTTP/1.1 200 OK\r\n\r\n"); + + let mut res_rx = tx + .try_send(crate::Request::new(crate::Body::empty())) + .unwrap(); + + tokio_test::assert_ready_ok!(Pin::new(&mut dispatcher).poll(cx)); + let err = tokio_test::assert_ready_ok!(Pin::new(&mut res_rx).poll(cx)) + .expect_err("callback should send error"); + + match (err.0.kind(), err.1) { + (&crate::error::Kind::Canceled, Some(_)) => (), + other => panic!("expected Canceled, got {:?}", other), + } + }); + } + + #[tokio::test] + async fn client_flushing_is_not_ready_for_next_request() { + let _ = pretty_env_logger::try_init(); + + let (io, _handle) = tokio_test::io::Builder::new() + .write(b"POST / HTTP/1.1\r\ncontent-length: 4\r\n\r\n") + .read(b"HTTP/1.1 200 OK\r\ncontent-length: 0\r\n\r\n") + .wait(std::time::Duration::from_secs(2)) + .build_with_handle(); + + let (mut tx, rx) = crate::client::dispatch::channel(); + let mut conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(io); + conn.set_write_strategy_queue(); + + let dispatcher = Dispatcher::new(Client::new(rx), conn); + let _dispatcher = tokio::spawn(async move { dispatcher.await }); + + let req = crate::Request::builder() + .method("POST") + .body(crate::Body::from("reee")) + .unwrap(); + + let res = tx.try_send(req).unwrap().await.expect("response"); + drop(res); + + assert!(!tx.is_ready()); + } + + #[tokio::test] + async fn body_empty_chunks_ignored() { + let _ = pretty_env_logger::try_init(); + + let io = tokio_test::io::Builder::new() + // no reading or writing, just be blocked for the test... + .wait(Duration::from_secs(5)) + .build(); + + let (mut tx, rx) = crate::client::dispatch::channel(); + let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(io); + let mut dispatcher = tokio_test::task::spawn(Dispatcher::new(Client::new(rx), conn)); + + // First poll is needed to allow tx to send... + assert!(dispatcher.poll().is_pending()); + + let body = { + let (mut tx, body) = crate::Body::channel(); + tx.try_send_data("".into()).unwrap(); + body + }; + + let _res_rx = tx.try_send(crate::Request::new(body)).unwrap(); + + // Ensure conn.write_body wasn't called with the empty chunk. + // If it is, it will trigger an assertion. + assert!(dispatcher.poll().is_pending()); + } +} diff --git a/third_party/rust/hyper/src/proto/h1/encode.rs b/third_party/rust/hyper/src/proto/h1/encode.rs new file mode 100644 index 0000000000..f0aa261a4f --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/encode.rs @@ -0,0 +1,439 @@ +use std::fmt; +use std::io::IoSlice; + +use bytes::buf::{Chain, Take}; +use bytes::Buf; +use tracing::trace; + +use super::io::WriteBuf; + +type StaticBuf = &'static [u8]; + +/// Encoders to handle different Transfer-Encodings. +#[derive(Debug, Clone, PartialEq)] +pub(crate) struct Encoder { + kind: Kind, + is_last: bool, +} + +#[derive(Debug)] +pub(crate) struct EncodedBuf<B> { + kind: BufKind<B>, +} + +#[derive(Debug)] +pub(crate) struct NotEof(u64); + +#[derive(Debug, PartialEq, Clone)] +enum Kind { + /// An Encoder for when Transfer-Encoding includes `chunked`. + Chunked, + /// An Encoder for when Content-Length is set. + /// + /// Enforces that the body is not longer than the Content-Length header. + Length(u64), + /// An Encoder for when neither Content-Length nor Chunked encoding is set. + /// + /// This is mostly only used with HTTP/1.0 with a length. This kind requires + /// the connection to be closed when the body is finished. + #[cfg(feature = "server")] + CloseDelimited, +} + +#[derive(Debug)] +enum BufKind<B> { + Exact(B), + Limited(Take<B>), + Chunked(Chain<Chain<ChunkSize, B>, StaticBuf>), + ChunkedEnd(StaticBuf), +} + +impl Encoder { + fn new(kind: Kind) -> Encoder { + Encoder { + kind, + is_last: false, + } + } + pub(crate) fn chunked() -> Encoder { + Encoder::new(Kind::Chunked) + } + + pub(crate) fn length(len: u64) -> Encoder { + Encoder::new(Kind::Length(len)) + } + + #[cfg(feature = "server")] + pub(crate) fn close_delimited() -> Encoder { + Encoder::new(Kind::CloseDelimited) + } + + pub(crate) fn is_eof(&self) -> bool { + matches!(self.kind, Kind::Length(0)) + } + + #[cfg(feature = "server")] + pub(crate) fn set_last(mut self, is_last: bool) -> Self { + self.is_last = is_last; + self + } + + pub(crate) fn is_last(&self) -> bool { + self.is_last + } + + pub(crate) fn is_close_delimited(&self) -> bool { + match self.kind { + #[cfg(feature = "server")] + Kind::CloseDelimited => true, + _ => false, + } + } + + pub(crate) fn end<B>(&self) -> Result<Option<EncodedBuf<B>>, NotEof> { + match self.kind { + Kind::Length(0) => Ok(None), + Kind::Chunked => Ok(Some(EncodedBuf { + kind: BufKind::ChunkedEnd(b"0\r\n\r\n"), + })), + #[cfg(feature = "server")] + Kind::CloseDelimited => Ok(None), + Kind::Length(n) => Err(NotEof(n)), + } + } + + pub(crate) fn encode<B>(&mut self, msg: B) -> EncodedBuf<B> + where + B: Buf, + { + let len = msg.remaining(); + debug_assert!(len > 0, "encode() called with empty buf"); + + let kind = match self.kind { + Kind::Chunked => { + trace!("encoding chunked {}B", len); + let buf = ChunkSize::new(len) + .chain(msg) + .chain(b"\r\n" as &'static [u8]); + BufKind::Chunked(buf) + } + Kind::Length(ref mut remaining) => { + trace!("sized write, len = {}", len); + if len as u64 > *remaining { + let limit = *remaining as usize; + *remaining = 0; + BufKind::Limited(msg.take(limit)) + } else { + *remaining -= len as u64; + BufKind::Exact(msg) + } + } + #[cfg(feature = "server")] + Kind::CloseDelimited => { + trace!("close delimited write {}B", len); + BufKind::Exact(msg) + } + }; + EncodedBuf { kind } + } + + pub(super) fn encode_and_end<B>(&self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) -> bool + where + B: Buf, + { + let len = msg.remaining(); + debug_assert!(len > 0, "encode() called with empty buf"); + + match self.kind { + Kind::Chunked => { + trace!("encoding chunked {}B", len); + let buf = ChunkSize::new(len) + .chain(msg) + .chain(b"\r\n0\r\n\r\n" as &'static [u8]); + dst.buffer(buf); + !self.is_last + } + Kind::Length(remaining) => { + use std::cmp::Ordering; + + trace!("sized write, len = {}", len); + match (len as u64).cmp(&remaining) { + Ordering::Equal => { + dst.buffer(msg); + !self.is_last + } + Ordering::Greater => { + dst.buffer(msg.take(remaining as usize)); + !self.is_last + } + Ordering::Less => { + dst.buffer(msg); + false + } + } + } + #[cfg(feature = "server")] + Kind::CloseDelimited => { + trace!("close delimited write {}B", len); + dst.buffer(msg); + false + } + } + } + + /// Encodes the full body, without verifying the remaining length matches. + /// + /// This is used in conjunction with HttpBody::__hyper_full_data(), which + /// means we can trust that the buf has the correct size (the buf itself + /// was checked to make the headers). + pub(super) fn danger_full_buf<B>(self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) + where + B: Buf, + { + debug_assert!(msg.remaining() > 0, "encode() called with empty buf"); + debug_assert!( + match self.kind { + Kind::Length(len) => len == msg.remaining() as u64, + _ => true, + }, + "danger_full_buf length mismatches" + ); + + match self.kind { + Kind::Chunked => { + let len = msg.remaining(); + trace!("encoding chunked {}B", len); + let buf = ChunkSize::new(len) + .chain(msg) + .chain(b"\r\n0\r\n\r\n" as &'static [u8]); + dst.buffer(buf); + } + _ => { + dst.buffer(msg); + } + } + } +} + +impl<B> Buf for EncodedBuf<B> +where + B: Buf, +{ + #[inline] + fn remaining(&self) -> usize { + match self.kind { + BufKind::Exact(ref b) => b.remaining(), + BufKind::Limited(ref b) => b.remaining(), + BufKind::Chunked(ref b) => b.remaining(), + BufKind::ChunkedEnd(ref b) => b.remaining(), + } + } + + #[inline] + fn chunk(&self) -> &[u8] { + match self.kind { + BufKind::Exact(ref b) => b.chunk(), + BufKind::Limited(ref b) => b.chunk(), + BufKind::Chunked(ref b) => b.chunk(), + BufKind::ChunkedEnd(ref b) => b.chunk(), + } + } + + #[inline] + fn advance(&mut self, cnt: usize) { + match self.kind { + BufKind::Exact(ref mut b) => b.advance(cnt), + BufKind::Limited(ref mut b) => b.advance(cnt), + BufKind::Chunked(ref mut b) => b.advance(cnt), + BufKind::ChunkedEnd(ref mut b) => b.advance(cnt), + } + } + + #[inline] + fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { + match self.kind { + BufKind::Exact(ref b) => b.chunks_vectored(dst), + BufKind::Limited(ref b) => b.chunks_vectored(dst), + BufKind::Chunked(ref b) => b.chunks_vectored(dst), + BufKind::ChunkedEnd(ref b) => b.chunks_vectored(dst), + } + } +} + +#[cfg(target_pointer_width = "32")] +const USIZE_BYTES: usize = 4; + +#[cfg(target_pointer_width = "64")] +const USIZE_BYTES: usize = 8; + +// each byte will become 2 hex +const CHUNK_SIZE_MAX_BYTES: usize = USIZE_BYTES * 2; + +#[derive(Clone, Copy)] +struct ChunkSize { + bytes: [u8; CHUNK_SIZE_MAX_BYTES + 2], + pos: u8, + len: u8, +} + +impl ChunkSize { + fn new(len: usize) -> ChunkSize { + use std::fmt::Write; + let mut size = ChunkSize { + bytes: [0; CHUNK_SIZE_MAX_BYTES + 2], + pos: 0, + len: 0, + }; + write!(&mut size, "{:X}\r\n", len).expect("CHUNK_SIZE_MAX_BYTES should fit any usize"); + size + } +} + +impl Buf for ChunkSize { + #[inline] + fn remaining(&self) -> usize { + (self.len - self.pos).into() + } + + #[inline] + fn chunk(&self) -> &[u8] { + &self.bytes[self.pos.into()..self.len.into()] + } + + #[inline] + fn advance(&mut self, cnt: usize) { + assert!(cnt <= self.remaining()); + self.pos += cnt as u8; // just asserted cnt fits in u8 + } +} + +impl fmt::Debug for ChunkSize { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ChunkSize") + .field("bytes", &&self.bytes[..self.len.into()]) + .field("pos", &self.pos) + .finish() + } +} + +impl fmt::Write for ChunkSize { + fn write_str(&mut self, num: &str) -> fmt::Result { + use std::io::Write; + (&mut self.bytes[self.len.into()..]) + .write_all(num.as_bytes()) + .expect("&mut [u8].write() cannot error"); + self.len += num.len() as u8; // safe because bytes is never bigger than 256 + Ok(()) + } +} + +impl<B: Buf> From<B> for EncodedBuf<B> { + fn from(buf: B) -> Self { + EncodedBuf { + kind: BufKind::Exact(buf), + } + } +} + +impl<B: Buf> From<Take<B>> for EncodedBuf<B> { + fn from(buf: Take<B>) -> Self { + EncodedBuf { + kind: BufKind::Limited(buf), + } + } +} + +impl<B: Buf> From<Chain<Chain<ChunkSize, B>, StaticBuf>> for EncodedBuf<B> { + fn from(buf: Chain<Chain<ChunkSize, B>, StaticBuf>) -> Self { + EncodedBuf { + kind: BufKind::Chunked(buf), + } + } +} + +impl fmt::Display for NotEof { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "early end, expected {} more bytes", self.0) + } +} + +impl std::error::Error for NotEof {} + +#[cfg(test)] +mod tests { + use bytes::BufMut; + + use super::super::io::Cursor; + use super::Encoder; + + #[test] + fn chunked() { + let mut encoder = Encoder::chunked(); + let mut dst = Vec::new(); + + let msg1 = b"foo bar".as_ref(); + let buf1 = encoder.encode(msg1); + dst.put(buf1); + assert_eq!(dst, b"7\r\nfoo bar\r\n"); + + let msg2 = b"baz quux herp".as_ref(); + let buf2 = encoder.encode(msg2); + dst.put(buf2); + + assert_eq!(dst, b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n"); + + let end = encoder.end::<Cursor<Vec<u8>>>().unwrap().unwrap(); + dst.put(end); + + assert_eq!( + dst, + b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n0\r\n\r\n".as_ref() + ); + } + + #[test] + fn length() { + let max_len = 8; + let mut encoder = Encoder::length(max_len as u64); + let mut dst = Vec::new(); + + let msg1 = b"foo bar".as_ref(); + let buf1 = encoder.encode(msg1); + dst.put(buf1); + + assert_eq!(dst, b"foo bar"); + assert!(!encoder.is_eof()); + encoder.end::<()>().unwrap_err(); + + let msg2 = b"baz".as_ref(); + let buf2 = encoder.encode(msg2); + dst.put(buf2); + + assert_eq!(dst.len(), max_len); + assert_eq!(dst, b"foo barb"); + assert!(encoder.is_eof()); + assert!(encoder.end::<()>().unwrap().is_none()); + } + + #[test] + fn eof() { + let mut encoder = Encoder::close_delimited(); + let mut dst = Vec::new(); + + let msg1 = b"foo bar".as_ref(); + let buf1 = encoder.encode(msg1); + dst.put(buf1); + + assert_eq!(dst, b"foo bar"); + assert!(!encoder.is_eof()); + encoder.end::<()>().unwrap(); + + let msg2 = b"baz".as_ref(); + let buf2 = encoder.encode(msg2); + dst.put(buf2); + + assert_eq!(dst, b"foo barbaz"); + assert!(!encoder.is_eof()); + encoder.end::<()>().unwrap(); + } +} diff --git a/third_party/rust/hyper/src/proto/h1/io.rs b/third_party/rust/hyper/src/proto/h1/io.rs new file mode 100644 index 0000000000..1d251e2c84 --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/io.rs @@ -0,0 +1,1002 @@ +use std::cmp; +use std::fmt; +#[cfg(all(feature = "server", feature = "runtime"))] +use std::future::Future; +use std::io::{self, IoSlice}; +use std::marker::Unpin; +use std::mem::MaybeUninit; +#[cfg(all(feature = "server", feature = "runtime"))] +use std::time::Duration; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +#[cfg(all(feature = "server", feature = "runtime"))] +use tokio::time::Instant; +use tracing::{debug, trace}; + +use super::{Http1Transaction, ParseContext, ParsedMessage}; +use crate::common::buf::BufList; +use crate::common::{task, Pin, Poll}; + +/// The initial buffer size allocated before trying to read from IO. +pub(crate) const INIT_BUFFER_SIZE: usize = 8192; + +/// The minimum value that can be set to max buffer size. +pub(crate) const MINIMUM_MAX_BUFFER_SIZE: usize = INIT_BUFFER_SIZE; + +/// The default maximum read buffer size. If the buffer gets this big and +/// a message is still not complete, a `TooLarge` error is triggered. +// Note: if this changes, update server::conn::Http::max_buf_size docs. +pub(crate) const DEFAULT_MAX_BUFFER_SIZE: usize = 8192 + 4096 * 100; + +/// The maximum number of distinct `Buf`s to hold in a list before requiring +/// a flush. Only affects when the buffer strategy is to queue buffers. +/// +/// Note that a flush can happen before reaching the maximum. This simply +/// forces a flush if the queue gets this big. +const MAX_BUF_LIST_BUFFERS: usize = 16; + +pub(crate) struct Buffered<T, B> { + flush_pipeline: bool, + io: T, + read_blocked: bool, + read_buf: BytesMut, + read_buf_strategy: ReadStrategy, + write_buf: WriteBuf<B>, +} + +impl<T, B> fmt::Debug for Buffered<T, B> +where + B: Buf, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Buffered") + .field("read_buf", &self.read_buf) + .field("write_buf", &self.write_buf) + .finish() + } +} + +impl<T, B> Buffered<T, B> +where + T: AsyncRead + AsyncWrite + Unpin, + B: Buf, +{ + pub(crate) fn new(io: T) -> Buffered<T, B> { + let strategy = if io.is_write_vectored() { + WriteStrategy::Queue + } else { + WriteStrategy::Flatten + }; + let write_buf = WriteBuf::new(strategy); + Buffered { + flush_pipeline: false, + io, + read_blocked: false, + read_buf: BytesMut::with_capacity(0), + read_buf_strategy: ReadStrategy::default(), + write_buf, + } + } + + #[cfg(feature = "server")] + pub(crate) fn set_flush_pipeline(&mut self, enabled: bool) { + debug_assert!(!self.write_buf.has_remaining()); + self.flush_pipeline = enabled; + if enabled { + self.set_write_strategy_flatten(); + } + } + + pub(crate) fn set_max_buf_size(&mut self, max: usize) { + assert!( + max >= MINIMUM_MAX_BUFFER_SIZE, + "The max_buf_size cannot be smaller than {}.", + MINIMUM_MAX_BUFFER_SIZE, + ); + self.read_buf_strategy = ReadStrategy::with_max(max); + self.write_buf.max_buf_size = max; + } + + #[cfg(feature = "client")] + pub(crate) fn set_read_buf_exact_size(&mut self, sz: usize) { + self.read_buf_strategy = ReadStrategy::Exact(sz); + } + + pub(crate) fn set_write_strategy_flatten(&mut self) { + // this should always be called only at construction time, + // so this assert is here to catch myself + debug_assert!(self.write_buf.queue.bufs_cnt() == 0); + self.write_buf.set_strategy(WriteStrategy::Flatten); + } + + pub(crate) fn set_write_strategy_queue(&mut self) { + // this should always be called only at construction time, + // so this assert is here to catch myself + debug_assert!(self.write_buf.queue.bufs_cnt() == 0); + self.write_buf.set_strategy(WriteStrategy::Queue); + } + + pub(crate) fn read_buf(&self) -> &[u8] { + self.read_buf.as_ref() + } + + #[cfg(test)] + #[cfg(feature = "nightly")] + pub(super) fn read_buf_mut(&mut self) -> &mut BytesMut { + &mut self.read_buf + } + + /// Return the "allocated" available space, not the potential space + /// that could be allocated in the future. + fn read_buf_remaining_mut(&self) -> usize { + self.read_buf.capacity() - self.read_buf.len() + } + + /// Return whether we can append to the headers buffer. + /// + /// Reasons we can't: + /// - The write buf is in queue mode, and some of the past body is still + /// needing to be flushed. + pub(crate) fn can_headers_buf(&self) -> bool { + !self.write_buf.queue.has_remaining() + } + + pub(crate) fn headers_buf(&mut self) -> &mut Vec<u8> { + let buf = self.write_buf.headers_mut(); + &mut buf.bytes + } + + pub(super) fn write_buf(&mut self) -> &mut WriteBuf<B> { + &mut self.write_buf + } + + pub(crate) fn buffer<BB: Buf + Into<B>>(&mut self, buf: BB) { + self.write_buf.buffer(buf) + } + + pub(crate) fn can_buffer(&self) -> bool { + self.flush_pipeline || self.write_buf.can_buffer() + } + + pub(crate) fn consume_leading_lines(&mut self) { + if !self.read_buf.is_empty() { + let mut i = 0; + while i < self.read_buf.len() { + match self.read_buf[i] { + b'\r' | b'\n' => i += 1, + _ => break, + } + } + self.read_buf.advance(i); + } + } + + pub(super) fn parse<S>( + &mut self, + cx: &mut task::Context<'_>, + parse_ctx: ParseContext<'_>, + ) -> Poll<crate::Result<ParsedMessage<S::Incoming>>> + where + S: Http1Transaction, + { + loop { + match super::role::parse_headers::<S>( + &mut self.read_buf, + ParseContext { + cached_headers: parse_ctx.cached_headers, + req_method: parse_ctx.req_method, + h1_parser_config: parse_ctx.h1_parser_config.clone(), + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout: parse_ctx.h1_header_read_timeout, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_fut: parse_ctx.h1_header_read_timeout_fut, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_running: parse_ctx.h1_header_read_timeout_running, + preserve_header_case: parse_ctx.preserve_header_case, + #[cfg(feature = "ffi")] + preserve_header_order: parse_ctx.preserve_header_order, + h09_responses: parse_ctx.h09_responses, + #[cfg(feature = "ffi")] + on_informational: parse_ctx.on_informational, + #[cfg(feature = "ffi")] + raw_headers: parse_ctx.raw_headers, + }, + )? { + Some(msg) => { + debug!("parsed {} headers", msg.head.headers.len()); + + #[cfg(all(feature = "server", feature = "runtime"))] + { + *parse_ctx.h1_header_read_timeout_running = false; + + if let Some(h1_header_read_timeout_fut) = + parse_ctx.h1_header_read_timeout_fut + { + // Reset the timer in order to avoid woken up when the timeout finishes + h1_header_read_timeout_fut + .as_mut() + .reset(Instant::now() + Duration::from_secs(30 * 24 * 60 * 60)); + } + } + return Poll::Ready(Ok(msg)); + } + None => { + let max = self.read_buf_strategy.max(); + if self.read_buf.len() >= max { + debug!("max_buf_size ({}) reached, closing", max); + return Poll::Ready(Err(crate::Error::new_too_large())); + } + + #[cfg(all(feature = "server", feature = "runtime"))] + if *parse_ctx.h1_header_read_timeout_running { + if let Some(h1_header_read_timeout_fut) = + parse_ctx.h1_header_read_timeout_fut + { + if Pin::new(h1_header_read_timeout_fut).poll(cx).is_ready() { + *parse_ctx.h1_header_read_timeout_running = false; + + tracing::warn!("read header from client timeout"); + return Poll::Ready(Err(crate::Error::new_header_timeout())); + } + } + } + } + } + if ready!(self.poll_read_from_io(cx)).map_err(crate::Error::new_io)? == 0 { + trace!("parse eof"); + return Poll::Ready(Err(crate::Error::new_incomplete())); + } + } + } + + pub(crate) fn poll_read_from_io( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<io::Result<usize>> { + self.read_blocked = false; + let next = self.read_buf_strategy.next(); + if self.read_buf_remaining_mut() < next { + self.read_buf.reserve(next); + } + + let dst = self.read_buf.chunk_mut(); + let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) }; + let mut buf = ReadBuf::uninit(dst); + match Pin::new(&mut self.io).poll_read(cx, &mut buf) { + Poll::Ready(Ok(_)) => { + let n = buf.filled().len(); + trace!("received {} bytes", n); + unsafe { + // Safety: we just read that many bytes into the + // uninitialized part of the buffer, so this is okay. + // @tokio pls give me back `poll_read_buf` thanks + self.read_buf.advance_mut(n); + } + self.read_buf_strategy.record(n); + Poll::Ready(Ok(n)) + } + Poll::Pending => { + self.read_blocked = true; + Poll::Pending + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + } + } + + pub(crate) fn into_inner(self) -> (T, Bytes) { + (self.io, self.read_buf.freeze()) + } + + pub(crate) fn io_mut(&mut self) -> &mut T { + &mut self.io + } + + pub(crate) fn is_read_blocked(&self) -> bool { + self.read_blocked + } + + pub(crate) fn poll_flush(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + if self.flush_pipeline && !self.read_buf.is_empty() { + Poll::Ready(Ok(())) + } else if self.write_buf.remaining() == 0 { + Pin::new(&mut self.io).poll_flush(cx) + } else { + if let WriteStrategy::Flatten = self.write_buf.strategy { + return self.poll_flush_flattened(cx); + } + + const MAX_WRITEV_BUFS: usize = 64; + loop { + let n = { + let mut iovs = [IoSlice::new(&[]); MAX_WRITEV_BUFS]; + let len = self.write_buf.chunks_vectored(&mut iovs); + ready!(Pin::new(&mut self.io).poll_write_vectored(cx, &iovs[..len]))? + }; + // TODO(eliza): we have to do this manually because + // `poll_write_buf` doesn't exist in Tokio 0.3 yet...when + // `poll_write_buf` comes back, the manual advance will need to leave! + self.write_buf.advance(n); + debug!("flushed {} bytes", n); + if self.write_buf.remaining() == 0 { + break; + } else if n == 0 { + trace!( + "write returned zero, but {} bytes remaining", + self.write_buf.remaining() + ); + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + } + Pin::new(&mut self.io).poll_flush(cx) + } + } + + /// Specialized version of `flush` when strategy is Flatten. + /// + /// Since all buffered bytes are flattened into the single headers buffer, + /// that skips some bookkeeping around using multiple buffers. + fn poll_flush_flattened(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + loop { + let n = ready!(Pin::new(&mut self.io).poll_write(cx, self.write_buf.headers.chunk()))?; + debug!("flushed {} bytes", n); + self.write_buf.headers.advance(n); + if self.write_buf.headers.remaining() == 0 { + self.write_buf.headers.reset(); + break; + } else if n == 0 { + trace!( + "write returned zero, but {} bytes remaining", + self.write_buf.remaining() + ); + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + } + Pin::new(&mut self.io).poll_flush(cx) + } + + #[cfg(test)] + fn flush<'a>(&'a mut self) -> impl std::future::Future<Output = io::Result<()>> + 'a { + futures_util::future::poll_fn(move |cx| self.poll_flush(cx)) + } +} + +// The `B` is a `Buf`, we never project a pin to it +impl<T: Unpin, B> Unpin for Buffered<T, B> {} + +// TODO: This trait is old... at least rename to PollBytes or something... +pub(crate) trait MemRead { + fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>>; +} + +impl<T, B> MemRead for Buffered<T, B> +where + T: AsyncRead + AsyncWrite + Unpin, + B: Buf, +{ + fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> { + if !self.read_buf.is_empty() { + let n = std::cmp::min(len, self.read_buf.len()); + Poll::Ready(Ok(self.read_buf.split_to(n).freeze())) + } else { + let n = ready!(self.poll_read_from_io(cx))?; + Poll::Ready(Ok(self.read_buf.split_to(::std::cmp::min(len, n)).freeze())) + } + } +} + +#[derive(Clone, Copy, Debug)] +enum ReadStrategy { + Adaptive { + decrease_now: bool, + next: usize, + max: usize, + }, + #[cfg(feature = "client")] + Exact(usize), +} + +impl ReadStrategy { + fn with_max(max: usize) -> ReadStrategy { + ReadStrategy::Adaptive { + decrease_now: false, + next: INIT_BUFFER_SIZE, + max, + } + } + + fn next(&self) -> usize { + match *self { + ReadStrategy::Adaptive { next, .. } => next, + #[cfg(feature = "client")] + ReadStrategy::Exact(exact) => exact, + } + } + + fn max(&self) -> usize { + match *self { + ReadStrategy::Adaptive { max, .. } => max, + #[cfg(feature = "client")] + ReadStrategy::Exact(exact) => exact, + } + } + + fn record(&mut self, bytes_read: usize) { + match *self { + ReadStrategy::Adaptive { + ref mut decrease_now, + ref mut next, + max, + .. + } => { + if bytes_read >= *next { + *next = cmp::min(incr_power_of_two(*next), max); + *decrease_now = false; + } else { + let decr_to = prev_power_of_two(*next); + if bytes_read < decr_to { + if *decrease_now { + *next = cmp::max(decr_to, INIT_BUFFER_SIZE); + *decrease_now = false; + } else { + // Decreasing is a two "record" process. + *decrease_now = true; + } + } else { + // A read within the current range should cancel + // a potential decrease, since we just saw proof + // that we still need this size. + *decrease_now = false; + } + } + } + #[cfg(feature = "client")] + ReadStrategy::Exact(_) => (), + } + } +} + +fn incr_power_of_two(n: usize) -> usize { + n.saturating_mul(2) +} + +fn prev_power_of_two(n: usize) -> usize { + // Only way this shift can underflow is if n is less than 4. + // (Which would means `usize::MAX >> 64` and underflowed!) + debug_assert!(n >= 4); + (::std::usize::MAX >> (n.leading_zeros() + 2)) + 1 +} + +impl Default for ReadStrategy { + fn default() -> ReadStrategy { + ReadStrategy::with_max(DEFAULT_MAX_BUFFER_SIZE) + } +} + +#[derive(Clone)] +pub(crate) struct Cursor<T> { + bytes: T, + pos: usize, +} + +impl<T: AsRef<[u8]>> Cursor<T> { + #[inline] + pub(crate) fn new(bytes: T) -> Cursor<T> { + Cursor { bytes, pos: 0 } + } +} + +impl Cursor<Vec<u8>> { + /// If we've advanced the position a bit in this cursor, and wish to + /// extend the underlying vector, we may wish to unshift the "read" bytes + /// off, and move everything else over. + fn maybe_unshift(&mut self, additional: usize) { + if self.pos == 0 { + // nothing to do + return; + } + + if self.bytes.capacity() - self.bytes.len() >= additional { + // there's room! + return; + } + + self.bytes.drain(0..self.pos); + self.pos = 0; + } + + fn reset(&mut self) { + self.pos = 0; + self.bytes.clear(); + } +} + +impl<T: AsRef<[u8]>> fmt::Debug for Cursor<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Cursor") + .field("pos", &self.pos) + .field("len", &self.bytes.as_ref().len()) + .finish() + } +} + +impl<T: AsRef<[u8]>> Buf for Cursor<T> { + #[inline] + fn remaining(&self) -> usize { + self.bytes.as_ref().len() - self.pos + } + + #[inline] + fn chunk(&self) -> &[u8] { + &self.bytes.as_ref()[self.pos..] + } + + #[inline] + fn advance(&mut self, cnt: usize) { + debug_assert!(self.pos + cnt <= self.bytes.as_ref().len()); + self.pos += cnt; + } +} + +// an internal buffer to collect writes before flushes +pub(super) struct WriteBuf<B> { + /// Re-usable buffer that holds message headers + headers: Cursor<Vec<u8>>, + max_buf_size: usize, + /// Deque of user buffers if strategy is Queue + queue: BufList<B>, + strategy: WriteStrategy, +} + +impl<B: Buf> WriteBuf<B> { + fn new(strategy: WriteStrategy) -> WriteBuf<B> { + WriteBuf { + headers: Cursor::new(Vec::with_capacity(INIT_BUFFER_SIZE)), + max_buf_size: DEFAULT_MAX_BUFFER_SIZE, + queue: BufList::new(), + strategy, + } + } +} + +impl<B> WriteBuf<B> +where + B: Buf, +{ + fn set_strategy(&mut self, strategy: WriteStrategy) { + self.strategy = strategy; + } + + pub(super) fn buffer<BB: Buf + Into<B>>(&mut self, mut buf: BB) { + debug_assert!(buf.has_remaining()); + match self.strategy { + WriteStrategy::Flatten => { + let head = self.headers_mut(); + + head.maybe_unshift(buf.remaining()); + trace!( + self.len = head.remaining(), + buf.len = buf.remaining(), + "buffer.flatten" + ); + //perf: This is a little faster than <Vec as BufMut>>::put, + //but accomplishes the same result. + loop { + let adv = { + let slice = buf.chunk(); + if slice.is_empty() { + return; + } + head.bytes.extend_from_slice(slice); + slice.len() + }; + buf.advance(adv); + } + } + WriteStrategy::Queue => { + trace!( + self.len = self.remaining(), + buf.len = buf.remaining(), + "buffer.queue" + ); + self.queue.push(buf.into()); + } + } + } + + fn can_buffer(&self) -> bool { + match self.strategy { + WriteStrategy::Flatten => self.remaining() < self.max_buf_size, + WriteStrategy::Queue => { + self.queue.bufs_cnt() < MAX_BUF_LIST_BUFFERS && self.remaining() < self.max_buf_size + } + } + } + + fn headers_mut(&mut self) -> &mut Cursor<Vec<u8>> { + debug_assert!(!self.queue.has_remaining()); + &mut self.headers + } +} + +impl<B: Buf> fmt::Debug for WriteBuf<B> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WriteBuf") + .field("remaining", &self.remaining()) + .field("strategy", &self.strategy) + .finish() + } +} + +impl<B: Buf> Buf for WriteBuf<B> { + #[inline] + fn remaining(&self) -> usize { + self.headers.remaining() + self.queue.remaining() + } + + #[inline] + fn chunk(&self) -> &[u8] { + let headers = self.headers.chunk(); + if !headers.is_empty() { + headers + } else { + self.queue.chunk() + } + } + + #[inline] + fn advance(&mut self, cnt: usize) { + let hrem = self.headers.remaining(); + + match hrem.cmp(&cnt) { + cmp::Ordering::Equal => self.headers.reset(), + cmp::Ordering::Greater => self.headers.advance(cnt), + cmp::Ordering::Less => { + let qcnt = cnt - hrem; + self.headers.reset(); + self.queue.advance(qcnt); + } + } + } + + #[inline] + fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { + let n = self.headers.chunks_vectored(dst); + self.queue.chunks_vectored(&mut dst[n..]) + n + } +} + +#[derive(Debug)] +enum WriteStrategy { + Flatten, + Queue, +} + +#[cfg(test)] +mod tests { + use super::*; + use std::time::Duration; + + use tokio_test::io::Builder as Mock; + + // #[cfg(feature = "nightly")] + // use test::Bencher; + + /* + impl<T: Read> MemRead for AsyncIo<T> { + fn read_mem(&mut self, len: usize) -> Poll<Bytes, io::Error> { + let mut v = vec![0; len]; + let n = try_nb!(self.read(v.as_mut_slice())); + Ok(Async::Ready(BytesMut::from(&v[..n]).freeze())) + } + } + */ + + #[tokio::test] + #[ignore] + async fn iobuf_write_empty_slice() { + // TODO(eliza): can i have writev back pls T_T + // // First, let's just check that the Mock would normally return an + // // error on an unexpected write, even if the buffer is empty... + // let mut mock = Mock::new().build(); + // futures_util::future::poll_fn(|cx| { + // Pin::new(&mut mock).poll_write_buf(cx, &mut Cursor::new(&[])) + // }) + // .await + // .expect_err("should be a broken pipe"); + + // // underlying io will return the logic error upon write, + // // so we are testing that the io_buf does not trigger a write + // // when there is nothing to flush + // let mock = Mock::new().build(); + // let mut io_buf = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + // io_buf.flush().await.expect("should short-circuit flush"); + } + + #[tokio::test] + async fn parse_reads_until_blocked() { + use crate::proto::h1::ClientTransaction; + + let _ = pretty_env_logger::try_init(); + let mock = Mock::new() + // Split over multiple reads will read all of it + .read(b"HTTP/1.1 200 OK\r\n") + .read(b"Server: hyper\r\n") + // missing last line ending + .wait(Duration::from_secs(1)) + .build(); + + let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + + // We expect a `parse` to be not ready, and so can't await it directly. + // Rather, this `poll_fn` will wrap the `Poll` result. + futures_util::future::poll_fn(|cx| { + let parse_ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut None, + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }; + assert!(buffered + .parse::<ClientTransaction>(cx, parse_ctx) + .is_pending()); + Poll::Ready(()) + }) + .await; + + assert_eq!( + buffered.read_buf, + b"HTTP/1.1 200 OK\r\nServer: hyper\r\n"[..] + ); + } + + #[test] + fn read_strategy_adaptive_increments() { + let mut strategy = ReadStrategy::default(); + assert_eq!(strategy.next(), 8192); + + // Grows if record == next + strategy.record(8192); + assert_eq!(strategy.next(), 16384); + + strategy.record(16384); + assert_eq!(strategy.next(), 32768); + + // Enormous records still increment at same rate + strategy.record(::std::usize::MAX); + assert_eq!(strategy.next(), 65536); + + let max = strategy.max(); + while strategy.next() < max { + strategy.record(max); + } + + assert_eq!(strategy.next(), max, "never goes over max"); + strategy.record(max + 1); + assert_eq!(strategy.next(), max, "never goes over max"); + } + + #[test] + fn read_strategy_adaptive_decrements() { + let mut strategy = ReadStrategy::default(); + strategy.record(8192); + assert_eq!(strategy.next(), 16384); + + strategy.record(1); + assert_eq!( + strategy.next(), + 16384, + "first smaller record doesn't decrement yet" + ); + strategy.record(8192); + assert_eq!(strategy.next(), 16384, "record was with range"); + + strategy.record(1); + assert_eq!( + strategy.next(), + 16384, + "in-range record should make this the 'first' again" + ); + + strategy.record(1); + assert_eq!(strategy.next(), 8192, "second smaller record decrements"); + + strategy.record(1); + assert_eq!(strategy.next(), 8192, "first doesn't decrement"); + strategy.record(1); + assert_eq!(strategy.next(), 8192, "doesn't decrement under minimum"); + } + + #[test] + fn read_strategy_adaptive_stays_the_same() { + let mut strategy = ReadStrategy::default(); + strategy.record(8192); + assert_eq!(strategy.next(), 16384); + + strategy.record(8193); + assert_eq!( + strategy.next(), + 16384, + "first smaller record doesn't decrement yet" + ); + + strategy.record(8193); + assert_eq!( + strategy.next(), + 16384, + "with current step does not decrement" + ); + } + + #[test] + fn read_strategy_adaptive_max_fuzz() { + fn fuzz(max: usize) { + let mut strategy = ReadStrategy::with_max(max); + while strategy.next() < max { + strategy.record(::std::usize::MAX); + } + let mut next = strategy.next(); + while next > 8192 { + strategy.record(1); + strategy.record(1); + next = strategy.next(); + assert!( + next.is_power_of_two(), + "decrement should be powers of two: {} (max = {})", + next, + max, + ); + } + } + + let mut max = 8192; + while max < std::usize::MAX { + fuzz(max); + max = (max / 2).saturating_mul(3); + } + fuzz(::std::usize::MAX); + } + + #[test] + #[should_panic] + #[cfg(debug_assertions)] // needs to trigger a debug_assert + fn write_buf_requires_non_empty_bufs() { + let mock = Mock::new().build(); + let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + + buffered.buffer(Cursor::new(Vec::new())); + } + + /* + TODO: needs tokio_test::io to allow configure write_buf calls + #[test] + fn write_buf_queue() { + let _ = pretty_env_logger::try_init(); + + let mock = AsyncIo::new_buf(vec![], 1024); + let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + + + buffered.headers_buf().extend(b"hello "); + buffered.buffer(Cursor::new(b"world, ".to_vec())); + buffered.buffer(Cursor::new(b"it's ".to_vec())); + buffered.buffer(Cursor::new(b"hyper!".to_vec())); + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3); + buffered.flush().unwrap(); + + assert_eq!(buffered.io, b"hello world, it's hyper!"); + assert_eq!(buffered.io.num_writes(), 1); + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0); + } + */ + + #[tokio::test] + async fn write_buf_flatten() { + let _ = pretty_env_logger::try_init(); + + let mock = Mock::new().write(b"hello world, it's hyper!").build(); + + let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + buffered.write_buf.set_strategy(WriteStrategy::Flatten); + + buffered.headers_buf().extend(b"hello "); + buffered.buffer(Cursor::new(b"world, ".to_vec())); + buffered.buffer(Cursor::new(b"it's ".to_vec())); + buffered.buffer(Cursor::new(b"hyper!".to_vec())); + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0); + + buffered.flush().await.expect("flush"); + } + + #[test] + fn write_buf_flatten_partially_flushed() { + let _ = pretty_env_logger::try_init(); + + let b = |s: &str| Cursor::new(s.as_bytes().to_vec()); + + let mut write_buf = WriteBuf::<Cursor<Vec<u8>>>::new(WriteStrategy::Flatten); + + write_buf.buffer(b("hello ")); + write_buf.buffer(b("world, ")); + + assert_eq!(write_buf.chunk(), b"hello world, "); + + // advance most of the way, but not all + write_buf.advance(11); + + assert_eq!(write_buf.chunk(), b", "); + assert_eq!(write_buf.headers.pos, 11); + assert_eq!(write_buf.headers.bytes.capacity(), INIT_BUFFER_SIZE); + + // there's still room in the headers buffer, so just push on the end + write_buf.buffer(b("it's hyper!")); + + assert_eq!(write_buf.chunk(), b", it's hyper!"); + assert_eq!(write_buf.headers.pos, 11); + + let rem1 = write_buf.remaining(); + let cap = write_buf.headers.bytes.capacity(); + + // but when this would go over capacity, don't copy the old bytes + write_buf.buffer(Cursor::new(vec![b'X'; cap])); + assert_eq!(write_buf.remaining(), cap + rem1); + assert_eq!(write_buf.headers.pos, 0); + } + + #[tokio::test] + async fn write_buf_queue_disable_auto() { + let _ = pretty_env_logger::try_init(); + + let mock = Mock::new() + .write(b"hello ") + .write(b"world, ") + .write(b"it's ") + .write(b"hyper!") + .build(); + + let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + buffered.write_buf.set_strategy(WriteStrategy::Queue); + + // we have 4 buffers, and vec IO disabled, but explicitly said + // don't try to auto detect (via setting strategy above) + + buffered.headers_buf().extend(b"hello "); + buffered.buffer(Cursor::new(b"world, ".to_vec())); + buffered.buffer(Cursor::new(b"it's ".to_vec())); + buffered.buffer(Cursor::new(b"hyper!".to_vec())); + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3); + + buffered.flush().await.expect("flush"); + + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0); + } + + // #[cfg(feature = "nightly")] + // #[bench] + // fn bench_write_buf_flatten_buffer_chunk(b: &mut Bencher) { + // let s = "Hello, World!"; + // b.bytes = s.len() as u64; + + // let mut write_buf = WriteBuf::<bytes::Bytes>::new(); + // write_buf.set_strategy(WriteStrategy::Flatten); + // b.iter(|| { + // let chunk = bytes::Bytes::from(s); + // write_buf.buffer(chunk); + // ::test::black_box(&write_buf); + // write_buf.headers.bytes.clear(); + // }) + // } +} diff --git a/third_party/rust/hyper/src/proto/h1/mod.rs b/third_party/rust/hyper/src/proto/h1/mod.rs new file mode 100644 index 0000000000..5a2587a843 --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/mod.rs @@ -0,0 +1,122 @@ +#[cfg(all(feature = "server", feature = "runtime"))] +use std::{pin::Pin, time::Duration}; + +use bytes::BytesMut; +use http::{HeaderMap, Method}; +use httparse::ParserConfig; +#[cfg(all(feature = "server", feature = "runtime"))] +use tokio::time::Sleep; + +use crate::body::DecodedLength; +use crate::proto::{BodyLength, MessageHead}; + +pub(crate) use self::conn::Conn; +pub(crate) use self::decode::Decoder; +pub(crate) use self::dispatch::Dispatcher; +pub(crate) use self::encode::{EncodedBuf, Encoder}; +//TODO: move out of h1::io +pub(crate) use self::io::MINIMUM_MAX_BUFFER_SIZE; + +mod conn; +mod decode; +pub(crate) mod dispatch; +mod encode; +mod io; +mod role; + +cfg_client! { + pub(crate) type ClientTransaction = role::Client; +} + +cfg_server! { + pub(crate) type ServerTransaction = role::Server; +} + +pub(crate) trait Http1Transaction { + type Incoming; + type Outgoing: Default; + const LOG: &'static str; + fn parse(bytes: &mut BytesMut, ctx: ParseContext<'_>) -> ParseResult<Self::Incoming>; + fn encode(enc: Encode<'_, Self::Outgoing>, dst: &mut Vec<u8>) -> crate::Result<Encoder>; + + fn on_error(err: &crate::Error) -> Option<MessageHead<Self::Outgoing>>; + + fn is_client() -> bool { + !Self::is_server() + } + + fn is_server() -> bool { + !Self::is_client() + } + + fn should_error_on_parse_eof() -> bool { + Self::is_client() + } + + fn should_read_first() -> bool { + Self::is_server() + } + + fn update_date() {} +} + +/// Result newtype for Http1Transaction::parse. +pub(crate) type ParseResult<T> = Result<Option<ParsedMessage<T>>, crate::error::Parse>; + +#[derive(Debug)] +pub(crate) struct ParsedMessage<T> { + head: MessageHead<T>, + decode: DecodedLength, + expect_continue: bool, + keep_alive: bool, + wants_upgrade: bool, +} + +pub(crate) struct ParseContext<'a> { + cached_headers: &'a mut Option<HeaderMap>, + req_method: &'a mut Option<Method>, + h1_parser_config: ParserConfig, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout: Option<Duration>, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_fut: &'a mut Option<Pin<Box<Sleep>>>, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_running: &'a mut bool, + preserve_header_case: bool, + #[cfg(feature = "ffi")] + preserve_header_order: bool, + h09_responses: bool, + #[cfg(feature = "ffi")] + on_informational: &'a mut Option<crate::ffi::OnInformational>, + #[cfg(feature = "ffi")] + raw_headers: bool, +} + +/// Passed to Http1Transaction::encode +pub(crate) struct Encode<'a, T> { + head: &'a mut MessageHead<T>, + body: Option<BodyLength>, + #[cfg(feature = "server")] + keep_alive: bool, + req_method: &'a mut Option<Method>, + title_case_headers: bool, +} + +/// Extra flags that a request "wants", like expect-continue or upgrades. +#[derive(Clone, Copy, Debug)] +struct Wants(u8); + +impl Wants { + const EMPTY: Wants = Wants(0b00); + const EXPECT: Wants = Wants(0b01); + const UPGRADE: Wants = Wants(0b10); + + #[must_use] + fn add(self, other: Wants) -> Wants { + Wants(self.0 | other.0) + } + + fn contains(&self, other: Wants) -> bool { + (self.0 & other.0) == other.0 + } +} diff --git a/third_party/rust/hyper/src/proto/h1/role.rs b/third_party/rust/hyper/src/proto/h1/role.rs new file mode 100644 index 0000000000..6252207baf --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/role.rs @@ -0,0 +1,2847 @@ +use std::fmt::{self, Write}; +use std::mem::MaybeUninit; + +use bytes::Bytes; +use bytes::BytesMut; +#[cfg(feature = "server")] +use http::header::ValueIter; +use http::header::{self, Entry, HeaderName, HeaderValue}; +use http::{HeaderMap, Method, StatusCode, Version}; +#[cfg(all(feature = "server", feature = "runtime"))] +use tokio::time::Instant; +use tracing::{debug, error, trace, trace_span, warn}; + +use crate::body::DecodedLength; +#[cfg(feature = "server")] +use crate::common::date; +use crate::error::Parse; +use crate::ext::HeaderCaseMap; +#[cfg(feature = "ffi")] +use crate::ext::OriginalHeaderOrder; +use crate::headers; +use crate::proto::h1::{ + Encode, Encoder, Http1Transaction, ParseContext, ParseResult, ParsedMessage, +}; +use crate::proto::{BodyLength, MessageHead, RequestHead, RequestLine}; + +const MAX_HEADERS: usize = 100; +const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific +#[cfg(feature = "server")] +const MAX_URI_LEN: usize = (u16::MAX - 1) as usize; + +macro_rules! header_name { + ($bytes:expr) => {{ + { + match HeaderName::from_bytes($bytes) { + Ok(name) => name, + Err(e) => maybe_panic!(e), + } + } + }}; +} + +macro_rules! header_value { + ($bytes:expr) => {{ + { + unsafe { HeaderValue::from_maybe_shared_unchecked($bytes) } + } + }}; +} + +macro_rules! maybe_panic { + ($($arg:tt)*) => ({ + let _err = ($($arg)*); + if cfg!(debug_assertions) { + panic!("{:?}", _err); + } else { + error!("Internal Hyper error, please report {:?}", _err); + return Err(Parse::Internal) + } + }) +} + +pub(super) fn parse_headers<T>( + bytes: &mut BytesMut, + ctx: ParseContext<'_>, +) -> ParseResult<T::Incoming> +where + T: Http1Transaction, +{ + // If the buffer is empty, don't bother entering the span, it's just noise. + if bytes.is_empty() { + return Ok(None); + } + + let span = trace_span!("parse_headers"); + let _s = span.enter(); + + #[cfg(all(feature = "server", feature = "runtime"))] + if !*ctx.h1_header_read_timeout_running { + if let Some(h1_header_read_timeout) = ctx.h1_header_read_timeout { + let deadline = Instant::now() + h1_header_read_timeout; + *ctx.h1_header_read_timeout_running = true; + match ctx.h1_header_read_timeout_fut { + Some(h1_header_read_timeout_fut) => { + debug!("resetting h1 header read timeout timer"); + h1_header_read_timeout_fut.as_mut().reset(deadline); + } + None => { + debug!("setting h1 header read timeout timer"); + *ctx.h1_header_read_timeout_fut = + Some(Box::pin(tokio::time::sleep_until(deadline))); + } + } + } + } + + T::parse(bytes, ctx) +} + +pub(super) fn encode_headers<T>( + enc: Encode<'_, T::Outgoing>, + dst: &mut Vec<u8>, +) -> crate::Result<Encoder> +where + T: Http1Transaction, +{ + let span = trace_span!("encode_headers"); + let _s = span.enter(); + T::encode(enc, dst) +} + +// There are 2 main roles, Client and Server. + +#[cfg(feature = "client")] +pub(crate) enum Client {} + +#[cfg(feature = "server")] +pub(crate) enum Server {} + +#[cfg(feature = "server")] +impl Http1Transaction for Server { + type Incoming = RequestLine; + type Outgoing = StatusCode; + const LOG: &'static str = "{role=server}"; + + fn parse(buf: &mut BytesMut, ctx: ParseContext<'_>) -> ParseResult<RequestLine> { + debug_assert!(!buf.is_empty(), "parse called with empty buf"); + + let mut keep_alive; + let is_http_11; + let subject; + let version; + let len; + let headers_len; + + // Unsafe: both headers_indices and headers are using uninitialized memory, + // but we *never* read any of it until after httparse has assigned + // values into it. By not zeroing out the stack memory, this saves + // a good ~5% on pipeline benchmarks. + let mut headers_indices: [MaybeUninit<HeaderIndices>; MAX_HEADERS] = unsafe { + // SAFETY: We can go safely from MaybeUninit array to array of MaybeUninit + MaybeUninit::uninit().assume_init() + }; + { + /* SAFETY: it is safe to go from MaybeUninit array to array of MaybeUninit */ + let mut headers: [MaybeUninit<httparse::Header<'_>>; MAX_HEADERS] = + unsafe { MaybeUninit::uninit().assume_init() }; + trace!(bytes = buf.len(), "Request.parse"); + let mut req = httparse::Request::new(&mut []); + let bytes = buf.as_ref(); + match req.parse_with_uninit_headers(bytes, &mut headers) { + Ok(httparse::Status::Complete(parsed_len)) => { + trace!("Request.parse Complete({})", parsed_len); + len = parsed_len; + let uri = req.path.unwrap(); + if uri.len() > MAX_URI_LEN { + return Err(Parse::UriTooLong); + } + subject = RequestLine( + Method::from_bytes(req.method.unwrap().as_bytes())?, + uri.parse()?, + ); + version = if req.version.unwrap() == 1 { + keep_alive = true; + is_http_11 = true; + Version::HTTP_11 + } else { + keep_alive = false; + is_http_11 = false; + Version::HTTP_10 + }; + + record_header_indices(bytes, &req.headers, &mut headers_indices)?; + headers_len = req.headers.len(); + } + Ok(httparse::Status::Partial) => return Ok(None), + Err(err) => { + return Err(match err { + // if invalid Token, try to determine if for method or path + httparse::Error::Token => { + if req.method.is_none() { + Parse::Method + } else { + debug_assert!(req.path.is_none()); + Parse::Uri + } + } + other => other.into(), + }); + } + } + }; + + let slice = buf.split_to(len).freeze(); + + // According to https://tools.ietf.org/html/rfc7230#section-3.3.3 + // 1. (irrelevant to Request) + // 2. (irrelevant to Request) + // 3. Transfer-Encoding: chunked has a chunked body. + // 4. If multiple differing Content-Length headers or invalid, close connection. + // 5. Content-Length header has a sized body. + // 6. Length 0. + // 7. (irrelevant to Request) + + let mut decoder = DecodedLength::ZERO; + let mut expect_continue = false; + let mut con_len = None; + let mut is_te = false; + let mut is_te_chunked = false; + let mut wants_upgrade = subject.0 == Method::CONNECT; + + let mut header_case_map = if ctx.preserve_header_case { + Some(HeaderCaseMap::default()) + } else { + None + }; + + #[cfg(feature = "ffi")] + let mut header_order = if ctx.preserve_header_order { + Some(OriginalHeaderOrder::default()) + } else { + None + }; + + let mut headers = ctx.cached_headers.take().unwrap_or_else(HeaderMap::new); + + headers.reserve(headers_len); + + for header in &headers_indices[..headers_len] { + // SAFETY: array is valid up to `headers_len` + let header = unsafe { &*header.as_ptr() }; + let name = header_name!(&slice[header.name.0..header.name.1]); + let value = header_value!(slice.slice(header.value.0..header.value.1)); + + match name { + header::TRANSFER_ENCODING => { + // https://tools.ietf.org/html/rfc7230#section-3.3.3 + // If Transfer-Encoding header is present, and 'chunked' is + // not the final encoding, and this is a Request, then it is + // malformed. A server should respond with 400 Bad Request. + if !is_http_11 { + debug!("HTTP/1.0 cannot have Transfer-Encoding header"); + return Err(Parse::transfer_encoding_unexpected()); + } + is_te = true; + if headers::is_chunked_(&value) { + is_te_chunked = true; + decoder = DecodedLength::CHUNKED; + } else { + is_te_chunked = false; + } + } + header::CONTENT_LENGTH => { + if is_te { + continue; + } + let len = headers::content_length_parse(&value) + .ok_or_else(Parse::content_length_invalid)?; + if let Some(prev) = con_len { + if prev != len { + debug!( + "multiple Content-Length headers with different values: [{}, {}]", + prev, len, + ); + return Err(Parse::content_length_invalid()); + } + // we don't need to append this secondary length + continue; + } + decoder = DecodedLength::checked_new(len)?; + con_len = Some(len); + } + header::CONNECTION => { + // keep_alive was previously set to default for Version + if keep_alive { + // HTTP/1.1 + keep_alive = !headers::connection_close(&value); + } else { + // HTTP/1.0 + keep_alive = headers::connection_keep_alive(&value); + } + } + header::EXPECT => { + // According to https://datatracker.ietf.org/doc/html/rfc2616#section-14.20 + // Comparison of expectation values is case-insensitive for unquoted tokens + // (including the 100-continue token) + expect_continue = value.as_bytes().eq_ignore_ascii_case(b"100-continue"); + } + header::UPGRADE => { + // Upgrades are only allowed with HTTP/1.1 + wants_upgrade = is_http_11; + } + + _ => (), + } + + if let Some(ref mut header_case_map) = header_case_map { + header_case_map.append(&name, slice.slice(header.name.0..header.name.1)); + } + + #[cfg(feature = "ffi")] + if let Some(ref mut header_order) = header_order { + header_order.append(&name); + } + + headers.append(name, value); + } + + if is_te && !is_te_chunked { + debug!("request with transfer-encoding header, but not chunked, bad request"); + return Err(Parse::transfer_encoding_invalid()); + } + + let mut extensions = http::Extensions::default(); + + if let Some(header_case_map) = header_case_map { + extensions.insert(header_case_map); + } + + #[cfg(feature = "ffi")] + if let Some(header_order) = header_order { + extensions.insert(header_order); + } + + *ctx.req_method = Some(subject.0.clone()); + + Ok(Some(ParsedMessage { + head: MessageHead { + version, + subject, + headers, + extensions, + }, + decode: decoder, + expect_continue, + keep_alive, + wants_upgrade, + })) + } + + fn encode(mut msg: Encode<'_, Self::Outgoing>, dst: &mut Vec<u8>) -> crate::Result<Encoder> { + trace!( + "Server::encode status={:?}, body={:?}, req_method={:?}", + msg.head.subject, + msg.body, + msg.req_method + ); + + let mut wrote_len = false; + + // hyper currently doesn't support returning 1xx status codes as a Response + // This is because Service only allows returning a single Response, and + // so if you try to reply with a e.g. 100 Continue, you have no way of + // replying with the latter status code response. + let (ret, is_last) = if msg.head.subject == StatusCode::SWITCHING_PROTOCOLS { + (Ok(()), true) + } else if msg.req_method == &Some(Method::CONNECT) && msg.head.subject.is_success() { + // Sending content-length or transfer-encoding header on 2xx response + // to CONNECT is forbidden in RFC 7231. + wrote_len = true; + (Ok(()), true) + } else if msg.head.subject.is_informational() { + warn!("response with 1xx status code not supported"); + *msg.head = MessageHead::default(); + msg.head.subject = StatusCode::INTERNAL_SERVER_ERROR; + msg.body = None; + (Err(crate::Error::new_user_unsupported_status_code()), true) + } else { + (Ok(()), !msg.keep_alive) + }; + + // In some error cases, we don't know about the invalid message until already + // pushing some bytes onto the `dst`. In those cases, we don't want to send + // the half-pushed message, so rewind to before. + let orig_len = dst.len(); + + let init_cap = 30 + msg.head.headers.len() * AVERAGE_HEADER_SIZE; + dst.reserve(init_cap); + + let custom_reason_phrase = msg.head.extensions.get::<crate::ext::ReasonPhrase>(); + + if msg.head.version == Version::HTTP_11 + && msg.head.subject == StatusCode::OK + && custom_reason_phrase.is_none() + { + extend(dst, b"HTTP/1.1 200 OK\r\n"); + } else { + match msg.head.version { + Version::HTTP_10 => extend(dst, b"HTTP/1.0 "), + Version::HTTP_11 => extend(dst, b"HTTP/1.1 "), + Version::HTTP_2 => { + debug!("response with HTTP2 version coerced to HTTP/1.1"); + extend(dst, b"HTTP/1.1 "); + } + other => panic!("unexpected response version: {:?}", other), + } + + extend(dst, msg.head.subject.as_str().as_bytes()); + extend(dst, b" "); + + if let Some(reason) = custom_reason_phrase { + extend(dst, reason.as_bytes()); + } else { + // a reason MUST be written, as many parsers will expect it. + extend( + dst, + msg.head + .subject + .canonical_reason() + .unwrap_or("<none>") + .as_bytes(), + ); + } + + extend(dst, b"\r\n"); + } + + let orig_headers; + let extensions = std::mem::take(&mut msg.head.extensions); + let orig_headers = match extensions.get::<HeaderCaseMap>() { + None if msg.title_case_headers => { + orig_headers = HeaderCaseMap::default(); + Some(&orig_headers) + } + orig_headers => orig_headers, + }; + let encoder = if let Some(orig_headers) = orig_headers { + Self::encode_headers_with_original_case( + msg, + dst, + is_last, + orig_len, + wrote_len, + orig_headers, + )? + } else { + Self::encode_headers_with_lower_case(msg, dst, is_last, orig_len, wrote_len)? + }; + + ret.map(|()| encoder) + } + + fn on_error(err: &crate::Error) -> Option<MessageHead<Self::Outgoing>> { + use crate::error::Kind; + let status = match *err.kind() { + Kind::Parse(Parse::Method) + | Kind::Parse(Parse::Header(_)) + | Kind::Parse(Parse::Uri) + | Kind::Parse(Parse::Version) => StatusCode::BAD_REQUEST, + Kind::Parse(Parse::TooLarge) => StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE, + Kind::Parse(Parse::UriTooLong) => StatusCode::URI_TOO_LONG, + _ => return None, + }; + + debug!("sending automatic response ({}) for parse error", status); + let mut msg = MessageHead::default(); + msg.subject = status; + Some(msg) + } + + fn is_server() -> bool { + true + } + + fn update_date() { + date::update(); + } +} + +#[cfg(feature = "server")] +impl Server { + fn can_have_body(method: &Option<Method>, status: StatusCode) -> bool { + Server::can_chunked(method, status) + } + + fn can_chunked(method: &Option<Method>, status: StatusCode) -> bool { + if method == &Some(Method::HEAD) || method == &Some(Method::CONNECT) && status.is_success() + { + false + } else if status.is_informational() { + false + } else { + match status { + StatusCode::NO_CONTENT | StatusCode::NOT_MODIFIED => false, + _ => true, + } + } + } + + fn can_have_content_length(method: &Option<Method>, status: StatusCode) -> bool { + if status.is_informational() || method == &Some(Method::CONNECT) && status.is_success() { + false + } else { + match status { + StatusCode::NO_CONTENT | StatusCode::NOT_MODIFIED => false, + _ => true, + } + } + } + + fn can_have_implicit_zero_content_length(method: &Option<Method>, status: StatusCode) -> bool { + Server::can_have_content_length(method, status) && method != &Some(Method::HEAD) + } + + fn encode_headers_with_lower_case( + msg: Encode<'_, StatusCode>, + dst: &mut Vec<u8>, + is_last: bool, + orig_len: usize, + wrote_len: bool, + ) -> crate::Result<Encoder> { + struct LowercaseWriter; + + impl HeaderNameWriter for LowercaseWriter { + #[inline] + fn write_full_header_line( + &mut self, + dst: &mut Vec<u8>, + line: &str, + _: (HeaderName, &str), + ) { + extend(dst, line.as_bytes()) + } + + #[inline] + fn write_header_name_with_colon( + &mut self, + dst: &mut Vec<u8>, + name_with_colon: &str, + _: HeaderName, + ) { + extend(dst, name_with_colon.as_bytes()) + } + + #[inline] + fn write_header_name(&mut self, dst: &mut Vec<u8>, name: &HeaderName) { + extend(dst, name.as_str().as_bytes()) + } + } + + Self::encode_headers(msg, dst, is_last, orig_len, wrote_len, LowercaseWriter) + } + + #[cold] + #[inline(never)] + fn encode_headers_with_original_case( + msg: Encode<'_, StatusCode>, + dst: &mut Vec<u8>, + is_last: bool, + orig_len: usize, + wrote_len: bool, + orig_headers: &HeaderCaseMap, + ) -> crate::Result<Encoder> { + struct OrigCaseWriter<'map> { + map: &'map HeaderCaseMap, + current: Option<(HeaderName, ValueIter<'map, Bytes>)>, + title_case_headers: bool, + } + + impl HeaderNameWriter for OrigCaseWriter<'_> { + #[inline] + fn write_full_header_line( + &mut self, + dst: &mut Vec<u8>, + _: &str, + (name, rest): (HeaderName, &str), + ) { + self.write_header_name(dst, &name); + extend(dst, rest.as_bytes()); + } + + #[inline] + fn write_header_name_with_colon( + &mut self, + dst: &mut Vec<u8>, + _: &str, + name: HeaderName, + ) { + self.write_header_name(dst, &name); + extend(dst, b": "); + } + + #[inline] + fn write_header_name(&mut self, dst: &mut Vec<u8>, name: &HeaderName) { + let Self { + map, + ref mut current, + title_case_headers, + } = *self; + if current.as_ref().map_or(true, |(last, _)| last != name) { + *current = None; + } + let (_, values) = + current.get_or_insert_with(|| (name.clone(), map.get_all_internal(name))); + + if let Some(orig_name) = values.next() { + extend(dst, orig_name); + } else if title_case_headers { + title_case(dst, name.as_str().as_bytes()); + } else { + extend(dst, name.as_str().as_bytes()); + } + } + } + + let header_name_writer = OrigCaseWriter { + map: orig_headers, + current: None, + title_case_headers: msg.title_case_headers, + }; + + Self::encode_headers(msg, dst, is_last, orig_len, wrote_len, header_name_writer) + } + + #[inline] + fn encode_headers<W>( + msg: Encode<'_, StatusCode>, + dst: &mut Vec<u8>, + mut is_last: bool, + orig_len: usize, + mut wrote_len: bool, + mut header_name_writer: W, + ) -> crate::Result<Encoder> + where + W: HeaderNameWriter, + { + // In some error cases, we don't know about the invalid message until already + // pushing some bytes onto the `dst`. In those cases, we don't want to send + // the half-pushed message, so rewind to before. + let rewind = |dst: &mut Vec<u8>| { + dst.truncate(orig_len); + }; + + let mut encoder = Encoder::length(0); + let mut wrote_date = false; + let mut cur_name = None; + let mut is_name_written = false; + let mut must_write_chunked = false; + let mut prev_con_len = None; + + macro_rules! handle_is_name_written { + () => {{ + if is_name_written { + // we need to clean up and write the newline + debug_assert_ne!( + &dst[dst.len() - 2..], + b"\r\n", + "previous header wrote newline but set is_name_written" + ); + + if must_write_chunked { + extend(dst, b", chunked\r\n"); + } else { + extend(dst, b"\r\n"); + } + } + }}; + } + + 'headers: for (opt_name, value) in msg.head.headers.drain() { + if let Some(n) = opt_name { + cur_name = Some(n); + handle_is_name_written!(); + is_name_written = false; + } + let name = cur_name.as_ref().expect("current header name"); + match *name { + header::CONTENT_LENGTH => { + if wrote_len && !is_name_written { + warn!("unexpected content-length found, canceling"); + rewind(dst); + return Err(crate::Error::new_user_header()); + } + match msg.body { + Some(BodyLength::Known(known_len)) => { + // The HttpBody claims to know a length, and + // the headers are already set. For performance + // reasons, we are just going to trust that + // the values match. + // + // In debug builds, we'll assert they are the + // same to help developers find bugs. + #[cfg(debug_assertions)] + { + if let Some(len) = headers::content_length_parse(&value) { + assert!( + len == known_len, + "payload claims content-length of {}, custom content-length header claims {}", + known_len, + len, + ); + } + } + + if !is_name_written { + encoder = Encoder::length(known_len); + header_name_writer.write_header_name_with_colon( + dst, + "content-length: ", + header::CONTENT_LENGTH, + ); + extend(dst, value.as_bytes()); + wrote_len = true; + is_name_written = true; + } + continue 'headers; + } + Some(BodyLength::Unknown) => { + // The HttpBody impl didn't know how long the + // body is, but a length header was included. + // We have to parse the value to return our + // Encoder... + + if let Some(len) = headers::content_length_parse(&value) { + if let Some(prev) = prev_con_len { + if prev != len { + warn!( + "multiple Content-Length values found: [{}, {}]", + prev, len + ); + rewind(dst); + return Err(crate::Error::new_user_header()); + } + debug_assert!(is_name_written); + continue 'headers; + } else { + // we haven't written content-length yet! + encoder = Encoder::length(len); + header_name_writer.write_header_name_with_colon( + dst, + "content-length: ", + header::CONTENT_LENGTH, + ); + extend(dst, value.as_bytes()); + wrote_len = true; + is_name_written = true; + prev_con_len = Some(len); + continue 'headers; + } + } else { + warn!("illegal Content-Length value: {:?}", value); + rewind(dst); + return Err(crate::Error::new_user_header()); + } + } + None => { + // We have no body to actually send, + // but the headers claim a content-length. + // There's only 2 ways this makes sense: + // + // - The header says the length is `0`. + // - This is a response to a `HEAD` request. + if msg.req_method == &Some(Method::HEAD) { + debug_assert_eq!(encoder, Encoder::length(0)); + } else { + if value.as_bytes() != b"0" { + warn!( + "content-length value found, but empty body provided: {:?}", + value + ); + } + continue 'headers; + } + } + } + wrote_len = true; + } + header::TRANSFER_ENCODING => { + if wrote_len && !is_name_written { + warn!("unexpected transfer-encoding found, canceling"); + rewind(dst); + return Err(crate::Error::new_user_header()); + } + // check that we actually can send a chunked body... + if msg.head.version == Version::HTTP_10 + || !Server::can_chunked(msg.req_method, msg.head.subject) + { + continue; + } + wrote_len = true; + // Must check each value, because `chunked` needs to be the + // last encoding, or else we add it. + must_write_chunked = !headers::is_chunked_(&value); + + if !is_name_written { + encoder = Encoder::chunked(); + is_name_written = true; + header_name_writer.write_header_name_with_colon( + dst, + "transfer-encoding: ", + header::TRANSFER_ENCODING, + ); + extend(dst, value.as_bytes()); + } else { + extend(dst, b", "); + extend(dst, value.as_bytes()); + } + continue 'headers; + } + header::CONNECTION => { + if !is_last && headers::connection_close(&value) { + is_last = true; + } + if !is_name_written { + is_name_written = true; + header_name_writer.write_header_name_with_colon( + dst, + "connection: ", + header::CONNECTION, + ); + extend(dst, value.as_bytes()); + } else { + extend(dst, b", "); + extend(dst, value.as_bytes()); + } + continue 'headers; + } + header::DATE => { + wrote_date = true; + } + _ => (), + } + //TODO: this should perhaps instead combine them into + //single lines, as RFC7230 suggests is preferable. + + // non-special write Name and Value + debug_assert!( + !is_name_written, + "{:?} set is_name_written and didn't continue loop", + name, + ); + header_name_writer.write_header_name(dst, name); + extend(dst, b": "); + extend(dst, value.as_bytes()); + extend(dst, b"\r\n"); + } + + handle_is_name_written!(); + + if !wrote_len { + encoder = match msg.body { + Some(BodyLength::Unknown) => { + if msg.head.version == Version::HTTP_10 + || !Server::can_chunked(msg.req_method, msg.head.subject) + { + Encoder::close_delimited() + } else { + header_name_writer.write_full_header_line( + dst, + "transfer-encoding: chunked\r\n", + (header::TRANSFER_ENCODING, ": chunked\r\n"), + ); + Encoder::chunked() + } + } + None | Some(BodyLength::Known(0)) => { + if Server::can_have_implicit_zero_content_length( + msg.req_method, + msg.head.subject, + ) { + header_name_writer.write_full_header_line( + dst, + "content-length: 0\r\n", + (header::CONTENT_LENGTH, ": 0\r\n"), + ) + } + Encoder::length(0) + } + Some(BodyLength::Known(len)) => { + if !Server::can_have_content_length(msg.req_method, msg.head.subject) { + Encoder::length(0) + } else { + header_name_writer.write_header_name_with_colon( + dst, + "content-length: ", + header::CONTENT_LENGTH, + ); + extend(dst, ::itoa::Buffer::new().format(len).as_bytes()); + extend(dst, b"\r\n"); + Encoder::length(len) + } + } + }; + } + + if !Server::can_have_body(msg.req_method, msg.head.subject) { + trace!( + "server body forced to 0; method={:?}, status={:?}", + msg.req_method, + msg.head.subject + ); + encoder = Encoder::length(0); + } + + // cached date is much faster than formatting every request + if !wrote_date { + dst.reserve(date::DATE_VALUE_LENGTH + 8); + header_name_writer.write_header_name_with_colon(dst, "date: ", header::DATE); + date::extend(dst); + extend(dst, b"\r\n\r\n"); + } else { + extend(dst, b"\r\n"); + } + + Ok(encoder.set_last(is_last)) + } +} + +#[cfg(feature = "server")] +trait HeaderNameWriter { + fn write_full_header_line( + &mut self, + dst: &mut Vec<u8>, + line: &str, + name_value_pair: (HeaderName, &str), + ); + fn write_header_name_with_colon( + &mut self, + dst: &mut Vec<u8>, + name_with_colon: &str, + name: HeaderName, + ); + fn write_header_name(&mut self, dst: &mut Vec<u8>, name: &HeaderName); +} + +#[cfg(feature = "client")] +impl Http1Transaction for Client { + type Incoming = StatusCode; + type Outgoing = RequestLine; + const LOG: &'static str = "{role=client}"; + + fn parse(buf: &mut BytesMut, ctx: ParseContext<'_>) -> ParseResult<StatusCode> { + debug_assert!(!buf.is_empty(), "parse called with empty buf"); + + // Loop to skip information status code headers (100 Continue, etc). + loop { + // Unsafe: see comment in Server Http1Transaction, above. + let mut headers_indices: [MaybeUninit<HeaderIndices>; MAX_HEADERS] = unsafe { + // SAFETY: We can go safely from MaybeUninit array to array of MaybeUninit + MaybeUninit::uninit().assume_init() + }; + let (len, status, reason, version, headers_len) = { + // SAFETY: We can go safely from MaybeUninit array to array of MaybeUninit + let mut headers: [MaybeUninit<httparse::Header<'_>>; MAX_HEADERS] = + unsafe { MaybeUninit::uninit().assume_init() }; + trace!(bytes = buf.len(), "Response.parse"); + let mut res = httparse::Response::new(&mut []); + let bytes = buf.as_ref(); + match ctx.h1_parser_config.parse_response_with_uninit_headers( + &mut res, + bytes, + &mut headers, + ) { + Ok(httparse::Status::Complete(len)) => { + trace!("Response.parse Complete({})", len); + let status = StatusCode::from_u16(res.code.unwrap())?; + + let reason = { + let reason = res.reason.unwrap(); + // Only save the reason phrase if it isn't the canonical reason + if Some(reason) != status.canonical_reason() { + Some(Bytes::copy_from_slice(reason.as_bytes())) + } else { + None + } + }; + + let version = if res.version.unwrap() == 1 { + Version::HTTP_11 + } else { + Version::HTTP_10 + }; + record_header_indices(bytes, &res.headers, &mut headers_indices)?; + let headers_len = res.headers.len(); + (len, status, reason, version, headers_len) + } + Ok(httparse::Status::Partial) => return Ok(None), + Err(httparse::Error::Version) if ctx.h09_responses => { + trace!("Response.parse accepted HTTP/0.9 response"); + + (0, StatusCode::OK, None, Version::HTTP_09, 0) + } + Err(e) => return Err(e.into()), + } + }; + + let mut slice = buf.split_to(len); + + if ctx + .h1_parser_config + .obsolete_multiline_headers_in_responses_are_allowed() + { + for header in &headers_indices[..headers_len] { + // SAFETY: array is valid up to `headers_len` + let header = unsafe { &*header.as_ptr() }; + for b in &mut slice[header.value.0..header.value.1] { + if *b == b'\r' || *b == b'\n' { + *b = b' '; + } + } + } + } + + let slice = slice.freeze(); + + let mut headers = ctx.cached_headers.take().unwrap_or_else(HeaderMap::new); + + let mut keep_alive = version == Version::HTTP_11; + + let mut header_case_map = if ctx.preserve_header_case { + Some(HeaderCaseMap::default()) + } else { + None + }; + + #[cfg(feature = "ffi")] + let mut header_order = if ctx.preserve_header_order { + Some(OriginalHeaderOrder::default()) + } else { + None + }; + + headers.reserve(headers_len); + for header in &headers_indices[..headers_len] { + // SAFETY: array is valid up to `headers_len` + let header = unsafe { &*header.as_ptr() }; + let name = header_name!(&slice[header.name.0..header.name.1]); + let value = header_value!(slice.slice(header.value.0..header.value.1)); + + if let header::CONNECTION = name { + // keep_alive was previously set to default for Version + if keep_alive { + // HTTP/1.1 + keep_alive = !headers::connection_close(&value); + } else { + // HTTP/1.0 + keep_alive = headers::connection_keep_alive(&value); + } + } + + if let Some(ref mut header_case_map) = header_case_map { + header_case_map.append(&name, slice.slice(header.name.0..header.name.1)); + } + + #[cfg(feature = "ffi")] + if let Some(ref mut header_order) = header_order { + header_order.append(&name); + } + + headers.append(name, value); + } + + let mut extensions = http::Extensions::default(); + + if let Some(header_case_map) = header_case_map { + extensions.insert(header_case_map); + } + + #[cfg(feature = "ffi")] + if let Some(header_order) = header_order { + extensions.insert(header_order); + } + + if let Some(reason) = reason { + // Safety: httparse ensures that only valid reason phrase bytes are present in this + // field. + let reason = unsafe { crate::ext::ReasonPhrase::from_bytes_unchecked(reason) }; + extensions.insert(reason); + } + + #[cfg(feature = "ffi")] + if ctx.raw_headers { + extensions.insert(crate::ffi::RawHeaders(crate::ffi::hyper_buf(slice))); + } + + let head = MessageHead { + version, + subject: status, + headers, + extensions, + }; + if let Some((decode, is_upgrade)) = Client::decoder(&head, ctx.req_method)? { + return Ok(Some(ParsedMessage { + head, + decode, + expect_continue: false, + // a client upgrade means the connection can't be used + // again, as it is definitely upgrading. + keep_alive: keep_alive && !is_upgrade, + wants_upgrade: is_upgrade, + })); + } + + #[cfg(feature = "ffi")] + if head.subject.is_informational() { + if let Some(callback) = ctx.on_informational { + callback.call(head.into_response(crate::Body::empty())); + } + } + + // Parsing a 1xx response could have consumed the buffer, check if + // it is empty now... + if buf.is_empty() { + return Ok(None); + } + } + } + + fn encode(msg: Encode<'_, Self::Outgoing>, dst: &mut Vec<u8>) -> crate::Result<Encoder> { + trace!( + "Client::encode method={:?}, body={:?}", + msg.head.subject.0, + msg.body + ); + + *msg.req_method = Some(msg.head.subject.0.clone()); + + let body = Client::set_length(msg.head, msg.body); + + let init_cap = 30 + msg.head.headers.len() * AVERAGE_HEADER_SIZE; + dst.reserve(init_cap); + + extend(dst, msg.head.subject.0.as_str().as_bytes()); + extend(dst, b" "); + //TODO: add API to http::Uri to encode without std::fmt + let _ = write!(FastWrite(dst), "{} ", msg.head.subject.1); + + match msg.head.version { + Version::HTTP_10 => extend(dst, b"HTTP/1.0"), + Version::HTTP_11 => extend(dst, b"HTTP/1.1"), + Version::HTTP_2 => { + debug!("request with HTTP2 version coerced to HTTP/1.1"); + extend(dst, b"HTTP/1.1"); + } + other => panic!("unexpected request version: {:?}", other), + } + extend(dst, b"\r\n"); + + if let Some(orig_headers) = msg.head.extensions.get::<HeaderCaseMap>() { + write_headers_original_case( + &msg.head.headers, + orig_headers, + dst, + msg.title_case_headers, + ); + } else if msg.title_case_headers { + write_headers_title_case(&msg.head.headers, dst); + } else { + write_headers(&msg.head.headers, dst); + } + + extend(dst, b"\r\n"); + msg.head.headers.clear(); //TODO: remove when switching to drain() + + Ok(body) + } + + fn on_error(_err: &crate::Error) -> Option<MessageHead<Self::Outgoing>> { + // we can't tell the server about any errors it creates + None + } + + fn is_client() -> bool { + true + } +} + +#[cfg(feature = "client")] +impl Client { + /// Returns Some(length, wants_upgrade) if successful. + /// + /// Returns None if this message head should be skipped (like a 100 status). + fn decoder( + inc: &MessageHead<StatusCode>, + method: &mut Option<Method>, + ) -> Result<Option<(DecodedLength, bool)>, Parse> { + // According to https://tools.ietf.org/html/rfc7230#section-3.3.3 + // 1. HEAD responses, and Status 1xx, 204, and 304 cannot have a body. + // 2. Status 2xx to a CONNECT cannot have a body. + // 3. Transfer-Encoding: chunked has a chunked body. + // 4. If multiple differing Content-Length headers or invalid, close connection. + // 5. Content-Length header has a sized body. + // 6. (irrelevant to Response) + // 7. Read till EOF. + + match inc.subject.as_u16() { + 101 => { + return Ok(Some((DecodedLength::ZERO, true))); + } + 100 | 102..=199 => { + trace!("ignoring informational response: {}", inc.subject.as_u16()); + return Ok(None); + } + 204 | 304 => return Ok(Some((DecodedLength::ZERO, false))), + _ => (), + } + match *method { + Some(Method::HEAD) => { + return Ok(Some((DecodedLength::ZERO, false))); + } + Some(Method::CONNECT) => { + if let 200..=299 = inc.subject.as_u16() { + return Ok(Some((DecodedLength::ZERO, true))); + } + } + Some(_) => {} + None => { + trace!("Client::decoder is missing the Method"); + } + } + + if inc.headers.contains_key(header::TRANSFER_ENCODING) { + // https://tools.ietf.org/html/rfc7230#section-3.3.3 + // If Transfer-Encoding header is present, and 'chunked' is + // not the final encoding, and this is a Request, then it is + // malformed. A server should respond with 400 Bad Request. + if inc.version == Version::HTTP_10 { + debug!("HTTP/1.0 cannot have Transfer-Encoding header"); + Err(Parse::transfer_encoding_unexpected()) + } else if headers::transfer_encoding_is_chunked(&inc.headers) { + Ok(Some((DecodedLength::CHUNKED, false))) + } else { + trace!("not chunked, read till eof"); + Ok(Some((DecodedLength::CLOSE_DELIMITED, false))) + } + } else if let Some(len) = headers::content_length_parse_all(&inc.headers) { + Ok(Some((DecodedLength::checked_new(len)?, false))) + } else if inc.headers.contains_key(header::CONTENT_LENGTH) { + debug!("illegal Content-Length header"); + Err(Parse::content_length_invalid()) + } else { + trace!("neither Transfer-Encoding nor Content-Length"); + Ok(Some((DecodedLength::CLOSE_DELIMITED, false))) + } + } + fn set_length(head: &mut RequestHead, body: Option<BodyLength>) -> Encoder { + let body = if let Some(body) = body { + body + } else { + head.headers.remove(header::TRANSFER_ENCODING); + return Encoder::length(0); + }; + + // HTTP/1.0 doesn't know about chunked + let can_chunked = head.version == Version::HTTP_11; + let headers = &mut head.headers; + + // If the user already set specific headers, we should respect them, regardless + // of what the HttpBody knows about itself. They set them for a reason. + + // Because of the borrow checker, we can't check the for an existing + // Content-Length header while holding an `Entry` for the Transfer-Encoding + // header, so unfortunately, we must do the check here, first. + + let existing_con_len = headers::content_length_parse_all(headers); + let mut should_remove_con_len = false; + + if !can_chunked { + // Chunked isn't legal, so if it is set, we need to remove it. + if headers.remove(header::TRANSFER_ENCODING).is_some() { + trace!("removing illegal transfer-encoding header"); + } + + return if let Some(len) = existing_con_len { + Encoder::length(len) + } else if let BodyLength::Known(len) = body { + set_content_length(headers, len) + } else { + // HTTP/1.0 client requests without a content-length + // cannot have any body at all. + Encoder::length(0) + }; + } + + // If the user set a transfer-encoding, respect that. Let's just + // make sure `chunked` is the final encoding. + let encoder = match headers.entry(header::TRANSFER_ENCODING) { + Entry::Occupied(te) => { + should_remove_con_len = true; + if headers::is_chunked(te.iter()) { + Some(Encoder::chunked()) + } else { + warn!("user provided transfer-encoding does not end in 'chunked'"); + + // There's a Transfer-Encoding, but it doesn't end in 'chunked'! + // An example that could trigger this: + // + // Transfer-Encoding: gzip + // + // This can be bad, depending on if this is a request or a + // response. + // + // - A request is illegal if there is a `Transfer-Encoding` + // but it doesn't end in `chunked`. + // - A response that has `Transfer-Encoding` but doesn't + // end in `chunked` isn't illegal, it just forces this + // to be close-delimited. + // + // We can try to repair this, by adding `chunked` ourselves. + + headers::add_chunked(te); + Some(Encoder::chunked()) + } + } + Entry::Vacant(te) => { + if let Some(len) = existing_con_len { + Some(Encoder::length(len)) + } else if let BodyLength::Unknown = body { + // GET, HEAD, and CONNECT almost never have bodies. + // + // So instead of sending a "chunked" body with a 0-chunk, + // assume no body here. If you *must* send a body, + // set the headers explicitly. + match head.subject.0 { + Method::GET | Method::HEAD | Method::CONNECT => Some(Encoder::length(0)), + _ => { + te.insert(HeaderValue::from_static("chunked")); + Some(Encoder::chunked()) + } + } + } else { + None + } + } + }; + + // This is because we need a second mutable borrow to remove + // content-length header. + if let Some(encoder) = encoder { + if should_remove_con_len && existing_con_len.is_some() { + headers.remove(header::CONTENT_LENGTH); + } + return encoder; + } + + // User didn't set transfer-encoding, AND we know body length, + // so we can just set the Content-Length automatically. + + let len = if let BodyLength::Known(len) = body { + len + } else { + unreachable!("BodyLength::Unknown would set chunked"); + }; + + set_content_length(headers, len) + } +} + +fn set_content_length(headers: &mut HeaderMap, len: u64) -> Encoder { + // At this point, there should not be a valid Content-Length + // header. However, since we'll be indexing in anyways, we can + // warn the user if there was an existing illegal header. + // + // Or at least, we can in theory. It's actually a little bit slower, + // so perhaps only do that while the user is developing/testing. + + if cfg!(debug_assertions) { + match headers.entry(header::CONTENT_LENGTH) { + Entry::Occupied(mut cl) => { + // Internal sanity check, we should have already determined + // that the header was illegal before calling this function. + debug_assert!(headers::content_length_parse_all_values(cl.iter()).is_none()); + // Uh oh, the user set `Content-Length` headers, but set bad ones. + // This would be an illegal message anyways, so let's try to repair + // with our known good length. + error!("user provided content-length header was invalid"); + + cl.insert(HeaderValue::from(len)); + Encoder::length(len) + } + Entry::Vacant(cl) => { + cl.insert(HeaderValue::from(len)); + Encoder::length(len) + } + } + } else { + headers.insert(header::CONTENT_LENGTH, HeaderValue::from(len)); + Encoder::length(len) + } +} + +#[derive(Clone, Copy)] +struct HeaderIndices { + name: (usize, usize), + value: (usize, usize), +} + +fn record_header_indices( + bytes: &[u8], + headers: &[httparse::Header<'_>], + indices: &mut [MaybeUninit<HeaderIndices>], +) -> Result<(), crate::error::Parse> { + let bytes_ptr = bytes.as_ptr() as usize; + + for (header, indices) in headers.iter().zip(indices.iter_mut()) { + if header.name.len() >= (1 << 16) { + debug!("header name larger than 64kb: {:?}", header.name); + return Err(crate::error::Parse::TooLarge); + } + let name_start = header.name.as_ptr() as usize - bytes_ptr; + let name_end = name_start + header.name.len(); + let value_start = header.value.as_ptr() as usize - bytes_ptr; + let value_end = value_start + header.value.len(); + + // FIXME(maybe_uninit_extra) + // FIXME(addr_of) + // Currently we don't have `ptr::addr_of_mut` in stable rust or + // MaybeUninit::write, so this is some way of assigning into a MaybeUninit + // safely + let new_header_indices = HeaderIndices { + name: (name_start, name_end), + value: (value_start, value_end), + }; + *indices = MaybeUninit::new(new_header_indices); + } + + Ok(()) +} + +// Write header names as title case. The header name is assumed to be ASCII. +fn title_case(dst: &mut Vec<u8>, name: &[u8]) { + dst.reserve(name.len()); + + // Ensure first character is uppercased + let mut prev = b'-'; + for &(mut c) in name { + if prev == b'-' { + c.make_ascii_uppercase(); + } + dst.push(c); + prev = c; + } +} + +fn write_headers_title_case(headers: &HeaderMap, dst: &mut Vec<u8>) { + for (name, value) in headers { + title_case(dst, name.as_str().as_bytes()); + extend(dst, b": "); + extend(dst, value.as_bytes()); + extend(dst, b"\r\n"); + } +} + +fn write_headers(headers: &HeaderMap, dst: &mut Vec<u8>) { + for (name, value) in headers { + extend(dst, name.as_str().as_bytes()); + extend(dst, b": "); + extend(dst, value.as_bytes()); + extend(dst, b"\r\n"); + } +} + +#[cold] +fn write_headers_original_case( + headers: &HeaderMap, + orig_case: &HeaderCaseMap, + dst: &mut Vec<u8>, + title_case_headers: bool, +) { + // For each header name/value pair, there may be a value in the casemap + // that corresponds to the HeaderValue. So, we iterator all the keys, + // and for each one, try to pair the originally cased name with the value. + // + // TODO: consider adding http::HeaderMap::entries() iterator + for name in headers.keys() { + let mut names = orig_case.get_all(name); + + for value in headers.get_all(name) { + if let Some(orig_name) = names.next() { + extend(dst, orig_name.as_ref()); + } else if title_case_headers { + title_case(dst, name.as_str().as_bytes()); + } else { + extend(dst, name.as_str().as_bytes()); + } + + // Wanted for curl test cases that send `X-Custom-Header:\r\n` + if value.is_empty() { + extend(dst, b":\r\n"); + } else { + extend(dst, b": "); + extend(dst, value.as_bytes()); + extend(dst, b"\r\n"); + } + } + } +} + +struct FastWrite<'a>(&'a mut Vec<u8>); + +impl<'a> fmt::Write for FastWrite<'a> { + #[inline] + fn write_str(&mut self, s: &str) -> fmt::Result { + extend(self.0, s.as_bytes()); + Ok(()) + } + + #[inline] + fn write_fmt(&mut self, args: fmt::Arguments<'_>) -> fmt::Result { + fmt::write(self, args) + } +} + +#[inline] +fn extend(dst: &mut Vec<u8>, data: &[u8]) { + dst.extend_from_slice(data); +} + +#[cfg(test)] +mod tests { + use bytes::BytesMut; + + use super::*; + + #[test] + fn test_parse_request() { + let _ = pretty_env_logger::try_init(); + let mut raw = BytesMut::from("GET /echo HTTP/1.1\r\nHost: hyper.rs\r\n\r\n"); + let mut method = None; + let msg = Server::parse( + &mut raw, + ParseContext { + cached_headers: &mut None, + req_method: &mut method, + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }, + ) + .unwrap() + .unwrap(); + assert_eq!(raw.len(), 0); + assert_eq!(msg.head.subject.0, crate::Method::GET); + assert_eq!(msg.head.subject.1, "/echo"); + assert_eq!(msg.head.version, crate::Version::HTTP_11); + assert_eq!(msg.head.headers.len(), 1); + assert_eq!(msg.head.headers["Host"], "hyper.rs"); + assert_eq!(method, Some(crate::Method::GET)); + } + + #[test] + fn test_parse_response() { + let _ = pretty_env_logger::try_init(); + let mut raw = BytesMut::from("HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"); + let ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut Some(crate::Method::GET), + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }; + let msg = Client::parse(&mut raw, ctx).unwrap().unwrap(); + assert_eq!(raw.len(), 0); + assert_eq!(msg.head.subject, crate::StatusCode::OK); + assert_eq!(msg.head.version, crate::Version::HTTP_11); + assert_eq!(msg.head.headers.len(), 1); + assert_eq!(msg.head.headers["Content-Length"], "0"); + } + + #[test] + fn test_parse_request_errors() { + let mut raw = BytesMut::from("GET htt:p// HTTP/1.1\r\nHost: hyper.rs\r\n\r\n"); + let ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut None, + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }; + Server::parse(&mut raw, ctx).unwrap_err(); + } + + const H09_RESPONSE: &'static str = "Baguettes are super delicious, don't you agree?"; + + #[test] + fn test_parse_response_h09_allowed() { + let _ = pretty_env_logger::try_init(); + let mut raw = BytesMut::from(H09_RESPONSE); + let ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut Some(crate::Method::GET), + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: true, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }; + let msg = Client::parse(&mut raw, ctx).unwrap().unwrap(); + assert_eq!(raw, H09_RESPONSE); + assert_eq!(msg.head.subject, crate::StatusCode::OK); + assert_eq!(msg.head.version, crate::Version::HTTP_09); + assert_eq!(msg.head.headers.len(), 0); + } + + #[test] + fn test_parse_response_h09_rejected() { + let _ = pretty_env_logger::try_init(); + let mut raw = BytesMut::from(H09_RESPONSE); + let ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut Some(crate::Method::GET), + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }; + Client::parse(&mut raw, ctx).unwrap_err(); + assert_eq!(raw, H09_RESPONSE); + } + + const RESPONSE_WITH_WHITESPACE_BETWEEN_HEADER_NAME_AND_COLON: &'static str = + "HTTP/1.1 200 OK\r\nAccess-Control-Allow-Credentials : true\r\n\r\n"; + + #[test] + fn test_parse_allow_response_with_spaces_before_colons() { + use httparse::ParserConfig; + + let _ = pretty_env_logger::try_init(); + let mut raw = BytesMut::from(RESPONSE_WITH_WHITESPACE_BETWEEN_HEADER_NAME_AND_COLON); + let mut h1_parser_config = ParserConfig::default(); + h1_parser_config.allow_spaces_after_header_name_in_responses(true); + let ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut Some(crate::Method::GET), + h1_parser_config, + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }; + let msg = Client::parse(&mut raw, ctx).unwrap().unwrap(); + assert_eq!(raw.len(), 0); + assert_eq!(msg.head.subject, crate::StatusCode::OK); + assert_eq!(msg.head.version, crate::Version::HTTP_11); + assert_eq!(msg.head.headers.len(), 1); + assert_eq!(msg.head.headers["Access-Control-Allow-Credentials"], "true"); + } + + #[test] + fn test_parse_reject_response_with_spaces_before_colons() { + let _ = pretty_env_logger::try_init(); + let mut raw = BytesMut::from(RESPONSE_WITH_WHITESPACE_BETWEEN_HEADER_NAME_AND_COLON); + let ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut Some(crate::Method::GET), + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }; + Client::parse(&mut raw, ctx).unwrap_err(); + } + + #[test] + fn test_parse_preserve_header_case_in_request() { + let mut raw = + BytesMut::from("GET / HTTP/1.1\r\nHost: hyper.rs\r\nX-BREAD: baguette\r\n\r\n"); + let ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut None, + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: true, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }; + let parsed_message = Server::parse(&mut raw, ctx).unwrap().unwrap(); + let orig_headers = parsed_message + .head + .extensions + .get::<HeaderCaseMap>() + .unwrap(); + assert_eq!( + orig_headers + .get_all_internal(&HeaderName::from_static("host")) + .into_iter() + .collect::<Vec<_>>(), + vec![&Bytes::from("Host")] + ); + assert_eq!( + orig_headers + .get_all_internal(&HeaderName::from_static("x-bread")) + .into_iter() + .collect::<Vec<_>>(), + vec![&Bytes::from("X-BREAD")] + ); + } + + #[test] + fn test_decoder_request() { + fn parse(s: &str) -> ParsedMessage<RequestLine> { + let mut bytes = BytesMut::from(s); + Server::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut None, + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }, + ) + .expect("parse ok") + .expect("parse complete") + } + + fn parse_err(s: &str, comment: &str) -> crate::error::Parse { + let mut bytes = BytesMut::from(s); + Server::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut None, + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }, + ) + .expect_err(comment) + } + + // no length or transfer-encoding means 0-length body + assert_eq!( + parse( + "\ + GET / HTTP/1.1\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::ZERO + ); + + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::ZERO + ); + + // transfer-encoding: chunked + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: gzip, chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: gzip\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + // content-length + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + content-length: 10\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::new(10) + ); + + // transfer-encoding and content-length = chunked + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + content-length: 10\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\ + content-length: 10\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: gzip\r\n\ + content-length: 10\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + // multiple content-lengths of same value are fine + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + content-length: 10\r\n\ + content-length: 10\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::new(10) + ); + + // multiple content-lengths with different values is an error + parse_err( + "\ + POST / HTTP/1.1\r\n\ + content-length: 10\r\n\ + content-length: 11\r\n\ + \r\n\ + ", + "multiple content-lengths", + ); + + // content-length with prefix is not allowed + parse_err( + "\ + POST / HTTP/1.1\r\n\ + content-length: +10\r\n\ + \r\n\ + ", + "prefixed content-length", + ); + + // transfer-encoding that isn't chunked is an error + parse_err( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: gzip\r\n\ + \r\n\ + ", + "transfer-encoding but not chunked", + ); + + parse_err( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: chunked, gzip\r\n\ + \r\n\ + ", + "transfer-encoding doesn't end in chunked", + ); + + parse_err( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\ + transfer-encoding: afterlol\r\n\ + \r\n\ + ", + "transfer-encoding multiple lines doesn't end in chunked", + ); + + // http/1.0 + + assert_eq!( + parse( + "\ + POST / HTTP/1.0\r\n\ + content-length: 10\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::new(10) + ); + + // 1.0 doesn't understand chunked, so its an error + parse_err( + "\ + POST / HTTP/1.0\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + ", + "1.0 chunked", + ); + } + + #[test] + fn test_decoder_response() { + fn parse(s: &str) -> ParsedMessage<StatusCode> { + parse_with_method(s, Method::GET) + } + + fn parse_ignores(s: &str) { + let mut bytes = BytesMut::from(s); + assert!(Client::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut Some(Method::GET), + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + } + ) + .expect("parse ok") + .is_none()) + } + + fn parse_with_method(s: &str, m: Method) -> ParsedMessage<StatusCode> { + let mut bytes = BytesMut::from(s); + Client::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut Some(m), + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }, + ) + .expect("parse ok") + .expect("parse complete") + } + + fn parse_err(s: &str) -> crate::error::Parse { + let mut bytes = BytesMut::from(s); + Client::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut Some(Method::GET), + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }, + ) + .expect_err("parse should err") + } + + // no content-length or transfer-encoding means close-delimited + assert_eq!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CLOSE_DELIMITED + ); + + // 204 and 304 never have a body + assert_eq!( + parse( + "\ + HTTP/1.1 204 No Content\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::ZERO + ); + + assert_eq!( + parse( + "\ + HTTP/1.1 304 Not Modified\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::ZERO + ); + + // content-length + assert_eq!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 8\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::new(8) + ); + + assert_eq!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 8\r\n\ + content-length: 8\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::new(8) + ); + + parse_err( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 8\r\n\ + content-length: 9\r\n\ + \r\n\ + ", + ); + + parse_err( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: +8\r\n\ + \r\n\ + ", + ); + + // transfer-encoding: chunked + assert_eq!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + // transfer-encoding not-chunked is close-delimited + assert_eq!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + transfer-encoding: yolo\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CLOSE_DELIMITED + ); + + // transfer-encoding and content-length = chunked + assert_eq!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 10\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + // HEAD can have content-length, but not body + assert_eq!( + parse_with_method( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 8\r\n\ + \r\n\ + ", + Method::HEAD + ) + .decode, + DecodedLength::ZERO + ); + + // CONNECT with 200 never has body + { + let msg = parse_with_method( + "\ + HTTP/1.1 200 OK\r\n\ + \r\n\ + ", + Method::CONNECT, + ); + assert_eq!(msg.decode, DecodedLength::ZERO); + assert!(!msg.keep_alive, "should be upgrade"); + assert!(msg.wants_upgrade, "should be upgrade"); + } + + // CONNECT receiving non 200 can have a body + assert_eq!( + parse_with_method( + "\ + HTTP/1.1 400 Bad Request\r\n\ + \r\n\ + ", + Method::CONNECT + ) + .decode, + DecodedLength::CLOSE_DELIMITED + ); + + // 1xx status codes + parse_ignores( + "\ + HTTP/1.1 100 Continue\r\n\ + \r\n\ + ", + ); + + parse_ignores( + "\ + HTTP/1.1 103 Early Hints\r\n\ + \r\n\ + ", + ); + + // 101 upgrade not supported yet + { + let msg = parse( + "\ + HTTP/1.1 101 Switching Protocols\r\n\ + \r\n\ + ", + ); + assert_eq!(msg.decode, DecodedLength::ZERO); + assert!(!msg.keep_alive, "should be last"); + assert!(msg.wants_upgrade, "should be upgrade"); + } + + // http/1.0 + assert_eq!( + parse( + "\ + HTTP/1.0 200 OK\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CLOSE_DELIMITED + ); + + // 1.0 doesn't understand chunked + parse_err( + "\ + HTTP/1.0 200 OK\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + ", + ); + + // keep-alive + assert!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 0\r\n\ + \r\n\ + " + ) + .keep_alive, + "HTTP/1.1 keep-alive is default" + ); + + assert!( + !parse( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 0\r\n\ + connection: foo, close, bar\r\n\ + \r\n\ + " + ) + .keep_alive, + "connection close is always close" + ); + + assert!( + !parse( + "\ + HTTP/1.0 200 OK\r\n\ + content-length: 0\r\n\ + \r\n\ + " + ) + .keep_alive, + "HTTP/1.0 close is default" + ); + + assert!( + parse( + "\ + HTTP/1.0 200 OK\r\n\ + content-length: 0\r\n\ + connection: foo, keep-alive, bar\r\n\ + \r\n\ + " + ) + .keep_alive, + "connection keep-alive is always keep-alive" + ); + } + + #[test] + fn test_client_request_encode_title_case() { + use crate::proto::BodyLength; + use http::header::HeaderValue; + + let mut head = MessageHead::default(); + head.headers + .insert("content-length", HeaderValue::from_static("10")); + head.headers + .insert("content-type", HeaderValue::from_static("application/json")); + head.headers.insert("*-*", HeaderValue::from_static("o_o")); + + let mut vec = Vec::new(); + Client::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut None, + title_case_headers: true, + }, + &mut vec, + ) + .unwrap(); + + assert_eq!(vec, b"GET / HTTP/1.1\r\nContent-Length: 10\r\nContent-Type: application/json\r\n*-*: o_o\r\n\r\n".to_vec()); + } + + #[test] + fn test_client_request_encode_orig_case() { + use crate::proto::BodyLength; + use http::header::{HeaderValue, CONTENT_LENGTH}; + + let mut head = MessageHead::default(); + head.headers + .insert("content-length", HeaderValue::from_static("10")); + head.headers + .insert("content-type", HeaderValue::from_static("application/json")); + + let mut orig_headers = HeaderCaseMap::default(); + orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into()); + head.extensions.insert(orig_headers); + + let mut vec = Vec::new(); + Client::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut None, + title_case_headers: false, + }, + &mut vec, + ) + .unwrap(); + + assert_eq!( + &*vec, + b"GET / HTTP/1.1\r\nCONTENT-LENGTH: 10\r\ncontent-type: application/json\r\n\r\n" + .as_ref(), + ); + } + #[test] + fn test_client_request_encode_orig_and_title_case() { + use crate::proto::BodyLength; + use http::header::{HeaderValue, CONTENT_LENGTH}; + + let mut head = MessageHead::default(); + head.headers + .insert("content-length", HeaderValue::from_static("10")); + head.headers + .insert("content-type", HeaderValue::from_static("application/json")); + + let mut orig_headers = HeaderCaseMap::default(); + orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into()); + head.extensions.insert(orig_headers); + + let mut vec = Vec::new(); + Client::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut None, + title_case_headers: true, + }, + &mut vec, + ) + .unwrap(); + + assert_eq!( + &*vec, + b"GET / HTTP/1.1\r\nCONTENT-LENGTH: 10\r\nContent-Type: application/json\r\n\r\n" + .as_ref(), + ); + } + + #[test] + fn test_server_encode_connect_method() { + let mut head = MessageHead::default(); + + let mut vec = Vec::new(); + let encoder = Server::encode( + Encode { + head: &mut head, + body: None, + keep_alive: true, + req_method: &mut Some(Method::CONNECT), + title_case_headers: false, + }, + &mut vec, + ) + .unwrap(); + + assert!(encoder.is_last()); + } + + #[test] + fn test_server_response_encode_title_case() { + use crate::proto::BodyLength; + use http::header::HeaderValue; + + let mut head = MessageHead::default(); + head.headers + .insert("content-length", HeaderValue::from_static("10")); + head.headers + .insert("content-type", HeaderValue::from_static("application/json")); + head.headers + .insert("weird--header", HeaderValue::from_static("")); + + let mut vec = Vec::new(); + Server::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut None, + title_case_headers: true, + }, + &mut vec, + ) + .unwrap(); + + let expected_response = + b"HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Type: application/json\r\nWeird--Header: \r\n"; + + assert_eq!(&vec[..expected_response.len()], &expected_response[..]); + } + + #[test] + fn test_server_response_encode_orig_case() { + use crate::proto::BodyLength; + use http::header::{HeaderValue, CONTENT_LENGTH}; + + let mut head = MessageHead::default(); + head.headers + .insert("content-length", HeaderValue::from_static("10")); + head.headers + .insert("content-type", HeaderValue::from_static("application/json")); + + let mut orig_headers = HeaderCaseMap::default(); + orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into()); + head.extensions.insert(orig_headers); + + let mut vec = Vec::new(); + Server::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut None, + title_case_headers: false, + }, + &mut vec, + ) + .unwrap(); + + let expected_response = + b"HTTP/1.1 200 OK\r\nCONTENT-LENGTH: 10\r\ncontent-type: application/json\r\ndate: "; + + assert_eq!(&vec[..expected_response.len()], &expected_response[..]); + } + + #[test] + fn test_server_response_encode_orig_and_title_case() { + use crate::proto::BodyLength; + use http::header::{HeaderValue, CONTENT_LENGTH}; + + let mut head = MessageHead::default(); + head.headers + .insert("content-length", HeaderValue::from_static("10")); + head.headers + .insert("content-type", HeaderValue::from_static("application/json")); + + let mut orig_headers = HeaderCaseMap::default(); + orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into()); + head.extensions.insert(orig_headers); + + let mut vec = Vec::new(); + Server::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut None, + title_case_headers: true, + }, + &mut vec, + ) + .unwrap(); + + let expected_response = + b"HTTP/1.1 200 OK\r\nCONTENT-LENGTH: 10\r\nContent-Type: application/json\r\nDate: "; + + assert_eq!(&vec[..expected_response.len()], &expected_response[..]); + } + + #[test] + fn parse_header_htabs() { + let mut bytes = BytesMut::from("HTTP/1.1 200 OK\r\nserver: hello\tworld\r\n\r\n"); + let parsed = Client::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut Some(Method::GET), + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }, + ) + .expect("parse ok") + .expect("parse complete"); + + assert_eq!(parsed.head.headers["server"], "hello\tworld"); + } + + #[test] + fn test_write_headers_orig_case_empty_value() { + let mut headers = HeaderMap::new(); + let name = http::header::HeaderName::from_static("x-empty"); + headers.insert(&name, "".parse().expect("parse empty")); + let mut orig_cases = HeaderCaseMap::default(); + orig_cases.insert(name, Bytes::from_static(b"X-EmptY")); + + let mut dst = Vec::new(); + super::write_headers_original_case(&headers, &orig_cases, &mut dst, false); + + assert_eq!( + dst, b"X-EmptY:\r\n", + "there should be no space between the colon and CRLF" + ); + } + + #[test] + fn test_write_headers_orig_case_multiple_entries() { + let mut headers = HeaderMap::new(); + let name = http::header::HeaderName::from_static("x-empty"); + headers.insert(&name, "a".parse().unwrap()); + headers.append(&name, "b".parse().unwrap()); + + let mut orig_cases = HeaderCaseMap::default(); + orig_cases.insert(name.clone(), Bytes::from_static(b"X-Empty")); + orig_cases.append(name, Bytes::from_static(b"X-EMPTY")); + + let mut dst = Vec::new(); + super::write_headers_original_case(&headers, &orig_cases, &mut dst, false); + + assert_eq!(dst, b"X-Empty: a\r\nX-EMPTY: b\r\n"); + } + + #[cfg(feature = "nightly")] + use test::Bencher; + + #[cfg(feature = "nightly")] + #[bench] + fn bench_parse_incoming(b: &mut Bencher) { + let mut raw = BytesMut::from( + &b"GET /super_long_uri/and_whatever?what_should_we_talk_about/\ + I_wonder/Hard_to_write_in_an_uri_after_all/you_have_to_make\ + _up_the_punctuation_yourself/how_fun_is_that?test=foo&test1=\ + foo1&test2=foo2&test3=foo3&test4=foo4 HTTP/1.1\r\nHost: \ + hyper.rs\r\nAccept: a lot of things\r\nAccept-Charset: \ + utf8\r\nAccept-Encoding: *\r\nAccess-Control-Allow-\ + Credentials: None\r\nAccess-Control-Allow-Origin: None\r\n\ + Access-Control-Allow-Methods: None\r\nAccess-Control-Allow-\ + Headers: None\r\nContent-Encoding: utf8\r\nContent-Security-\ + Policy: None\r\nContent-Type: text/html\r\nOrigin: hyper\ + \r\nSec-Websocket-Extensions: It looks super important!\r\n\ + Sec-Websocket-Origin: hyper\r\nSec-Websocket-Version: 4.3\r\ + \nStrict-Transport-Security: None\r\nUser-Agent: hyper\r\n\ + X-Content-Duration: None\r\nX-Content-Security-Policy: None\ + \r\nX-DNSPrefetch-Control: None\r\nX-Frame-Options: \ + Something important obviously\r\nX-Requested-With: Nothing\ + \r\n\r\n"[..], + ); + let len = raw.len(); + let mut headers = Some(HeaderMap::new()); + + b.bytes = len as u64; + b.iter(|| { + let mut msg = Server::parse( + &mut raw, + ParseContext { + cached_headers: &mut headers, + req_method: &mut None, + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }, + ) + .unwrap() + .unwrap(); + ::test::black_box(&msg); + msg.head.headers.clear(); + headers = Some(msg.head.headers); + restart(&mut raw, len); + }); + + fn restart(b: &mut BytesMut, len: usize) { + b.reserve(1); + unsafe { + b.set_len(len); + } + } + } + + #[cfg(feature = "nightly")] + #[bench] + fn bench_parse_short(b: &mut Bencher) { + let s = &b"GET / HTTP/1.1\r\nHost: localhost:8080\r\n\r\n"[..]; + let mut raw = BytesMut::from(s); + let len = raw.len(); + let mut headers = Some(HeaderMap::new()); + + b.bytes = len as u64; + b.iter(|| { + let mut msg = Server::parse( + &mut raw, + ParseContext { + cached_headers: &mut headers, + req_method: &mut None, + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }, + ) + .unwrap() + .unwrap(); + ::test::black_box(&msg); + msg.head.headers.clear(); + headers = Some(msg.head.headers); + restart(&mut raw, len); + }); + + fn restart(b: &mut BytesMut, len: usize) { + b.reserve(1); + unsafe { + b.set_len(len); + } + } + } + + #[cfg(feature = "nightly")] + #[bench] + fn bench_server_encode_headers_preset(b: &mut Bencher) { + use crate::proto::BodyLength; + use http::header::HeaderValue; + + let len = 108; + b.bytes = len as u64; + + let mut head = MessageHead::default(); + let mut headers = HeaderMap::new(); + headers.insert("content-length", HeaderValue::from_static("10")); + headers.insert("content-type", HeaderValue::from_static("application/json")); + + b.iter(|| { + let mut vec = Vec::new(); + head.headers = headers.clone(); + Server::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut Some(Method::GET), + title_case_headers: false, + }, + &mut vec, + ) + .unwrap(); + assert_eq!(vec.len(), len); + ::test::black_box(vec); + }) + } + + #[cfg(feature = "nightly")] + #[bench] + fn bench_server_encode_no_headers(b: &mut Bencher) { + use crate::proto::BodyLength; + + let len = 76; + b.bytes = len as u64; + + let mut head = MessageHead::default(); + let mut vec = Vec::with_capacity(128); + + b.iter(|| { + Server::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut Some(Method::GET), + title_case_headers: false, + }, + &mut vec, + ) + .unwrap(); + assert_eq!(vec.len(), len); + ::test::black_box(&vec); + + vec.clear(); + }) + } +} diff --git a/third_party/rust/hyper/src/proto/h2/client.rs b/third_party/rust/hyper/src/proto/h2/client.rs new file mode 100644 index 0000000000..bac8eceb3a --- /dev/null +++ b/third_party/rust/hyper/src/proto/h2/client.rs @@ -0,0 +1,450 @@ +use std::error::Error as StdError; +#[cfg(feature = "runtime")] +use std::time::Duration; + +use bytes::Bytes; +use futures_channel::{mpsc, oneshot}; +use futures_util::future::{self, Either, FutureExt as _, TryFutureExt as _}; +use futures_util::stream::StreamExt as _; +use h2::client::{Builder, SendRequest}; +use h2::SendStream; +use http::{Method, StatusCode}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tracing::{debug, trace, warn}; + +use super::{ping, H2Upgraded, PipeToSendStream, SendBuf}; +use crate::body::HttpBody; +use crate::client::dispatch::Callback; +use crate::common::{exec::Exec, task, Future, Never, Pin, Poll}; +use crate::ext::Protocol; +use crate::headers; +use crate::proto::h2::UpgradedSendStream; +use crate::proto::Dispatched; +use crate::upgrade::Upgraded; +use crate::{Body, Request, Response}; +use h2::client::ResponseFuture; + +type ClientRx<B> = crate::client::dispatch::Receiver<Request<B>, Response<Body>>; + +///// An mpsc channel is used to help notify the `Connection` task when *all* +///// other handles to it have been dropped, so that it can shutdown. +type ConnDropRef = mpsc::Sender<Never>; + +///// A oneshot channel watches the `Connection` task, and when it completes, +///// the "dispatch" task will be notified and can shutdown sooner. +type ConnEof = oneshot::Receiver<Never>; + +// Our defaults are chosen for the "majority" case, which usually are not +// resource constrained, and so the spec default of 64kb can be too limiting +// for performance. +const DEFAULT_CONN_WINDOW: u32 = 1024 * 1024 * 5; // 5mb +const DEFAULT_STREAM_WINDOW: u32 = 1024 * 1024 * 2; // 2mb +const DEFAULT_MAX_FRAME_SIZE: u32 = 1024 * 16; // 16kb +const DEFAULT_MAX_SEND_BUF_SIZE: usize = 1024 * 1024; // 1mb + +#[derive(Clone, Debug)] +pub(crate) struct Config { + pub(crate) adaptive_window: bool, + pub(crate) initial_conn_window_size: u32, + pub(crate) initial_stream_window_size: u32, + pub(crate) max_frame_size: u32, + #[cfg(feature = "runtime")] + pub(crate) keep_alive_interval: Option<Duration>, + #[cfg(feature = "runtime")] + pub(crate) keep_alive_timeout: Duration, + #[cfg(feature = "runtime")] + pub(crate) keep_alive_while_idle: bool, + pub(crate) max_concurrent_reset_streams: Option<usize>, + pub(crate) max_send_buffer_size: usize, +} + +impl Default for Config { + fn default() -> Config { + Config { + adaptive_window: false, + initial_conn_window_size: DEFAULT_CONN_WINDOW, + initial_stream_window_size: DEFAULT_STREAM_WINDOW, + max_frame_size: DEFAULT_MAX_FRAME_SIZE, + #[cfg(feature = "runtime")] + keep_alive_interval: None, + #[cfg(feature = "runtime")] + keep_alive_timeout: Duration::from_secs(20), + #[cfg(feature = "runtime")] + keep_alive_while_idle: false, + max_concurrent_reset_streams: None, + max_send_buffer_size: DEFAULT_MAX_SEND_BUF_SIZE, + } + } +} + +fn new_builder(config: &Config) -> Builder { + let mut builder = Builder::default(); + builder + .initial_window_size(config.initial_stream_window_size) + .initial_connection_window_size(config.initial_conn_window_size) + .max_frame_size(config.max_frame_size) + .max_send_buffer_size(config.max_send_buffer_size) + .enable_push(false); + if let Some(max) = config.max_concurrent_reset_streams { + builder.max_concurrent_reset_streams(max); + } + builder +} + +fn new_ping_config(config: &Config) -> ping::Config { + ping::Config { + bdp_initial_window: if config.adaptive_window { + Some(config.initial_stream_window_size) + } else { + None + }, + #[cfg(feature = "runtime")] + keep_alive_interval: config.keep_alive_interval, + #[cfg(feature = "runtime")] + keep_alive_timeout: config.keep_alive_timeout, + #[cfg(feature = "runtime")] + keep_alive_while_idle: config.keep_alive_while_idle, + } +} + +pub(crate) async fn handshake<T, B>( + io: T, + req_rx: ClientRx<B>, + config: &Config, + exec: Exec, +) -> crate::Result<ClientTask<B>> +where + T: AsyncRead + AsyncWrite + Send + Unpin + 'static, + B: HttpBody, + B::Data: Send + 'static, +{ + let (h2_tx, mut conn) = new_builder(config) + .handshake::<_, SendBuf<B::Data>>(io) + .await + .map_err(crate::Error::new_h2)?; + + // An mpsc channel is used entirely to detect when the + // 'Client' has been dropped. This is to get around a bug + // in h2 where dropping all SendRequests won't notify a + // parked Connection. + let (conn_drop_ref, rx) = mpsc::channel(1); + let (cancel_tx, conn_eof) = oneshot::channel(); + + let conn_drop_rx = rx.into_future().map(|(item, _rx)| { + if let Some(never) = item { + match never {} + } + }); + + let ping_config = new_ping_config(&config); + + let (conn, ping) = if ping_config.is_enabled() { + let pp = conn.ping_pong().expect("conn.ping_pong"); + let (recorder, mut ponger) = ping::channel(pp, ping_config); + + let conn = future::poll_fn(move |cx| { + match ponger.poll(cx) { + Poll::Ready(ping::Ponged::SizeUpdate(wnd)) => { + conn.set_target_window_size(wnd); + conn.set_initial_window_size(wnd)?; + } + #[cfg(feature = "runtime")] + Poll::Ready(ping::Ponged::KeepAliveTimedOut) => { + debug!("connection keep-alive timed out"); + return Poll::Ready(Ok(())); + } + Poll::Pending => {} + } + + Pin::new(&mut conn).poll(cx) + }); + (Either::Left(conn), recorder) + } else { + (Either::Right(conn), ping::disabled()) + }; + let conn = conn.map_err(|e| debug!("connection error: {}", e)); + + exec.execute(conn_task(conn, conn_drop_rx, cancel_tx)); + + Ok(ClientTask { + ping, + conn_drop_ref, + conn_eof, + executor: exec, + h2_tx, + req_rx, + fut_ctx: None, + }) +} + +async fn conn_task<C, D>(conn: C, drop_rx: D, cancel_tx: oneshot::Sender<Never>) +where + C: Future + Unpin, + D: Future<Output = ()> + Unpin, +{ + match future::select(conn, drop_rx).await { + Either::Left(_) => { + // ok or err, the `conn` has finished + } + Either::Right(((), conn)) => { + // mpsc has been dropped, hopefully polling + // the connection some more should start shutdown + // and then close + trace!("send_request dropped, starting conn shutdown"); + drop(cancel_tx); + let _ = conn.await; + } + } +} + +struct FutCtx<B> +where + B: HttpBody, +{ + is_connect: bool, + eos: bool, + fut: ResponseFuture, + body_tx: SendStream<SendBuf<B::Data>>, + body: B, + cb: Callback<Request<B>, Response<Body>>, +} + +impl<B: HttpBody> Unpin for FutCtx<B> {} + +pub(crate) struct ClientTask<B> +where + B: HttpBody, +{ + ping: ping::Recorder, + conn_drop_ref: ConnDropRef, + conn_eof: ConnEof, + executor: Exec, + h2_tx: SendRequest<SendBuf<B::Data>>, + req_rx: ClientRx<B>, + fut_ctx: Option<FutCtx<B>>, +} + +impl<B> ClientTask<B> +where + B: HttpBody + 'static, +{ + pub(crate) fn is_extended_connect_protocol_enabled(&self) -> bool { + self.h2_tx.is_extended_connect_protocol_enabled() + } +} + +impl<B> ClientTask<B> +where + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + fn poll_pipe(&mut self, f: FutCtx<B>, cx: &mut task::Context<'_>) { + let ping = self.ping.clone(); + let send_stream = if !f.is_connect { + if !f.eos { + let mut pipe = Box::pin(PipeToSendStream::new(f.body, f.body_tx)).map(|res| { + if let Err(e) = res { + debug!("client request body error: {}", e); + } + }); + + // eagerly see if the body pipe is ready and + // can thus skip allocating in the executor + match Pin::new(&mut pipe).poll(cx) { + Poll::Ready(_) => (), + Poll::Pending => { + let conn_drop_ref = self.conn_drop_ref.clone(); + // keep the ping recorder's knowledge of an + // "open stream" alive while this body is + // still sending... + let ping = ping.clone(); + let pipe = pipe.map(move |x| { + drop(conn_drop_ref); + drop(ping); + x + }); + // Clear send task + self.executor.execute(pipe); + } + } + } + + None + } else { + Some(f.body_tx) + }; + + let fut = f.fut.map(move |result| match result { + Ok(res) => { + // record that we got the response headers + ping.record_non_data(); + + let content_length = headers::content_length_parse_all(res.headers()); + if let (Some(mut send_stream), StatusCode::OK) = (send_stream, res.status()) { + if content_length.map_or(false, |len| len != 0) { + warn!("h2 connect response with non-zero body not supported"); + + send_stream.send_reset(h2::Reason::INTERNAL_ERROR); + return Err(( + crate::Error::new_h2(h2::Reason::INTERNAL_ERROR.into()), + None, + )); + } + let (parts, recv_stream) = res.into_parts(); + let mut res = Response::from_parts(parts, Body::empty()); + + let (pending, on_upgrade) = crate::upgrade::pending(); + let io = H2Upgraded { + ping, + send_stream: unsafe { UpgradedSendStream::new(send_stream) }, + recv_stream, + buf: Bytes::new(), + }; + let upgraded = Upgraded::new(io, Bytes::new()); + + pending.fulfill(upgraded); + res.extensions_mut().insert(on_upgrade); + + Ok(res) + } else { + let res = res.map(|stream| { + let ping = ping.for_stream(&stream); + crate::Body::h2(stream, content_length.into(), ping) + }); + Ok(res) + } + } + Err(err) => { + ping.ensure_not_timed_out().map_err(|e| (e, None))?; + + debug!("client response error: {}", err); + Err((crate::Error::new_h2(err), None)) + } + }); + self.executor.execute(f.cb.send_when(fut)); + } +} + +impl<B> Future for ClientTask<B> +where + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + type Output = crate::Result<Dispatched>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + loop { + match ready!(self.h2_tx.poll_ready(cx)) { + Ok(()) => (), + Err(err) => { + self.ping.ensure_not_timed_out()?; + return if err.reason() == Some(::h2::Reason::NO_ERROR) { + trace!("connection gracefully shutdown"); + Poll::Ready(Ok(Dispatched::Shutdown)) + } else { + Poll::Ready(Err(crate::Error::new_h2(err))) + }; + } + }; + + match self.fut_ctx.take() { + // If we were waiting on pending open + // continue where we left off. + Some(f) => { + self.poll_pipe(f, cx); + continue; + } + None => (), + } + + match self.req_rx.poll_recv(cx) { + Poll::Ready(Some((req, cb))) => { + // check that future hasn't been canceled already + if cb.is_canceled() { + trace!("request callback is canceled"); + continue; + } + let (head, body) = req.into_parts(); + let mut req = ::http::Request::from_parts(head, ()); + super::strip_connection_headers(req.headers_mut(), true); + if let Some(len) = body.size_hint().exact() { + if len != 0 || headers::method_has_defined_payload_semantics(req.method()) { + headers::set_content_length_if_missing(req.headers_mut(), len); + } + } + + let is_connect = req.method() == Method::CONNECT; + let eos = body.is_end_stream(); + + if is_connect { + if headers::content_length_parse_all(req.headers()) + .map_or(false, |len| len != 0) + { + warn!("h2 connect request with non-zero body not supported"); + cb.send(Err(( + crate::Error::new_h2(h2::Reason::INTERNAL_ERROR.into()), + None, + ))); + continue; + } + } + + if let Some(protocol) = req.extensions_mut().remove::<Protocol>() { + req.extensions_mut().insert(protocol.into_inner()); + } + + let (fut, body_tx) = match self.h2_tx.send_request(req, !is_connect && eos) { + Ok(ok) => ok, + Err(err) => { + debug!("client send request error: {}", err); + cb.send(Err((crate::Error::new_h2(err), None))); + continue; + } + }; + + let f = FutCtx { + is_connect, + eos, + fut, + body_tx, + body, + cb, + }; + + // Check poll_ready() again. + // If the call to send_request() resulted in the new stream being pending open + // we have to wait for the open to complete before accepting new requests. + match self.h2_tx.poll_ready(cx) { + Poll::Pending => { + // Save Context + self.fut_ctx = Some(f); + return Poll::Pending; + } + Poll::Ready(Ok(())) => (), + Poll::Ready(Err(err)) => { + f.cb.send(Err((crate::Error::new_h2(err), None))); + continue; + } + } + self.poll_pipe(f, cx); + continue; + } + + Poll::Ready(None) => { + trace!("client::dispatch::Sender dropped"); + return Poll::Ready(Ok(Dispatched::Shutdown)); + } + + Poll::Pending => match ready!(Pin::new(&mut self.conn_eof).poll(cx)) { + Ok(never) => match never {}, + Err(_conn_is_eof) => { + trace!("connection task is closed, closing dispatch task"); + return Poll::Ready(Ok(Dispatched::Shutdown)); + } + }, + } + } + } +} diff --git a/third_party/rust/hyper/src/proto/h2/mod.rs b/third_party/rust/hyper/src/proto/h2/mod.rs new file mode 100644 index 0000000000..5857c919d1 --- /dev/null +++ b/third_party/rust/hyper/src/proto/h2/mod.rs @@ -0,0 +1,471 @@ +use bytes::{Buf, Bytes}; +use h2::{Reason, RecvStream, SendStream}; +use http::header::{HeaderName, CONNECTION, TE, TRAILER, TRANSFER_ENCODING, UPGRADE}; +use http::HeaderMap; +use pin_project_lite::pin_project; +use std::error::Error as StdError; +use std::io::{self, Cursor, IoSlice}; +use std::mem; +use std::task::Context; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tracing::{debug, trace, warn}; + +use crate::body::HttpBody; +use crate::common::{task, Future, Pin, Poll}; +use crate::proto::h2::ping::Recorder; + +pub(crate) mod ping; + +cfg_client! { + pub(crate) mod client; + pub(crate) use self::client::ClientTask; +} + +cfg_server! { + pub(crate) mod server; + pub(crate) use self::server::Server; +} + +/// Default initial stream window size defined in HTTP2 spec. +pub(crate) const SPEC_WINDOW_SIZE: u32 = 65_535; + +fn strip_connection_headers(headers: &mut HeaderMap, is_request: bool) { + // List of connection headers from: + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Connection + // + // TE headers are allowed in HTTP/2 requests as long as the value is "trailers", so they're + // tested separately. + let connection_headers = [ + HeaderName::from_lowercase(b"keep-alive").unwrap(), + HeaderName::from_lowercase(b"proxy-connection").unwrap(), + TRAILER, + TRANSFER_ENCODING, + UPGRADE, + ]; + + for header in connection_headers.iter() { + if headers.remove(header).is_some() { + warn!("Connection header illegal in HTTP/2: {}", header.as_str()); + } + } + + if is_request { + if headers + .get(TE) + .map(|te_header| te_header != "trailers") + .unwrap_or(false) + { + warn!("TE headers not set to \"trailers\" are illegal in HTTP/2 requests"); + headers.remove(TE); + } + } else if headers.remove(TE).is_some() { + warn!("TE headers illegal in HTTP/2 responses"); + } + + if let Some(header) = headers.remove(CONNECTION) { + warn!( + "Connection header illegal in HTTP/2: {}", + CONNECTION.as_str() + ); + let header_contents = header.to_str().unwrap(); + + // A `Connection` header may have a comma-separated list of names of other headers that + // are meant for only this specific connection. + // + // Iterate these names and remove them as headers. Connection-specific headers are + // forbidden in HTTP2, as that information has been moved into frame types of the h2 + // protocol. + for name in header_contents.split(',') { + let name = name.trim(); + headers.remove(name); + } + } +} + +// body adapters used by both Client and Server + +pin_project! { + struct PipeToSendStream<S> + where + S: HttpBody, + { + body_tx: SendStream<SendBuf<S::Data>>, + data_done: bool, + #[pin] + stream: S, + } +} + +impl<S> PipeToSendStream<S> +where + S: HttpBody, +{ + fn new(stream: S, tx: SendStream<SendBuf<S::Data>>) -> PipeToSendStream<S> { + PipeToSendStream { + body_tx: tx, + data_done: false, + stream, + } + } +} + +impl<S> Future for PipeToSendStream<S> +where + S: HttpBody, + S::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + type Output = crate::Result<()>; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + let mut me = self.project(); + loop { + if !*me.data_done { + // we don't have the next chunk of data yet, so just reserve 1 byte to make + // sure there's some capacity available. h2 will handle the capacity management + // for the actual body chunk. + me.body_tx.reserve_capacity(1); + + if me.body_tx.capacity() == 0 { + loop { + match ready!(me.body_tx.poll_capacity(cx)) { + Some(Ok(0)) => {} + Some(Ok(_)) => break, + Some(Err(e)) => { + return Poll::Ready(Err(crate::Error::new_body_write(e))) + } + None => { + // None means the stream is no longer in a + // streaming state, we either finished it + // somehow, or the remote reset us. + return Poll::Ready(Err(crate::Error::new_body_write( + "send stream capacity unexpectedly closed", + ))); + } + } + } + } else if let Poll::Ready(reason) = me + .body_tx + .poll_reset(cx) + .map_err(crate::Error::new_body_write)? + { + debug!("stream received RST_STREAM: {:?}", reason); + return Poll::Ready(Err(crate::Error::new_body_write(::h2::Error::from( + reason, + )))); + } + + match ready!(me.stream.as_mut().poll_data(cx)) { + Some(Ok(chunk)) => { + let is_eos = me.stream.is_end_stream(); + trace!( + "send body chunk: {} bytes, eos={}", + chunk.remaining(), + is_eos, + ); + + let buf = SendBuf::Buf(chunk); + me.body_tx + .send_data(buf, is_eos) + .map_err(crate::Error::new_body_write)?; + + if is_eos { + return Poll::Ready(Ok(())); + } + } + Some(Err(e)) => return Poll::Ready(Err(me.body_tx.on_user_err(e))), + None => { + me.body_tx.reserve_capacity(0); + let is_eos = me.stream.is_end_stream(); + if is_eos { + return Poll::Ready(me.body_tx.send_eos_frame()); + } else { + *me.data_done = true; + // loop again to poll_trailers + } + } + } + } else { + if let Poll::Ready(reason) = me + .body_tx + .poll_reset(cx) + .map_err(crate::Error::new_body_write)? + { + debug!("stream received RST_STREAM: {:?}", reason); + return Poll::Ready(Err(crate::Error::new_body_write(::h2::Error::from( + reason, + )))); + } + + match ready!(me.stream.poll_trailers(cx)) { + Ok(Some(trailers)) => { + me.body_tx + .send_trailers(trailers) + .map_err(crate::Error::new_body_write)?; + return Poll::Ready(Ok(())); + } + Ok(None) => { + // There were no trailers, so send an empty DATA frame... + return Poll::Ready(me.body_tx.send_eos_frame()); + } + Err(e) => return Poll::Ready(Err(me.body_tx.on_user_err(e))), + } + } + } + } +} + +trait SendStreamExt { + fn on_user_err<E>(&mut self, err: E) -> crate::Error + where + E: Into<Box<dyn std::error::Error + Send + Sync>>; + fn send_eos_frame(&mut self) -> crate::Result<()>; +} + +impl<B: Buf> SendStreamExt for SendStream<SendBuf<B>> { + fn on_user_err<E>(&mut self, err: E) -> crate::Error + where + E: Into<Box<dyn std::error::Error + Send + Sync>>, + { + let err = crate::Error::new_user_body(err); + debug!("send body user stream error: {}", err); + self.send_reset(err.h2_reason()); + err + } + + fn send_eos_frame(&mut self) -> crate::Result<()> { + trace!("send body eos"); + self.send_data(SendBuf::None, true) + .map_err(crate::Error::new_body_write) + } +} + +#[repr(usize)] +enum SendBuf<B> { + Buf(B), + Cursor(Cursor<Box<[u8]>>), + None, +} + +impl<B: Buf> Buf for SendBuf<B> { + #[inline] + fn remaining(&self) -> usize { + match *self { + Self::Buf(ref b) => b.remaining(), + Self::Cursor(ref c) => Buf::remaining(c), + Self::None => 0, + } + } + + #[inline] + fn chunk(&self) -> &[u8] { + match *self { + Self::Buf(ref b) => b.chunk(), + Self::Cursor(ref c) => c.chunk(), + Self::None => &[], + } + } + + #[inline] + fn advance(&mut self, cnt: usize) { + match *self { + Self::Buf(ref mut b) => b.advance(cnt), + Self::Cursor(ref mut c) => c.advance(cnt), + Self::None => {} + } + } + + fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize { + match *self { + Self::Buf(ref b) => b.chunks_vectored(dst), + Self::Cursor(ref c) => c.chunks_vectored(dst), + Self::None => 0, + } + } +} + +struct H2Upgraded<B> +where + B: Buf, +{ + ping: Recorder, + send_stream: UpgradedSendStream<B>, + recv_stream: RecvStream, + buf: Bytes, +} + +impl<B> AsyncRead for H2Upgraded<B> +where + B: Buf, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + read_buf: &mut ReadBuf<'_>, + ) -> Poll<Result<(), io::Error>> { + if self.buf.is_empty() { + self.buf = loop { + match ready!(self.recv_stream.poll_data(cx)) { + None => return Poll::Ready(Ok(())), + Some(Ok(buf)) if buf.is_empty() && !self.recv_stream.is_end_stream() => { + continue + } + Some(Ok(buf)) => { + self.ping.record_data(buf.len()); + break buf; + } + Some(Err(e)) => { + return Poll::Ready(match e.reason() { + Some(Reason::NO_ERROR) | Some(Reason::CANCEL) => Ok(()), + Some(Reason::STREAM_CLOSED) => { + Err(io::Error::new(io::ErrorKind::BrokenPipe, e)) + } + _ => Err(h2_to_io_error(e)), + }) + } + } + }; + } + let cnt = std::cmp::min(self.buf.len(), read_buf.remaining()); + read_buf.put_slice(&self.buf[..cnt]); + self.buf.advance(cnt); + let _ = self.recv_stream.flow_control().release_capacity(cnt); + Poll::Ready(Ok(())) + } +} + +impl<B> AsyncWrite for H2Upgraded<B> +where + B: Buf, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, io::Error>> { + if buf.is_empty() { + return Poll::Ready(Ok(0)); + } + self.send_stream.reserve_capacity(buf.len()); + + // We ignore all errors returned by `poll_capacity` and `write`, as we + // will get the correct from `poll_reset` anyway. + let cnt = match ready!(self.send_stream.poll_capacity(cx)) { + None => Some(0), + Some(Ok(cnt)) => self + .send_stream + .write(&buf[..cnt], false) + .ok() + .map(|()| cnt), + Some(Err(_)) => None, + }; + + if let Some(cnt) = cnt { + return Poll::Ready(Ok(cnt)); + } + + Poll::Ready(Err(h2_to_io_error( + match ready!(self.send_stream.poll_reset(cx)) { + Ok(Reason::NO_ERROR) | Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => { + return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())) + } + Ok(reason) => reason.into(), + Err(e) => e, + }, + ))) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Result<(), io::Error>> { + if self.send_stream.write(&[], true).is_ok() { + return Poll::Ready(Ok(())) + } + + Poll::Ready(Err(h2_to_io_error( + match ready!(self.send_stream.poll_reset(cx)) { + Ok(Reason::NO_ERROR) => { + return Poll::Ready(Ok(())) + } + Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => { + return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())) + } + Ok(reason) => reason.into(), + Err(e) => e, + }, + ))) + } +} + +fn h2_to_io_error(e: h2::Error) -> io::Error { + if e.is_io() { + e.into_io().unwrap() + } else { + io::Error::new(io::ErrorKind::Other, e) + } +} + +struct UpgradedSendStream<B>(SendStream<SendBuf<Neutered<B>>>); + +impl<B> UpgradedSendStream<B> +where + B: Buf, +{ + unsafe fn new(inner: SendStream<SendBuf<B>>) -> Self { + assert_eq!(mem::size_of::<B>(), mem::size_of::<Neutered<B>>()); + Self(mem::transmute(inner)) + } + + fn reserve_capacity(&mut self, cnt: usize) { + unsafe { self.as_inner_unchecked().reserve_capacity(cnt) } + } + + fn poll_capacity(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<usize, h2::Error>>> { + unsafe { self.as_inner_unchecked().poll_capacity(cx) } + } + + fn poll_reset(&mut self, cx: &mut Context<'_>) -> Poll<Result<h2::Reason, h2::Error>> { + unsafe { self.as_inner_unchecked().poll_reset(cx) } + } + + fn write(&mut self, buf: &[u8], end_of_stream: bool) -> Result<(), io::Error> { + let send_buf = SendBuf::Cursor(Cursor::new(buf.into())); + unsafe { + self.as_inner_unchecked() + .send_data(send_buf, end_of_stream) + .map_err(h2_to_io_error) + } + } + + unsafe fn as_inner_unchecked(&mut self) -> &mut SendStream<SendBuf<B>> { + &mut *(&mut self.0 as *mut _ as *mut _) + } +} + +#[repr(transparent)] +struct Neutered<B> { + _inner: B, + impossible: Impossible, +} + +enum Impossible {} + +unsafe impl<B> Send for Neutered<B> {} + +impl<B> Buf for Neutered<B> { + fn remaining(&self) -> usize { + match self.impossible {} + } + + fn chunk(&self) -> &[u8] { + match self.impossible {} + } + + fn advance(&mut self, _cnt: usize) { + match self.impossible {} + } +} diff --git a/third_party/rust/hyper/src/proto/h2/ping.rs b/third_party/rust/hyper/src/proto/h2/ping.rs new file mode 100644 index 0000000000..1e8386497c --- /dev/null +++ b/third_party/rust/hyper/src/proto/h2/ping.rs @@ -0,0 +1,555 @@ +/// HTTP2 Ping usage +/// +/// hyper uses HTTP2 pings for two purposes: +/// +/// 1. Adaptive flow control using BDP +/// 2. Connection keep-alive +/// +/// Both cases are optional. +/// +/// # BDP Algorithm +/// +/// 1. When receiving a DATA frame, if a BDP ping isn't outstanding: +/// 1a. Record current time. +/// 1b. Send a BDP ping. +/// 2. Increment the number of received bytes. +/// 3. When the BDP ping ack is received: +/// 3a. Record duration from sent time. +/// 3b. Merge RTT with a running average. +/// 3c. Calculate bdp as bytes/rtt. +/// 3d. If bdp is over 2/3 max, set new max to bdp and update windows. + +#[cfg(feature = "runtime")] +use std::fmt; +#[cfg(feature = "runtime")] +use std::future::Future; +#[cfg(feature = "runtime")] +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::task::{self, Poll}; +use std::time::Duration; +#[cfg(not(feature = "runtime"))] +use std::time::Instant; + +use h2::{Ping, PingPong}; +#[cfg(feature = "runtime")] +use tokio::time::{Instant, Sleep}; +use tracing::{debug, trace}; + +type WindowSize = u32; + +pub(super) fn disabled() -> Recorder { + Recorder { shared: None } +} + +pub(super) fn channel(ping_pong: PingPong, config: Config) -> (Recorder, Ponger) { + debug_assert!( + config.is_enabled(), + "ping channel requires bdp or keep-alive config", + ); + + let bdp = config.bdp_initial_window.map(|wnd| Bdp { + bdp: wnd, + max_bandwidth: 0.0, + rtt: 0.0, + ping_delay: Duration::from_millis(100), + stable_count: 0, + }); + + let (bytes, next_bdp_at) = if bdp.is_some() { + (Some(0), Some(Instant::now())) + } else { + (None, None) + }; + + #[cfg(feature = "runtime")] + let keep_alive = config.keep_alive_interval.map(|interval| KeepAlive { + interval, + timeout: config.keep_alive_timeout, + while_idle: config.keep_alive_while_idle, + timer: Box::pin(tokio::time::sleep(interval)), + state: KeepAliveState::Init, + }); + + #[cfg(feature = "runtime")] + let last_read_at = keep_alive.as_ref().map(|_| Instant::now()); + + let shared = Arc::new(Mutex::new(Shared { + bytes, + #[cfg(feature = "runtime")] + last_read_at, + #[cfg(feature = "runtime")] + is_keep_alive_timed_out: false, + ping_pong, + ping_sent_at: None, + next_bdp_at, + })); + + ( + Recorder { + shared: Some(shared.clone()), + }, + Ponger { + bdp, + #[cfg(feature = "runtime")] + keep_alive, + shared, + }, + ) +} + +#[derive(Clone)] +pub(super) struct Config { + pub(super) bdp_initial_window: Option<WindowSize>, + /// If no frames are received in this amount of time, a PING frame is sent. + #[cfg(feature = "runtime")] + pub(super) keep_alive_interval: Option<Duration>, + /// After sending a keepalive PING, the connection will be closed if + /// a pong is not received in this amount of time. + #[cfg(feature = "runtime")] + pub(super) keep_alive_timeout: Duration, + /// If true, sends pings even when there are no active streams. + #[cfg(feature = "runtime")] + pub(super) keep_alive_while_idle: bool, +} + +#[derive(Clone)] +pub(crate) struct Recorder { + shared: Option<Arc<Mutex<Shared>>>, +} + +pub(super) struct Ponger { + bdp: Option<Bdp>, + #[cfg(feature = "runtime")] + keep_alive: Option<KeepAlive>, + shared: Arc<Mutex<Shared>>, +} + +struct Shared { + ping_pong: PingPong, + ping_sent_at: Option<Instant>, + + // bdp + /// If `Some`, bdp is enabled, and this tracks how many bytes have been + /// read during the current sample. + bytes: Option<usize>, + /// We delay a variable amount of time between BDP pings. This allows us + /// to send less pings as the bandwidth stabilizes. + next_bdp_at: Option<Instant>, + + // keep-alive + /// If `Some`, keep-alive is enabled, and the Instant is how long ago + /// the connection read the last frame. + #[cfg(feature = "runtime")] + last_read_at: Option<Instant>, + + #[cfg(feature = "runtime")] + is_keep_alive_timed_out: bool, +} + +struct Bdp { + /// Current BDP in bytes + bdp: u32, + /// Largest bandwidth we've seen so far. + max_bandwidth: f64, + /// Round trip time in seconds + rtt: f64, + /// Delay the next ping by this amount. + /// + /// This will change depending on how stable the current bandwidth is. + ping_delay: Duration, + /// The count of ping round trips where BDP has stayed the same. + stable_count: u32, +} + +#[cfg(feature = "runtime")] +struct KeepAlive { + /// If no frames are received in this amount of time, a PING frame is sent. + interval: Duration, + /// After sending a keepalive PING, the connection will be closed if + /// a pong is not received in this amount of time. + timeout: Duration, + /// If true, sends pings even when there are no active streams. + while_idle: bool, + + state: KeepAliveState, + timer: Pin<Box<Sleep>>, +} + +#[cfg(feature = "runtime")] +enum KeepAliveState { + Init, + Scheduled, + PingSent, +} + +pub(super) enum Ponged { + SizeUpdate(WindowSize), + #[cfg(feature = "runtime")] + KeepAliveTimedOut, +} + +#[cfg(feature = "runtime")] +#[derive(Debug)] +pub(super) struct KeepAliveTimedOut; + +// ===== impl Config ===== + +impl Config { + pub(super) fn is_enabled(&self) -> bool { + #[cfg(feature = "runtime")] + { + self.bdp_initial_window.is_some() || self.keep_alive_interval.is_some() + } + + #[cfg(not(feature = "runtime"))] + { + self.bdp_initial_window.is_some() + } + } +} + +// ===== impl Recorder ===== + +impl Recorder { + pub(crate) fn record_data(&self, len: usize) { + let shared = if let Some(ref shared) = self.shared { + shared + } else { + return; + }; + + let mut locked = shared.lock().unwrap(); + + #[cfg(feature = "runtime")] + locked.update_last_read_at(); + + // are we ready to send another bdp ping? + // if not, we don't need to record bytes either + + if let Some(ref next_bdp_at) = locked.next_bdp_at { + if Instant::now() < *next_bdp_at { + return; + } else { + locked.next_bdp_at = None; + } + } + + if let Some(ref mut bytes) = locked.bytes { + *bytes += len; + } else { + // no need to send bdp ping if bdp is disabled + return; + } + + if !locked.is_ping_sent() { + locked.send_ping(); + } + } + + pub(crate) fn record_non_data(&self) { + #[cfg(feature = "runtime")] + { + let shared = if let Some(ref shared) = self.shared { + shared + } else { + return; + }; + + let mut locked = shared.lock().unwrap(); + + locked.update_last_read_at(); + } + } + + /// If the incoming stream is already closed, convert self into + /// a disabled reporter. + #[cfg(feature = "client")] + pub(super) fn for_stream(self, stream: &h2::RecvStream) -> Self { + if stream.is_end_stream() { + disabled() + } else { + self + } + } + + pub(super) fn ensure_not_timed_out(&self) -> crate::Result<()> { + #[cfg(feature = "runtime")] + { + if let Some(ref shared) = self.shared { + let locked = shared.lock().unwrap(); + if locked.is_keep_alive_timed_out { + return Err(KeepAliveTimedOut.crate_error()); + } + } + } + + // else + Ok(()) + } +} + +// ===== impl Ponger ===== + +impl Ponger { + pub(super) fn poll(&mut self, cx: &mut task::Context<'_>) -> Poll<Ponged> { + let now = Instant::now(); + let mut locked = self.shared.lock().unwrap(); + #[cfg(feature = "runtime")] + let is_idle = self.is_idle(); + + #[cfg(feature = "runtime")] + { + if let Some(ref mut ka) = self.keep_alive { + ka.schedule(is_idle, &locked); + ka.maybe_ping(cx, &mut locked); + } + } + + if !locked.is_ping_sent() { + // XXX: this doesn't register a waker...? + return Poll::Pending; + } + + match locked.ping_pong.poll_pong(cx) { + Poll::Ready(Ok(_pong)) => { + let start = locked + .ping_sent_at + .expect("pong received implies ping_sent_at"); + locked.ping_sent_at = None; + let rtt = now - start; + trace!("recv pong"); + + #[cfg(feature = "runtime")] + { + if let Some(ref mut ka) = self.keep_alive { + locked.update_last_read_at(); + ka.schedule(is_idle, &locked); + } + } + + if let Some(ref mut bdp) = self.bdp { + let bytes = locked.bytes.expect("bdp enabled implies bytes"); + locked.bytes = Some(0); // reset + trace!("received BDP ack; bytes = {}, rtt = {:?}", bytes, rtt); + + let update = bdp.calculate(bytes, rtt); + locked.next_bdp_at = Some(now + bdp.ping_delay); + if let Some(update) = update { + return Poll::Ready(Ponged::SizeUpdate(update)) + } + } + } + Poll::Ready(Err(e)) => { + debug!("pong error: {}", e); + } + Poll::Pending => { + #[cfg(feature = "runtime")] + { + if let Some(ref mut ka) = self.keep_alive { + if let Err(KeepAliveTimedOut) = ka.maybe_timeout(cx) { + self.keep_alive = None; + locked.is_keep_alive_timed_out = true; + return Poll::Ready(Ponged::KeepAliveTimedOut); + } + } + } + } + } + + // XXX: this doesn't register a waker...? + Poll::Pending + } + + #[cfg(feature = "runtime")] + fn is_idle(&self) -> bool { + Arc::strong_count(&self.shared) <= 2 + } +} + +// ===== impl Shared ===== + +impl Shared { + fn send_ping(&mut self) { + match self.ping_pong.send_ping(Ping::opaque()) { + Ok(()) => { + self.ping_sent_at = Some(Instant::now()); + trace!("sent ping"); + } + Err(err) => { + debug!("error sending ping: {}", err); + } + } + } + + fn is_ping_sent(&self) -> bool { + self.ping_sent_at.is_some() + } + + #[cfg(feature = "runtime")] + fn update_last_read_at(&mut self) { + if self.last_read_at.is_some() { + self.last_read_at = Some(Instant::now()); + } + } + + #[cfg(feature = "runtime")] + fn last_read_at(&self) -> Instant { + self.last_read_at.expect("keep_alive expects last_read_at") + } +} + +// ===== impl Bdp ===== + +/// Any higher than this likely will be hitting the TCP flow control. +const BDP_LIMIT: usize = 1024 * 1024 * 16; + +impl Bdp { + fn calculate(&mut self, bytes: usize, rtt: Duration) -> Option<WindowSize> { + // No need to do any math if we're at the limit. + if self.bdp as usize == BDP_LIMIT { + self.stabilize_delay(); + return None; + } + + // average the rtt + let rtt = seconds(rtt); + if self.rtt == 0.0 { + // First sample means rtt is first rtt. + self.rtt = rtt; + } else { + // Weigh this rtt as 1/8 for a moving average. + self.rtt += (rtt - self.rtt) * 0.125; + } + + // calculate the current bandwidth + let bw = (bytes as f64) / (self.rtt * 1.5); + trace!("current bandwidth = {:.1}B/s", bw); + + if bw < self.max_bandwidth { + // not a faster bandwidth, so don't update + self.stabilize_delay(); + return None; + } else { + self.max_bandwidth = bw; + } + + // if the current `bytes` sample is at least 2/3 the previous + // bdp, increase to double the current sample. + if bytes >= self.bdp as usize * 2 / 3 { + self.bdp = (bytes * 2).min(BDP_LIMIT) as WindowSize; + trace!("BDP increased to {}", self.bdp); + + self.stable_count = 0; + self.ping_delay /= 2; + Some(self.bdp) + } else { + self.stabilize_delay(); + None + } + } + + fn stabilize_delay(&mut self) { + if self.ping_delay < Duration::from_secs(10) { + self.stable_count += 1; + + if self.stable_count >= 2 { + self.ping_delay *= 4; + self.stable_count = 0; + } + } + } +} + +fn seconds(dur: Duration) -> f64 { + const NANOS_PER_SEC: f64 = 1_000_000_000.0; + let secs = dur.as_secs() as f64; + secs + (dur.subsec_nanos() as f64) / NANOS_PER_SEC +} + +// ===== impl KeepAlive ===== + +#[cfg(feature = "runtime")] +impl KeepAlive { + fn schedule(&mut self, is_idle: bool, shared: &Shared) { + match self.state { + KeepAliveState::Init => { + if !self.while_idle && is_idle { + return; + } + + self.state = KeepAliveState::Scheduled; + let interval = shared.last_read_at() + self.interval; + self.timer.as_mut().reset(interval); + } + KeepAliveState::PingSent => { + if shared.is_ping_sent() { + return; + } + + self.state = KeepAliveState::Scheduled; + let interval = shared.last_read_at() + self.interval; + self.timer.as_mut().reset(interval); + } + KeepAliveState::Scheduled => (), + } + } + + fn maybe_ping(&mut self, cx: &mut task::Context<'_>, shared: &mut Shared) { + match self.state { + KeepAliveState::Scheduled => { + if Pin::new(&mut self.timer).poll(cx).is_pending() { + return; + } + // check if we've received a frame while we were scheduled + if shared.last_read_at() + self.interval > self.timer.deadline() { + self.state = KeepAliveState::Init; + cx.waker().wake_by_ref(); // schedule us again + return; + } + trace!("keep-alive interval ({:?}) reached", self.interval); + shared.send_ping(); + self.state = KeepAliveState::PingSent; + let timeout = Instant::now() + self.timeout; + self.timer.as_mut().reset(timeout); + } + KeepAliveState::Init | KeepAliveState::PingSent => (), + } + } + + fn maybe_timeout(&mut self, cx: &mut task::Context<'_>) -> Result<(), KeepAliveTimedOut> { + match self.state { + KeepAliveState::PingSent => { + if Pin::new(&mut self.timer).poll(cx).is_pending() { + return Ok(()); + } + trace!("keep-alive timeout ({:?}) reached", self.timeout); + Err(KeepAliveTimedOut) + } + KeepAliveState::Init | KeepAliveState::Scheduled => Ok(()), + } + } +} + +// ===== impl KeepAliveTimedOut ===== + +#[cfg(feature = "runtime")] +impl KeepAliveTimedOut { + pub(super) fn crate_error(self) -> crate::Error { + crate::Error::new(crate::error::Kind::Http2).with(self) + } +} + +#[cfg(feature = "runtime")] +impl fmt::Display for KeepAliveTimedOut { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("keep-alive timed out") + } +} + +#[cfg(feature = "runtime")] +impl std::error::Error for KeepAliveTimedOut { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(&crate::error::TimedOut) + } +} diff --git a/third_party/rust/hyper/src/proto/h2/server.rs b/third_party/rust/hyper/src/proto/h2/server.rs new file mode 100644 index 0000000000..d24e6bac5f --- /dev/null +++ b/third_party/rust/hyper/src/proto/h2/server.rs @@ -0,0 +1,548 @@ +use std::error::Error as StdError; +use std::marker::Unpin; +#[cfg(feature = "runtime")] +use std::time::Duration; + +use bytes::Bytes; +use h2::server::{Connection, Handshake, SendResponse}; +use h2::{Reason, RecvStream}; +use http::{Method, Request}; +use pin_project_lite::pin_project; +use tokio::io::{AsyncRead, AsyncWrite}; +use tracing::{debug, trace, warn}; + +use super::{ping, PipeToSendStream, SendBuf}; +use crate::body::HttpBody; +use crate::common::exec::ConnStreamExec; +use crate::common::{date, task, Future, Pin, Poll}; +use crate::ext::Protocol; +use crate::headers; +use crate::proto::h2::ping::Recorder; +use crate::proto::h2::{H2Upgraded, UpgradedSendStream}; +use crate::proto::Dispatched; +use crate::service::HttpService; + +use crate::upgrade::{OnUpgrade, Pending, Upgraded}; +use crate::{Body, Response}; + +// Our defaults are chosen for the "majority" case, which usually are not +// resource constrained, and so the spec default of 64kb can be too limiting +// for performance. +// +// At the same time, a server more often has multiple clients connected, and +// so is more likely to use more resources than a client would. +const DEFAULT_CONN_WINDOW: u32 = 1024 * 1024; // 1mb +const DEFAULT_STREAM_WINDOW: u32 = 1024 * 1024; // 1mb +const DEFAULT_MAX_FRAME_SIZE: u32 = 1024 * 16; // 16kb +const DEFAULT_MAX_SEND_BUF_SIZE: usize = 1024 * 400; // 400kb +// 16 MB "sane default" taken from golang http2 +const DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE: u32 = 16 << 20; + +#[derive(Clone, Debug)] +pub(crate) struct Config { + pub(crate) adaptive_window: bool, + pub(crate) initial_conn_window_size: u32, + pub(crate) initial_stream_window_size: u32, + pub(crate) max_frame_size: u32, + pub(crate) enable_connect_protocol: bool, + pub(crate) max_concurrent_streams: Option<u32>, + #[cfg(feature = "runtime")] + pub(crate) keep_alive_interval: Option<Duration>, + #[cfg(feature = "runtime")] + pub(crate) keep_alive_timeout: Duration, + pub(crate) max_send_buffer_size: usize, + pub(crate) max_header_list_size: u32, +} + +impl Default for Config { + fn default() -> Config { + Config { + adaptive_window: false, + initial_conn_window_size: DEFAULT_CONN_WINDOW, + initial_stream_window_size: DEFAULT_STREAM_WINDOW, + max_frame_size: DEFAULT_MAX_FRAME_SIZE, + enable_connect_protocol: false, + max_concurrent_streams: None, + #[cfg(feature = "runtime")] + keep_alive_interval: None, + #[cfg(feature = "runtime")] + keep_alive_timeout: Duration::from_secs(20), + max_send_buffer_size: DEFAULT_MAX_SEND_BUF_SIZE, + max_header_list_size: DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE, + } + } +} + +pin_project! { + pub(crate) struct Server<T, S, B, E> + where + S: HttpService<Body>, + B: HttpBody, + { + exec: E, + service: S, + state: State<T, B>, + } +} + +enum State<T, B> +where + B: HttpBody, +{ + Handshaking { + ping_config: ping::Config, + hs: Handshake<T, SendBuf<B::Data>>, + }, + Serving(Serving<T, B>), + Closed, +} + +struct Serving<T, B> +where + B: HttpBody, +{ + ping: Option<(ping::Recorder, ping::Ponger)>, + conn: Connection<T, SendBuf<B::Data>>, + closing: Option<crate::Error>, +} + +impl<T, S, B, E> Server<T, S, B, E> +where + T: AsyncRead + AsyncWrite + Unpin, + S: HttpService<Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: HttpBody + 'static, + E: ConnStreamExec<S::Future, B>, +{ + pub(crate) fn new(io: T, service: S, config: &Config, exec: E) -> Server<T, S, B, E> { + let mut builder = h2::server::Builder::default(); + builder + .initial_window_size(config.initial_stream_window_size) + .initial_connection_window_size(config.initial_conn_window_size) + .max_frame_size(config.max_frame_size) + .max_header_list_size(config.max_header_list_size) + .max_send_buffer_size(config.max_send_buffer_size); + if let Some(max) = config.max_concurrent_streams { + builder.max_concurrent_streams(max); + } + if config.enable_connect_protocol { + builder.enable_connect_protocol(); + } + let handshake = builder.handshake(io); + + let bdp = if config.adaptive_window { + Some(config.initial_stream_window_size) + } else { + None + }; + + let ping_config = ping::Config { + bdp_initial_window: bdp, + #[cfg(feature = "runtime")] + keep_alive_interval: config.keep_alive_interval, + #[cfg(feature = "runtime")] + keep_alive_timeout: config.keep_alive_timeout, + // If keep-alive is enabled for servers, always enabled while + // idle, so it can more aggressively close dead connections. + #[cfg(feature = "runtime")] + keep_alive_while_idle: true, + }; + + Server { + exec, + state: State::Handshaking { + ping_config, + hs: handshake, + }, + service, + } + } + + pub(crate) fn graceful_shutdown(&mut self) { + trace!("graceful_shutdown"); + match self.state { + State::Handshaking { .. } => { + // fall-through, to replace state with Closed + } + State::Serving(ref mut srv) => { + if srv.closing.is_none() { + srv.conn.graceful_shutdown(); + } + return; + } + State::Closed => { + return; + } + } + self.state = State::Closed; + } +} + +impl<T, S, B, E> Future for Server<T, S, B, E> +where + T: AsyncRead + AsyncWrite + Unpin, + S: HttpService<Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: HttpBody + 'static, + E: ConnStreamExec<S::Future, B>, +{ + type Output = crate::Result<Dispatched>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + let me = &mut *self; + loop { + let next = match me.state { + State::Handshaking { + ref mut hs, + ref ping_config, + } => { + let mut conn = ready!(Pin::new(hs).poll(cx).map_err(crate::Error::new_h2))?; + let ping = if ping_config.is_enabled() { + let pp = conn.ping_pong().expect("conn.ping_pong"); + Some(ping::channel(pp, ping_config.clone())) + } else { + None + }; + State::Serving(Serving { + ping, + conn, + closing: None, + }) + } + State::Serving(ref mut srv) => { + ready!(srv.poll_server(cx, &mut me.service, &mut me.exec))?; + return Poll::Ready(Ok(Dispatched::Shutdown)); + } + State::Closed => { + // graceful_shutdown was called before handshaking finished, + // nothing to do here... + return Poll::Ready(Ok(Dispatched::Shutdown)); + } + }; + me.state = next; + } + } +} + +impl<T, B> Serving<T, B> +where + T: AsyncRead + AsyncWrite + Unpin, + B: HttpBody + 'static, +{ + fn poll_server<S, E>( + &mut self, + cx: &mut task::Context<'_>, + service: &mut S, + exec: &mut E, + ) -> Poll<crate::Result<()>> + where + S: HttpService<Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + E: ConnStreamExec<S::Future, B>, + { + if self.closing.is_none() { + loop { + self.poll_ping(cx); + + // Check that the service is ready to accept a new request. + // + // - If not, just drive the connection some. + // - If ready, try to accept a new request from the connection. + match service.poll_ready(cx) { + Poll::Ready(Ok(())) => (), + Poll::Pending => { + // use `poll_closed` instead of `poll_accept`, + // in order to avoid accepting a request. + ready!(self.conn.poll_closed(cx).map_err(crate::Error::new_h2))?; + trace!("incoming connection complete"); + return Poll::Ready(Ok(())); + } + Poll::Ready(Err(err)) => { + let err = crate::Error::new_user_service(err); + debug!("service closed: {}", err); + + let reason = err.h2_reason(); + if reason == Reason::NO_ERROR { + // NO_ERROR is only used for graceful shutdowns... + trace!("interpreting NO_ERROR user error as graceful_shutdown"); + self.conn.graceful_shutdown(); + } else { + trace!("abruptly shutting down with {:?}", reason); + self.conn.abrupt_shutdown(reason); + } + self.closing = Some(err); + break; + } + } + + // When the service is ready, accepts an incoming request. + match ready!(self.conn.poll_accept(cx)) { + Some(Ok((req, mut respond))) => { + trace!("incoming request"); + let content_length = headers::content_length_parse_all(req.headers()); + let ping = self + .ping + .as_ref() + .map(|ping| ping.0.clone()) + .unwrap_or_else(ping::disabled); + + // Record the headers received + ping.record_non_data(); + + let is_connect = req.method() == Method::CONNECT; + let (mut parts, stream) = req.into_parts(); + let (mut req, connect_parts) = if !is_connect { + ( + Request::from_parts( + parts, + crate::Body::h2(stream, content_length.into(), ping), + ), + None, + ) + } else { + if content_length.map_or(false, |len| len != 0) { + warn!("h2 connect request with non-zero body not supported"); + respond.send_reset(h2::Reason::INTERNAL_ERROR); + return Poll::Ready(Ok(())); + } + let (pending, upgrade) = crate::upgrade::pending(); + debug_assert!(parts.extensions.get::<OnUpgrade>().is_none()); + parts.extensions.insert(upgrade); + ( + Request::from_parts(parts, crate::Body::empty()), + Some(ConnectParts { + pending, + ping, + recv_stream: stream, + }), + ) + }; + + if let Some(protocol) = req.extensions_mut().remove::<h2::ext::Protocol>() { + req.extensions_mut().insert(Protocol::from_inner(protocol)); + } + + let fut = H2Stream::new(service.call(req), connect_parts, respond); + exec.execute_h2stream(fut); + } + Some(Err(e)) => { + return Poll::Ready(Err(crate::Error::new_h2(e))); + } + None => { + // no more incoming streams... + if let Some((ref ping, _)) = self.ping { + ping.ensure_not_timed_out()?; + } + + trace!("incoming connection complete"); + return Poll::Ready(Ok(())); + } + } + } + } + + debug_assert!( + self.closing.is_some(), + "poll_server broke loop without closing" + ); + + ready!(self.conn.poll_closed(cx).map_err(crate::Error::new_h2))?; + + Poll::Ready(Err(self.closing.take().expect("polled after error"))) + } + + fn poll_ping(&mut self, cx: &mut task::Context<'_>) { + if let Some((_, ref mut estimator)) = self.ping { + match estimator.poll(cx) { + Poll::Ready(ping::Ponged::SizeUpdate(wnd)) => { + self.conn.set_target_window_size(wnd); + let _ = self.conn.set_initial_window_size(wnd); + } + #[cfg(feature = "runtime")] + Poll::Ready(ping::Ponged::KeepAliveTimedOut) => { + debug!("keep-alive timed out, closing connection"); + self.conn.abrupt_shutdown(h2::Reason::NO_ERROR); + } + Poll::Pending => {} + } + } + } +} + +pin_project! { + #[allow(missing_debug_implementations)] + pub struct H2Stream<F, B> + where + B: HttpBody, + { + reply: SendResponse<SendBuf<B::Data>>, + #[pin] + state: H2StreamState<F, B>, + } +} + +pin_project! { + #[project = H2StreamStateProj] + enum H2StreamState<F, B> + where + B: HttpBody, + { + Service { + #[pin] + fut: F, + connect_parts: Option<ConnectParts>, + }, + Body { + #[pin] + pipe: PipeToSendStream<B>, + }, + } +} + +struct ConnectParts { + pending: Pending, + ping: Recorder, + recv_stream: RecvStream, +} + +impl<F, B> H2Stream<F, B> +where + B: HttpBody, +{ + fn new( + fut: F, + connect_parts: Option<ConnectParts>, + respond: SendResponse<SendBuf<B::Data>>, + ) -> H2Stream<F, B> { + H2Stream { + reply: respond, + state: H2StreamState::Service { fut, connect_parts }, + } + } +} + +macro_rules! reply { + ($me:expr, $res:expr, $eos:expr) => {{ + match $me.reply.send_response($res, $eos) { + Ok(tx) => tx, + Err(e) => { + debug!("send response error: {}", e); + $me.reply.send_reset(Reason::INTERNAL_ERROR); + return Poll::Ready(Err(crate::Error::new_h2(e))); + } + } + }}; +} + +impl<F, B, E> H2Stream<F, B> +where + F: Future<Output = Result<Response<B>, E>>, + B: HttpBody, + B::Data: 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + E: Into<Box<dyn StdError + Send + Sync>>, +{ + fn poll2(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + let mut me = self.project(); + loop { + let next = match me.state.as_mut().project() { + H2StreamStateProj::Service { + fut: h, + connect_parts, + } => { + let res = match h.poll(cx) { + Poll::Ready(Ok(r)) => r, + Poll::Pending => { + // Response is not yet ready, so we want to check if the client has sent a + // RST_STREAM frame which would cancel the current request. + if let Poll::Ready(reason) = + me.reply.poll_reset(cx).map_err(crate::Error::new_h2)? + { + debug!("stream received RST_STREAM: {:?}", reason); + return Poll::Ready(Err(crate::Error::new_h2(reason.into()))); + } + return Poll::Pending; + } + Poll::Ready(Err(e)) => { + let err = crate::Error::new_user_service(e); + warn!("http2 service errored: {}", err); + me.reply.send_reset(err.h2_reason()); + return Poll::Ready(Err(err)); + } + }; + + let (head, body) = res.into_parts(); + let mut res = ::http::Response::from_parts(head, ()); + super::strip_connection_headers(res.headers_mut(), false); + + // set Date header if it isn't already set... + res.headers_mut() + .entry(::http::header::DATE) + .or_insert_with(date::update_and_header_value); + + if let Some(connect_parts) = connect_parts.take() { + if res.status().is_success() { + if headers::content_length_parse_all(res.headers()) + .map_or(false, |len| len != 0) + { + warn!("h2 successful response to CONNECT request with body not supported"); + me.reply.send_reset(h2::Reason::INTERNAL_ERROR); + return Poll::Ready(Err(crate::Error::new_user_header())); + } + let send_stream = reply!(me, res, false); + connect_parts.pending.fulfill(Upgraded::new( + H2Upgraded { + ping: connect_parts.ping, + recv_stream: connect_parts.recv_stream, + send_stream: unsafe { UpgradedSendStream::new(send_stream) }, + buf: Bytes::new(), + }, + Bytes::new(), + )); + return Poll::Ready(Ok(())); + } + } + + + if !body.is_end_stream() { + // automatically set Content-Length from body... + if let Some(len) = body.size_hint().exact() { + headers::set_content_length_if_missing(res.headers_mut(), len); + } + + let body_tx = reply!(me, res, false); + H2StreamState::Body { + pipe: PipeToSendStream::new(body, body_tx), + } + } else { + reply!(me, res, true); + return Poll::Ready(Ok(())); + } + } + H2StreamStateProj::Body { pipe } => { + return pipe.poll(cx); + } + }; + me.state.set(next); + } + } +} + +impl<F, B, E> Future for H2Stream<F, B> +where + F: Future<Output = Result<Response<B>, E>>, + B: HttpBody, + B::Data: 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + E: Into<Box<dyn StdError + Send + Sync>>, +{ + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + self.poll2(cx).map(|res| { + if let Err(e) = res { + debug!("stream error: {}", e); + } + }) + } +} diff --git a/third_party/rust/hyper/src/proto/mod.rs b/third_party/rust/hyper/src/proto/mod.rs new file mode 100644 index 0000000000..f938bf532b --- /dev/null +++ b/third_party/rust/hyper/src/proto/mod.rs @@ -0,0 +1,71 @@ +//! Pieces pertaining to the HTTP message protocol. + +cfg_feature! { + #![feature = "http1"] + + pub(crate) mod h1; + + pub(crate) use self::h1::Conn; + + #[cfg(feature = "client")] + pub(crate) use self::h1::dispatch; + #[cfg(feature = "server")] + pub(crate) use self::h1::ServerTransaction; +} + +#[cfg(feature = "http2")] +pub(crate) mod h2; + +/// An Incoming Message head. Includes request/status line, and headers. +#[derive(Debug, Default)] +pub(crate) struct MessageHead<S> { + /// HTTP version of the message. + pub(crate) version: http::Version, + /// Subject (request line or status line) of Incoming message. + pub(crate) subject: S, + /// Headers of the Incoming message. + pub(crate) headers: http::HeaderMap, + /// Extensions. + extensions: http::Extensions, +} + +/// An incoming request message. +#[cfg(feature = "http1")] +pub(crate) type RequestHead = MessageHead<RequestLine>; + +#[derive(Debug, Default, PartialEq)] +#[cfg(feature = "http1")] +pub(crate) struct RequestLine(pub(crate) http::Method, pub(crate) http::Uri); + +/// An incoming response message. +#[cfg(all(feature = "http1", feature = "client"))] +pub(crate) type ResponseHead = MessageHead<http::StatusCode>; + +#[derive(Debug)] +#[cfg(feature = "http1")] +pub(crate) enum BodyLength { + /// Content-Length + Known(u64), + /// Transfer-Encoding: chunked (if h1) + Unknown, +} + +/// Status of when a Disaptcher future completes. +pub(crate) enum Dispatched { + /// Dispatcher completely shutdown connection. + Shutdown, + /// Dispatcher has pending upgrade, and so did not shutdown. + #[cfg(feature = "http1")] + Upgrade(crate::upgrade::Pending), +} + +impl MessageHead<http::StatusCode> { + fn into_response<B>(self, body: B) -> http::Response<B> { + let mut res = http::Response::new(body); + *res.status_mut() = self.subject; + *res.headers_mut() = self.headers; + *res.version_mut() = self.version; + *res.extensions_mut() = self.extensions; + res + } +} diff --git a/third_party/rust/hyper/src/rt.rs b/third_party/rust/hyper/src/rt.rs new file mode 100644 index 0000000000..2614b59112 --- /dev/null +++ b/third_party/rust/hyper/src/rt.rs @@ -0,0 +1,12 @@ +//! Runtime components +//! +//! By default, hyper includes the [tokio](https://tokio.rs) runtime. +//! +//! If the `runtime` feature is disabled, the types in this module can be used +//! to plug in other runtimes. + +/// An executor of futures. +pub trait Executor<Fut> { + /// Place the future into the executor to be run. + fn execute(&self, fut: Fut); +} diff --git a/third_party/rust/hyper/src/server/accept.rs b/third_party/rust/hyper/src/server/accept.rs new file mode 100644 index 0000000000..4b7a1487dd --- /dev/null +++ b/third_party/rust/hyper/src/server/accept.rs @@ -0,0 +1,111 @@ +//! The `Accept` trait and supporting types. +//! +//! This module contains: +//! +//! - The [`Accept`](Accept) trait used to asynchronously accept incoming +//! connections. +//! - Utilities like `poll_fn` to ease creating a custom `Accept`. + +#[cfg(feature = "stream")] +use futures_core::Stream; +#[cfg(feature = "stream")] +use pin_project_lite::pin_project; + +use crate::common::{ + task::{self, Poll}, + Pin, +}; + +/// Asynchronously accept incoming connections. +pub trait Accept { + /// The connection type that can be accepted. + type Conn; + /// The error type that can occur when accepting a connection. + type Error; + + /// Poll to accept the next connection. + fn poll_accept( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<Option<Result<Self::Conn, Self::Error>>>; +} + +/// Create an `Accept` with a polling function. +/// +/// # Example +/// +/// ``` +/// use std::task::Poll; +/// use hyper::server::{accept, Server}; +/// +/// # let mock_conn = (); +/// // If we created some mocked connection... +/// let mut conn = Some(mock_conn); +/// +/// // And accept just the mocked conn once... +/// let once = accept::poll_fn(move |cx| { +/// Poll::Ready(conn.take().map(Ok::<_, ()>)) +/// }); +/// +/// let builder = Server::builder(once); +/// ``` +pub fn poll_fn<F, IO, E>(func: F) -> impl Accept<Conn = IO, Error = E> +where + F: FnMut(&mut task::Context<'_>) -> Poll<Option<Result<IO, E>>>, +{ + struct PollFn<F>(F); + + // The closure `F` is never pinned + impl<F> Unpin for PollFn<F> {} + + impl<F, IO, E> Accept for PollFn<F> + where + F: FnMut(&mut task::Context<'_>) -> Poll<Option<Result<IO, E>>>, + { + type Conn = IO; + type Error = E; + fn poll_accept( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<Option<Result<Self::Conn, Self::Error>>> { + (self.get_mut().0)(cx) + } + } + + PollFn(func) +} + +/// Adapt a `Stream` of incoming connections into an `Accept`. +/// +/// # Optional +/// +/// This function requires enabling the `stream` feature in your +/// `Cargo.toml`. +#[cfg(feature = "stream")] +pub fn from_stream<S, IO, E>(stream: S) -> impl Accept<Conn = IO, Error = E> +where + S: Stream<Item = Result<IO, E>>, +{ + pin_project! { + struct FromStream<S> { + #[pin] + stream: S, + } + } + + impl<S, IO, E> Accept for FromStream<S> + where + S: Stream<Item = Result<IO, E>>, + { + type Conn = IO; + type Error = E; + fn poll_accept( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<Option<Result<Self::Conn, Self::Error>>> { + self.project().stream.poll_next(cx) + } + } + + FromStream { stream } +} diff --git a/third_party/rust/hyper/src/server/conn.rs b/third_party/rust/hyper/src/server/conn.rs new file mode 100644 index 0000000000..d5370b0f14 --- /dev/null +++ b/third_party/rust/hyper/src/server/conn.rs @@ -0,0 +1,1045 @@ +//! Lower-level Server connection API. +//! +//! The types in this module are to provide a lower-level API based around a +//! single connection. Accepting a connection and binding it with a service +//! are not handled at this level. This module provides the building blocks to +//! customize those things externally. +//! +//! If you don't have need to manage connections yourself, consider using the +//! higher-level [Server](super) API. +//! +//! ## Example +//! A simple example that uses the `Http` struct to talk HTTP over a Tokio TCP stream +//! ```no_run +//! # #[cfg(all(feature = "http1", feature = "runtime"))] +//! # mod rt { +//! use http::{Request, Response, StatusCode}; +//! use hyper::{server::conn::Http, service::service_fn, Body}; +//! use std::{net::SocketAddr, convert::Infallible}; +//! use tokio::net::TcpListener; +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> { +//! let addr: SocketAddr = ([127, 0, 0, 1], 8080).into(); +//! +//! let mut tcp_listener = TcpListener::bind(addr).await?; +//! loop { +//! let (tcp_stream, _) = tcp_listener.accept().await?; +//! tokio::task::spawn(async move { +//! if let Err(http_err) = Http::new() +//! .http1_only(true) +//! .http1_keep_alive(true) +//! .serve_connection(tcp_stream, service_fn(hello)) +//! .await { +//! eprintln!("Error while serving HTTP connection: {}", http_err); +//! } +//! }); +//! } +//! } +//! +//! async fn hello(_req: Request<Body>) -> Result<Response<Body>, Infallible> { +//! Ok(Response::new(Body::from("Hello World!"))) +//! } +//! # } +//! ``` + +#[cfg(all( + any(feature = "http1", feature = "http2"), + not(all(feature = "http1", feature = "http2")) +))] +use std::marker::PhantomData; +#[cfg(all(any(feature = "http1", feature = "http2"), feature = "runtime"))] +use std::time::Duration; + +#[cfg(feature = "http2")] +use crate::common::io::Rewind; +#[cfg(all(feature = "http1", feature = "http2"))] +use crate::error::{Kind, Parse}; +#[cfg(feature = "http1")] +use crate::upgrade::Upgraded; + +cfg_feature! { + #![any(feature = "http1", feature = "http2")] + + use std::error::Error as StdError; + use std::fmt; + + use bytes::Bytes; + use pin_project_lite::pin_project; + use tokio::io::{AsyncRead, AsyncWrite}; + use tracing::trace; + + pub use super::server::Connecting; + use crate::body::{Body, HttpBody}; + use crate::common::{task, Future, Pin, Poll, Unpin}; + #[cfg(not(all(feature = "http1", feature = "http2")))] + use crate::common::Never; + use crate::common::exec::{ConnStreamExec, Exec}; + use crate::proto; + use crate::service::HttpService; + + pub(super) use self::upgrades::UpgradeableConnection; +} + +#[cfg(feature = "tcp")] +pub use super::tcp::{AddrIncoming, AddrStream}; + +/// A lower-level configuration of the HTTP protocol. +/// +/// This structure is used to configure options for an HTTP server connection. +/// +/// If you don't have need to manage connections yourself, consider using the +/// higher-level [Server](super) API. +#[derive(Clone, Debug)] +#[cfg(any(feature = "http1", feature = "http2"))] +#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))] +pub struct Http<E = Exec> { + pub(crate) exec: E, + h1_half_close: bool, + h1_keep_alive: bool, + h1_title_case_headers: bool, + h1_preserve_header_case: bool, + #[cfg(all(feature = "http1", feature = "runtime"))] + h1_header_read_timeout: Option<Duration>, + h1_writev: Option<bool>, + #[cfg(feature = "http2")] + h2_builder: proto::h2::server::Config, + mode: ConnectionMode, + max_buf_size: Option<usize>, + pipeline_flush: bool, +} + +/// The internal mode of HTTP protocol which indicates the behavior when a parse error occurs. +#[cfg(any(feature = "http1", feature = "http2"))] +#[derive(Clone, Debug, PartialEq)] +enum ConnectionMode { + /// Always use HTTP/1 and do not upgrade when a parse error occurs. + #[cfg(feature = "http1")] + H1Only, + /// Always use HTTP/2. + #[cfg(feature = "http2")] + H2Only, + /// Use HTTP/1 and try to upgrade to h2 when a parse error occurs. + #[cfg(all(feature = "http1", feature = "http2"))] + Fallback, +} + +#[cfg(any(feature = "http1", feature = "http2"))] +pin_project! { + /// A future binding a connection with a Service. + /// + /// Polling this future will drive HTTP forward. + #[must_use = "futures do nothing unless polled"] + #[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))] + pub struct Connection<T, S, E = Exec> + where + S: HttpService<Body>, + { + pub(super) conn: Option<ProtoServer<T, S::ResBody, S, E>>, + fallback: Fallback<E>, + } +} + +#[cfg(feature = "http1")] +type Http1Dispatcher<T, B, S> = + proto::h1::Dispatcher<proto::h1::dispatch::Server<S, Body>, B, T, proto::ServerTransaction>; + +#[cfg(all(not(feature = "http1"), feature = "http2"))] +type Http1Dispatcher<T, B, S> = (Never, PhantomData<(T, Box<Pin<B>>, Box<Pin<S>>)>); + +#[cfg(feature = "http2")] +type Http2Server<T, B, S, E> = proto::h2::Server<Rewind<T>, S, B, E>; + +#[cfg(all(not(feature = "http2"), feature = "http1"))] +type Http2Server<T, B, S, E> = ( + Never, + PhantomData<(T, Box<Pin<S>>, Box<Pin<B>>, Box<Pin<E>>)>, +); + +#[cfg(any(feature = "http1", feature = "http2"))] +pin_project! { + #[project = ProtoServerProj] + pub(super) enum ProtoServer<T, B, S, E = Exec> + where + S: HttpService<Body>, + B: HttpBody, + { + H1 { + #[pin] + h1: Http1Dispatcher<T, B, S>, + }, + H2 { + #[pin] + h2: Http2Server<T, B, S, E>, + }, + } +} + +#[cfg(all(feature = "http1", feature = "http2"))] +#[derive(Clone, Debug)] +enum Fallback<E> { + ToHttp2(proto::h2::server::Config, E), + Http1Only, +} + +#[cfg(all( + any(feature = "http1", feature = "http2"), + not(all(feature = "http1", feature = "http2")) +))] +type Fallback<E> = PhantomData<E>; + +#[cfg(all(feature = "http1", feature = "http2"))] +impl<E> Fallback<E> { + fn to_h2(&self) -> bool { + match *self { + Fallback::ToHttp2(..) => true, + Fallback::Http1Only => false, + } + } +} + +#[cfg(all(feature = "http1", feature = "http2"))] +impl<E> Unpin for Fallback<E> {} + +/// Deconstructed parts of a `Connection`. +/// +/// This allows taking apart a `Connection` at a later time, in order to +/// reclaim the IO object, and additional related pieces. +#[derive(Debug)] +#[cfg(any(feature = "http1", feature = "http2"))] +#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))] +pub struct Parts<T, S> { + /// The original IO object used in the handshake. + pub io: T, + /// A buffer of bytes that have been read but not processed as HTTP. + /// + /// If the client sent additional bytes after its last request, and + /// this connection "ended" with an upgrade, the read buffer will contain + /// those bytes. + /// + /// You will want to check for any existing bytes if you plan to continue + /// communicating on the IO object. + pub read_buf: Bytes, + /// The `Service` used to serve this connection. + pub service: S, + _inner: (), +} + +// ===== impl Http ===== + +#[cfg(any(feature = "http1", feature = "http2"))] +impl Http { + /// Creates a new instance of the HTTP protocol, ready to spawn a server or + /// start accepting connections. + pub fn new() -> Http { + Http { + exec: Exec::Default, + h1_half_close: false, + h1_keep_alive: true, + h1_title_case_headers: false, + h1_preserve_header_case: false, + #[cfg(all(feature = "http1", feature = "runtime"))] + h1_header_read_timeout: None, + h1_writev: None, + #[cfg(feature = "http2")] + h2_builder: Default::default(), + mode: ConnectionMode::default(), + max_buf_size: None, + pipeline_flush: false, + } + } +} + +#[cfg(any(feature = "http1", feature = "http2"))] +impl<E> Http<E> { + /// Sets whether HTTP1 is required. + /// + /// Default is false + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_only(&mut self, val: bool) -> &mut Self { + if val { + self.mode = ConnectionMode::H1Only; + } else { + #[cfg(feature = "http2")] + { + self.mode = ConnectionMode::Fallback; + } + } + self + } + + /// Set whether HTTP/1 connections should support half-closures. + /// + /// Clients can chose to shutdown their write-side while waiting + /// for the server to respond. Setting this to `true` will + /// prevent closing the connection immediately if `read` + /// detects an EOF in the middle of a request. + /// + /// Default is `false`. + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_half_close(&mut self, val: bool) -> &mut Self { + self.h1_half_close = val; + self + } + + /// Enables or disables HTTP/1 keep-alive. + /// + /// Default is true. + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_keep_alive(&mut self, val: bool) -> &mut Self { + self.h1_keep_alive = val; + self + } + + /// Set whether HTTP/1 connections will write header names as title case at + /// the socket level. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_title_case_headers(&mut self, enabled: bool) -> &mut Self { + self.h1_title_case_headers = enabled; + self + } + + /// Set whether to support preserving original header cases. + /// + /// Currently, this will record the original cases received, and store them + /// in a private extension on the `Request`. It will also look for and use + /// such an extension in any provided `Response`. + /// + /// Since the relevant extension is still private, there is no way to + /// interact with the original cases. The only effect this can have now is + /// to forward the cases in a proxy-like fashion. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_preserve_header_case(&mut self, enabled: bool) -> &mut Self { + self.h1_preserve_header_case = enabled; + self + } + + /// Set a timeout for reading client request headers. If a client does not + /// transmit the entire header within this time, the connection is closed. + /// + /// Default is None. + #[cfg(all(feature = "http1", feature = "runtime"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "http1", feature = "runtime"))))] + pub fn http1_header_read_timeout(&mut self, read_timeout: Duration) -> &mut Self { + self.h1_header_read_timeout = Some(read_timeout); + self + } + + /// Set whether HTTP/1 connections should try to use vectored writes, + /// or always flatten into a single buffer. + /// + /// Note that setting this to false may mean more copies of body data, + /// but may also improve performance when an IO transport doesn't + /// support vectored writes well, such as most TLS implementations. + /// + /// Setting this to true will force hyper to use queued strategy + /// which may eliminate unnecessary cloning on some TLS backends + /// + /// Default is `auto`. In this mode hyper will try to guess which + /// mode to use + #[inline] + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_writev(&mut self, val: bool) -> &mut Self { + self.h1_writev = Some(val); + self + } + + /// Sets whether HTTP2 is required. + /// + /// Default is false + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_only(&mut self, val: bool) -> &mut Self { + if val { + self.mode = ConnectionMode::H2Only; + } else { + #[cfg(feature = "http1")] + { + self.mode = ConnectionMode::Fallback; + } + } + self + } + + /// Sets the [`SETTINGS_INITIAL_WINDOW_SIZE`][spec] option for HTTP2 + /// stream-level flow control. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + /// + /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_INITIAL_WINDOW_SIZE + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_initial_stream_window_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self { + if let Some(sz) = sz.into() { + self.h2_builder.adaptive_window = false; + self.h2_builder.initial_stream_window_size = sz; + } + self + } + + /// Sets the max connection-level flow control for HTTP2. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_initial_connection_window_size( + &mut self, + sz: impl Into<Option<u32>>, + ) -> &mut Self { + if let Some(sz) = sz.into() { + self.h2_builder.adaptive_window = false; + self.h2_builder.initial_conn_window_size = sz; + } + self + } + + /// Sets whether to use an adaptive flow control. + /// + /// Enabling this will override the limits set in + /// `http2_initial_stream_window_size` and + /// `http2_initial_connection_window_size`. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_adaptive_window(&mut self, enabled: bool) -> &mut Self { + use proto::h2::SPEC_WINDOW_SIZE; + + self.h2_builder.adaptive_window = enabled; + if enabled { + self.h2_builder.initial_conn_window_size = SPEC_WINDOW_SIZE; + self.h2_builder.initial_stream_window_size = SPEC_WINDOW_SIZE; + } + self + } + + /// Sets the maximum frame size to use for HTTP2. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_max_frame_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self { + if let Some(sz) = sz.into() { + self.h2_builder.max_frame_size = sz; + } + self + } + + /// Sets the [`SETTINGS_MAX_CONCURRENT_STREAMS`][spec] option for HTTP2 + /// connections. + /// + /// Default is no limit (`std::u32::MAX`). Passing `None` will do nothing. + /// + /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_MAX_CONCURRENT_STREAMS + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_max_concurrent_streams(&mut self, max: impl Into<Option<u32>>) -> &mut Self { + self.h2_builder.max_concurrent_streams = max.into(); + self + } + + /// Sets an interval for HTTP2 Ping frames should be sent to keep a + /// connection alive. + /// + /// Pass `None` to disable HTTP2 keep-alive. + /// + /// Default is currently disabled. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_keep_alive_interval( + &mut self, + interval: impl Into<Option<Duration>>, + ) -> &mut Self { + self.h2_builder.keep_alive_interval = interval.into(); + self + } + + /// Sets a timeout for receiving an acknowledgement of the keep-alive ping. + /// + /// If the ping is not acknowledged within the timeout, the connection will + /// be closed. Does nothing if `http2_keep_alive_interval` is disabled. + /// + /// Default is 20 seconds. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_keep_alive_timeout(&mut self, timeout: Duration) -> &mut Self { + self.h2_builder.keep_alive_timeout = timeout; + self + } + + /// Set the maximum write buffer size for each HTTP/2 stream. + /// + /// Default is currently ~400KB, but may change. + /// + /// # Panics + /// + /// The value must be no larger than `u32::MAX`. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_max_send_buf_size(&mut self, max: usize) -> &mut Self { + assert!(max <= std::u32::MAX as usize); + self.h2_builder.max_send_buffer_size = max; + self + } + + /// Enables the [extended CONNECT protocol]. + /// + /// [extended CONNECT protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 + #[cfg(feature = "http2")] + pub fn http2_enable_connect_protocol(&mut self) -> &mut Self { + self.h2_builder.enable_connect_protocol = true; + self + } + + /// Sets the max size of received header frames. + /// + /// Default is currently ~16MB, but may change. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_max_header_list_size(&mut self, max: u32) -> &mut Self { + self.h2_builder.max_header_list_size = max; + self + } + + /// Set the maximum buffer size for the connection. + /// + /// Default is ~400kb. + /// + /// # Panics + /// + /// The minimum value allowed is 8192. This method panics if the passed `max` is less than the minimum. + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn max_buf_size(&mut self, max: usize) -> &mut Self { + assert!( + max >= proto::h1::MINIMUM_MAX_BUFFER_SIZE, + "the max_buf_size cannot be smaller than the minimum that h1 specifies." + ); + self.max_buf_size = Some(max); + self + } + + /// Aggregates flushes to better support pipelined responses. + /// + /// Experimental, may have bugs. + /// + /// Default is false. + pub fn pipeline_flush(&mut self, enabled: bool) -> &mut Self { + self.pipeline_flush = enabled; + self + } + + /// Set the executor used to spawn background tasks. + /// + /// Default uses implicit default (like `tokio::spawn`). + pub fn with_executor<E2>(self, exec: E2) -> Http<E2> { + Http { + exec, + h1_half_close: self.h1_half_close, + h1_keep_alive: self.h1_keep_alive, + h1_title_case_headers: self.h1_title_case_headers, + h1_preserve_header_case: self.h1_preserve_header_case, + #[cfg(all(feature = "http1", feature = "runtime"))] + h1_header_read_timeout: self.h1_header_read_timeout, + h1_writev: self.h1_writev, + #[cfg(feature = "http2")] + h2_builder: self.h2_builder, + mode: self.mode, + max_buf_size: self.max_buf_size, + pipeline_flush: self.pipeline_flush, + } + } + + /// Bind a connection together with a [`Service`](crate::service::Service). + /// + /// This returns a Future that must be polled in order for HTTP to be + /// driven on the connection. + /// + /// # Example + /// + /// ``` + /// # use hyper::{Body, Request, Response}; + /// # use hyper::service::Service; + /// # use hyper::server::conn::Http; + /// # use tokio::io::{AsyncRead, AsyncWrite}; + /// # async fn run<I, S>(some_io: I, some_service: S) + /// # where + /// # I: AsyncRead + AsyncWrite + Unpin + Send + 'static, + /// # S: Service<hyper::Request<Body>, Response=hyper::Response<Body>> + Send + 'static, + /// # S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + /// # S::Future: Send, + /// # { + /// let http = Http::new(); + /// let conn = http.serve_connection(some_io, some_service); + /// + /// if let Err(e) = conn.await { + /// eprintln!("server connection error: {}", e); + /// } + /// # } + /// # fn main() {} + /// ``` + pub fn serve_connection<S, I, Bd>(&self, io: I, service: S) -> Connection<I, S, E> + where + S: HttpService<Body, ResBody = Bd>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + Bd: HttpBody + 'static, + Bd::Error: Into<Box<dyn StdError + Send + Sync>>, + I: AsyncRead + AsyncWrite + Unpin, + E: ConnStreamExec<S::Future, Bd>, + { + #[cfg(feature = "http1")] + macro_rules! h1 { + () => {{ + let mut conn = proto::Conn::new(io); + if !self.h1_keep_alive { + conn.disable_keep_alive(); + } + if self.h1_half_close { + conn.set_allow_half_close(); + } + if self.h1_title_case_headers { + conn.set_title_case_headers(); + } + if self.h1_preserve_header_case { + conn.set_preserve_header_case(); + } + #[cfg(all(feature = "http1", feature = "runtime"))] + if let Some(header_read_timeout) = self.h1_header_read_timeout { + conn.set_http1_header_read_timeout(header_read_timeout); + } + if let Some(writev) = self.h1_writev { + if writev { + conn.set_write_strategy_queue(); + } else { + conn.set_write_strategy_flatten(); + } + } + conn.set_flush_pipeline(self.pipeline_flush); + if let Some(max) = self.max_buf_size { + conn.set_max_buf_size(max); + } + let sd = proto::h1::dispatch::Server::new(service); + ProtoServer::H1 { + h1: proto::h1::Dispatcher::new(sd, conn), + } + }}; + } + + let proto = match self.mode { + #[cfg(feature = "http1")] + #[cfg(not(feature = "http2"))] + ConnectionMode::H1Only => h1!(), + #[cfg(feature = "http2")] + #[cfg(feature = "http1")] + ConnectionMode::H1Only | ConnectionMode::Fallback => h1!(), + #[cfg(feature = "http2")] + ConnectionMode::H2Only => { + let rewind_io = Rewind::new(io); + let h2 = + proto::h2::Server::new(rewind_io, service, &self.h2_builder, self.exec.clone()); + ProtoServer::H2 { h2 } + } + }; + + Connection { + conn: Some(proto), + #[cfg(all(feature = "http1", feature = "http2"))] + fallback: if self.mode == ConnectionMode::Fallback { + Fallback::ToHttp2(self.h2_builder.clone(), self.exec.clone()) + } else { + Fallback::Http1Only + }, + #[cfg(not(all(feature = "http1", feature = "http2")))] + fallback: PhantomData, + } + } +} + +// ===== impl Connection ===== + +#[cfg(any(feature = "http1", feature = "http2"))] +impl<I, B, S, E> Connection<I, S, E> +where + S: HttpService<Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + I: AsyncRead + AsyncWrite + Unpin, + B: HttpBody + 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + E: ConnStreamExec<S::Future, B>, +{ + /// Start a graceful shutdown process for this connection. + /// + /// This `Connection` should continue to be polled until shutdown + /// can finish. + /// + /// # Note + /// + /// This should only be called while the `Connection` future is still + /// pending. If called after `Connection::poll` has resolved, this does + /// nothing. + pub fn graceful_shutdown(mut self: Pin<&mut Self>) { + match self.conn { + #[cfg(feature = "http1")] + Some(ProtoServer::H1 { ref mut h1, .. }) => { + h1.disable_keep_alive(); + } + #[cfg(feature = "http2")] + Some(ProtoServer::H2 { ref mut h2 }) => { + h2.graceful_shutdown(); + } + None => (), + + #[cfg(not(feature = "http1"))] + Some(ProtoServer::H1 { ref mut h1, .. }) => match h1.0 {}, + #[cfg(not(feature = "http2"))] + Some(ProtoServer::H2 { ref mut h2 }) => match h2.0 {}, + } + } + + /// Return the inner IO object, and additional information. + /// + /// If the IO object has been "rewound" the io will not contain those bytes rewound. + /// This should only be called after `poll_without_shutdown` signals + /// that the connection is "done". Otherwise, it may not have finished + /// flushing all necessary HTTP bytes. + /// + /// # Panics + /// This method will panic if this connection is using an h2 protocol. + pub fn into_parts(self) -> Parts<I, S> { + self.try_into_parts() + .unwrap_or_else(|| panic!("h2 cannot into_inner")) + } + + /// Return the inner IO object, and additional information, if available. + /// + /// This method will return a `None` if this connection is using an h2 protocol. + pub fn try_into_parts(self) -> Option<Parts<I, S>> { + match self.conn.unwrap() { + #[cfg(feature = "http1")] + ProtoServer::H1 { h1, .. } => { + let (io, read_buf, dispatch) = h1.into_inner(); + Some(Parts { + io, + read_buf, + service: dispatch.into_service(), + _inner: (), + }) + } + ProtoServer::H2 { .. } => None, + + #[cfg(not(feature = "http1"))] + ProtoServer::H1 { h1, .. } => match h1.0 {}, + } + } + + /// Poll the connection for completion, but without calling `shutdown` + /// on the underlying IO. + /// + /// This is useful to allow running a connection while doing an HTTP + /// upgrade. Once the upgrade is completed, the connection would be "done", + /// but it is not desired to actually shutdown the IO object. Instead you + /// would take it back using `into_parts`. + pub fn poll_without_shutdown(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> + where + S: Unpin, + S::Future: Unpin, + B: Unpin, + { + loop { + match *self.conn.as_mut().unwrap() { + #[cfg(feature = "http1")] + ProtoServer::H1 { ref mut h1, .. } => match ready!(h1.poll_without_shutdown(cx)) { + Ok(()) => return Poll::Ready(Ok(())), + Err(e) => { + #[cfg(feature = "http2")] + match *e.kind() { + Kind::Parse(Parse::VersionH2) if self.fallback.to_h2() => { + self.upgrade_h2(); + continue; + } + _ => (), + } + + return Poll::Ready(Err(e)); + } + }, + #[cfg(feature = "http2")] + ProtoServer::H2 { ref mut h2 } => return Pin::new(h2).poll(cx).map_ok(|_| ()), + + #[cfg(not(feature = "http1"))] + ProtoServer::H1 { ref mut h1, .. } => match h1.0 {}, + #[cfg(not(feature = "http2"))] + ProtoServer::H2 { ref mut h2 } => match h2.0 {}, + }; + } + } + + /// Prevent shutdown of the underlying IO object at the end of service the request, + /// instead run `into_parts`. This is a convenience wrapper over `poll_without_shutdown`. + /// + /// # Error + /// + /// This errors if the underlying connection protocol is not HTTP/1. + pub fn without_shutdown(self) -> impl Future<Output = crate::Result<Parts<I, S>>> + where + S: Unpin, + S::Future: Unpin, + B: Unpin, + { + let mut conn = Some(self); + futures_util::future::poll_fn(move |cx| { + ready!(conn.as_mut().unwrap().poll_without_shutdown(cx))?; + Poll::Ready(conn.take().unwrap().try_into_parts().ok_or_else(crate::Error::new_without_shutdown_not_h1)) + }) + } + + #[cfg(all(feature = "http1", feature = "http2"))] + fn upgrade_h2(&mut self) { + trace!("Trying to upgrade connection to h2"); + let conn = self.conn.take(); + + let (io, read_buf, dispatch) = match conn.unwrap() { + ProtoServer::H1 { h1, .. } => h1.into_inner(), + ProtoServer::H2 { .. } => { + panic!("h2 cannot into_inner"); + } + }; + let mut rewind_io = Rewind::new(io); + rewind_io.rewind(read_buf); + let (builder, exec) = match self.fallback { + Fallback::ToHttp2(ref builder, ref exec) => (builder, exec), + Fallback::Http1Only => unreachable!("upgrade_h2 with Fallback::Http1Only"), + }; + let h2 = proto::h2::Server::new(rewind_io, dispatch.into_service(), builder, exec.clone()); + + debug_assert!(self.conn.is_none()); + self.conn = Some(ProtoServer::H2 { h2 }); + } + + /// Enable this connection to support higher-level HTTP upgrades. + /// + /// See [the `upgrade` module](crate::upgrade) for more. + pub fn with_upgrades(self) -> UpgradeableConnection<I, S, E> + where + I: Send, + { + UpgradeableConnection { inner: self } + } +} + +#[cfg(any(feature = "http1", feature = "http2"))] +impl<I, B, S, E> Future for Connection<I, S, E> +where + S: HttpService<Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + I: AsyncRead + AsyncWrite + Unpin + 'static, + B: HttpBody + 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + E: ConnStreamExec<S::Future, B>, +{ + type Output = crate::Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + loop { + match ready!(Pin::new(self.conn.as_mut().unwrap()).poll(cx)) { + Ok(done) => { + match done { + proto::Dispatched::Shutdown => {} + #[cfg(feature = "http1")] + proto::Dispatched::Upgrade(pending) => { + // With no `Send` bound on `I`, we can't try to do + // upgrades here. In case a user was trying to use + // `Body::on_upgrade` with this API, send a special + // error letting them know about that. + pending.manual(); + } + }; + return Poll::Ready(Ok(())); + } + Err(e) => { + #[cfg(feature = "http1")] + #[cfg(feature = "http2")] + match *e.kind() { + Kind::Parse(Parse::VersionH2) if self.fallback.to_h2() => { + self.upgrade_h2(); + continue; + } + _ => (), + } + + return Poll::Ready(Err(e)); + } + } + } + } +} + +#[cfg(any(feature = "http1", feature = "http2"))] +impl<I, S> fmt::Debug for Connection<I, S> +where + S: HttpService<Body>, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Connection").finish() + } +} + +// ===== impl ConnectionMode ===== + +#[cfg(any(feature = "http1", feature = "http2"))] +impl Default for ConnectionMode { + #[cfg(all(feature = "http1", feature = "http2"))] + fn default() -> ConnectionMode { + ConnectionMode::Fallback + } + + #[cfg(all(feature = "http1", not(feature = "http2")))] + fn default() -> ConnectionMode { + ConnectionMode::H1Only + } + + #[cfg(all(not(feature = "http1"), feature = "http2"))] + fn default() -> ConnectionMode { + ConnectionMode::H2Only + } +} + +// ===== impl ProtoServer ===== + +#[cfg(any(feature = "http1", feature = "http2"))] +impl<T, B, S, E> Future for ProtoServer<T, B, S, E> +where + T: AsyncRead + AsyncWrite + Unpin, + S: HttpService<Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: HttpBody + 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + E: ConnStreamExec<S::Future, B>, +{ + type Output = crate::Result<proto::Dispatched>; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + match self.project() { + #[cfg(feature = "http1")] + ProtoServerProj::H1 { h1, .. } => h1.poll(cx), + #[cfg(feature = "http2")] + ProtoServerProj::H2 { h2 } => h2.poll(cx), + + #[cfg(not(feature = "http1"))] + ProtoServerProj::H1 { h1, .. } => match h1.0 {}, + #[cfg(not(feature = "http2"))] + ProtoServerProj::H2 { h2 } => match h2.0 {}, + } + } +} + +#[cfg(any(feature = "http1", feature = "http2"))] +mod upgrades { + use super::*; + + // A future binding a connection with a Service with Upgrade support. + // + // This type is unnameable outside the crate, and so basically just an + // `impl Future`, without requiring Rust 1.26. + #[must_use = "futures do nothing unless polled"] + #[allow(missing_debug_implementations)] + pub struct UpgradeableConnection<T, S, E> + where + S: HttpService<Body>, + { + pub(super) inner: Connection<T, S, E>, + } + + impl<I, B, S, E> UpgradeableConnection<I, S, E> + where + S: HttpService<Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + I: AsyncRead + AsyncWrite + Unpin, + B: HttpBody + 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + E: ConnStreamExec<S::Future, B>, + { + /// Start a graceful shutdown process for this connection. + /// + /// This `Connection` should continue to be polled until shutdown + /// can finish. + pub fn graceful_shutdown(mut self: Pin<&mut Self>) { + Pin::new(&mut self.inner).graceful_shutdown() + } + } + + impl<I, B, S, E> Future for UpgradeableConnection<I, S, E> + where + S: HttpService<Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + I: AsyncRead + AsyncWrite + Unpin + Send + 'static, + B: HttpBody + 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + E: ConnStreamExec<S::Future, B>, + { + type Output = crate::Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + loop { + match ready!(Pin::new(self.inner.conn.as_mut().unwrap()).poll(cx)) { + Ok(proto::Dispatched::Shutdown) => return Poll::Ready(Ok(())), + #[cfg(feature = "http1")] + Ok(proto::Dispatched::Upgrade(pending)) => { + match self.inner.conn.take() { + Some(ProtoServer::H1 { h1, .. }) => { + let (io, buf, _) = h1.into_inner(); + pending.fulfill(Upgraded::new(io, buf)); + return Poll::Ready(Ok(())); + } + _ => { + drop(pending); + unreachable!("Upgrade expects h1") + } + }; + } + Err(e) => { + #[cfg(feature = "http1")] + #[cfg(feature = "http2")] + match *e.kind() { + Kind::Parse(Parse::VersionH2) if self.inner.fallback.to_h2() => { + self.inner.upgrade_h2(); + continue; + } + _ => (), + } + + return Poll::Ready(Err(e)); + } + } + } + } + } +} diff --git a/third_party/rust/hyper/src/server/mod.rs b/third_party/rust/hyper/src/server/mod.rs new file mode 100644 index 0000000000..e763d0e7c0 --- /dev/null +++ b/third_party/rust/hyper/src/server/mod.rs @@ -0,0 +1,172 @@ +//! HTTP Server +//! +//! A `Server` is created to listen on a port, parse HTTP requests, and hand +//! them off to a `Service`. +//! +//! There are two levels of APIs provide for constructing HTTP servers: +//! +//! - The higher-level [`Server`](Server) type. +//! - The lower-level [`conn`](conn) module. +//! +//! # Server +//! +//! The [`Server`](Server) is main way to start listening for HTTP requests. +//! It wraps a listener with a [`MakeService`](crate::service), and then should +//! be executed to start serving requests. +//! +//! [`Server`](Server) accepts connections in both HTTP1 and HTTP2 by default. +//! +//! ## Examples +//! +//! ```no_run +//! use std::convert::Infallible; +//! use std::net::SocketAddr; +//! use hyper::{Body, Request, Response, Server}; +//! use hyper::service::{make_service_fn, service_fn}; +//! +//! async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> { +//! Ok(Response::new(Body::from("Hello World"))) +//! } +//! +//! # #[cfg(feature = "runtime")] +//! #[tokio::main] +//! async fn main() { +//! // Construct our SocketAddr to listen on... +//! let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); +//! +//! // And a MakeService to handle each connection... +//! let make_service = make_service_fn(|_conn| async { +//! Ok::<_, Infallible>(service_fn(handle)) +//! }); +//! +//! // Then bind and serve... +//! let server = Server::bind(&addr).serve(make_service); +//! +//! // And run forever... +//! if let Err(e) = server.await { +//! eprintln!("server error: {}", e); +//! } +//! } +//! # #[cfg(not(feature = "runtime"))] +//! # fn main() {} +//! ``` +//! +//! If you don't need the connection and your service implements `Clone` you can use +//! [`tower::make::Shared`] instead of `make_service_fn` which is a bit simpler: +//! +//! ```no_run +//! # use std::convert::Infallible; +//! # use std::net::SocketAddr; +//! # use hyper::{Body, Request, Response, Server}; +//! # use hyper::service::{make_service_fn, service_fn}; +//! # use tower::make::Shared; +//! # async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> { +//! # Ok(Response::new(Body::from("Hello World"))) +//! # } +//! # #[cfg(feature = "runtime")] +//! #[tokio::main] +//! async fn main() { +//! // Construct our SocketAddr to listen on... +//! let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); +//! +//! // Shared is a MakeService that produces services by cloning an inner service... +//! let make_service = Shared::new(service_fn(handle)); +//! +//! // Then bind and serve... +//! let server = Server::bind(&addr).serve(make_service); +//! +//! // And run forever... +//! if let Err(e) = server.await { +//! eprintln!("server error: {}", e); +//! } +//! } +//! # #[cfg(not(feature = "runtime"))] +//! # fn main() {} +//! ``` +//! +//! Passing data to your request handler can be done like so: +//! +//! ```no_run +//! use std::convert::Infallible; +//! use std::net::SocketAddr; +//! use hyper::{Body, Request, Response, Server}; +//! use hyper::service::{make_service_fn, service_fn}; +//! # #[cfg(feature = "runtime")] +//! use hyper::server::conn::AddrStream; +//! +//! #[derive(Clone)] +//! struct AppContext { +//! // Whatever data your application needs can go here +//! } +//! +//! async fn handle( +//! context: AppContext, +//! addr: SocketAddr, +//! req: Request<Body> +//! ) -> Result<Response<Body>, Infallible> { +//! Ok(Response::new(Body::from("Hello World"))) +//! } +//! +//! # #[cfg(feature = "runtime")] +//! #[tokio::main] +//! async fn main() { +//! let context = AppContext { +//! // ... +//! }; +//! +//! // A `MakeService` that produces a `Service` to handle each connection. +//! let make_service = make_service_fn(move |conn: &AddrStream| { +//! // We have to clone the context to share it with each invocation of +//! // `make_service`. If your data doesn't implement `Clone` consider using +//! // an `std::sync::Arc`. +//! let context = context.clone(); +//! +//! // You can grab the address of the incoming connection like so. +//! let addr = conn.remote_addr(); +//! +//! // Create a `Service` for responding to the request. +//! let service = service_fn(move |req| { +//! handle(context.clone(), addr, req) +//! }); +//! +//! // Return the service to hyper. +//! async move { Ok::<_, Infallible>(service) } +//! }); +//! +//! // Run the server like above... +//! let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); +//! +//! let server = Server::bind(&addr).serve(make_service); +//! +//! if let Err(e) = server.await { +//! eprintln!("server error: {}", e); +//! } +//! } +//! # #[cfg(not(feature = "runtime"))] +//! # fn main() {} +//! ``` +//! +//! [`tower::make::Shared`]: https://docs.rs/tower/latest/tower/make/struct.Shared.html + +pub mod accept; +pub mod conn; +#[cfg(feature = "tcp")] +mod tcp; + +pub use self::server::Server; + +cfg_feature! { + #![any(feature = "http1", feature = "http2")] + + pub(crate) mod server; + pub use self::server::Builder; + + mod shutdown; +} + +cfg_feature! { + #![not(any(feature = "http1", feature = "http2"))] + + mod server_stub; + use server_stub as server; +} diff --git a/third_party/rust/hyper/src/server/server.rs b/third_party/rust/hyper/src/server/server.rs new file mode 100644 index 0000000000..e4273674fc --- /dev/null +++ b/third_party/rust/hyper/src/server/server.rs @@ -0,0 +1,799 @@ +use std::error::Error as StdError; +use std::fmt; +#[cfg(feature = "tcp")] +use std::net::{SocketAddr, TcpListener as StdTcpListener}; + +#[cfg(feature = "tcp")] +use std::time::Duration; + +use pin_project_lite::pin_project; + +use tokio::io::{AsyncRead, AsyncWrite}; +use tracing::trace; + +use super::accept::Accept; +#[cfg(all(feature = "tcp"))] +use super::tcp::AddrIncoming; +use crate::body::{Body, HttpBody}; +use crate::common::exec::Exec; +use crate::common::exec::{ConnStreamExec, NewSvcExec}; +use crate::common::{task, Future, Pin, Poll, Unpin}; +// Renamed `Http` as `Http_` for now so that people upgrading don't see an +// error that `hyper::server::Http` is private... +use super::conn::{Connection, Http as Http_, UpgradeableConnection}; +use super::shutdown::{Graceful, GracefulWatcher}; +use crate::service::{HttpService, MakeServiceRef}; + +use self::new_svc::NewSvcTask; + +pin_project! { + /// A listening HTTP server that accepts connections in both HTTP1 and HTTP2 by default. + /// + /// `Server` is a `Future` mapping a bound listener with a set of service + /// handlers. It is built using the [`Builder`](Builder), and the future + /// completes when the server has been shutdown. It should be run by an + /// `Executor`. + pub struct Server<I, S, E = Exec> { + #[pin] + incoming: I, + make_service: S, + protocol: Http_<E>, + } +} + +/// A builder for a [`Server`](Server). +#[derive(Debug)] +#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))] +pub struct Builder<I, E = Exec> { + incoming: I, + protocol: Http_<E>, +} + +// ===== impl Server ===== + +#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))] +impl<I> Server<I, ()> { + /// Starts a [`Builder`](Builder) with the provided incoming stream. + pub fn builder(incoming: I) -> Builder<I> { + Builder { + incoming, + protocol: Http_::new(), + } + } +} + +#[cfg(feature = "tcp")] +#[cfg_attr( + docsrs, + doc(cfg(all(feature = "tcp", any(feature = "http1", feature = "http2")))) +)] +impl Server<AddrIncoming, ()> { + /// Binds to the provided address, and returns a [`Builder`](Builder). + /// + /// # Panics + /// + /// This method will panic if binding to the address fails. For a method + /// to bind to an address and return a `Result`, see `Server::try_bind`. + pub fn bind(addr: &SocketAddr) -> Builder<AddrIncoming> { + let incoming = AddrIncoming::new(addr).unwrap_or_else(|e| { + panic!("error binding to {}: {}", addr, e); + }); + Server::builder(incoming) + } + + /// Tries to bind to the provided address, and returns a [`Builder`](Builder). + pub fn try_bind(addr: &SocketAddr) -> crate::Result<Builder<AddrIncoming>> { + AddrIncoming::new(addr).map(Server::builder) + } + + /// Create a new instance from a `std::net::TcpListener` instance. + pub fn from_tcp(listener: StdTcpListener) -> Result<Builder<AddrIncoming>, crate::Error> { + AddrIncoming::from_std(listener).map(Server::builder) + } +} + +#[cfg(feature = "tcp")] +#[cfg_attr( + docsrs, + doc(cfg(all(feature = "tcp", any(feature = "http1", feature = "http2")))) +)] +impl<S, E> Server<AddrIncoming, S, E> { + /// Returns the local address that this server is bound to. + pub fn local_addr(&self) -> SocketAddr { + self.incoming.local_addr() + } +} + +#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))] +impl<I, IO, IE, S, E, B> Server<I, S, E> +where + I: Accept<Conn = IO, Error = IE>, + IE: Into<Box<dyn StdError + Send + Sync>>, + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + S: MakeServiceRef<IO, Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: HttpBody + 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + E: ConnStreamExec<<S::Service as HttpService<Body>>::Future, B>, +{ + /// Prepares a server to handle graceful shutdown when the provided future + /// completes. + /// + /// # Example + /// + /// ``` + /// # fn main() {} + /// # #[cfg(feature = "tcp")] + /// # async fn run() { + /// # use hyper::{Body, Response, Server, Error}; + /// # use hyper::service::{make_service_fn, service_fn}; + /// # let make_service = make_service_fn(|_| async { + /// # Ok::<_, Error>(service_fn(|_req| async { + /// # Ok::<_, Error>(Response::new(Body::from("Hello World"))) + /// # })) + /// # }); + /// // Make a server from the previous examples... + /// let server = Server::bind(&([127, 0, 0, 1], 3000).into()) + /// .serve(make_service); + /// + /// // Prepare some signal for when the server should start shutting down... + /// let (tx, rx) = tokio::sync::oneshot::channel::<()>(); + /// let graceful = server + /// .with_graceful_shutdown(async { + /// rx.await.ok(); + /// }); + /// + /// // Await the `server` receiving the signal... + /// if let Err(e) = graceful.await { + /// eprintln!("server error: {}", e); + /// } + /// + /// // And later, trigger the signal by calling `tx.send(())`. + /// let _ = tx.send(()); + /// # } + /// ``` + pub fn with_graceful_shutdown<F>(self, signal: F) -> Graceful<I, S, F, E> + where + F: Future<Output = ()>, + E: NewSvcExec<IO, S::Future, S::Service, E, GracefulWatcher>, + { + Graceful::new(self, signal) + } + + fn poll_next_( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<Option<crate::Result<Connecting<IO, S::Future, E>>>> { + let me = self.project(); + match ready!(me.make_service.poll_ready_ref(cx)) { + Ok(()) => (), + Err(e) => { + trace!("make_service closed"); + return Poll::Ready(Some(Err(crate::Error::new_user_make_service(e)))); + } + } + + if let Some(item) = ready!(me.incoming.poll_accept(cx)) { + let io = item.map_err(crate::Error::new_accept)?; + let new_fut = me.make_service.make_service_ref(&io); + Poll::Ready(Some(Ok(Connecting { + future: new_fut, + io: Some(io), + protocol: me.protocol.clone(), + }))) + } else { + Poll::Ready(None) + } + } + + pub(super) fn poll_watch<W>( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + watcher: &W, + ) -> Poll<crate::Result<()>> + where + E: NewSvcExec<IO, S::Future, S::Service, E, W>, + W: Watcher<IO, S::Service, E>, + { + loop { + if let Some(connecting) = ready!(self.as_mut().poll_next_(cx)?) { + let fut = NewSvcTask::new(connecting, watcher.clone()); + self.as_mut().project().protocol.exec.execute_new_svc(fut); + } else { + return Poll::Ready(Ok(())); + } + } + } +} + +#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))] +impl<I, IO, IE, S, B, E> Future for Server<I, S, E> +where + I: Accept<Conn = IO, Error = IE>, + IE: Into<Box<dyn StdError + Send + Sync>>, + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + S: MakeServiceRef<IO, Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: HttpBody + 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + E: ConnStreamExec<<S::Service as HttpService<Body>>::Future, B>, + E: NewSvcExec<IO, S::Future, S::Service, E, NoopWatcher>, +{ + type Output = crate::Result<()>; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + self.poll_watch(cx, &NoopWatcher) + } +} + +impl<I: fmt::Debug, S: fmt::Debug> fmt::Debug for Server<I, S> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut st = f.debug_struct("Server"); + st.field("listener", &self.incoming); + st.finish() + } +} + +// ===== impl Builder ===== + +#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))] +impl<I, E> Builder<I, E> { + /// Start a new builder, wrapping an incoming stream and low-level options. + /// + /// For a more convenient constructor, see [`Server::bind`](Server::bind). + pub fn new(incoming: I, protocol: Http_<E>) -> Self { + Builder { incoming, protocol } + } + + /// Sets whether to use keep-alive for HTTP/1 connections. + /// + /// Default is `true`. + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_keepalive(mut self, val: bool) -> Self { + self.protocol.http1_keep_alive(val); + self + } + + /// Set whether HTTP/1 connections should support half-closures. + /// + /// Clients can chose to shutdown their write-side while waiting + /// for the server to respond. Setting this to `true` will + /// prevent closing the connection immediately if `read` + /// detects an EOF in the middle of a request. + /// + /// Default is `false`. + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_half_close(mut self, val: bool) -> Self { + self.protocol.http1_half_close(val); + self + } + + /// Set the maximum buffer size. + /// + /// Default is ~ 400kb. + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_max_buf_size(mut self, val: usize) -> Self { + self.protocol.max_buf_size(val); + self + } + + // Sets whether to bunch up HTTP/1 writes until the read buffer is empty. + // + // This isn't really desirable in most cases, only really being useful in + // silly pipeline benchmarks. + #[doc(hidden)] + #[cfg(feature = "http1")] + pub fn http1_pipeline_flush(mut self, val: bool) -> Self { + self.protocol.pipeline_flush(val); + self + } + + /// Set whether HTTP/1 connections should try to use vectored writes, + /// or always flatten into a single buffer. + /// + /// Note that setting this to false may mean more copies of body data, + /// but may also improve performance when an IO transport doesn't + /// support vectored writes well, such as most TLS implementations. + /// + /// Setting this to true will force hyper to use queued strategy + /// which may eliminate unnecessary cloning on some TLS backends + /// + /// Default is `auto`. In this mode hyper will try to guess which + /// mode to use + #[cfg(feature = "http1")] + pub fn http1_writev(mut self, enabled: bool) -> Self { + self.protocol.http1_writev(enabled); + self + } + + /// Set whether HTTP/1 connections will write header names as title case at + /// the socket level. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_title_case_headers(mut self, val: bool) -> Self { + self.protocol.http1_title_case_headers(val); + self + } + + /// Set whether to support preserving original header cases. + /// + /// Currently, this will record the original cases received, and store them + /// in a private extension on the `Request`. It will also look for and use + /// such an extension in any provided `Response`. + /// + /// Since the relevant extension is still private, there is no way to + /// interact with the original cases. The only effect this can have now is + /// to forward the cases in a proxy-like fashion. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_preserve_header_case(mut self, val: bool) -> Self { + self.protocol.http1_preserve_header_case(val); + self + } + + /// Set a timeout for reading client request headers. If a client does not + /// transmit the entire header within this time, the connection is closed. + /// + /// Default is None. + #[cfg(all(feature = "http1", feature = "runtime"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "http1", feature = "runtime"))))] + pub fn http1_header_read_timeout(mut self, read_timeout: Duration) -> Self { + self.protocol.http1_header_read_timeout(read_timeout); + self + } + + /// Sets whether HTTP/1 is required. + /// + /// Default is `false`. + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_only(mut self, val: bool) -> Self { + self.protocol.http1_only(val); + self + } + + /// Sets whether HTTP/2 is required. + /// + /// Default is `false`. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_only(mut self, val: bool) -> Self { + self.protocol.http2_only(val); + self + } + + /// Sets the [`SETTINGS_INITIAL_WINDOW_SIZE`][spec] option for HTTP2 + /// stream-level flow control. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + /// + /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_INITIAL_WINDOW_SIZE + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_initial_stream_window_size(mut self, sz: impl Into<Option<u32>>) -> Self { + self.protocol.http2_initial_stream_window_size(sz.into()); + self + } + + /// Sets the max connection-level flow control for HTTP2 + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_initial_connection_window_size(mut self, sz: impl Into<Option<u32>>) -> Self { + self.protocol + .http2_initial_connection_window_size(sz.into()); + self + } + + /// Sets whether to use an adaptive flow control. + /// + /// Enabling this will override the limits set in + /// `http2_initial_stream_window_size` and + /// `http2_initial_connection_window_size`. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_adaptive_window(mut self, enabled: bool) -> Self { + self.protocol.http2_adaptive_window(enabled); + self + } + + /// Sets the maximum frame size to use for HTTP2. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_max_frame_size(mut self, sz: impl Into<Option<u32>>) -> Self { + self.protocol.http2_max_frame_size(sz); + self + } + + /// Sets the max size of received header frames. + /// + /// Default is currently ~16MB, but may change. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_max_header_list_size(mut self, max: u32) -> Self { + self.protocol.http2_max_header_list_size(max); + self + } + + /// Sets the [`SETTINGS_MAX_CONCURRENT_STREAMS`][spec] option for HTTP2 + /// connections. + /// + /// Default is no limit (`std::u32::MAX`). Passing `None` will do nothing. + /// + /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_MAX_CONCURRENT_STREAMS + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_max_concurrent_streams(mut self, max: impl Into<Option<u32>>) -> Self { + self.protocol.http2_max_concurrent_streams(max.into()); + self + } + + /// Sets an interval for HTTP2 Ping frames should be sent to keep a + /// connection alive. + /// + /// Pass `None` to disable HTTP2 keep-alive. + /// + /// Default is currently disabled. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(all(feature = "runtime", feature = "http2"))] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_keep_alive_interval(mut self, interval: impl Into<Option<Duration>>) -> Self { + self.protocol.http2_keep_alive_interval(interval); + self + } + + /// Sets a timeout for receiving an acknowledgement of the keep-alive ping. + /// + /// If the ping is not acknowledged within the timeout, the connection will + /// be closed. Does nothing if `http2_keep_alive_interval` is disabled. + /// + /// Default is 20 seconds. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(all(feature = "runtime", feature = "http2"))] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_keep_alive_timeout(mut self, timeout: Duration) -> Self { + self.protocol.http2_keep_alive_timeout(timeout); + self + } + + /// Set the maximum write buffer size for each HTTP/2 stream. + /// + /// Default is currently ~400KB, but may change. + /// + /// # Panics + /// + /// The value must be no larger than `u32::MAX`. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_max_send_buf_size(mut self, max: usize) -> Self { + self.protocol.http2_max_send_buf_size(max); + self + } + + /// Enables the [extended CONNECT protocol]. + /// + /// [extended CONNECT protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 + #[cfg(feature = "http2")] + pub fn http2_enable_connect_protocol(mut self) -> Self { + self.protocol.http2_enable_connect_protocol(); + self + } + + /// Sets the `Executor` to deal with connection tasks. + /// + /// Default is `tokio::spawn`. + pub fn executor<E2>(self, executor: E2) -> Builder<I, E2> { + Builder { + incoming: self.incoming, + protocol: self.protocol.with_executor(executor), + } + } + + /// Consume this `Builder`, creating a [`Server`](Server). + /// + /// # Example + /// + /// ``` + /// # #[cfg(feature = "tcp")] + /// # async fn run() { + /// use hyper::{Body, Error, Response, Server}; + /// use hyper::service::{make_service_fn, service_fn}; + /// + /// // Construct our SocketAddr to listen on... + /// let addr = ([127, 0, 0, 1], 3000).into(); + /// + /// // And a MakeService to handle each connection... + /// let make_svc = make_service_fn(|_| async { + /// Ok::<_, Error>(service_fn(|_req| async { + /// Ok::<_, Error>(Response::new(Body::from("Hello World"))) + /// })) + /// }); + /// + /// // Then bind and serve... + /// let server = Server::bind(&addr) + /// .serve(make_svc); + /// + /// // Run forever-ish... + /// if let Err(err) = server.await { + /// eprintln!("server error: {}", err); + /// } + /// # } + /// ``` + pub fn serve<S, B>(self, make_service: S) -> Server<I, S, E> + where + I: Accept, + I::Error: Into<Box<dyn StdError + Send + Sync>>, + I::Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static, + S: MakeServiceRef<I::Conn, Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: HttpBody + 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + E: NewSvcExec<I::Conn, S::Future, S::Service, E, NoopWatcher>, + E: ConnStreamExec<<S::Service as HttpService<Body>>::Future, B>, + { + Server { + incoming: self.incoming, + make_service, + protocol: self.protocol.clone(), + } + } +} + +#[cfg(feature = "tcp")] +#[cfg_attr( + docsrs, + doc(cfg(all(feature = "tcp", any(feature = "http1", feature = "http2")))) +)] +impl<E> Builder<AddrIncoming, E> { + /// Set the duration to remain idle before sending TCP keepalive probes. + /// + /// If `None` is specified, keepalive is disabled. + pub fn tcp_keepalive(mut self, keepalive: Option<Duration>) -> Self { + self.incoming.set_keepalive(keepalive); + self + } + + /// Set the duration between two successive TCP keepalive retransmissions, + /// if acknowledgement to the previous keepalive transmission is not received. + pub fn tcp_keepalive_interval(mut self, interval: Option<Duration>) -> Self { + self.incoming.set_keepalive_interval(interval); + self + } + + /// Set the number of retransmissions to be carried out before declaring that remote end is not available. + pub fn tcp_keepalive_retries(mut self, retries: Option<u32>) -> Self { + self.incoming.set_keepalive_retries(retries); + self + } + + /// Set the value of `TCP_NODELAY` option for accepted connections. + pub fn tcp_nodelay(mut self, enabled: bool) -> Self { + self.incoming.set_nodelay(enabled); + self + } + + /// Set whether to sleep on accept errors. + /// + /// A possible scenario is that the process has hit the max open files + /// allowed, and so trying to accept a new connection will fail with + /// EMFILE. In some cases, it's preferable to just wait for some time, if + /// the application will likely close some files (or connections), and try + /// to accept the connection again. If this option is true, the error will + /// be logged at the error level, since it is still a big deal, and then + /// the listener will sleep for 1 second. + /// + /// In other cases, hitting the max open files should be treat similarly + /// to being out-of-memory, and simply error (and shutdown). Setting this + /// option to false will allow that. + /// + /// For more details see [`AddrIncoming::set_sleep_on_errors`] + pub fn tcp_sleep_on_accept_errors(mut self, val: bool) -> Self { + self.incoming.set_sleep_on_errors(val); + self + } +} + +// Used by `Server` to optionally watch a `Connection` future. +// +// The regular `hyper::Server` just uses a `NoopWatcher`, which does +// not need to watch anything, and so returns the `Connection` untouched. +// +// The `Server::with_graceful_shutdown` needs to keep track of all active +// connections, and signal that they start to shutdown when prompted, so +// it has a `GracefulWatcher` implementation to do that. +pub trait Watcher<I, S: HttpService<Body>, E>: Clone { + type Future: Future<Output = crate::Result<()>>; + + fn watch(&self, conn: UpgradeableConnection<I, S, E>) -> Self::Future; +} + +#[allow(missing_debug_implementations)] +#[derive(Copy, Clone)] +pub struct NoopWatcher; + +impl<I, S, E> Watcher<I, S, E> for NoopWatcher +where + I: AsyncRead + AsyncWrite + Unpin + Send + 'static, + S: HttpService<Body>, + E: ConnStreamExec<S::Future, S::ResBody>, + S::ResBody: 'static, + <S::ResBody as HttpBody>::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + type Future = UpgradeableConnection<I, S, E>; + + fn watch(&self, conn: UpgradeableConnection<I, S, E>) -> Self::Future { + conn + } +} + +// used by exec.rs +pub(crate) mod new_svc { + use std::error::Error as StdError; + use tokio::io::{AsyncRead, AsyncWrite}; + use tracing::debug; + + use super::{Connecting, Watcher}; + use crate::body::{Body, HttpBody}; + use crate::common::exec::ConnStreamExec; + use crate::common::{task, Future, Pin, Poll, Unpin}; + use crate::service::HttpService; + use pin_project_lite::pin_project; + + // This is a `Future<Item=(), Error=()>` spawned to an `Executor` inside + // the `Server`. By being a nameable type, we can be generic over the + // user's `Service::Future`, and thus an `Executor` can execute it. + // + // Doing this allows for the server to conditionally require `Send` futures, + // depending on the `Executor` configured. + // + // Users cannot import this type, nor the associated `NewSvcExec`. Instead, + // a blanket implementation for `Executor<impl Future>` is sufficient. + + pin_project! { + #[allow(missing_debug_implementations)] + pub struct NewSvcTask<I, N, S: HttpService<Body>, E, W: Watcher<I, S, E>> { + #[pin] + state: State<I, N, S, E, W>, + } + } + + pin_project! { + #[project = StateProj] + pub(super) enum State<I, N, S: HttpService<Body>, E, W: Watcher<I, S, E>> { + Connecting { + #[pin] + connecting: Connecting<I, N, E>, + watcher: W, + }, + Connected { + #[pin] + future: W::Future, + }, + } + } + + impl<I, N, S: HttpService<Body>, E, W: Watcher<I, S, E>> NewSvcTask<I, N, S, E, W> { + pub(super) fn new(connecting: Connecting<I, N, E>, watcher: W) -> Self { + NewSvcTask { + state: State::Connecting { + connecting, + watcher, + }, + } + } + } + + impl<I, N, S, NE, B, E, W> Future for NewSvcTask<I, N, S, E, W> + where + I: AsyncRead + AsyncWrite + Unpin + Send + 'static, + N: Future<Output = Result<S, NE>>, + NE: Into<Box<dyn StdError + Send + Sync>>, + S: HttpService<Body, ResBody = B>, + B: HttpBody + 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + E: ConnStreamExec<S::Future, B>, + W: Watcher<I, S, E>, + { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + // If it weren't for needing to name this type so the `Send` bounds + // could be projected to the `Serve` executor, this could just be + // an `async fn`, and much safer. Woe is me. + + let mut me = self.project(); + loop { + let next = { + match me.state.as_mut().project() { + StateProj::Connecting { + connecting, + watcher, + } => { + let res = ready!(connecting.poll(cx)); + let conn = match res { + Ok(conn) => conn, + Err(err) => { + let err = crate::Error::new_user_make_service(err); + debug!("connecting error: {}", err); + return Poll::Ready(()); + } + }; + let future = watcher.watch(conn.with_upgrades()); + State::Connected { future } + } + StateProj::Connected { future } => { + return future.poll(cx).map(|res| { + if let Err(err) = res { + debug!("connection error: {}", err); + } + }); + } + } + }; + + me.state.set(next); + } + } + } +} + +pin_project! { + /// A future building a new `Service` to a `Connection`. + /// + /// Wraps the future returned from `MakeService` into one that returns + /// a `Connection`. + #[must_use = "futures do nothing unless polled"] + #[derive(Debug)] + #[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))] + pub struct Connecting<I, F, E = Exec> { + #[pin] + future: F, + io: Option<I>, + protocol: Http_<E>, + } +} + +impl<I, F, S, FE, E, B> Future for Connecting<I, F, E> +where + I: AsyncRead + AsyncWrite + Unpin, + F: Future<Output = Result<S, FE>>, + S: HttpService<Body, ResBody = B>, + B: HttpBody + 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + E: ConnStreamExec<S::Future, B>, +{ + type Output = Result<Connection<I, S, E>, FE>; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + let mut me = self.project(); + let service = ready!(me.future.poll(cx))?; + let io = Option::take(&mut me.io).expect("polled after complete"); + Poll::Ready(Ok(me.protocol.serve_connection(io, service))) + } +} diff --git a/third_party/rust/hyper/src/server/server_stub.rs b/third_party/rust/hyper/src/server/server_stub.rs new file mode 100644 index 0000000000..87b1f5131f --- /dev/null +++ b/third_party/rust/hyper/src/server/server_stub.rs @@ -0,0 +1,16 @@ +use std::fmt; + +use crate::common::exec::Exec; + +/// A listening HTTP server that accepts connections in both HTTP1 and HTTP2 by default. +/// +/// Needs at least one of the `http1` and `http2` features to be activated to actually be useful. +pub struct Server<I, S, E = Exec> { + _marker: std::marker::PhantomData<(I, S, E)>, +} + +impl<I: fmt::Debug, S: fmt::Debug> fmt::Debug for Server<I, S> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Server").finish() + } +} diff --git a/third_party/rust/hyper/src/server/shutdown.rs b/third_party/rust/hyper/src/server/shutdown.rs new file mode 100644 index 0000000000..96937d0827 --- /dev/null +++ b/third_party/rust/hyper/src/server/shutdown.rs @@ -0,0 +1,128 @@ +use std::error::Error as StdError; + +use pin_project_lite::pin_project; +use tokio::io::{AsyncRead, AsyncWrite}; +use tracing::debug; + +use super::accept::Accept; +use super::conn::UpgradeableConnection; +use super::server::{Server, Watcher}; +use crate::body::{Body, HttpBody}; +use crate::common::drain::{self, Draining, Signal, Watch, Watching}; +use crate::common::exec::{ConnStreamExec, NewSvcExec}; +use crate::common::{task, Future, Pin, Poll, Unpin}; +use crate::service::{HttpService, MakeServiceRef}; + +pin_project! { + #[allow(missing_debug_implementations)] + pub struct Graceful<I, S, F, E> { + #[pin] + state: State<I, S, F, E>, + } +} + +pin_project! { + #[project = StateProj] + pub(super) enum State<I, S, F, E> { + Running { + drain: Option<(Signal, Watch)>, + #[pin] + server: Server<I, S, E>, + #[pin] + signal: F, + }, + Draining { draining: Draining }, + } +} + +impl<I, S, F, E> Graceful<I, S, F, E> { + pub(super) fn new(server: Server<I, S, E>, signal: F) -> Self { + let drain = Some(drain::channel()); + Graceful { + state: State::Running { + drain, + server, + signal, + }, + } + } +} + +impl<I, IO, IE, S, B, F, E> Future for Graceful<I, S, F, E> +where + I: Accept<Conn = IO, Error = IE>, + IE: Into<Box<dyn StdError + Send + Sync>>, + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + S: MakeServiceRef<IO, Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: HttpBody + 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + F: Future<Output = ()>, + E: ConnStreamExec<<S::Service as HttpService<Body>>::Future, B>, + E: NewSvcExec<IO, S::Future, S::Service, E, GracefulWatcher>, +{ + type Output = crate::Result<()>; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + let mut me = self.project(); + loop { + let next = { + match me.state.as_mut().project() { + StateProj::Running { + drain, + server, + signal, + } => match signal.poll(cx) { + Poll::Ready(()) => { + debug!("signal received, starting graceful shutdown"); + let sig = drain.take().expect("drain channel").0; + State::Draining { + draining: sig.drain(), + } + } + Poll::Pending => { + let watch = drain.as_ref().expect("drain channel").1.clone(); + return server.poll_watch(cx, &GracefulWatcher(watch)); + } + }, + StateProj::Draining { ref mut draining } => { + return Pin::new(draining).poll(cx).map(Ok); + } + } + }; + me.state.set(next); + } + } +} + +#[allow(missing_debug_implementations)] +#[derive(Clone)] +pub struct GracefulWatcher(Watch); + +impl<I, S, E> Watcher<I, S, E> for GracefulWatcher +where + I: AsyncRead + AsyncWrite + Unpin + Send + 'static, + S: HttpService<Body>, + E: ConnStreamExec<S::Future, S::ResBody>, + S::ResBody: 'static, + <S::ResBody as HttpBody>::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + type Future = + Watching<UpgradeableConnection<I, S, E>, fn(Pin<&mut UpgradeableConnection<I, S, E>>)>; + + fn watch(&self, conn: UpgradeableConnection<I, S, E>) -> Self::Future { + self.0.clone().watch(conn, on_drain) + } +} + +fn on_drain<I, S, E>(conn: Pin<&mut UpgradeableConnection<I, S, E>>) +where + S: HttpService<Body>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + I: AsyncRead + AsyncWrite + Unpin, + S::ResBody: HttpBody + 'static, + <S::ResBody as HttpBody>::Error: Into<Box<dyn StdError + Send + Sync>>, + E: ConnStreamExec<S::Future, S::ResBody>, +{ + conn.graceful_shutdown() +} diff --git a/third_party/rust/hyper/src/server/tcp.rs b/third_party/rust/hyper/src/server/tcp.rs new file mode 100644 index 0000000000..3f937154be --- /dev/null +++ b/third_party/rust/hyper/src/server/tcp.rs @@ -0,0 +1,484 @@ +use std::fmt; +use std::io; +use std::net::{SocketAddr, TcpListener as StdTcpListener}; +use std::time::Duration; +use socket2::TcpKeepalive; + +use tokio::net::TcpListener; +use tokio::time::Sleep; +use tracing::{debug, error, trace}; + +use crate::common::{task, Future, Pin, Poll}; + +#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 +pub use self::addr_stream::AddrStream; +use super::accept::Accept; + +#[derive(Default, Debug, Clone, Copy)] +struct TcpKeepaliveConfig { + time: Option<Duration>, + interval: Option<Duration>, + retries: Option<u32>, +} + +impl TcpKeepaliveConfig { + /// Converts into a `socket2::TcpKeealive` if there is any keep alive configuration. + fn into_socket2(self) -> Option<TcpKeepalive> { + let mut dirty = false; + let mut ka = TcpKeepalive::new(); + if let Some(time) = self.time { + ka = ka.with_time(time); + dirty = true + } + if let Some(interval) = self.interval { + ka = Self::ka_with_interval(ka, interval, &mut dirty) + }; + if let Some(retries) = self.retries { + ka = Self::ka_with_retries(ka, retries, &mut dirty) + }; + if dirty { + Some(ka) + } else { + None + } + } + + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "linux", + target_os = "netbsd", + target_vendor = "apple", + windows, + ))] + fn ka_with_interval(ka: TcpKeepalive, interval: Duration, dirty: &mut bool) -> TcpKeepalive { + *dirty = true; + ka.with_interval(interval) + } + + #[cfg(not(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "linux", + target_os = "netbsd", + target_vendor = "apple", + windows, + )))] + fn ka_with_interval(ka: TcpKeepalive, _: Duration, _: &mut bool) -> TcpKeepalive { + ka // no-op as keepalive interval is not supported on this platform + } + + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "linux", + target_os = "netbsd", + target_vendor = "apple", + ))] + fn ka_with_retries(ka: TcpKeepalive, retries: u32, dirty: &mut bool) -> TcpKeepalive { + *dirty = true; + ka.with_retries(retries) + } + + #[cfg(not(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "linux", + target_os = "netbsd", + target_vendor = "apple", + )))] + fn ka_with_retries(ka: TcpKeepalive, _: u32, _: &mut bool) -> TcpKeepalive { + ka // no-op as keepalive retries is not supported on this platform + } +} + +/// A stream of connections from binding to an address. +#[must_use = "streams do nothing unless polled"] +pub struct AddrIncoming { + addr: SocketAddr, + listener: TcpListener, + sleep_on_errors: bool, + tcp_keepalive_config: TcpKeepaliveConfig, + tcp_nodelay: bool, + timeout: Option<Pin<Box<Sleep>>>, +} + +impl AddrIncoming { + pub(super) fn new(addr: &SocketAddr) -> crate::Result<Self> { + let std_listener = StdTcpListener::bind(addr).map_err(crate::Error::new_listen)?; + + AddrIncoming::from_std(std_listener) + } + + pub(super) fn from_std(std_listener: StdTcpListener) -> crate::Result<Self> { + // TcpListener::from_std doesn't set O_NONBLOCK + std_listener + .set_nonblocking(true) + .map_err(crate::Error::new_listen)?; + let listener = TcpListener::from_std(std_listener).map_err(crate::Error::new_listen)?; + AddrIncoming::from_listener(listener) + } + + /// Creates a new `AddrIncoming` binding to provided socket address. + pub fn bind(addr: &SocketAddr) -> crate::Result<Self> { + AddrIncoming::new(addr) + } + + /// Creates a new `AddrIncoming` from an existing `tokio::net::TcpListener`. + pub fn from_listener(listener: TcpListener) -> crate::Result<Self> { + let addr = listener.local_addr().map_err(crate::Error::new_listen)?; + Ok(AddrIncoming { + listener, + addr, + sleep_on_errors: true, + tcp_keepalive_config: TcpKeepaliveConfig::default(), + tcp_nodelay: false, + timeout: None, + }) + } + + /// Get the local address bound to this listener. + pub fn local_addr(&self) -> SocketAddr { + self.addr + } + + /// Set the duration to remain idle before sending TCP keepalive probes. + /// + /// If `None` is specified, keepalive is disabled. + pub fn set_keepalive(&mut self, time: Option<Duration>) -> &mut Self { + self.tcp_keepalive_config.time = time; + self + } + + /// Set the duration between two successive TCP keepalive retransmissions, + /// if acknowledgement to the previous keepalive transmission is not received. + pub fn set_keepalive_interval(&mut self, interval: Option<Duration>) -> &mut Self { + self.tcp_keepalive_config.interval = interval; + self + } + + /// Set the number of retransmissions to be carried out before declaring that remote end is not available. + pub fn set_keepalive_retries(&mut self, retries: Option<u32>) -> &mut Self { + self.tcp_keepalive_config.retries = retries; + self + } + + /// Set the value of `TCP_NODELAY` option for accepted connections. + pub fn set_nodelay(&mut self, enabled: bool) -> &mut Self { + self.tcp_nodelay = enabled; + self + } + + /// Set whether to sleep on accept errors. + /// + /// A possible scenario is that the process has hit the max open files + /// allowed, and so trying to accept a new connection will fail with + /// `EMFILE`. In some cases, it's preferable to just wait for some time, if + /// the application will likely close some files (or connections), and try + /// to accept the connection again. If this option is `true`, the error + /// will be logged at the `error` level, since it is still a big deal, + /// and then the listener will sleep for 1 second. + /// + /// In other cases, hitting the max open files should be treat similarly + /// to being out-of-memory, and simply error (and shutdown). Setting + /// this option to `false` will allow that. + /// + /// Default is `true`. + pub fn set_sleep_on_errors(&mut self, val: bool) { + self.sleep_on_errors = val; + } + + fn poll_next_(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<AddrStream>> { + // Check if a previous timeout is active that was set by IO errors. + if let Some(ref mut to) = self.timeout { + ready!(Pin::new(to).poll(cx)); + } + self.timeout = None; + + loop { + match ready!(self.listener.poll_accept(cx)) { + Ok((socket, remote_addr)) => { + if let Some(tcp_keepalive) = &self.tcp_keepalive_config.into_socket2() { + let sock_ref = socket2::SockRef::from(&socket); + if let Err(e) = sock_ref.set_tcp_keepalive(tcp_keepalive) { + trace!("error trying to set TCP keepalive: {}", e); + } + } + if let Err(e) = socket.set_nodelay(self.tcp_nodelay) { + trace!("error trying to set TCP nodelay: {}", e); + } + let local_addr = socket.local_addr()?; + return Poll::Ready(Ok(AddrStream::new(socket, remote_addr, local_addr))); + } + Err(e) => { + // Connection errors can be ignored directly, continue by + // accepting the next request. + if is_connection_error(&e) { + debug!("accepted connection already errored: {}", e); + continue; + } + + if self.sleep_on_errors { + error!("accept error: {}", e); + + // Sleep 1s. + let mut timeout = Box::pin(tokio::time::sleep(Duration::from_secs(1))); + + match timeout.as_mut().poll(cx) { + Poll::Ready(()) => { + // Wow, it's been a second already? Ok then... + continue; + } + Poll::Pending => { + self.timeout = Some(timeout); + return Poll::Pending; + } + } + } else { + return Poll::Ready(Err(e)); + } + } + } + } + } +} + +impl Accept for AddrIncoming { + type Conn = AddrStream; + type Error = io::Error; + + fn poll_accept( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<Option<Result<Self::Conn, Self::Error>>> { + let result = ready!(self.poll_next_(cx)); + Poll::Ready(Some(result)) + } +} + +/// This function defines errors that are per-connection. Which basically +/// means that if we get this error from `accept()` system call it means +/// next connection might be ready to be accepted. +/// +/// All other errors will incur a timeout before next `accept()` is performed. +/// The timeout is useful to handle resource exhaustion errors like ENFILE +/// and EMFILE. Otherwise, could enter into tight loop. +fn is_connection_error(e: &io::Error) -> bool { + matches!( + e.kind(), + io::ErrorKind::ConnectionRefused + | io::ErrorKind::ConnectionAborted + | io::ErrorKind::ConnectionReset + ) +} + +impl fmt::Debug for AddrIncoming { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AddrIncoming") + .field("addr", &self.addr) + .field("sleep_on_errors", &self.sleep_on_errors) + .field("tcp_keepalive_config", &self.tcp_keepalive_config) + .field("tcp_nodelay", &self.tcp_nodelay) + .finish() + } +} + +mod addr_stream { + use std::io; + use std::net::SocketAddr; + #[cfg(unix)] + use std::os::unix::io::{AsRawFd, RawFd}; + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + use tokio::net::TcpStream; + + use crate::common::{task, Pin, Poll}; + + pin_project_lite::pin_project! { + /// A transport returned yieled by `AddrIncoming`. + #[derive(Debug)] + pub struct AddrStream { + #[pin] + inner: TcpStream, + pub(super) remote_addr: SocketAddr, + pub(super) local_addr: SocketAddr + } + } + + impl AddrStream { + pub(super) fn new( + tcp: TcpStream, + remote_addr: SocketAddr, + local_addr: SocketAddr, + ) -> AddrStream { + AddrStream { + inner: tcp, + remote_addr, + local_addr, + } + } + + /// Returns the remote (peer) address of this connection. + #[inline] + pub fn remote_addr(&self) -> SocketAddr { + self.remote_addr + } + + /// Returns the local address of this connection. + #[inline] + pub fn local_addr(&self) -> SocketAddr { + self.local_addr + } + + /// Consumes the AddrStream and returns the underlying IO object + #[inline] + pub fn into_inner(self) -> TcpStream { + self.inner + } + + /// Attempt to receive data on the socket, without removing that data + /// from the queue, registering the current task for wakeup if data is + /// not yet available. + pub fn poll_peek( + &mut self, + cx: &mut task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll<io::Result<usize>> { + self.inner.poll_peek(cx, buf) + } + } + + impl AsyncRead for AddrStream { + #[inline] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + self.project().inner.poll_read(cx, buf) + } + } + + impl AsyncWrite for AddrStream { + #[inline] + fn poll_write( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.project().inner.poll_write(cx, buf) + } + + #[inline] + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + self.project().inner.poll_write_vectored(cx, bufs) + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + // TCP flush is a noop + Poll::Ready(Ok(())) + } + + #[inline] + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + self.project().inner.poll_shutdown(cx) + } + + #[inline] + fn is_write_vectored(&self) -> bool { + // Note that since `self.inner` is a `TcpStream`, this could + // *probably* be hard-coded to return `true`...but it seems more + // correct to ask it anyway (maybe we're on some platform without + // scatter-gather IO?) + self.inner.is_write_vectored() + } + } + + #[cfg(unix)] + impl AsRawFd for AddrStream { + fn as_raw_fd(&self) -> RawFd { + self.inner.as_raw_fd() + } + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + use crate::server::tcp::TcpKeepaliveConfig; + + #[test] + fn no_tcp_keepalive_config() { + assert!(TcpKeepaliveConfig::default().into_socket2().is_none()); + } + + #[test] + fn tcp_keepalive_time_config() { + let mut kac = TcpKeepaliveConfig::default(); + kac.time = Some(Duration::from_secs(60)); + if let Some(tcp_keepalive) = kac.into_socket2() { + assert!(format!("{tcp_keepalive:?}").contains("time: Some(60s)")); + } else { + panic!("test failed"); + } + } + + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "linux", + target_os = "netbsd", + target_vendor = "apple", + windows, + ))] + #[test] + fn tcp_keepalive_interval_config() { + let mut kac = TcpKeepaliveConfig::default(); + kac.interval = Some(Duration::from_secs(1)); + if let Some(tcp_keepalive) = kac.into_socket2() { + assert!(format!("{tcp_keepalive:?}").contains("interval: Some(1s)")); + } else { + panic!("test failed"); + } + } + + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "linux", + target_os = "netbsd", + target_vendor = "apple", + ))] + #[test] + fn tcp_keepalive_retries_config() { + let mut kac = TcpKeepaliveConfig::default(); + kac.retries = Some(3); + if let Some(tcp_keepalive) = kac.into_socket2() { + assert!(format!("{tcp_keepalive:?}").contains("retries: Some(3)")); + } else { + panic!("test failed"); + } + } +} diff --git a/third_party/rust/hyper/src/service/http.rs b/third_party/rust/hyper/src/service/http.rs new file mode 100644 index 0000000000..81a20c80b5 --- /dev/null +++ b/third_party/rust/hyper/src/service/http.rs @@ -0,0 +1,58 @@ +use std::error::Error as StdError; + +use crate::body::HttpBody; +use crate::common::{task, Future, Poll}; +use crate::{Request, Response}; + +/// An asynchronous function from `Request` to `Response`. +pub trait HttpService<ReqBody>: sealed::Sealed<ReqBody> { + /// The `HttpBody` body of the `http::Response`. + type ResBody: HttpBody; + + /// The error type that can occur within this `Service`. + /// + /// Note: Returning an `Error` to a hyper server will cause the connection + /// to be abruptly aborted. In most cases, it is better to return a `Response` + /// with a 4xx or 5xx status code. + type Error: Into<Box<dyn StdError + Send + Sync>>; + + /// The `Future` returned by this `Service`. + type Future: Future<Output = Result<Response<Self::ResBody>, Self::Error>>; + + #[doc(hidden)] + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>; + + #[doc(hidden)] + fn call(&mut self, req: Request<ReqBody>) -> Self::Future; +} + +impl<T, B1, B2> HttpService<B1> for T +where + T: tower_service::Service<Request<B1>, Response = Response<B2>>, + B2: HttpBody, + T::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + type ResBody = B2; + + type Error = T::Error; + type Future = T::Future; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + tower_service::Service::poll_ready(self, cx) + } + + fn call(&mut self, req: Request<B1>) -> Self::Future { + tower_service::Service::call(self, req) + } +} + +impl<T, B1, B2> sealed::Sealed<B1> for T +where + T: tower_service::Service<Request<B1>, Response = Response<B2>>, + B2: HttpBody, +{ +} + +mod sealed { + pub trait Sealed<T> {} +} diff --git a/third_party/rust/hyper/src/service/make.rs b/third_party/rust/hyper/src/service/make.rs new file mode 100644 index 0000000000..63e6f298f1 --- /dev/null +++ b/third_party/rust/hyper/src/service/make.rs @@ -0,0 +1,187 @@ +use std::error::Error as StdError; +use std::fmt; + +use tokio::io::{AsyncRead, AsyncWrite}; + +use super::{HttpService, Service}; +use crate::body::HttpBody; +use crate::common::{task, Future, Poll}; + +// The same "trait alias" as tower::MakeConnection, but inlined to reduce +// dependencies. +pub trait MakeConnection<Target>: self::sealed::Sealed<(Target,)> { + type Connection: AsyncRead + AsyncWrite; + type Error; + type Future: Future<Output = Result<Self::Connection, Self::Error>>; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>; + fn make_connection(&mut self, target: Target) -> Self::Future; +} + +impl<S, Target> self::sealed::Sealed<(Target,)> for S where S: Service<Target> {} + +impl<S, Target> MakeConnection<Target> for S +where + S: Service<Target>, + S::Response: AsyncRead + AsyncWrite, +{ + type Connection = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + Service::poll_ready(self, cx) + } + + fn make_connection(&mut self, target: Target) -> Self::Future { + Service::call(self, target) + } +} + +// Just a sort-of "trait alias" of `MakeService`, not to be implemented +// by anyone, only used as bounds. +pub trait MakeServiceRef<Target, ReqBody>: self::sealed::Sealed<(Target, ReqBody)> { + type ResBody: HttpBody; + type Error: Into<Box<dyn StdError + Send + Sync>>; + type Service: HttpService<ReqBody, ResBody = Self::ResBody, Error = Self::Error>; + type MakeError: Into<Box<dyn StdError + Send + Sync>>; + type Future: Future<Output = Result<Self::Service, Self::MakeError>>; + + // Acting like a #[non_exhaustive] for associated types of this trait. + // + // Basically, no one outside of hyper should be able to set this type + // or declare bounds on it, so it should prevent people from creating + // trait objects or otherwise writing code that requires using *all* + // of the associated types. + // + // Why? So we can add new associated types to this alias in the future, + // if necessary. + type __DontNameMe: self::sealed::CantImpl; + + fn poll_ready_ref(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::MakeError>>; + + fn make_service_ref(&mut self, target: &Target) -> Self::Future; +} + +impl<T, Target, E, ME, S, F, IB, OB> MakeServiceRef<Target, IB> for T +where + T: for<'a> Service<&'a Target, Error = ME, Response = S, Future = F>, + E: Into<Box<dyn StdError + Send + Sync>>, + ME: Into<Box<dyn StdError + Send + Sync>>, + S: HttpService<IB, ResBody = OB, Error = E>, + F: Future<Output = Result<S, ME>>, + IB: HttpBody, + OB: HttpBody, +{ + type Error = E; + type Service = S; + type ResBody = OB; + type MakeError = ME; + type Future = F; + + type __DontNameMe = self::sealed::CantName; + + fn poll_ready_ref(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::MakeError>> { + self.poll_ready(cx) + } + + fn make_service_ref(&mut self, target: &Target) -> Self::Future { + self.call(target) + } +} + +impl<T, Target, S, B1, B2> self::sealed::Sealed<(Target, B1)> for T +where + T: for<'a> Service<&'a Target, Response = S>, + S: HttpService<B1, ResBody = B2>, + B1: HttpBody, + B2: HttpBody, +{ +} + +/// Create a `MakeService` from a function. +/// +/// # Example +/// +/// ``` +/// # #[cfg(feature = "runtime")] +/// # async fn run() { +/// use std::convert::Infallible; +/// use hyper::{Body, Request, Response, Server}; +/// use hyper::server::conn::AddrStream; +/// use hyper::service::{make_service_fn, service_fn}; +/// +/// let addr = ([127, 0, 0, 1], 3000).into(); +/// +/// let make_svc = make_service_fn(|socket: &AddrStream| { +/// let remote_addr = socket.remote_addr(); +/// async move { +/// Ok::<_, Infallible>(service_fn(move |_: Request<Body>| async move { +/// Ok::<_, Infallible>( +/// Response::new(Body::from(format!("Hello, {}!", remote_addr))) +/// ) +/// })) +/// } +/// }); +/// +/// // Then bind and serve... +/// let server = Server::bind(&addr) +/// .serve(make_svc); +/// +/// // Finally, spawn `server` onto an Executor... +/// if let Err(e) = server.await { +/// eprintln!("server error: {}", e); +/// } +/// # } +/// # fn main() {} +/// ``` +pub fn make_service_fn<F, Target, Ret>(f: F) -> MakeServiceFn<F> +where + F: FnMut(&Target) -> Ret, + Ret: Future, +{ + MakeServiceFn { f } +} + +/// `MakeService` returned from [`make_service_fn`] +#[derive(Clone, Copy)] +pub struct MakeServiceFn<F> { + f: F, +} + +impl<'t, F, Ret, Target, Svc, MkErr> Service<&'t Target> for MakeServiceFn<F> +where + F: FnMut(&Target) -> Ret, + Ret: Future<Output = Result<Svc, MkErr>>, + MkErr: Into<Box<dyn StdError + Send + Sync>>, +{ + type Error = MkErr; + type Response = Svc; + type Future = Ret; + + fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, target: &'t Target) -> Self::Future { + (self.f)(target) + } +} + +impl<F> fmt::Debug for MakeServiceFn<F> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MakeServiceFn").finish() + } +} + +mod sealed { + pub trait Sealed<X> {} + + #[allow(unreachable_pub)] // This is intentional. + pub trait CantImpl {} + + #[allow(missing_debug_implementations)] + pub enum CantName {} + + impl CantImpl for CantName {} +} diff --git a/third_party/rust/hyper/src/service/mod.rs b/third_party/rust/hyper/src/service/mod.rs new file mode 100644 index 0000000000..22f850ca47 --- /dev/null +++ b/third_party/rust/hyper/src/service/mod.rs @@ -0,0 +1,55 @@ +//! Asynchronous Services +//! +//! A [`Service`](Service) is a trait representing an asynchronous +//! function of a request to a response. It's similar to +//! `async fn(Request) -> Result<Response, Error>`. +//! +//! The argument and return value isn't strictly required to be for HTTP. +//! Therefore, hyper uses several "trait aliases" to reduce clutter around +//! bounds. These are: +//! +//! - `HttpService`: This is blanketly implemented for all types that +//! implement `Service<http::Request<B1>, Response = http::Response<B2>>`. +//! - `MakeService`: When a `Service` returns a new `Service` as its "response", +//! we consider it a `MakeService`. Again, blanketly implemented in those cases. +//! - `MakeConnection`: A `Service` that returns a "connection", a type that +//! implements `AsyncRead` and `AsyncWrite`. +//! +//! # HttpService +//! +//! In hyper, especially in the server setting, a `Service` is usually bound +//! to a single connection. It defines how to respond to **all** requests that +//! connection will receive. +//! +//! The helper [`service_fn`](service_fn) should be sufficient for most cases, but +//! if you need to implement `Service` for a type manually, you can follow the example +//! in `service_struct_impl.rs`. +//! +//! # MakeService +//! +//! Since a `Service` is bound to a single connection, a [`Server`](crate::Server) +//! needs a way to make them as it accepts connections. This is what a +//! `MakeService` does. +//! +//! Resources that need to be shared by all `Service`s can be put into a +//! `MakeService`, and then passed to individual `Service`s when `call` +//! is called. + +pub use tower_service::Service; + +mod http; +mod make; +#[cfg(all(any(feature = "http1", feature = "http2"), feature = "client"))] +mod oneshot; +mod util; + +pub(super) use self::http::HttpService; +#[cfg(all(any(feature = "http1", feature = "http2"), feature = "client"))] +pub(super) use self::make::MakeConnection; +#[cfg(all(any(feature = "http1", feature = "http2"), feature = "server"))] +pub(super) use self::make::MakeServiceRef; +#[cfg(all(any(feature = "http1", feature = "http2"), feature = "client"))] +pub(super) use self::oneshot::{oneshot, Oneshot}; + +pub use self::make::make_service_fn; +pub use self::util::service_fn; diff --git a/third_party/rust/hyper/src/service/oneshot.rs b/third_party/rust/hyper/src/service/oneshot.rs new file mode 100644 index 0000000000..2697af8f4c --- /dev/null +++ b/third_party/rust/hyper/src/service/oneshot.rs @@ -0,0 +1,73 @@ +// TODO: Eventually to be replaced with tower_util::Oneshot. + +use pin_project_lite::pin_project; +use tower_service::Service; + +use crate::common::{task, Future, Pin, Poll}; + +pub(crate) fn oneshot<S, Req>(svc: S, req: Req) -> Oneshot<S, Req> +where + S: Service<Req>, +{ + Oneshot { + state: State::NotReady { svc, req }, + } +} + +pin_project! { + // A `Future` consuming a `Service` and request, waiting until the `Service` + // is ready, and then calling `Service::call` with the request, and + // waiting for that `Future`. + #[allow(missing_debug_implementations)] + pub struct Oneshot<S: Service<Req>, Req> { + #[pin] + state: State<S, Req>, + } +} + +pin_project! { + #[project = StateProj] + #[project_replace = StateProjOwn] + enum State<S: Service<Req>, Req> { + NotReady { + svc: S, + req: Req, + }, + Called { + #[pin] + fut: S::Future, + }, + Tmp, + } +} + +impl<S, Req> Future for Oneshot<S, Req> +where + S: Service<Req>, +{ + type Output = Result<S::Response, S::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + let mut me = self.project(); + + loop { + match me.state.as_mut().project() { + StateProj::NotReady { ref mut svc, .. } => { + ready!(svc.poll_ready(cx))?; + // fallthrough out of the match's borrow + } + StateProj::Called { fut } => { + return fut.poll(cx); + } + StateProj::Tmp => unreachable!(), + } + + match me.state.as_mut().project_replace(State::Tmp) { + StateProjOwn::NotReady { mut svc, req } => { + me.state.set(State::Called { fut: svc.call(req) }); + } + _ => unreachable!(), + } + } + } +} diff --git a/third_party/rust/hyper/src/service/util.rs b/third_party/rust/hyper/src/service/util.rs new file mode 100644 index 0000000000..7cba1206f1 --- /dev/null +++ b/third_party/rust/hyper/src/service/util.rs @@ -0,0 +1,84 @@ +use std::error::Error as StdError; +use std::fmt; +use std::marker::PhantomData; + +use crate::body::HttpBody; +use crate::common::{task, Future, Poll}; +use crate::{Request, Response}; + +/// Create a `Service` from a function. +/// +/// # Example +/// +/// ``` +/// use hyper::{Body, Request, Response, Version}; +/// use hyper::service::service_fn; +/// +/// let service = service_fn(|req: Request<Body>| async move { +/// if req.version() == Version::HTTP_11 { +/// Ok(Response::new(Body::from("Hello World"))) +/// } else { +/// // Note: it's usually better to return a Response +/// // with an appropriate StatusCode instead of an Err. +/// Err("not HTTP/1.1, abort connection") +/// } +/// }); +/// ``` +pub fn service_fn<F, R, S>(f: F) -> ServiceFn<F, R> +where + F: FnMut(Request<R>) -> S, + S: Future, +{ + ServiceFn { + f, + _req: PhantomData, + } +} + +/// Service returned by [`service_fn`] +pub struct ServiceFn<F, R> { + f: F, + _req: PhantomData<fn(R)>, +} + +impl<F, ReqBody, Ret, ResBody, E> tower_service::Service<crate::Request<ReqBody>> + for ServiceFn<F, ReqBody> +where + F: FnMut(Request<ReqBody>) -> Ret, + ReqBody: HttpBody, + Ret: Future<Output = Result<Response<ResBody>, E>>, + E: Into<Box<dyn StdError + Send + Sync>>, + ResBody: HttpBody, +{ + type Response = crate::Response<ResBody>; + type Error = E; + type Future = Ret; + + fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request<ReqBody>) -> Self::Future { + (self.f)(req) + } +} + +impl<F, R> fmt::Debug for ServiceFn<F, R> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("impl Service").finish() + } +} + +impl<F, R> Clone for ServiceFn<F, R> +where + F: Clone, +{ + fn clone(&self) -> Self { + ServiceFn { + f: self.f.clone(), + _req: PhantomData, + } + } +} + +impl<F, R> Copy for ServiceFn<F, R> where F: Copy {} diff --git a/third_party/rust/hyper/src/upgrade.rs b/third_party/rust/hyper/src/upgrade.rs new file mode 100644 index 0000000000..1c7b5b01cd --- /dev/null +++ b/third_party/rust/hyper/src/upgrade.rs @@ -0,0 +1,382 @@ +//! HTTP Upgrades +//! +//! This module deals with managing [HTTP Upgrades][mdn] in hyper. Since +//! several concepts in HTTP allow for first talking HTTP, and then converting +//! to a different protocol, this module conflates them into a single API. +//! Those include: +//! +//! - HTTP/1.1 Upgrades +//! - HTTP `CONNECT` +//! +//! You are responsible for any other pre-requisites to establish an upgrade, +//! such as sending the appropriate headers, methods, and status codes. You can +//! then use [`on`][] to grab a `Future` which will resolve to the upgraded +//! connection object, or an error if the upgrade fails. +//! +//! [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Protocol_upgrade_mechanism +//! +//! # Client +//! +//! Sending an HTTP upgrade from the [`client`](super::client) involves setting +//! either the appropriate method, if wanting to `CONNECT`, or headers such as +//! `Upgrade` and `Connection`, on the `http::Request`. Once receiving the +//! `http::Response` back, you must check for the specific information that the +//! upgrade is agreed upon by the server (such as a `101` status code), and then +//! get the `Future` from the `Response`. +//! +//! # Server +//! +//! Receiving upgrade requests in a server requires you to check the relevant +//! headers in a `Request`, and if an upgrade should be done, you then send the +//! corresponding headers in a response. To then wait for hyper to finish the +//! upgrade, you call `on()` with the `Request`, and then can spawn a task +//! awaiting it. +//! +//! # Example +//! +//! See [this example][example] showing how upgrades work with both +//! Clients and Servers. +//! +//! [example]: https://github.com/hyperium/hyper/blob/master/examples/upgrades.rs + +use std::any::TypeId; +use std::error::Error as StdError; +use std::fmt; +use std::io; +use std::marker::Unpin; + +use bytes::Bytes; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::sync::oneshot; +#[cfg(any(feature = "http1", feature = "http2"))] +use tracing::trace; + +use crate::common::io::Rewind; +use crate::common::{task, Future, Pin, Poll}; + +/// An upgraded HTTP connection. +/// +/// This type holds a trait object internally of the original IO that +/// was used to speak HTTP before the upgrade. It can be used directly +/// as a `Read` or `Write` for convenience. +/// +/// Alternatively, if the exact type is known, this can be deconstructed +/// into its parts. +pub struct Upgraded { + io: Rewind<Box<dyn Io + Send>>, +} + +/// A future for a possible HTTP upgrade. +/// +/// If no upgrade was available, or it doesn't succeed, yields an `Error`. +pub struct OnUpgrade { + rx: Option<oneshot::Receiver<crate::Result<Upgraded>>>, +} + +/// The deconstructed parts of an [`Upgraded`](Upgraded) type. +/// +/// Includes the original IO type, and a read buffer of bytes that the +/// HTTP state machine may have already read before completing an upgrade. +#[derive(Debug)] +pub struct Parts<T> { + /// The original IO object used before the upgrade. + pub io: T, + /// A buffer of bytes that have been read but not processed as HTTP. + /// + /// For instance, if the `Connection` is used for an HTTP upgrade request, + /// it is possible the server sent back the first bytes of the new protocol + /// along with the response upgrade. + /// + /// You will want to check for any existing bytes if you plan to continue + /// communicating on the IO object. + pub read_buf: Bytes, + _inner: (), +} + +/// Gets a pending HTTP upgrade from this message. +/// +/// This can be called on the following types: +/// +/// - `http::Request<B>` +/// - `http::Response<B>` +/// - `&mut http::Request<B>` +/// - `&mut http::Response<B>` +pub fn on<T: sealed::CanUpgrade>(msg: T) -> OnUpgrade { + msg.on_upgrade() +} + +#[cfg(any(feature = "http1", feature = "http2"))] +pub(super) struct Pending { + tx: oneshot::Sender<crate::Result<Upgraded>>, +} + +#[cfg(any(feature = "http1", feature = "http2"))] +pub(super) fn pending() -> (Pending, OnUpgrade) { + let (tx, rx) = oneshot::channel(); + (Pending { tx }, OnUpgrade { rx: Some(rx) }) +} + +// ===== impl Upgraded ===== + +impl Upgraded { + #[cfg(any(feature = "http1", feature = "http2", test))] + pub(super) fn new<T>(io: T, read_buf: Bytes) -> Self + where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + Upgraded { + io: Rewind::new_buffered(Box::new(io), read_buf), + } + } + + /// Tries to downcast the internal trait object to the type passed. + /// + /// On success, returns the downcasted parts. On error, returns the + /// `Upgraded` back. + pub fn downcast<T: AsyncRead + AsyncWrite + Unpin + 'static>(self) -> Result<Parts<T>, Self> { + let (io, buf) = self.io.into_inner(); + match io.__hyper_downcast() { + Ok(t) => Ok(Parts { + io: *t, + read_buf: buf, + _inner: (), + }), + Err(io) => Err(Upgraded { + io: Rewind::new_buffered(io, buf), + }), + } + } +} + +impl AsyncRead for Upgraded { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + Pin::new(&mut self.io).poll_read(cx, buf) + } +} + +impl AsyncWrite for Upgraded { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.io).poll_write(cx, buf) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.io).poll_write_vectored(cx, bufs) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.io).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.io).poll_shutdown(cx) + } + + fn is_write_vectored(&self) -> bool { + self.io.is_write_vectored() + } +} + +impl fmt::Debug for Upgraded { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Upgraded").finish() + } +} + +// ===== impl OnUpgrade ===== + +impl OnUpgrade { + pub(super) fn none() -> Self { + OnUpgrade { rx: None } + } + + #[cfg(feature = "http1")] + pub(super) fn is_none(&self) -> bool { + self.rx.is_none() + } +} + +impl Future for OnUpgrade { + type Output = Result<Upgraded, crate::Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + match self.rx { + Some(ref mut rx) => Pin::new(rx).poll(cx).map(|res| match res { + Ok(Ok(upgraded)) => Ok(upgraded), + Ok(Err(err)) => Err(err), + Err(_oneshot_canceled) => Err(crate::Error::new_canceled().with(UpgradeExpected)), + }), + None => Poll::Ready(Err(crate::Error::new_user_no_upgrade())), + } + } +} + +impl fmt::Debug for OnUpgrade { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("OnUpgrade").finish() + } +} + +// ===== impl Pending ===== + +#[cfg(any(feature = "http1", feature = "http2"))] +impl Pending { + pub(super) fn fulfill(self, upgraded: Upgraded) { + trace!("pending upgrade fulfill"); + let _ = self.tx.send(Ok(upgraded)); + } + + #[cfg(feature = "http1")] + /// Don't fulfill the pending Upgrade, but instead signal that + /// upgrades are handled manually. + pub(super) fn manual(self) { + trace!("pending upgrade handled manually"); + let _ = self.tx.send(Err(crate::Error::new_user_manual_upgrade())); + } +} + +// ===== impl UpgradeExpected ===== + +/// Error cause returned when an upgrade was expected but canceled +/// for whatever reason. +/// +/// This likely means the actual `Conn` future wasn't polled and upgraded. +#[derive(Debug)] +struct UpgradeExpected; + +impl fmt::Display for UpgradeExpected { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("upgrade expected but not completed") + } +} + +impl StdError for UpgradeExpected {} + +// ===== impl Io ===== + +pub(super) trait Io: AsyncRead + AsyncWrite + Unpin + 'static { + fn __hyper_type_id(&self) -> TypeId { + TypeId::of::<Self>() + } +} + +impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Io for T {} + +impl dyn Io + Send { + fn __hyper_is<T: Io>(&self) -> bool { + let t = TypeId::of::<T>(); + self.__hyper_type_id() == t + } + + fn __hyper_downcast<T: Io>(self: Box<Self>) -> Result<Box<T>, Box<Self>> { + if self.__hyper_is::<T>() { + // Taken from `std::error::Error::downcast()`. + unsafe { + let raw: *mut dyn Io = Box::into_raw(self); + Ok(Box::from_raw(raw as *mut T)) + } + } else { + Err(self) + } + } +} + +mod sealed { + use super::OnUpgrade; + + pub trait CanUpgrade { + fn on_upgrade(self) -> OnUpgrade; + } + + impl<B> CanUpgrade for http::Request<B> { + fn on_upgrade(mut self) -> OnUpgrade { + self.extensions_mut() + .remove::<OnUpgrade>() + .unwrap_or_else(OnUpgrade::none) + } + } + + impl<B> CanUpgrade for &'_ mut http::Request<B> { + fn on_upgrade(self) -> OnUpgrade { + self.extensions_mut() + .remove::<OnUpgrade>() + .unwrap_or_else(OnUpgrade::none) + } + } + + impl<B> CanUpgrade for http::Response<B> { + fn on_upgrade(mut self) -> OnUpgrade { + self.extensions_mut() + .remove::<OnUpgrade>() + .unwrap_or_else(OnUpgrade::none) + } + } + + impl<B> CanUpgrade for &'_ mut http::Response<B> { + fn on_upgrade(self) -> OnUpgrade { + self.extensions_mut() + .remove::<OnUpgrade>() + .unwrap_or_else(OnUpgrade::none) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn upgraded_downcast() { + let upgraded = Upgraded::new(Mock, Bytes::new()); + + let upgraded = upgraded.downcast::<std::io::Cursor<Vec<u8>>>().unwrap_err(); + + upgraded.downcast::<Mock>().unwrap(); + } + + // TODO: replace with tokio_test::io when it can test write_buf + struct Mock; + + impl AsyncRead for Mock { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut task::Context<'_>, + _buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + unreachable!("Mock::poll_read") + } + } + + impl AsyncWrite for Mock { + fn poll_write( + self: Pin<&mut Self>, + _: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + // panic!("poll_write shouldn't be called"); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + unreachable!("Mock::poll_flush") + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut task::Context<'_>, + ) -> Poll<io::Result<()>> { + unreachable!("Mock::poll_shutdown") + } + } +} |