summaryrefslogtreecommitdiffstats
path: root/third_party/rust/hyper/src/proto
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-19 00:47:55 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-19 00:47:55 +0000
commit26a029d407be480d791972afb5975cf62c9360a6 (patch)
treef435a8308119effd964b339f76abb83a57c29483 /third_party/rust/hyper/src/proto
parentInitial commit. (diff)
downloadfirefox-26a029d407be480d791972afb5975cf62c9360a6.tar.xz
firefox-26a029d407be480d791972afb5975cf62c9360a6.zip
Adding upstream version 124.0.1.upstream/124.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/hyper/src/proto')
-rw-r--r--third_party/rust/hyper/src/proto/h1/conn.rs1425
-rw-r--r--third_party/rust/hyper/src/proto/h1/decode.rs731
-rw-r--r--third_party/rust/hyper/src/proto/h1/dispatch.rs750
-rw-r--r--third_party/rust/hyper/src/proto/h1/encode.rs439
-rw-r--r--third_party/rust/hyper/src/proto/h1/io.rs1002
-rw-r--r--third_party/rust/hyper/src/proto/h1/mod.rs122
-rw-r--r--third_party/rust/hyper/src/proto/h1/role.rs2847
-rw-r--r--third_party/rust/hyper/src/proto/h2/client.rs450
-rw-r--r--third_party/rust/hyper/src/proto/h2/mod.rs471
-rw-r--r--third_party/rust/hyper/src/proto/h2/ping.rs555
-rw-r--r--third_party/rust/hyper/src/proto/h2/server.rs548
-rw-r--r--third_party/rust/hyper/src/proto/mod.rs71
12 files changed, 9411 insertions, 0 deletions
diff --git a/third_party/rust/hyper/src/proto/h1/conn.rs b/third_party/rust/hyper/src/proto/h1/conn.rs
new file mode 100644
index 0000000000..5ebff2803e
--- /dev/null
+++ b/third_party/rust/hyper/src/proto/h1/conn.rs
@@ -0,0 +1,1425 @@
+use std::fmt;
+use std::io;
+use std::marker::PhantomData;
+#[cfg(all(feature = "server", feature = "runtime"))]
+use std::time::Duration;
+
+use bytes::{Buf, Bytes};
+use http::header::{HeaderValue, CONNECTION};
+use http::{HeaderMap, Method, Version};
+use httparse::ParserConfig;
+use tokio::io::{AsyncRead, AsyncWrite};
+#[cfg(all(feature = "server", feature = "runtime"))]
+use tokio::time::Sleep;
+use tracing::{debug, error, trace};
+
+use super::io::Buffered;
+use super::{Decoder, Encode, EncodedBuf, Encoder, Http1Transaction, ParseContext, Wants};
+use crate::body::DecodedLength;
+use crate::common::{task, Pin, Poll, Unpin};
+use crate::headers::connection_keep_alive;
+use crate::proto::{BodyLength, MessageHead};
+
+const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
+
+/// This handles a connection, which will have been established over an
+/// `AsyncRead + AsyncWrite` (like a socket), and will likely include multiple
+/// `Transaction`s over HTTP.
+///
+/// The connection will determine when a message begins and ends as well as
+/// determine if this connection can be kept alive after the message,
+/// or if it is complete.
+pub(crate) struct Conn<I, B, T> {
+ io: Buffered<I, EncodedBuf<B>>,
+ state: State,
+ _marker: PhantomData<fn(T)>,
+}
+
+impl<I, B, T> Conn<I, B, T>
+where
+ I: AsyncRead + AsyncWrite + Unpin,
+ B: Buf,
+ T: Http1Transaction,
+{
+ pub(crate) fn new(io: I) -> Conn<I, B, T> {
+ Conn {
+ io: Buffered::new(io),
+ state: State {
+ allow_half_close: false,
+ cached_headers: None,
+ error: None,
+ keep_alive: KA::Busy,
+ method: None,
+ h1_parser_config: ParserConfig::default(),
+ #[cfg(all(feature = "server", feature = "runtime"))]
+ h1_header_read_timeout: None,
+ #[cfg(all(feature = "server", feature = "runtime"))]
+ h1_header_read_timeout_fut: None,
+ #[cfg(all(feature = "server", feature = "runtime"))]
+ h1_header_read_timeout_running: false,
+ preserve_header_case: false,
+ #[cfg(feature = "ffi")]
+ preserve_header_order: false,
+ title_case_headers: false,
+ h09_responses: false,
+ #[cfg(feature = "ffi")]
+ on_informational: None,
+ #[cfg(feature = "ffi")]
+ raw_headers: false,
+ notify_read: false,
+ reading: Reading::Init,
+ writing: Writing::Init,
+ upgrade: None,
+ // We assume a modern world where the remote speaks HTTP/1.1.
+ // If they tell us otherwise, we'll downgrade in `read_head`.
+ version: Version::HTTP_11,
+ },
+ _marker: PhantomData,
+ }
+ }
+
+ #[cfg(feature = "server")]
+ pub(crate) fn set_flush_pipeline(&mut self, enabled: bool) {
+ self.io.set_flush_pipeline(enabled);
+ }
+
+ pub(crate) fn set_write_strategy_queue(&mut self) {
+ self.io.set_write_strategy_queue();
+ }
+
+ pub(crate) fn set_max_buf_size(&mut self, max: usize) {
+ self.io.set_max_buf_size(max);
+ }
+
+ #[cfg(feature = "client")]
+ pub(crate) fn set_read_buf_exact_size(&mut self, sz: usize) {
+ self.io.set_read_buf_exact_size(sz);
+ }
+
+ pub(crate) fn set_write_strategy_flatten(&mut self) {
+ self.io.set_write_strategy_flatten();
+ }
+
+ #[cfg(feature = "client")]
+ pub(crate) fn set_h1_parser_config(&mut self, parser_config: ParserConfig) {
+ self.state.h1_parser_config = parser_config;
+ }
+
+ pub(crate) fn set_title_case_headers(&mut self) {
+ self.state.title_case_headers = true;
+ }
+
+ pub(crate) fn set_preserve_header_case(&mut self) {
+ self.state.preserve_header_case = true;
+ }
+
+ #[cfg(feature = "ffi")]
+ pub(crate) fn set_preserve_header_order(&mut self) {
+ self.state.preserve_header_order = true;
+ }
+
+ #[cfg(feature = "client")]
+ pub(crate) fn set_h09_responses(&mut self) {
+ self.state.h09_responses = true;
+ }
+
+ #[cfg(all(feature = "server", feature = "runtime"))]
+ pub(crate) fn set_http1_header_read_timeout(&mut self, val: Duration) {
+ self.state.h1_header_read_timeout = Some(val);
+ }
+
+ #[cfg(feature = "server")]
+ pub(crate) fn set_allow_half_close(&mut self) {
+ self.state.allow_half_close = true;
+ }
+
+ #[cfg(feature = "ffi")]
+ pub(crate) fn set_raw_headers(&mut self, enabled: bool) {
+ self.state.raw_headers = enabled;
+ }
+
+ pub(crate) fn into_inner(self) -> (I, Bytes) {
+ self.io.into_inner()
+ }
+
+ pub(crate) fn pending_upgrade(&mut self) -> Option<crate::upgrade::Pending> {
+ self.state.upgrade.take()
+ }
+
+ pub(crate) fn is_read_closed(&self) -> bool {
+ self.state.is_read_closed()
+ }
+
+ pub(crate) fn is_write_closed(&self) -> bool {
+ self.state.is_write_closed()
+ }
+
+ pub(crate) fn can_read_head(&self) -> bool {
+ if !matches!(self.state.reading, Reading::Init) {
+ return false;
+ }
+
+ if T::should_read_first() {
+ return true;
+ }
+
+ !matches!(self.state.writing, Writing::Init)
+ }
+
+ pub(crate) fn can_read_body(&self) -> bool {
+ match self.state.reading {
+ Reading::Body(..) | Reading::Continue(..) => true,
+ _ => false,
+ }
+ }
+
+ fn should_error_on_eof(&self) -> bool {
+ // If we're idle, it's probably just the connection closing gracefully.
+ T::should_error_on_parse_eof() && !self.state.is_idle()
+ }
+
+ fn has_h2_prefix(&self) -> bool {
+ let read_buf = self.io.read_buf();
+ read_buf.len() >= 24 && read_buf[..24] == *H2_PREFACE
+ }
+
+ pub(super) fn poll_read_head(
+ &mut self,
+ cx: &mut task::Context<'_>,
+ ) -> Poll<Option<crate::Result<(MessageHead<T::Incoming>, DecodedLength, Wants)>>> {
+ debug_assert!(self.can_read_head());
+ trace!("Conn::read_head");
+
+ let msg = match ready!(self.io.parse::<T>(
+ cx,
+ ParseContext {
+ cached_headers: &mut self.state.cached_headers,
+ req_method: &mut self.state.method,
+ h1_parser_config: self.state.h1_parser_config.clone(),
+ #[cfg(all(feature = "server", feature = "runtime"))]
+ h1_header_read_timeout: self.state.h1_header_read_timeout,
+ #[cfg(all(feature = "server", feature = "runtime"))]
+ h1_header_read_timeout_fut: &mut self.state.h1_header_read_timeout_fut,
+ #[cfg(all(feature = "server", feature = "runtime"))]
+ h1_header_read_timeout_running: &mut self.state.h1_header_read_timeout_running,
+ preserve_header_case: self.state.preserve_header_case,
+ #[cfg(feature = "ffi")]
+ preserve_header_order: self.state.preserve_header_order,
+ h09_responses: self.state.h09_responses,
+ #[cfg(feature = "ffi")]
+ on_informational: &mut self.state.on_informational,
+ #[cfg(feature = "ffi")]
+ raw_headers: self.state.raw_headers,
+ }
+ )) {
+ Ok(msg) => msg,
+ Err(e) => return self.on_read_head_error(e),
+ };
+
+ // Note: don't deconstruct `msg` into local variables, it appears
+ // the optimizer doesn't remove the extra copies.
+
+ debug!("incoming body is {}", msg.decode);
+
+ // Prevent accepting HTTP/0.9 responses after the initial one, if any.
+ self.state.h09_responses = false;
+
+ // Drop any OnInformational callbacks, we're done there!
+ #[cfg(feature = "ffi")]
+ {
+ self.state.on_informational = None;
+ }
+
+ self.state.busy();
+ self.state.keep_alive &= msg.keep_alive;
+ self.state.version = msg.head.version;
+
+ let mut wants = if msg.wants_upgrade {
+ Wants::UPGRADE
+ } else {
+ Wants::EMPTY
+ };
+
+ if msg.decode == DecodedLength::ZERO {
+ if msg.expect_continue {
+ debug!("ignoring expect-continue since body is empty");
+ }
+ self.state.reading = Reading::KeepAlive;
+ if !T::should_read_first() {
+ self.try_keep_alive(cx);
+ }
+ } else if msg.expect_continue {
+ self.state.reading = Reading::Continue(Decoder::new(msg.decode));
+ wants = wants.add(Wants::EXPECT);
+ } else {
+ self.state.reading = Reading::Body(Decoder::new(msg.decode));
+ }
+
+ Poll::Ready(Some(Ok((msg.head, msg.decode, wants))))
+ }
+
+ fn on_read_head_error<Z>(&mut self, e: crate::Error) -> Poll<Option<crate::Result<Z>>> {
+ // If we are currently waiting on a message, then an empty
+ // message should be reported as an error. If not, it is just
+ // the connection closing gracefully.
+ let must_error = self.should_error_on_eof();
+ self.close_read();
+ self.io.consume_leading_lines();
+ let was_mid_parse = e.is_parse() || !self.io.read_buf().is_empty();
+ if was_mid_parse || must_error {
+ // We check if the buf contains the h2 Preface
+ debug!(
+ "parse error ({}) with {} bytes",
+ e,
+ self.io.read_buf().len()
+ );
+ match self.on_parse_error(e) {
+ Ok(()) => Poll::Pending, // XXX: wat?
+ Err(e) => Poll::Ready(Some(Err(e))),
+ }
+ } else {
+ debug!("read eof");
+ self.close_write();
+ Poll::Ready(None)
+ }
+ }
+
+ pub(crate) fn poll_read_body(
+ &mut self,
+ cx: &mut task::Context<'_>,
+ ) -> Poll<Option<io::Result<Bytes>>> {
+ debug_assert!(self.can_read_body());
+
+ let (reading, ret) = match self.state.reading {
+ Reading::Body(ref mut decoder) => {
+ match ready!(decoder.decode(cx, &mut self.io)) {
+ Ok(slice) => {
+ let (reading, chunk) = if decoder.is_eof() {
+ debug!("incoming body completed");
+ (
+ Reading::KeepAlive,
+ if !slice.is_empty() {
+ Some(Ok(slice))
+ } else {
+ None
+ },
+ )
+ } else if slice.is_empty() {
+ error!("incoming body unexpectedly ended");
+ // This should be unreachable, since all 3 decoders
+ // either set eof=true or return an Err when reading
+ // an empty slice...
+ (Reading::Closed, None)
+ } else {
+ return Poll::Ready(Some(Ok(slice)));
+ };
+ (reading, Poll::Ready(chunk))
+ }
+ Err(e) => {
+ debug!("incoming body decode error: {}", e);
+ (Reading::Closed, Poll::Ready(Some(Err(e))))
+ }
+ }
+ }
+ Reading::Continue(ref decoder) => {
+ // Write the 100 Continue if not already responded...
+ if let Writing::Init = self.state.writing {
+ trace!("automatically sending 100 Continue");
+ let cont = b"HTTP/1.1 100 Continue\r\n\r\n";
+ self.io.headers_buf().extend_from_slice(cont);
+ }
+
+ // And now recurse once in the Reading::Body state...
+ self.state.reading = Reading::Body(decoder.clone());
+ return self.poll_read_body(cx);
+ }
+ _ => unreachable!("poll_read_body invalid state: {:?}", self.state.reading),
+ };
+
+ self.state.reading = reading;
+ self.try_keep_alive(cx);
+ ret
+ }
+
+ pub(crate) fn wants_read_again(&mut self) -> bool {
+ let ret = self.state.notify_read;
+ self.state.notify_read = false;
+ ret
+ }
+
+ pub(crate) fn poll_read_keep_alive(
+ &mut self,
+ cx: &mut task::Context<'_>,
+ ) -> Poll<crate::Result<()>> {
+ debug_assert!(!self.can_read_head() && !self.can_read_body());
+
+ if self.is_read_closed() {
+ Poll::Pending
+ } else if self.is_mid_message() {
+ self.mid_message_detect_eof(cx)
+ } else {
+ self.require_empty_read(cx)
+ }
+ }
+
+ fn is_mid_message(&self) -> bool {
+ !matches!(
+ (&self.state.reading, &self.state.writing),
+ (&Reading::Init, &Writing::Init)
+ )
+ }
+
+ // This will check to make sure the io object read is empty.
+ //
+ // This should only be called for Clients wanting to enter the idle
+ // state.
+ fn require_empty_read(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
+ debug_assert!(!self.can_read_head() && !self.can_read_body() && !self.is_read_closed());
+ debug_assert!(!self.is_mid_message());
+ debug_assert!(T::is_client());
+
+ if !self.io.read_buf().is_empty() {
+ debug!("received an unexpected {} bytes", self.io.read_buf().len());
+ return Poll::Ready(Err(crate::Error::new_unexpected_message()));
+ }
+
+ let num_read = ready!(self.force_io_read(cx)).map_err(crate::Error::new_io)?;
+
+ if num_read == 0 {
+ let ret = if self.should_error_on_eof() {
+ trace!("found unexpected EOF on busy connection: {:?}", self.state);
+ Poll::Ready(Err(crate::Error::new_incomplete()))
+ } else {
+ trace!("found EOF on idle connection, closing");
+ Poll::Ready(Ok(()))
+ };
+
+ // order is important: should_error needs state BEFORE close_read
+ self.state.close_read();
+ return ret;
+ }
+
+ debug!(
+ "received unexpected {} bytes on an idle connection",
+ num_read
+ );
+ Poll::Ready(Err(crate::Error::new_unexpected_message()))
+ }
+
+ fn mid_message_detect_eof(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
+ debug_assert!(!self.can_read_head() && !self.can_read_body() && !self.is_read_closed());
+ debug_assert!(self.is_mid_message());
+
+ if self.state.allow_half_close || !self.io.read_buf().is_empty() {
+ return Poll::Pending;
+ }
+
+ let num_read = ready!(self.force_io_read(cx)).map_err(crate::Error::new_io)?;
+
+ if num_read == 0 {
+ trace!("found unexpected EOF on busy connection: {:?}", self.state);
+ self.state.close_read();
+ Poll::Ready(Err(crate::Error::new_incomplete()))
+ } else {
+ Poll::Ready(Ok(()))
+ }
+ }
+
+ fn force_io_read(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<usize>> {
+ debug_assert!(!self.state.is_read_closed());
+
+ let result = ready!(self.io.poll_read_from_io(cx));
+ Poll::Ready(result.map_err(|e| {
+ trace!("force_io_read; io error = {:?}", e);
+ self.state.close();
+ e
+ }))
+ }
+
+ fn maybe_notify(&mut self, cx: &mut task::Context<'_>) {
+ // its possible that we returned NotReady from poll() without having
+ // exhausted the underlying Io. We would have done this when we
+ // determined we couldn't keep reading until we knew how writing
+ // would finish.
+
+ match self.state.reading {
+ Reading::Continue(..) | Reading::Body(..) | Reading::KeepAlive | Reading::Closed => {
+ return
+ }
+ Reading::Init => (),
+ };
+
+ match self.state.writing {
+ Writing::Body(..) => return,
+ Writing::Init | Writing::KeepAlive | Writing::Closed => (),
+ }
+
+ if !self.io.is_read_blocked() {
+ if self.io.read_buf().is_empty() {
+ match self.io.poll_read_from_io(cx) {
+ Poll::Ready(Ok(n)) => {
+ if n == 0 {
+ trace!("maybe_notify; read eof");
+ if self.state.is_idle() {
+ self.state.close();
+ } else {
+ self.close_read()
+ }
+ return;
+ }
+ }
+ Poll::Pending => {
+ trace!("maybe_notify; read_from_io blocked");
+ return;
+ }
+ Poll::Ready(Err(e)) => {
+ trace!("maybe_notify; read_from_io error: {}", e);
+ self.state.close();
+ self.state.error = Some(crate::Error::new_io(e));
+ }
+ }
+ }
+ self.state.notify_read = true;
+ }
+ }
+
+ fn try_keep_alive(&mut self, cx: &mut task::Context<'_>) {
+ self.state.try_keep_alive::<T>();
+ self.maybe_notify(cx);
+ }
+
+ pub(crate) fn can_write_head(&self) -> bool {
+ if !T::should_read_first() && matches!(self.state.reading, Reading::Closed) {
+ return false;
+ }
+
+ match self.state.writing {
+ Writing::Init => self.io.can_headers_buf(),
+ _ => false,
+ }
+ }
+
+ pub(crate) fn can_write_body(&self) -> bool {
+ match self.state.writing {
+ Writing::Body(..) => true,
+ Writing::Init | Writing::KeepAlive | Writing::Closed => false,
+ }
+ }
+
+ pub(crate) fn can_buffer_body(&self) -> bool {
+ self.io.can_buffer()
+ }
+
+ pub(crate) fn write_head(&mut self, head: MessageHead<T::Outgoing>, body: Option<BodyLength>) {
+ if let Some(encoder) = self.encode_head(head, body) {
+ self.state.writing = if !encoder.is_eof() {
+ Writing::Body(encoder)
+ } else if encoder.is_last() {
+ Writing::Closed
+ } else {
+ Writing::KeepAlive
+ };
+ }
+ }
+
+ pub(crate) fn write_full_msg(&mut self, head: MessageHead<T::Outgoing>, body: B) {
+ if let Some(encoder) =
+ self.encode_head(head, Some(BodyLength::Known(body.remaining() as u64)))
+ {
+ let is_last = encoder.is_last();
+ // Make sure we don't write a body if we weren't actually allowed
+ // to do so, like because its a HEAD request.
+ if !encoder.is_eof() {
+ encoder.danger_full_buf(body, self.io.write_buf());
+ }
+ self.state.writing = if is_last {
+ Writing::Closed
+ } else {
+ Writing::KeepAlive
+ }
+ }
+ }
+
+ fn encode_head(
+ &mut self,
+ mut head: MessageHead<T::Outgoing>,
+ body: Option<BodyLength>,
+ ) -> Option<Encoder> {
+ debug_assert!(self.can_write_head());
+
+ if !T::should_read_first() {
+ self.state.busy();
+ }
+
+ self.enforce_version(&mut head);
+
+ let buf = self.io.headers_buf();
+ match super::role::encode_headers::<T>(
+ Encode {
+ head: &mut head,
+ body,
+ #[cfg(feature = "server")]
+ keep_alive: self.state.wants_keep_alive(),
+ req_method: &mut self.state.method,
+ title_case_headers: self.state.title_case_headers,
+ },
+ buf,
+ ) {
+ Ok(encoder) => {
+ debug_assert!(self.state.cached_headers.is_none());
+ debug_assert!(head.headers.is_empty());
+ self.state.cached_headers = Some(head.headers);
+
+ #[cfg(feature = "ffi")]
+ {
+ self.state.on_informational =
+ head.extensions.remove::<crate::ffi::OnInformational>();
+ }
+
+ Some(encoder)
+ }
+ Err(err) => {
+ self.state.error = Some(err);
+ self.state.writing = Writing::Closed;
+ None
+ }
+ }
+ }
+
+ // Fix keep-alive when Connection: keep-alive header is not present
+ fn fix_keep_alive(&mut self, head: &mut MessageHead<T::Outgoing>) {
+ let outgoing_is_keep_alive = head
+ .headers
+ .get(CONNECTION)
+ .map(connection_keep_alive)
+ .unwrap_or(false);
+
+ if !outgoing_is_keep_alive {
+ match head.version {
+ // If response is version 1.0 and keep-alive is not present in the response,
+ // disable keep-alive so the server closes the connection
+ Version::HTTP_10 => self.state.disable_keep_alive(),
+ // If response is version 1.1 and keep-alive is wanted, add
+ // Connection: keep-alive header when not present
+ Version::HTTP_11 => {
+ if self.state.wants_keep_alive() {
+ head.headers
+ .insert(CONNECTION, HeaderValue::from_static("keep-alive"));
+ }
+ }
+ _ => (),
+ }
+ }
+ }
+
+ // If we know the remote speaks an older version, we try to fix up any messages
+ // to work with our older peer.
+ fn enforce_version(&mut self, head: &mut MessageHead<T::Outgoing>) {
+ if let Version::HTTP_10 = self.state.version {
+ // Fixes response or connection when keep-alive header is not present
+ self.fix_keep_alive(head);
+ // If the remote only knows HTTP/1.0, we should force ourselves
+ // to do only speak HTTP/1.0 as well.
+ head.version = Version::HTTP_10;
+ }
+ // If the remote speaks HTTP/1.1, then it *should* be fine with
+ // both HTTP/1.0 and HTTP/1.1 from us. So again, we just let
+ // the user's headers be.
+ }
+
+ pub(crate) fn write_body(&mut self, chunk: B) {
+ debug_assert!(self.can_write_body() && self.can_buffer_body());
+ // empty chunks should be discarded at Dispatcher level
+ debug_assert!(chunk.remaining() != 0);
+
+ let state = match self.state.writing {
+ Writing::Body(ref mut encoder) => {
+ self.io.buffer(encoder.encode(chunk));
+
+ if !encoder.is_eof() {
+ return;
+ }
+
+ if encoder.is_last() {
+ Writing::Closed
+ } else {
+ Writing::KeepAlive
+ }
+ }
+ _ => unreachable!("write_body invalid state: {:?}", self.state.writing),
+ };
+
+ self.state.writing = state;
+ }
+
+ pub(crate) fn write_body_and_end(&mut self, chunk: B) {
+ debug_assert!(self.can_write_body() && self.can_buffer_body());
+ // empty chunks should be discarded at Dispatcher level
+ debug_assert!(chunk.remaining() != 0);
+
+ let state = match self.state.writing {
+ Writing::Body(ref encoder) => {
+ let can_keep_alive = encoder.encode_and_end(chunk, self.io.write_buf());
+ if can_keep_alive {
+ Writing::KeepAlive
+ } else {
+ Writing::Closed
+ }
+ }
+ _ => unreachable!("write_body invalid state: {:?}", self.state.writing),
+ };
+
+ self.state.writing = state;
+ }
+
+ pub(crate) fn end_body(&mut self) -> crate::Result<()> {
+ debug_assert!(self.can_write_body());
+
+ let encoder = match self.state.writing {
+ Writing::Body(ref mut enc) => enc,
+ _ => return Ok(()),
+ };
+
+ // end of stream, that means we should try to eof
+ match encoder.end() {
+ Ok(end) => {
+ if let Some(end) = end {
+ self.io.buffer(end);
+ }
+
+ self.state.writing = if encoder.is_last() || encoder.is_close_delimited() {
+ Writing::Closed
+ } else {
+ Writing::KeepAlive
+ };
+
+ Ok(())
+ }
+ Err(not_eof) => {
+ self.state.writing = Writing::Closed;
+ Err(crate::Error::new_body_write_aborted().with(not_eof))
+ }
+ }
+ }
+
+ // When we get a parse error, depending on what side we are, we might be able
+ // to write a response before closing the connection.
+ //
+ // - Client: there is nothing we can do
+ // - Server: if Response hasn't been written yet, we can send a 4xx response
+ fn on_parse_error(&mut self, err: crate::Error) -> crate::Result<()> {
+ if let Writing::Init = self.state.writing {
+ if self.has_h2_prefix() {
+ return Err(crate::Error::new_version_h2());
+ }
+ if let Some(msg) = T::on_error(&err) {
+ // Drop the cached headers so as to not trigger a debug
+ // assert in `write_head`...
+ self.state.cached_headers.take();
+ self.write_head(msg, None);
+ self.state.error = Some(err);
+ return Ok(());
+ }
+ }
+
+ // fallback is pass the error back up
+ Err(err)
+ }
+
+ pub(crate) fn poll_flush(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
+ ready!(Pin::new(&mut self.io).poll_flush(cx))?;
+ self.try_keep_alive(cx);
+ trace!("flushed({}): {:?}", T::LOG, self.state);
+ Poll::Ready(Ok(()))
+ }
+
+ pub(crate) fn poll_shutdown(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
+ match ready!(Pin::new(self.io.io_mut()).poll_shutdown(cx)) {
+ Ok(()) => {
+ trace!("shut down IO complete");
+ Poll::Ready(Ok(()))
+ }
+ Err(e) => {
+ debug!("error shutting down IO: {}", e);
+ Poll::Ready(Err(e))
+ }
+ }
+ }
+
+ /// If the read side can be cheaply drained, do so. Otherwise, close.
+ pub(super) fn poll_drain_or_close_read(&mut self, cx: &mut task::Context<'_>) {
+ if let Reading::Continue(ref decoder) = self.state.reading {
+ // skip sending the 100-continue
+ // just move forward to a read, in case a tiny body was included
+ self.state.reading = Reading::Body(decoder.clone());
+ }
+
+ let _ = self.poll_read_body(cx);
+
+ // If still in Reading::Body, just give up
+ match self.state.reading {
+ Reading::Init | Reading::KeepAlive => trace!("body drained"),
+ _ => self.close_read(),
+ }
+ }
+
+ pub(crate) fn close_read(&mut self) {
+ self.state.close_read();
+ }
+
+ pub(crate) fn close_write(&mut self) {
+ self.state.close_write();
+ }
+
+ #[cfg(feature = "server")]
+ pub(crate) fn disable_keep_alive(&mut self) {
+ if self.state.is_idle() {
+ trace!("disable_keep_alive; closing idle connection");
+ self.state.close();
+ } else {
+ trace!("disable_keep_alive; in-progress connection");
+ self.state.disable_keep_alive();
+ }
+ }
+
+ pub(crate) fn take_error(&mut self) -> crate::Result<()> {
+ if let Some(err) = self.state.error.take() {
+ Err(err)
+ } else {
+ Ok(())
+ }
+ }
+
+ pub(super) fn on_upgrade(&mut self) -> crate::upgrade::OnUpgrade {
+ trace!("{}: prepare possible HTTP upgrade", T::LOG);
+ self.state.prepare_upgrade()
+ }
+}
+
+impl<I, B: Buf, T> fmt::Debug for Conn<I, B, T> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("Conn")
+ .field("state", &self.state)
+ .field("io", &self.io)
+ .finish()
+ }
+}
+
+// B and T are never pinned
+impl<I: Unpin, B, T> Unpin for Conn<I, B, T> {}
+
+struct State {
+ allow_half_close: bool,
+ /// Re-usable HeaderMap to reduce allocating new ones.
+ cached_headers: Option<HeaderMap>,
+ /// If an error occurs when there wasn't a direct way to return it
+ /// back to the user, this is set.
+ error: Option<crate::Error>,
+ /// Current keep-alive status.
+ keep_alive: KA,
+ /// If mid-message, the HTTP Method that started it.
+ ///
+ /// This is used to know things such as if the message can include
+ /// a body or not.
+ method: Option<Method>,
+ h1_parser_config: ParserConfig,
+ #[cfg(all(feature = "server", feature = "runtime"))]
+ h1_header_read_timeout: Option<Duration>,
+ #[cfg(all(feature = "server", feature = "runtime"))]
+ h1_header_read_timeout_fut: Option<Pin<Box<Sleep>>>,
+ #[cfg(all(feature = "server", feature = "runtime"))]
+ h1_header_read_timeout_running: bool,
+ preserve_header_case: bool,
+ #[cfg(feature = "ffi")]
+ preserve_header_order: bool,
+ title_case_headers: bool,
+ h09_responses: bool,
+ /// If set, called with each 1xx informational response received for
+ /// the current request. MUST be unset after a non-1xx response is
+ /// received.
+ #[cfg(feature = "ffi")]
+ on_informational: Option<crate::ffi::OnInformational>,
+ #[cfg(feature = "ffi")]
+ raw_headers: bool,
+ /// Set to true when the Dispatcher should poll read operations
+ /// again. See the `maybe_notify` method for more.
+ notify_read: bool,
+ /// State of allowed reads
+ reading: Reading,
+ /// State of allowed writes
+ writing: Writing,
+ /// An expected pending HTTP upgrade.
+ upgrade: Option<crate::upgrade::Pending>,
+ /// Either HTTP/1.0 or 1.1 connection
+ version: Version,
+}
+
+#[derive(Debug)]
+enum Reading {
+ Init,
+ Continue(Decoder),
+ Body(Decoder),
+ KeepAlive,
+ Closed,
+}
+
+enum Writing {
+ Init,
+ Body(Encoder),
+ KeepAlive,
+ Closed,
+}
+
+impl fmt::Debug for State {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ let mut builder = f.debug_struct("State");
+ builder
+ .field("reading", &self.reading)
+ .field("writing", &self.writing)
+ .field("keep_alive", &self.keep_alive);
+
+ // Only show error field if it's interesting...
+ if let Some(ref error) = self.error {
+ builder.field("error", error);
+ }
+
+ if self.allow_half_close {
+ builder.field("allow_half_close", &true);
+ }
+
+ // Purposefully leaving off other fields..
+
+ builder.finish()
+ }
+}
+
+impl fmt::Debug for Writing {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match *self {
+ Writing::Init => f.write_str("Init"),
+ Writing::Body(ref enc) => f.debug_tuple("Body").field(enc).finish(),
+ Writing::KeepAlive => f.write_str("KeepAlive"),
+ Writing::Closed => f.write_str("Closed"),
+ }
+ }
+}
+
+impl std::ops::BitAndAssign<bool> for KA {
+ fn bitand_assign(&mut self, enabled: bool) {
+ if !enabled {
+ trace!("remote disabling keep-alive");
+ *self = KA::Disabled;
+ }
+ }
+}
+
+#[derive(Clone, Copy, Debug)]
+enum KA {
+ Idle,
+ Busy,
+ Disabled,
+}
+
+impl Default for KA {
+ fn default() -> KA {
+ KA::Busy
+ }
+}
+
+impl KA {
+ fn idle(&mut self) {
+ *self = KA::Idle;
+ }
+
+ fn busy(&mut self) {
+ *self = KA::Busy;
+ }
+
+ fn disable(&mut self) {
+ *self = KA::Disabled;
+ }
+
+ fn status(&self) -> KA {
+ *self
+ }
+}
+
+impl State {
+ fn close(&mut self) {
+ trace!("State::close()");
+ self.reading = Reading::Closed;
+ self.writing = Writing::Closed;
+ self.keep_alive.disable();
+ }
+
+ fn close_read(&mut self) {
+ trace!("State::close_read()");
+ self.reading = Reading::Closed;
+ self.keep_alive.disable();
+ }
+
+ fn close_write(&mut self) {
+ trace!("State::close_write()");
+ self.writing = Writing::Closed;
+ self.keep_alive.disable();
+ }
+
+ fn wants_keep_alive(&self) -> bool {
+ if let KA::Disabled = self.keep_alive.status() {
+ false
+ } else {
+ true
+ }
+ }
+
+ fn try_keep_alive<T: Http1Transaction>(&mut self) {
+ match (&self.reading, &self.writing) {
+ (&Reading::KeepAlive, &Writing::KeepAlive) => {
+ if let KA::Busy = self.keep_alive.status() {
+ self.idle::<T>();
+ } else {
+ trace!(
+ "try_keep_alive({}): could keep-alive, but status = {:?}",
+ T::LOG,
+ self.keep_alive
+ );
+ self.close();
+ }
+ }
+ (&Reading::Closed, &Writing::KeepAlive) | (&Reading::KeepAlive, &Writing::Closed) => {
+ self.close()
+ }
+ _ => (),
+ }
+ }
+
+ fn disable_keep_alive(&mut self) {
+ self.keep_alive.disable()
+ }
+
+ fn busy(&mut self) {
+ if let KA::Disabled = self.keep_alive.status() {
+ return;
+ }
+ self.keep_alive.busy();
+ }
+
+ fn idle<T: Http1Transaction>(&mut self) {
+ debug_assert!(!self.is_idle(), "State::idle() called while idle");
+
+ self.method = None;
+ self.keep_alive.idle();
+
+ if !self.is_idle() {
+ self.close();
+ return;
+ }
+
+ self.reading = Reading::Init;
+ self.writing = Writing::Init;
+
+ // !T::should_read_first() means Client.
+ //
+ // If Client connection has just gone idle, the Dispatcher
+ // should try the poll loop one more time, so as to poll the
+ // pending requests stream.
+ if !T::should_read_first() {
+ self.notify_read = true;
+ }
+ }
+
+ fn is_idle(&self) -> bool {
+ matches!(self.keep_alive.status(), KA::Idle)
+ }
+
+ fn is_read_closed(&self) -> bool {
+ matches!(self.reading, Reading::Closed)
+ }
+
+ fn is_write_closed(&self) -> bool {
+ matches!(self.writing, Writing::Closed)
+ }
+
+ fn prepare_upgrade(&mut self) -> crate::upgrade::OnUpgrade {
+ let (tx, rx) = crate::upgrade::pending();
+ self.upgrade = Some(tx);
+ rx
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ #[cfg(feature = "nightly")]
+ #[bench]
+ fn bench_read_head_short(b: &mut ::test::Bencher) {
+ use super::*;
+ let s = b"GET / HTTP/1.1\r\nHost: localhost:8080\r\n\r\n";
+ let len = s.len();
+ b.bytes = len as u64;
+
+ // an empty IO, we'll be skipping and using the read buffer anyways
+ let io = tokio_test::io::Builder::new().build();
+ let mut conn = Conn::<_, bytes::Bytes, crate::proto::h1::ServerTransaction>::new(io);
+ *conn.io.read_buf_mut() = ::bytes::BytesMut::from(&s[..]);
+ conn.state.cached_headers = Some(HeaderMap::with_capacity(2));
+
+ let rt = tokio::runtime::Builder::new_current_thread()
+ .enable_all()
+ .build()
+ .unwrap();
+
+ b.iter(|| {
+ rt.block_on(futures_util::future::poll_fn(|cx| {
+ match conn.poll_read_head(cx) {
+ Poll::Ready(Some(Ok(x))) => {
+ ::test::black_box(&x);
+ let mut headers = x.0.headers;
+ headers.clear();
+ conn.state.cached_headers = Some(headers);
+ }
+ f => panic!("expected Ready(Some(Ok(..))): {:?}", f),
+ }
+
+ conn.io.read_buf_mut().reserve(1);
+ unsafe {
+ conn.io.read_buf_mut().set_len(len);
+ }
+ conn.state.reading = Reading::Init;
+ Poll::Ready(())
+ }));
+ });
+ }
+
+ /*
+ //TODO: rewrite these using dispatch... someday...
+ use futures::{Async, Future, Stream, Sink};
+ use futures::future;
+
+ use proto::{self, ClientTransaction, MessageHead, ServerTransaction};
+ use super::super::Encoder;
+ use mock::AsyncIo;
+
+ use super::{Conn, Decoder, Reading, Writing};
+ use ::uri::Uri;
+
+ use std::str::FromStr;
+
+ #[test]
+ fn test_conn_init_read() {
+ let good_message = b"GET / HTTP/1.1\r\n\r\n".to_vec();
+ let len = good_message.len();
+ let io = AsyncIo::new_buf(good_message, len);
+ let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io);
+
+ match conn.poll().unwrap() {
+ Async::Ready(Some(Frame::Message { message, body: false })) => {
+ assert_eq!(message, MessageHead {
+ subject: ::proto::RequestLine(::Get, Uri::from_str("/").unwrap()),
+ .. MessageHead::default()
+ })
+ },
+ f => panic!("frame is not Frame::Message: {:?}", f)
+ }
+ }
+
+ #[test]
+ fn test_conn_parse_partial() {
+ let _: Result<(), ()> = future::lazy(|| {
+ let good_message = b"GET / HTTP/1.1\r\nHost: foo.bar\r\n\r\n".to_vec();
+ let io = AsyncIo::new_buf(good_message, 10);
+ let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io);
+ assert!(conn.poll().unwrap().is_not_ready());
+ conn.io.io_mut().block_in(50);
+ let async = conn.poll().unwrap();
+ assert!(async.is_ready());
+ match async {
+ Async::Ready(Some(Frame::Message { .. })) => (),
+ f => panic!("frame is not Message: {:?}", f),
+ }
+ Ok(())
+ }).wait();
+ }
+
+ #[test]
+ fn test_conn_init_read_eof_idle() {
+ let io = AsyncIo::new_buf(vec![], 1);
+ let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io);
+ conn.state.idle();
+
+ match conn.poll().unwrap() {
+ Async::Ready(None) => {},
+ other => panic!("frame is not None: {:?}", other)
+ }
+ }
+
+ #[test]
+ fn test_conn_init_read_eof_idle_partial_parse() {
+ let io = AsyncIo::new_buf(b"GET / HTTP/1.1".to_vec(), 100);
+ let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io);
+ conn.state.idle();
+
+ match conn.poll() {
+ Err(ref err) if err.kind() == std::io::ErrorKind::UnexpectedEof => {},
+ other => panic!("unexpected frame: {:?}", other)
+ }
+ }
+
+ #[test]
+ fn test_conn_init_read_eof_busy() {
+ let _: Result<(), ()> = future::lazy(|| {
+ // server ignores
+ let io = AsyncIo::new_eof();
+ let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io);
+ conn.state.busy();
+
+ match conn.poll().unwrap() {
+ Async::Ready(None) => {},
+ other => panic!("unexpected frame: {:?}", other)
+ }
+
+ // client
+ let io = AsyncIo::new_eof();
+ let mut conn = Conn::<_, proto::Bytes, ClientTransaction>::new(io);
+ conn.state.busy();
+
+ match conn.poll() {
+ Err(ref err) if err.kind() == std::io::ErrorKind::UnexpectedEof => {},
+ other => panic!("unexpected frame: {:?}", other)
+ }
+ Ok(())
+ }).wait();
+ }
+
+ #[test]
+ fn test_conn_body_finish_read_eof() {
+ let _: Result<(), ()> = future::lazy(|| {
+ let io = AsyncIo::new_eof();
+ let mut conn = Conn::<_, proto::Bytes, ClientTransaction>::new(io);
+ conn.state.busy();
+ conn.state.writing = Writing::KeepAlive;
+ conn.state.reading = Reading::Body(Decoder::length(0));
+
+ match conn.poll() {
+ Ok(Async::Ready(Some(Frame::Body { chunk: None }))) => (),
+ other => panic!("unexpected frame: {:?}", other)
+ }
+
+ // conn eofs, but tokio-proto will call poll() again, before calling flush()
+ // the conn eof in this case is perfectly fine
+
+ match conn.poll() {
+ Ok(Async::Ready(None)) => (),
+ other => panic!("unexpected frame: {:?}", other)
+ }
+ Ok(())
+ }).wait();
+ }
+
+ #[test]
+ fn test_conn_message_empty_body_read_eof() {
+ let _: Result<(), ()> = future::lazy(|| {
+ let io = AsyncIo::new_buf(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n".to_vec(), 1024);
+ let mut conn = Conn::<_, proto::Bytes, ClientTransaction>::new(io);
+ conn.state.busy();
+ conn.state.writing = Writing::KeepAlive;
+
+ match conn.poll() {
+ Ok(Async::Ready(Some(Frame::Message { body: false, .. }))) => (),
+ other => panic!("unexpected frame: {:?}", other)
+ }
+
+ // conn eofs, but tokio-proto will call poll() again, before calling flush()
+ // the conn eof in this case is perfectly fine
+
+ match conn.poll() {
+ Ok(Async::Ready(None)) => (),
+ other => panic!("unexpected frame: {:?}", other)
+ }
+ Ok(())
+ }).wait();
+ }
+
+ #[test]
+ fn test_conn_read_body_end() {
+ let _: Result<(), ()> = future::lazy(|| {
+ let io = AsyncIo::new_buf(b"POST / HTTP/1.1\r\nContent-Length: 5\r\n\r\n12345".to_vec(), 1024);
+ let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io);
+ conn.state.busy();
+
+ match conn.poll() {
+ Ok(Async::Ready(Some(Frame::Message { body: true, .. }))) => (),
+ other => panic!("unexpected frame: {:?}", other)
+ }
+
+ match conn.poll() {
+ Ok(Async::Ready(Some(Frame::Body { chunk: Some(_) }))) => (),
+ other => panic!("unexpected frame: {:?}", other)
+ }
+
+ // When the body is done, `poll` MUST return a `Body` frame with chunk set to `None`
+ match conn.poll() {
+ Ok(Async::Ready(Some(Frame::Body { chunk: None }))) => (),
+ other => panic!("unexpected frame: {:?}", other)
+ }
+
+ match conn.poll() {
+ Ok(Async::NotReady) => (),
+ other => panic!("unexpected frame: {:?}", other)
+ }
+ Ok(())
+ }).wait();
+ }
+
+ #[test]
+ fn test_conn_closed_read() {
+ let io = AsyncIo::new_buf(vec![], 0);
+ let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io);
+ conn.state.close();
+
+ match conn.poll().unwrap() {
+ Async::Ready(None) => {},
+ other => panic!("frame is not None: {:?}", other)
+ }
+ }
+
+ #[test]
+ fn test_conn_body_write_length() {
+ let _ = pretty_env_logger::try_init();
+ let _: Result<(), ()> = future::lazy(|| {
+ let io = AsyncIo::new_buf(vec![], 0);
+ let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io);
+ let max = super::super::io::DEFAULT_MAX_BUFFER_SIZE + 4096;
+ conn.state.writing = Writing::Body(Encoder::length((max * 2) as u64));
+
+ assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'a'; max].into()) }).unwrap().is_ready());
+ assert!(!conn.can_buffer_body());
+
+ assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'b'; 1024 * 8].into()) }).unwrap().is_not_ready());
+
+ conn.io.io_mut().block_in(1024 * 3);
+ assert!(conn.poll_complete().unwrap().is_not_ready());
+ conn.io.io_mut().block_in(1024 * 3);
+ assert!(conn.poll_complete().unwrap().is_not_ready());
+ conn.io.io_mut().block_in(max * 2);
+ assert!(conn.poll_complete().unwrap().is_ready());
+
+ assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'c'; 1024 * 8].into()) }).unwrap().is_ready());
+ Ok(())
+ }).wait();
+ }
+
+ #[test]
+ fn test_conn_body_write_chunked() {
+ let _: Result<(), ()> = future::lazy(|| {
+ let io = AsyncIo::new_buf(vec![], 4096);
+ let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io);
+ conn.state.writing = Writing::Body(Encoder::chunked());
+
+ assert!(conn.start_send(Frame::Body { chunk: Some("headers".into()) }).unwrap().is_ready());
+ assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'x'; 8192].into()) }).unwrap().is_ready());
+ Ok(())
+ }).wait();
+ }
+
+ #[test]
+ fn test_conn_body_flush() {
+ let _: Result<(), ()> = future::lazy(|| {
+ let io = AsyncIo::new_buf(vec![], 1024 * 1024 * 5);
+ let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io);
+ conn.state.writing = Writing::Body(Encoder::length(1024 * 1024));
+ assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'a'; 1024 * 1024].into()) }).unwrap().is_ready());
+ assert!(!conn.can_buffer_body());
+ conn.io.io_mut().block_in(1024 * 1024 * 5);
+ assert!(conn.poll_complete().unwrap().is_ready());
+ assert!(conn.can_buffer_body());
+ assert!(conn.io.io_mut().flushed());
+
+ Ok(())
+ }).wait();
+ }
+
+ #[test]
+ fn test_conn_parking() {
+ use std::sync::Arc;
+ use futures::executor::Notify;
+ use futures::executor::NotifyHandle;
+
+ struct Car {
+ permit: bool,
+ }
+ impl Notify for Car {
+ fn notify(&self, _id: usize) {
+ assert!(self.permit, "unparked without permit");
+ }
+ }
+
+ fn car(permit: bool) -> NotifyHandle {
+ Arc::new(Car {
+ permit: permit,
+ }).into()
+ }
+
+ // test that once writing is done, unparks
+ let f = future::lazy(|| {
+ let io = AsyncIo::new_buf(vec![], 4096);
+ let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io);
+ conn.state.reading = Reading::KeepAlive;
+ assert!(conn.poll().unwrap().is_not_ready());
+
+ conn.state.writing = Writing::KeepAlive;
+ assert!(conn.poll_complete().unwrap().is_ready());
+ Ok::<(), ()>(())
+ });
+ ::futures::executor::spawn(f).poll_future_notify(&car(true), 0).unwrap();
+
+
+ // test that flushing when not waiting on read doesn't unpark
+ let f = future::lazy(|| {
+ let io = AsyncIo::new_buf(vec![], 4096);
+ let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io);
+ conn.state.writing = Writing::KeepAlive;
+ assert!(conn.poll_complete().unwrap().is_ready());
+ Ok::<(), ()>(())
+ });
+ ::futures::executor::spawn(f).poll_future_notify(&car(false), 0).unwrap();
+
+
+ // test that flushing and writing isn't done doesn't unpark
+ let f = future::lazy(|| {
+ let io = AsyncIo::new_buf(vec![], 4096);
+ let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io);
+ conn.state.reading = Reading::KeepAlive;
+ assert!(conn.poll().unwrap().is_not_ready());
+ conn.state.writing = Writing::Body(Encoder::length(5_000));
+ assert!(conn.poll_complete().unwrap().is_ready());
+ Ok::<(), ()>(())
+ });
+ ::futures::executor::spawn(f).poll_future_notify(&car(false), 0).unwrap();
+ }
+
+ #[test]
+ fn test_conn_closed_write() {
+ let io = AsyncIo::new_buf(vec![], 0);
+ let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io);
+ conn.state.close();
+
+ match conn.start_send(Frame::Body { chunk: Some(b"foobar".to_vec().into()) }) {
+ Err(_e) => {},
+ other => panic!("did not return Err: {:?}", other)
+ }
+
+ assert!(conn.state.is_write_closed());
+ }
+
+ #[test]
+ fn test_conn_write_empty_chunk() {
+ let io = AsyncIo::new_buf(vec![], 0);
+ let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io);
+ conn.state.writing = Writing::KeepAlive;
+
+ assert!(conn.start_send(Frame::Body { chunk: None }).unwrap().is_ready());
+ assert!(conn.start_send(Frame::Body { chunk: Some(Vec::new().into()) }).unwrap().is_ready());
+ conn.start_send(Frame::Body { chunk: Some(vec![b'a'].into()) }).unwrap_err();
+ }
+ */
+}
diff --git a/third_party/rust/hyper/src/proto/h1/decode.rs b/third_party/rust/hyper/src/proto/h1/decode.rs
new file mode 100644
index 0000000000..1e3a38effc
--- /dev/null
+++ b/third_party/rust/hyper/src/proto/h1/decode.rs
@@ -0,0 +1,731 @@
+use std::error::Error as StdError;
+use std::fmt;
+use std::io;
+use std::usize;
+
+use bytes::Bytes;
+use tracing::{debug, trace};
+
+use crate::common::{task, Poll};
+
+use super::io::MemRead;
+use super::DecodedLength;
+
+use self::Kind::{Chunked, Eof, Length};
+
+/// Decoders to handle different Transfer-Encodings.
+///
+/// If a message body does not include a Transfer-Encoding, it *should*
+/// include a Content-Length header.
+#[derive(Clone, PartialEq)]
+pub(crate) struct Decoder {
+ kind: Kind,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+enum Kind {
+ /// A Reader used when a Content-Length header is passed with a positive integer.
+ Length(u64),
+ /// A Reader used when Transfer-Encoding is `chunked`.
+ Chunked(ChunkedState, u64),
+ /// A Reader used for responses that don't indicate a length or chunked.
+ ///
+ /// The bool tracks when EOF is seen on the transport.
+ ///
+ /// Note: This should only used for `Response`s. It is illegal for a
+ /// `Request` to be made with both `Content-Length` and
+ /// `Transfer-Encoding: chunked` missing, as explained from the spec:
+ ///
+ /// > If a Transfer-Encoding header field is present in a response and
+ /// > the chunked transfer coding is not the final encoding, the
+ /// > message body length is determined by reading the connection until
+ /// > it is closed by the server. If a Transfer-Encoding header field
+ /// > is present in a request and the chunked transfer coding is not
+ /// > the final encoding, the message body length cannot be determined
+ /// > reliably; the server MUST respond with the 400 (Bad Request)
+ /// > status code and then close the connection.
+ Eof(bool),
+}
+
+#[derive(Debug, PartialEq, Clone, Copy)]
+enum ChunkedState {
+ Size,
+ SizeLws,
+ Extension,
+ SizeLf,
+ Body,
+ BodyCr,
+ BodyLf,
+ Trailer,
+ TrailerLf,
+ EndCr,
+ EndLf,
+ End,
+}
+
+impl Decoder {
+ // constructors
+
+ pub(crate) fn length(x: u64) -> Decoder {
+ Decoder {
+ kind: Kind::Length(x),
+ }
+ }
+
+ pub(crate) fn chunked() -> Decoder {
+ Decoder {
+ kind: Kind::Chunked(ChunkedState::Size, 0),
+ }
+ }
+
+ pub(crate) fn eof() -> Decoder {
+ Decoder {
+ kind: Kind::Eof(false),
+ }
+ }
+
+ pub(super) fn new(len: DecodedLength) -> Self {
+ match len {
+ DecodedLength::CHUNKED => Decoder::chunked(),
+ DecodedLength::CLOSE_DELIMITED => Decoder::eof(),
+ length => Decoder::length(length.danger_len()),
+ }
+ }
+
+ // methods
+
+ pub(crate) fn is_eof(&self) -> bool {
+ matches!(self.kind, Length(0) | Chunked(ChunkedState::End, _) | Eof(true))
+ }
+
+ pub(crate) fn decode<R: MemRead>(
+ &mut self,
+ cx: &mut task::Context<'_>,
+ body: &mut R,
+ ) -> Poll<Result<Bytes, io::Error>> {
+ trace!("decode; state={:?}", self.kind);
+ match self.kind {
+ Length(ref mut remaining) => {
+ if *remaining == 0 {
+ Poll::Ready(Ok(Bytes::new()))
+ } else {
+ let to_read = *remaining as usize;
+ let buf = ready!(body.read_mem(cx, to_read))?;
+ let num = buf.as_ref().len() as u64;
+ if num > *remaining {
+ *remaining = 0;
+ } else if num == 0 {
+ return Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::UnexpectedEof,
+ IncompleteBody,
+ )));
+ } else {
+ *remaining -= num;
+ }
+ Poll::Ready(Ok(buf))
+ }
+ }
+ Chunked(ref mut state, ref mut size) => {
+ loop {
+ let mut buf = None;
+ // advances the chunked state
+ *state = ready!(state.step(cx, body, size, &mut buf))?;
+ if *state == ChunkedState::End {
+ trace!("end of chunked");
+ return Poll::Ready(Ok(Bytes::new()));
+ }
+ if let Some(buf) = buf {
+ return Poll::Ready(Ok(buf));
+ }
+ }
+ }
+ Eof(ref mut is_eof) => {
+ if *is_eof {
+ Poll::Ready(Ok(Bytes::new()))
+ } else {
+ // 8192 chosen because its about 2 packets, there probably
+ // won't be that much available, so don't have MemReaders
+ // allocate buffers to big
+ body.read_mem(cx, 8192).map_ok(|slice| {
+ *is_eof = slice.is_empty();
+ slice
+ })
+ }
+ }
+ }
+ }
+
+ #[cfg(test)]
+ async fn decode_fut<R: MemRead>(&mut self, body: &mut R) -> Result<Bytes, io::Error> {
+ futures_util::future::poll_fn(move |cx| self.decode(cx, body)).await
+ }
+}
+
+impl fmt::Debug for Decoder {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt::Debug::fmt(&self.kind, f)
+ }
+}
+
+macro_rules! byte (
+ ($rdr:ident, $cx:expr) => ({
+ let buf = ready!($rdr.read_mem($cx, 1))?;
+ if !buf.is_empty() {
+ buf[0]
+ } else {
+ return Poll::Ready(Err(io::Error::new(io::ErrorKind::UnexpectedEof,
+ "unexpected EOF during chunk size line")));
+ }
+ })
+);
+
+impl ChunkedState {
+ fn step<R: MemRead>(
+ &self,
+ cx: &mut task::Context<'_>,
+ body: &mut R,
+ size: &mut u64,
+ buf: &mut Option<Bytes>,
+ ) -> Poll<Result<ChunkedState, io::Error>> {
+ use self::ChunkedState::*;
+ match *self {
+ Size => ChunkedState::read_size(cx, body, size),
+ SizeLws => ChunkedState::read_size_lws(cx, body),
+ Extension => ChunkedState::read_extension(cx, body),
+ SizeLf => ChunkedState::read_size_lf(cx, body, *size),
+ Body => ChunkedState::read_body(cx, body, size, buf),
+ BodyCr => ChunkedState::read_body_cr(cx, body),
+ BodyLf => ChunkedState::read_body_lf(cx, body),
+ Trailer => ChunkedState::read_trailer(cx, body),
+ TrailerLf => ChunkedState::read_trailer_lf(cx, body),
+ EndCr => ChunkedState::read_end_cr(cx, body),
+ EndLf => ChunkedState::read_end_lf(cx, body),
+ End => Poll::Ready(Ok(ChunkedState::End)),
+ }
+ }
+ fn read_size<R: MemRead>(
+ cx: &mut task::Context<'_>,
+ rdr: &mut R,
+ size: &mut u64,
+ ) -> Poll<Result<ChunkedState, io::Error>> {
+ trace!("Read chunk hex size");
+
+ macro_rules! or_overflow {
+ ($e:expr) => (
+ match $e {
+ Some(val) => val,
+ None => return Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ "invalid chunk size: overflow",
+ ))),
+ }
+ )
+ }
+
+ let radix = 16;
+ match byte!(rdr, cx) {
+ b @ b'0'..=b'9' => {
+ *size = or_overflow!(size.checked_mul(radix));
+ *size = or_overflow!(size.checked_add((b - b'0') as u64));
+ }
+ b @ b'a'..=b'f' => {
+ *size = or_overflow!(size.checked_mul(radix));
+ *size = or_overflow!(size.checked_add((b + 10 - b'a') as u64));
+ }
+ b @ b'A'..=b'F' => {
+ *size = or_overflow!(size.checked_mul(radix));
+ *size = or_overflow!(size.checked_add((b + 10 - b'A') as u64));
+ }
+ b'\t' | b' ' => return Poll::Ready(Ok(ChunkedState::SizeLws)),
+ b';' => return Poll::Ready(Ok(ChunkedState::Extension)),
+ b'\r' => return Poll::Ready(Ok(ChunkedState::SizeLf)),
+ _ => {
+ return Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "Invalid chunk size line: Invalid Size",
+ )));
+ }
+ }
+ Poll::Ready(Ok(ChunkedState::Size))
+ }
+ fn read_size_lws<R: MemRead>(
+ cx: &mut task::Context<'_>,
+ rdr: &mut R,
+ ) -> Poll<Result<ChunkedState, io::Error>> {
+ trace!("read_size_lws");
+ match byte!(rdr, cx) {
+ // LWS can follow the chunk size, but no more digits can come
+ b'\t' | b' ' => Poll::Ready(Ok(ChunkedState::SizeLws)),
+ b';' => Poll::Ready(Ok(ChunkedState::Extension)),
+ b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)),
+ _ => Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "Invalid chunk size linear white space",
+ ))),
+ }
+ }
+ fn read_extension<R: MemRead>(
+ cx: &mut task::Context<'_>,
+ rdr: &mut R,
+ ) -> Poll<Result<ChunkedState, io::Error>> {
+ trace!("read_extension");
+ // We don't care about extensions really at all. Just ignore them.
+ // They "end" at the next CRLF.
+ //
+ // However, some implementations may not check for the CR, so to save
+ // them from themselves, we reject extensions containing plain LF as
+ // well.
+ match byte!(rdr, cx) {
+ b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)),
+ b'\n' => Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ "invalid chunk extension contains newline",
+ ))),
+ _ => Poll::Ready(Ok(ChunkedState::Extension)), // no supported extensions
+ }
+ }
+ fn read_size_lf<R: MemRead>(
+ cx: &mut task::Context<'_>,
+ rdr: &mut R,
+ size: u64,
+ ) -> Poll<Result<ChunkedState, io::Error>> {
+ trace!("Chunk size is {:?}", size);
+ match byte!(rdr, cx) {
+ b'\n' => {
+ if size == 0 {
+ Poll::Ready(Ok(ChunkedState::EndCr))
+ } else {
+ debug!("incoming chunked header: {0:#X} ({0} bytes)", size);
+ Poll::Ready(Ok(ChunkedState::Body))
+ }
+ }
+ _ => Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "Invalid chunk size LF",
+ ))),
+ }
+ }
+
+ fn read_body<R: MemRead>(
+ cx: &mut task::Context<'_>,
+ rdr: &mut R,
+ rem: &mut u64,
+ buf: &mut Option<Bytes>,
+ ) -> Poll<Result<ChunkedState, io::Error>> {
+ trace!("Chunked read, remaining={:?}", rem);
+
+ // cap remaining bytes at the max capacity of usize
+ let rem_cap = match *rem {
+ r if r > usize::MAX as u64 => usize::MAX,
+ r => r as usize,
+ };
+
+ let to_read = rem_cap;
+ let slice = ready!(rdr.read_mem(cx, to_read))?;
+ let count = slice.len();
+
+ if count == 0 {
+ *rem = 0;
+ return Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::UnexpectedEof,
+ IncompleteBody,
+ )));
+ }
+ *buf = Some(slice);
+ *rem -= count as u64;
+
+ if *rem > 0 {
+ Poll::Ready(Ok(ChunkedState::Body))
+ } else {
+ Poll::Ready(Ok(ChunkedState::BodyCr))
+ }
+ }
+ fn read_body_cr<R: MemRead>(
+ cx: &mut task::Context<'_>,
+ rdr: &mut R,
+ ) -> Poll<Result<ChunkedState, io::Error>> {
+ match byte!(rdr, cx) {
+ b'\r' => Poll::Ready(Ok(ChunkedState::BodyLf)),
+ _ => Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "Invalid chunk body CR",
+ ))),
+ }
+ }
+ fn read_body_lf<R: MemRead>(
+ cx: &mut task::Context<'_>,
+ rdr: &mut R,
+ ) -> Poll<Result<ChunkedState, io::Error>> {
+ match byte!(rdr, cx) {
+ b'\n' => Poll::Ready(Ok(ChunkedState::Size)),
+ _ => Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "Invalid chunk body LF",
+ ))),
+ }
+ }
+
+ fn read_trailer<R: MemRead>(
+ cx: &mut task::Context<'_>,
+ rdr: &mut R,
+ ) -> Poll<Result<ChunkedState, io::Error>> {
+ trace!("read_trailer");
+ match byte!(rdr, cx) {
+ b'\r' => Poll::Ready(Ok(ChunkedState::TrailerLf)),
+ _ => Poll::Ready(Ok(ChunkedState::Trailer)),
+ }
+ }
+ fn read_trailer_lf<R: MemRead>(
+ cx: &mut task::Context<'_>,
+ rdr: &mut R,
+ ) -> Poll<Result<ChunkedState, io::Error>> {
+ match byte!(rdr, cx) {
+ b'\n' => Poll::Ready(Ok(ChunkedState::EndCr)),
+ _ => Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "Invalid trailer end LF",
+ ))),
+ }
+ }
+
+ fn read_end_cr<R: MemRead>(
+ cx: &mut task::Context<'_>,
+ rdr: &mut R,
+ ) -> Poll<Result<ChunkedState, io::Error>> {
+ match byte!(rdr, cx) {
+ b'\r' => Poll::Ready(Ok(ChunkedState::EndLf)),
+ _ => Poll::Ready(Ok(ChunkedState::Trailer)),
+ }
+ }
+ fn read_end_lf<R: MemRead>(
+ cx: &mut task::Context<'_>,
+ rdr: &mut R,
+ ) -> Poll<Result<ChunkedState, io::Error>> {
+ match byte!(rdr, cx) {
+ b'\n' => Poll::Ready(Ok(ChunkedState::End)),
+ _ => Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "Invalid chunk end LF",
+ ))),
+ }
+ }
+}
+
+#[derive(Debug)]
+struct IncompleteBody;
+
+impl fmt::Display for IncompleteBody {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "end of file before message length reached")
+ }
+}
+
+impl StdError for IncompleteBody {}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::pin::Pin;
+ use std::time::Duration;
+ use tokio::io::{AsyncRead, ReadBuf};
+
+ impl<'a> MemRead for &'a [u8] {
+ fn read_mem(&mut self, _: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> {
+ let n = std::cmp::min(len, self.len());
+ if n > 0 {
+ let (a, b) = self.split_at(n);
+ let buf = Bytes::copy_from_slice(a);
+ *self = b;
+ Poll::Ready(Ok(buf))
+ } else {
+ Poll::Ready(Ok(Bytes::new()))
+ }
+ }
+ }
+
+ impl<'a> MemRead for &'a mut (dyn AsyncRead + Unpin) {
+ fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> {
+ let mut v = vec![0; len];
+ let mut buf = ReadBuf::new(&mut v);
+ ready!(Pin::new(self).poll_read(cx, &mut buf)?);
+ Poll::Ready(Ok(Bytes::copy_from_slice(&buf.filled())))
+ }
+ }
+
+ #[cfg(feature = "nightly")]
+ impl MemRead for Bytes {
+ fn read_mem(&mut self, _: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> {
+ let n = std::cmp::min(len, self.len());
+ let ret = self.split_to(n);
+ Poll::Ready(Ok(ret))
+ }
+ }
+
+ /*
+ use std::io;
+ use std::io::Write;
+ use super::Decoder;
+ use super::ChunkedState;
+ use futures::{Async, Poll};
+ use bytes::{BytesMut, Bytes};
+ use crate::mock::AsyncIo;
+ */
+
+ #[tokio::test]
+ async fn test_read_chunk_size() {
+ use std::io::ErrorKind::{InvalidData, InvalidInput, UnexpectedEof};
+
+ async fn read(s: &str) -> u64 {
+ let mut state = ChunkedState::Size;
+ let rdr = &mut s.as_bytes();
+ let mut size = 0;
+ loop {
+ let result =
+ futures_util::future::poll_fn(|cx| state.step(cx, rdr, &mut size, &mut None))
+ .await;
+ let desc = format!("read_size failed for {:?}", s);
+ state = result.expect(desc.as_str());
+ if state == ChunkedState::Body || state == ChunkedState::EndCr {
+ break;
+ }
+ }
+ size
+ }
+
+ async fn read_err(s: &str, expected_err: io::ErrorKind) {
+ let mut state = ChunkedState::Size;
+ let rdr = &mut s.as_bytes();
+ let mut size = 0;
+ loop {
+ let result =
+ futures_util::future::poll_fn(|cx| state.step(cx, rdr, &mut size, &mut None))
+ .await;
+ state = match result {
+ Ok(s) => s,
+ Err(e) => {
+ assert!(
+ expected_err == e.kind(),
+ "Reading {:?}, expected {:?}, but got {:?}",
+ s,
+ expected_err,
+ e.kind()
+ );
+ return;
+ }
+ };
+ if state == ChunkedState::Body || state == ChunkedState::End {
+ panic!("Was Ok. Expected Err for {:?}", s);
+ }
+ }
+ }
+
+ assert_eq!(1, read("1\r\n").await);
+ assert_eq!(1, read("01\r\n").await);
+ assert_eq!(0, read("0\r\n").await);
+ assert_eq!(0, read("00\r\n").await);
+ assert_eq!(10, read("A\r\n").await);
+ assert_eq!(10, read("a\r\n").await);
+ assert_eq!(255, read("Ff\r\n").await);
+ assert_eq!(255, read("Ff \r\n").await);
+ // Missing LF or CRLF
+ read_err("F\rF", InvalidInput).await;
+ read_err("F", UnexpectedEof).await;
+ // Invalid hex digit
+ read_err("X\r\n", InvalidInput).await;
+ read_err("1X\r\n", InvalidInput).await;
+ read_err("-\r\n", InvalidInput).await;
+ read_err("-1\r\n", InvalidInput).await;
+ // Acceptable (if not fully valid) extensions do not influence the size
+ assert_eq!(1, read("1;extension\r\n").await);
+ assert_eq!(10, read("a;ext name=value\r\n").await);
+ assert_eq!(1, read("1;extension;extension2\r\n").await);
+ assert_eq!(1, read("1;;; ;\r\n").await);
+ assert_eq!(2, read("2; extension...\r\n").await);
+ assert_eq!(3, read("3 ; extension=123\r\n").await);
+ assert_eq!(3, read("3 ;\r\n").await);
+ assert_eq!(3, read("3 ; \r\n").await);
+ // Invalid extensions cause an error
+ read_err("1 invalid extension\r\n", InvalidInput).await;
+ read_err("1 A\r\n", InvalidInput).await;
+ read_err("1;no CRLF", UnexpectedEof).await;
+ read_err("1;reject\nnewlines\r\n", InvalidData).await;
+ // Overflow
+ read_err("f0000000000000003\r\n", InvalidData).await;
+ }
+
+ #[tokio::test]
+ async fn test_read_sized_early_eof() {
+ let mut bytes = &b"foo bar"[..];
+ let mut decoder = Decoder::length(10);
+ assert_eq!(decoder.decode_fut(&mut bytes).await.unwrap().len(), 7);
+ let e = decoder.decode_fut(&mut bytes).await.unwrap_err();
+ assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof);
+ }
+
+ #[tokio::test]
+ async fn test_read_chunked_early_eof() {
+ let mut bytes = &b"\
+ 9\r\n\
+ foo bar\
+ "[..];
+ let mut decoder = Decoder::chunked();
+ assert_eq!(decoder.decode_fut(&mut bytes).await.unwrap().len(), 7);
+ let e = decoder.decode_fut(&mut bytes).await.unwrap_err();
+ assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof);
+ }
+
+ #[tokio::test]
+ async fn test_read_chunked_single_read() {
+ let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\n"[..];
+ let buf = Decoder::chunked()
+ .decode_fut(&mut mock_buf)
+ .await
+ .expect("decode");
+ assert_eq!(16, buf.len());
+ let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String");
+ assert_eq!("1234567890abcdef", &result);
+ }
+
+ #[tokio::test]
+ async fn test_read_chunked_trailer_with_missing_lf() {
+ let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\nbad\r\r\n"[..];
+ let mut decoder = Decoder::chunked();
+ decoder.decode_fut(&mut mock_buf).await.expect("decode");
+ let e = decoder.decode_fut(&mut mock_buf).await.unwrap_err();
+ assert_eq!(e.kind(), io::ErrorKind::InvalidInput);
+ }
+
+ #[tokio::test]
+ async fn test_read_chunked_after_eof() {
+ let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\n\r\n"[..];
+ let mut decoder = Decoder::chunked();
+
+ // normal read
+ let buf = decoder.decode_fut(&mut mock_buf).await.unwrap();
+ assert_eq!(16, buf.len());
+ let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String");
+ assert_eq!("1234567890abcdef", &result);
+
+ // eof read
+ let buf = decoder.decode_fut(&mut mock_buf).await.expect("decode");
+ assert_eq!(0, buf.len());
+
+ // ensure read after eof also returns eof
+ let buf = decoder.decode_fut(&mut mock_buf).await.expect("decode");
+ assert_eq!(0, buf.len());
+ }
+
+ // perform an async read using a custom buffer size and causing a blocking
+ // read at the specified byte
+ async fn read_async(mut decoder: Decoder, content: &[u8], block_at: usize) -> String {
+ let mut outs = Vec::new();
+
+ let mut ins = if block_at == 0 {
+ tokio_test::io::Builder::new()
+ .wait(Duration::from_millis(10))
+ .read(content)
+ .build()
+ } else {
+ tokio_test::io::Builder::new()
+ .read(&content[..block_at])
+ .wait(Duration::from_millis(10))
+ .read(&content[block_at..])
+ .build()
+ };
+
+ let mut ins = &mut ins as &mut (dyn AsyncRead + Unpin);
+
+ loop {
+ let buf = decoder
+ .decode_fut(&mut ins)
+ .await
+ .expect("unexpected decode error");
+ if buf.is_empty() {
+ break; // eof
+ }
+ outs.extend(buf.as_ref());
+ }
+
+ String::from_utf8(outs).expect("decode String")
+ }
+
+ // iterate over the different ways that this async read could go.
+ // tests blocking a read at each byte along the content - The shotgun approach
+ async fn all_async_cases(content: &str, expected: &str, decoder: Decoder) {
+ let content_len = content.len();
+ for block_at in 0..content_len {
+ let actual = read_async(decoder.clone(), content.as_bytes(), block_at).await;
+ assert_eq!(expected, &actual) //, "Failed async. Blocking at {}", block_at);
+ }
+ }
+
+ #[tokio::test]
+ async fn test_read_length_async() {
+ let content = "foobar";
+ all_async_cases(content, content, Decoder::length(content.len() as u64)).await;
+ }
+
+ #[tokio::test]
+ async fn test_read_chunked_async() {
+ let content = "3\r\nfoo\r\n3\r\nbar\r\n0\r\n\r\n";
+ let expected = "foobar";
+ all_async_cases(content, expected, Decoder::chunked()).await;
+ }
+
+ #[tokio::test]
+ async fn test_read_eof_async() {
+ let content = "foobar";
+ all_async_cases(content, content, Decoder::eof()).await;
+ }
+
+ #[cfg(feature = "nightly")]
+ #[bench]
+ fn bench_decode_chunked_1kb(b: &mut test::Bencher) {
+ let rt = new_runtime();
+
+ const LEN: usize = 1024;
+ let mut vec = Vec::new();
+ vec.extend(format!("{:x}\r\n", LEN).as_bytes());
+ vec.extend(&[0; LEN][..]);
+ vec.extend(b"\r\n");
+ let content = Bytes::from(vec);
+
+ b.bytes = LEN as u64;
+
+ b.iter(|| {
+ let mut decoder = Decoder::chunked();
+ rt.block_on(async {
+ let mut raw = content.clone();
+ let chunk = decoder.decode_fut(&mut raw).await.unwrap();
+ assert_eq!(chunk.len(), LEN);
+ });
+ });
+ }
+
+ #[cfg(feature = "nightly")]
+ #[bench]
+ fn bench_decode_length_1kb(b: &mut test::Bencher) {
+ let rt = new_runtime();
+
+ const LEN: usize = 1024;
+ let content = Bytes::from(&[0; LEN][..]);
+ b.bytes = LEN as u64;
+
+ b.iter(|| {
+ let mut decoder = Decoder::length(LEN as u64);
+ rt.block_on(async {
+ let mut raw = content.clone();
+ let chunk = decoder.decode_fut(&mut raw).await.unwrap();
+ assert_eq!(chunk.len(), LEN);
+ });
+ });
+ }
+
+ #[cfg(feature = "nightly")]
+ fn new_runtime() -> tokio::runtime::Runtime {
+ tokio::runtime::Builder::new_current_thread()
+ .enable_all()
+ .build()
+ .expect("rt build")
+ }
+}
diff --git a/third_party/rust/hyper/src/proto/h1/dispatch.rs b/third_party/rust/hyper/src/proto/h1/dispatch.rs
new file mode 100644
index 0000000000..677131bfdd
--- /dev/null
+++ b/third_party/rust/hyper/src/proto/h1/dispatch.rs
@@ -0,0 +1,750 @@
+use std::error::Error as StdError;
+
+use bytes::{Buf, Bytes};
+use http::Request;
+use tokio::io::{AsyncRead, AsyncWrite};
+use tracing::{debug, trace};
+
+use super::{Http1Transaction, Wants};
+use crate::body::{Body, DecodedLength, HttpBody};
+use crate::common::{task, Future, Pin, Poll, Unpin};
+use crate::proto::{
+ BodyLength, Conn, Dispatched, MessageHead, RequestHead,
+};
+use crate::upgrade::OnUpgrade;
+
+pub(crate) struct Dispatcher<D, Bs: HttpBody, I, T> {
+ conn: Conn<I, Bs::Data, T>,
+ dispatch: D,
+ body_tx: Option<crate::body::Sender>,
+ body_rx: Pin<Box<Option<Bs>>>,
+ is_closing: bool,
+}
+
+pub(crate) trait Dispatch {
+ type PollItem;
+ type PollBody;
+ type PollError;
+ type RecvItem;
+ fn poll_msg(
+ self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ ) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), Self::PollError>>>;
+ fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, Body)>) -> crate::Result<()>;
+ fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), ()>>;
+ fn should_poll(&self) -> bool;
+}
+
+cfg_server! {
+ use crate::service::HttpService;
+
+ pub(crate) struct Server<S: HttpService<B>, B> {
+ in_flight: Pin<Box<Option<S::Future>>>,
+ pub(crate) service: S,
+ }
+}
+
+cfg_client! {
+ pin_project_lite::pin_project! {
+ pub(crate) struct Client<B> {
+ callback: Option<crate::client::dispatch::Callback<Request<B>, http::Response<Body>>>,
+ #[pin]
+ rx: ClientRx<B>,
+ rx_closed: bool,
+ }
+ }
+
+ type ClientRx<B> = crate::client::dispatch::Receiver<Request<B>, http::Response<Body>>;
+}
+
+impl<D, Bs, I, T> Dispatcher<D, Bs, I, T>
+where
+ D: Dispatch<
+ PollItem = MessageHead<T::Outgoing>,
+ PollBody = Bs,
+ RecvItem = MessageHead<T::Incoming>,
+ > + Unpin,
+ D::PollError: Into<Box<dyn StdError + Send + Sync>>,
+ I: AsyncRead + AsyncWrite + Unpin,
+ T: Http1Transaction + Unpin,
+ Bs: HttpBody + 'static,
+ Bs::Error: Into<Box<dyn StdError + Send + Sync>>,
+{
+ pub(crate) fn new(dispatch: D, conn: Conn<I, Bs::Data, T>) -> Self {
+ Dispatcher {
+ conn,
+ dispatch,
+ body_tx: None,
+ body_rx: Box::pin(None),
+ is_closing: false,
+ }
+ }
+
+ #[cfg(feature = "server")]
+ pub(crate) fn disable_keep_alive(&mut self) {
+ self.conn.disable_keep_alive();
+ if self.conn.is_write_closed() {
+ self.close();
+ }
+ }
+
+ pub(crate) fn into_inner(self) -> (I, Bytes, D) {
+ let (io, buf) = self.conn.into_inner();
+ (io, buf, self.dispatch)
+ }
+
+ /// Run this dispatcher until HTTP says this connection is done,
+ /// but don't call `AsyncWrite::shutdown` on the underlying IO.
+ ///
+ /// This is useful for old-style HTTP upgrades, but ignores
+ /// newer-style upgrade API.
+ pub(crate) fn poll_without_shutdown(
+ &mut self,
+ cx: &mut task::Context<'_>,
+ ) -> Poll<crate::Result<()>>
+ where
+ Self: Unpin,
+ {
+ Pin::new(self).poll_catch(cx, false).map_ok(|ds| {
+ if let Dispatched::Upgrade(pending) = ds {
+ pending.manual();
+ }
+ })
+ }
+
+ fn poll_catch(
+ &mut self,
+ cx: &mut task::Context<'_>,
+ should_shutdown: bool,
+ ) -> Poll<crate::Result<Dispatched>> {
+ Poll::Ready(ready!(self.poll_inner(cx, should_shutdown)).or_else(|e| {
+ // An error means we're shutting down either way.
+ // We just try to give the error to the user,
+ // and close the connection with an Ok. If we
+ // cannot give it to the user, then return the Err.
+ self.dispatch.recv_msg(Err(e))?;
+ Ok(Dispatched::Shutdown)
+ }))
+ }
+
+ fn poll_inner(
+ &mut self,
+ cx: &mut task::Context<'_>,
+ should_shutdown: bool,
+ ) -> Poll<crate::Result<Dispatched>> {
+ T::update_date();
+
+ ready!(self.poll_loop(cx))?;
+
+ if self.is_done() {
+ if let Some(pending) = self.conn.pending_upgrade() {
+ self.conn.take_error()?;
+ return Poll::Ready(Ok(Dispatched::Upgrade(pending)));
+ } else if should_shutdown {
+ ready!(self.conn.poll_shutdown(cx)).map_err(crate::Error::new_shutdown)?;
+ }
+ self.conn.take_error()?;
+ Poll::Ready(Ok(Dispatched::Shutdown))
+ } else {
+ Poll::Pending
+ }
+ }
+
+ fn poll_loop(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
+ // Limit the looping on this connection, in case it is ready far too
+ // often, so that other futures don't starve.
+ //
+ // 16 was chosen arbitrarily, as that is number of pipelined requests
+ // benchmarks often use. Perhaps it should be a config option instead.
+ for _ in 0..16 {
+ let _ = self.poll_read(cx)?;
+ let _ = self.poll_write(cx)?;
+ let _ = self.poll_flush(cx)?;
+
+ // This could happen if reading paused before blocking on IO,
+ // such as getting to the end of a framed message, but then
+ // writing/flushing set the state back to Init. In that case,
+ // if the read buffer still had bytes, we'd want to try poll_read
+ // again, or else we wouldn't ever be woken up again.
+ //
+ // Using this instead of task::current() and notify() inside
+ // the Conn is noticeably faster in pipelined benchmarks.
+ if !self.conn.wants_read_again() {
+ //break;
+ return Poll::Ready(Ok(()));
+ }
+ }
+
+ trace!("poll_loop yielding (self = {:p})", self);
+
+ task::yield_now(cx).map(|never| match never {})
+ }
+
+ fn poll_read(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
+ loop {
+ if self.is_closing {
+ return Poll::Ready(Ok(()));
+ } else if self.conn.can_read_head() {
+ ready!(self.poll_read_head(cx))?;
+ } else if let Some(mut body) = self.body_tx.take() {
+ if self.conn.can_read_body() {
+ match body.poll_ready(cx) {
+ Poll::Ready(Ok(())) => (),
+ Poll::Pending => {
+ self.body_tx = Some(body);
+ return Poll::Pending;
+ }
+ Poll::Ready(Err(_canceled)) => {
+ // user doesn't care about the body
+ // so we should stop reading
+ trace!("body receiver dropped before eof, draining or closing");
+ self.conn.poll_drain_or_close_read(cx);
+ continue;
+ }
+ }
+ match self.conn.poll_read_body(cx) {
+ Poll::Ready(Some(Ok(chunk))) => match body.try_send_data(chunk) {
+ Ok(()) => {
+ self.body_tx = Some(body);
+ }
+ Err(_canceled) => {
+ if self.conn.can_read_body() {
+ trace!("body receiver dropped before eof, closing");
+ self.conn.close_read();
+ }
+ }
+ },
+ Poll::Ready(None) => {
+ // just drop, the body will close automatically
+ }
+ Poll::Pending => {
+ self.body_tx = Some(body);
+ return Poll::Pending;
+ }
+ Poll::Ready(Some(Err(e))) => {
+ body.send_error(crate::Error::new_body(e));
+ }
+ }
+ } else {
+ // just drop, the body will close automatically
+ }
+ } else {
+ return self.conn.poll_read_keep_alive(cx);
+ }
+ }
+ }
+
+ fn poll_read_head(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
+ // can dispatch receive, or does it still care about, an incoming message?
+ match ready!(self.dispatch.poll_ready(cx)) {
+ Ok(()) => (),
+ Err(()) => {
+ trace!("dispatch no longer receiving messages");
+ self.close();
+ return Poll::Ready(Ok(()));
+ }
+ }
+ // dispatch is ready for a message, try to read one
+ match ready!(self.conn.poll_read_head(cx)) {
+ Some(Ok((mut head, body_len, wants))) => {
+ let body = match body_len {
+ DecodedLength::ZERO => Body::empty(),
+ other => {
+ let (tx, rx) = Body::new_channel(other, wants.contains(Wants::EXPECT));
+ self.body_tx = Some(tx);
+ rx
+ }
+ };
+ if wants.contains(Wants::UPGRADE) {
+ let upgrade = self.conn.on_upgrade();
+ debug_assert!(!upgrade.is_none(), "empty upgrade");
+ debug_assert!(head.extensions.get::<OnUpgrade>().is_none(), "OnUpgrade already set");
+ head.extensions.insert(upgrade);
+ }
+ self.dispatch.recv_msg(Ok((head, body)))?;
+ Poll::Ready(Ok(()))
+ }
+ Some(Err(err)) => {
+ debug!("read_head error: {}", err);
+ self.dispatch.recv_msg(Err(err))?;
+ // if here, the dispatcher gave the user the error
+ // somewhere else. we still need to shutdown, but
+ // not as a second error.
+ self.close();
+ Poll::Ready(Ok(()))
+ }
+ None => {
+ // read eof, the write side will have been closed too unless
+ // allow_read_close was set to true, in which case just do
+ // nothing...
+ debug_assert!(self.conn.is_read_closed());
+ if self.conn.is_write_closed() {
+ self.close();
+ }
+ Poll::Ready(Ok(()))
+ }
+ }
+ }
+
+ fn poll_write(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
+ loop {
+ if self.is_closing {
+ return Poll::Ready(Ok(()));
+ } else if self.body_rx.is_none()
+ && self.conn.can_write_head()
+ && self.dispatch.should_poll()
+ {
+ if let Some(msg) = ready!(Pin::new(&mut self.dispatch).poll_msg(cx)) {
+ let (head, mut body) = msg.map_err(crate::Error::new_user_service)?;
+
+ // Check if the body knows its full data immediately.
+ //
+ // If so, we can skip a bit of bookkeeping that streaming
+ // bodies need to do.
+ if let Some(full) = crate::body::take_full_data(&mut body) {
+ self.conn.write_full_msg(head, full);
+ return Poll::Ready(Ok(()));
+ }
+
+ let body_type = if body.is_end_stream() {
+ self.body_rx.set(None);
+ None
+ } else {
+ let btype = body
+ .size_hint()
+ .exact()
+ .map(BodyLength::Known)
+ .or_else(|| Some(BodyLength::Unknown));
+ self.body_rx.set(Some(body));
+ btype
+ };
+ self.conn.write_head(head, body_type);
+ } else {
+ self.close();
+ return Poll::Ready(Ok(()));
+ }
+ } else if !self.conn.can_buffer_body() {
+ ready!(self.poll_flush(cx))?;
+ } else {
+ // A new scope is needed :(
+ if let (Some(mut body), clear_body) =
+ OptGuard::new(self.body_rx.as_mut()).guard_mut()
+ {
+ debug_assert!(!*clear_body, "opt guard defaults to keeping body");
+ if !self.conn.can_write_body() {
+ trace!(
+ "no more write body allowed, user body is_end_stream = {}",
+ body.is_end_stream(),
+ );
+ *clear_body = true;
+ continue;
+ }
+
+ let item = ready!(body.as_mut().poll_data(cx));
+ if let Some(item) = item {
+ let chunk = item.map_err(|e| {
+ *clear_body = true;
+ crate::Error::new_user_body(e)
+ })?;
+ let eos = body.is_end_stream();
+ if eos {
+ *clear_body = true;
+ if chunk.remaining() == 0 {
+ trace!("discarding empty chunk");
+ self.conn.end_body()?;
+ } else {
+ self.conn.write_body_and_end(chunk);
+ }
+ } else {
+ if chunk.remaining() == 0 {
+ trace!("discarding empty chunk");
+ continue;
+ }
+ self.conn.write_body(chunk);
+ }
+ } else {
+ *clear_body = true;
+ self.conn.end_body()?;
+ }
+ } else {
+ return Poll::Pending;
+ }
+ }
+ }
+ }
+
+ fn poll_flush(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
+ self.conn.poll_flush(cx).map_err(|err| {
+ debug!("error writing: {}", err);
+ crate::Error::new_body_write(err)
+ })
+ }
+
+ fn close(&mut self) {
+ self.is_closing = true;
+ self.conn.close_read();
+ self.conn.close_write();
+ }
+
+ fn is_done(&self) -> bool {
+ if self.is_closing {
+ return true;
+ }
+
+ let read_done = self.conn.is_read_closed();
+
+ if !T::should_read_first() && read_done {
+ // a client that cannot read may was well be done.
+ true
+ } else {
+ let write_done = self.conn.is_write_closed()
+ || (!self.dispatch.should_poll() && self.body_rx.is_none());
+ read_done && write_done
+ }
+ }
+}
+
+impl<D, Bs, I, T> Future for Dispatcher<D, Bs, I, T>
+where
+ D: Dispatch<
+ PollItem = MessageHead<T::Outgoing>,
+ PollBody = Bs,
+ RecvItem = MessageHead<T::Incoming>,
+ > + Unpin,
+ D::PollError: Into<Box<dyn StdError + Send + Sync>>,
+ I: AsyncRead + AsyncWrite + Unpin,
+ T: Http1Transaction + Unpin,
+ Bs: HttpBody + 'static,
+ Bs::Error: Into<Box<dyn StdError + Send + Sync>>,
+{
+ type Output = crate::Result<Dispatched>;
+
+ #[inline]
+ fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
+ self.poll_catch(cx, true)
+ }
+}
+
+// ===== impl OptGuard =====
+
+/// A drop guard to allow a mutable borrow of an Option while being able to
+/// set whether the `Option` should be cleared on drop.
+struct OptGuard<'a, T>(Pin<&'a mut Option<T>>, bool);
+
+impl<'a, T> OptGuard<'a, T> {
+ fn new(pin: Pin<&'a mut Option<T>>) -> Self {
+ OptGuard(pin, false)
+ }
+
+ fn guard_mut(&mut self) -> (Option<Pin<&mut T>>, &mut bool) {
+ (self.0.as_mut().as_pin_mut(), &mut self.1)
+ }
+}
+
+impl<'a, T> Drop for OptGuard<'a, T> {
+ fn drop(&mut self) {
+ if self.1 {
+ self.0.set(None);
+ }
+ }
+}
+
+// ===== impl Server =====
+
+cfg_server! {
+ impl<S, B> Server<S, B>
+ where
+ S: HttpService<B>,
+ {
+ pub(crate) fn new(service: S) -> Server<S, B> {
+ Server {
+ in_flight: Box::pin(None),
+ service,
+ }
+ }
+
+ pub(crate) fn into_service(self) -> S {
+ self.service
+ }
+ }
+
+ // Service is never pinned
+ impl<S: HttpService<B>, B> Unpin for Server<S, B> {}
+
+ impl<S, Bs> Dispatch for Server<S, Body>
+ where
+ S: HttpService<Body, ResBody = Bs>,
+ S::Error: Into<Box<dyn StdError + Send + Sync>>,
+ Bs: HttpBody,
+ {
+ type PollItem = MessageHead<http::StatusCode>;
+ type PollBody = Bs;
+ type PollError = S::Error;
+ type RecvItem = RequestHead;
+
+ fn poll_msg(
+ mut self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ ) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), Self::PollError>>> {
+ let mut this = self.as_mut();
+ let ret = if let Some(ref mut fut) = this.in_flight.as_mut().as_pin_mut() {
+ let resp = ready!(fut.as_mut().poll(cx)?);
+ let (parts, body) = resp.into_parts();
+ let head = MessageHead {
+ version: parts.version,
+ subject: parts.status,
+ headers: parts.headers,
+ extensions: parts.extensions,
+ };
+ Poll::Ready(Some(Ok((head, body))))
+ } else {
+ unreachable!("poll_msg shouldn't be called if no inflight");
+ };
+
+ // Since in_flight finished, remove it
+ this.in_flight.set(None);
+ ret
+ }
+
+ fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, Body)>) -> crate::Result<()> {
+ let (msg, body) = msg?;
+ let mut req = Request::new(body);
+ *req.method_mut() = msg.subject.0;
+ *req.uri_mut() = msg.subject.1;
+ *req.headers_mut() = msg.headers;
+ *req.version_mut() = msg.version;
+ *req.extensions_mut() = msg.extensions;
+ let fut = self.service.call(req);
+ self.in_flight.set(Some(fut));
+ Ok(())
+ }
+
+ fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), ()>> {
+ if self.in_flight.is_some() {
+ Poll::Pending
+ } else {
+ self.service.poll_ready(cx).map_err(|_e| {
+ // FIXME: return error value.
+ trace!("service closed");
+ })
+ }
+ }
+
+ fn should_poll(&self) -> bool {
+ self.in_flight.is_some()
+ }
+ }
+}
+
+// ===== impl Client =====
+
+cfg_client! {
+ impl<B> Client<B> {
+ pub(crate) fn new(rx: ClientRx<B>) -> Client<B> {
+ Client {
+ callback: None,
+ rx,
+ rx_closed: false,
+ }
+ }
+ }
+
+ impl<B> Dispatch for Client<B>
+ where
+ B: HttpBody,
+ {
+ type PollItem = RequestHead;
+ type PollBody = B;
+ type PollError = crate::common::Never;
+ type RecvItem = crate::proto::ResponseHead;
+
+ fn poll_msg(
+ mut self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ ) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), crate::common::Never>>> {
+ let mut this = self.as_mut();
+ debug_assert!(!this.rx_closed);
+ match this.rx.poll_recv(cx) {
+ Poll::Ready(Some((req, mut cb))) => {
+ // check that future hasn't been canceled already
+ match cb.poll_canceled(cx) {
+ Poll::Ready(()) => {
+ trace!("request canceled");
+ Poll::Ready(None)
+ }
+ Poll::Pending => {
+ let (parts, body) = req.into_parts();
+ let head = RequestHead {
+ version: parts.version,
+ subject: crate::proto::RequestLine(parts.method, parts.uri),
+ headers: parts.headers,
+ extensions: parts.extensions,
+ };
+ this.callback = Some(cb);
+ Poll::Ready(Some(Ok((head, body))))
+ }
+ }
+ }
+ Poll::Ready(None) => {
+ // user has dropped sender handle
+ trace!("client tx closed");
+ this.rx_closed = true;
+ Poll::Ready(None)
+ }
+ Poll::Pending => Poll::Pending,
+ }
+ }
+
+ fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, Body)>) -> crate::Result<()> {
+ match msg {
+ Ok((msg, body)) => {
+ if let Some(cb) = self.callback.take() {
+ let res = msg.into_response(body);
+ cb.send(Ok(res));
+ Ok(())
+ } else {
+ // Getting here is likely a bug! An error should have happened
+ // in Conn::require_empty_read() before ever parsing a
+ // full message!
+ Err(crate::Error::new_unexpected_message())
+ }
+ }
+ Err(err) => {
+ if let Some(cb) = self.callback.take() {
+ cb.send(Err((err, None)));
+ Ok(())
+ } else if !self.rx_closed {
+ self.rx.close();
+ if let Some((req, cb)) = self.rx.try_recv() {
+ trace!("canceling queued request with connection error: {}", err);
+ // in this case, the message was never even started, so it's safe to tell
+ // the user that the request was completely canceled
+ cb.send(Err((crate::Error::new_canceled().with(err), Some(req))));
+ Ok(())
+ } else {
+ Err(err)
+ }
+ } else {
+ Err(err)
+ }
+ }
+ }
+ }
+
+ fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), ()>> {
+ match self.callback {
+ Some(ref mut cb) => match cb.poll_canceled(cx) {
+ Poll::Ready(()) => {
+ trace!("callback receiver has dropped");
+ Poll::Ready(Err(()))
+ }
+ Poll::Pending => Poll::Ready(Ok(())),
+ },
+ None => Poll::Ready(Err(())),
+ }
+ }
+
+ fn should_poll(&self) -> bool {
+ self.callback.is_none()
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::proto::h1::ClientTransaction;
+ use std::time::Duration;
+
+ #[test]
+ fn client_read_bytes_before_writing_request() {
+ let _ = pretty_env_logger::try_init();
+
+ tokio_test::task::spawn(()).enter(|cx, _| {
+ let (io, mut handle) = tokio_test::io::Builder::new().build_with_handle();
+
+ // Block at 0 for now, but we will release this response before
+ // the request is ready to write later...
+ let (mut tx, rx) = crate::client::dispatch::channel();
+ let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(io);
+ let mut dispatcher = Dispatcher::new(Client::new(rx), conn);
+
+ // First poll is needed to allow tx to send...
+ assert!(Pin::new(&mut dispatcher).poll(cx).is_pending());
+
+ // Unblock our IO, which has a response before we've sent request!
+ //
+ handle.read(b"HTTP/1.1 200 OK\r\n\r\n");
+
+ let mut res_rx = tx
+ .try_send(crate::Request::new(crate::Body::empty()))
+ .unwrap();
+
+ tokio_test::assert_ready_ok!(Pin::new(&mut dispatcher).poll(cx));
+ let err = tokio_test::assert_ready_ok!(Pin::new(&mut res_rx).poll(cx))
+ .expect_err("callback should send error");
+
+ match (err.0.kind(), err.1) {
+ (&crate::error::Kind::Canceled, Some(_)) => (),
+ other => panic!("expected Canceled, got {:?}", other),
+ }
+ });
+ }
+
+ #[tokio::test]
+ async fn client_flushing_is_not_ready_for_next_request() {
+ let _ = pretty_env_logger::try_init();
+
+ let (io, _handle) = tokio_test::io::Builder::new()
+ .write(b"POST / HTTP/1.1\r\ncontent-length: 4\r\n\r\n")
+ .read(b"HTTP/1.1 200 OK\r\ncontent-length: 0\r\n\r\n")
+ .wait(std::time::Duration::from_secs(2))
+ .build_with_handle();
+
+ let (mut tx, rx) = crate::client::dispatch::channel();
+ let mut conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(io);
+ conn.set_write_strategy_queue();
+
+ let dispatcher = Dispatcher::new(Client::new(rx), conn);
+ let _dispatcher = tokio::spawn(async move { dispatcher.await });
+
+ let req = crate::Request::builder()
+ .method("POST")
+ .body(crate::Body::from("reee"))
+ .unwrap();
+
+ let res = tx.try_send(req).unwrap().await.expect("response");
+ drop(res);
+
+ assert!(!tx.is_ready());
+ }
+
+ #[tokio::test]
+ async fn body_empty_chunks_ignored() {
+ let _ = pretty_env_logger::try_init();
+
+ let io = tokio_test::io::Builder::new()
+ // no reading or writing, just be blocked for the test...
+ .wait(Duration::from_secs(5))
+ .build();
+
+ let (mut tx, rx) = crate::client::dispatch::channel();
+ let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(io);
+ let mut dispatcher = tokio_test::task::spawn(Dispatcher::new(Client::new(rx), conn));
+
+ // First poll is needed to allow tx to send...
+ assert!(dispatcher.poll().is_pending());
+
+ let body = {
+ let (mut tx, body) = crate::Body::channel();
+ tx.try_send_data("".into()).unwrap();
+ body
+ };
+
+ let _res_rx = tx.try_send(crate::Request::new(body)).unwrap();
+
+ // Ensure conn.write_body wasn't called with the empty chunk.
+ // If it is, it will trigger an assertion.
+ assert!(dispatcher.poll().is_pending());
+ }
+}
diff --git a/third_party/rust/hyper/src/proto/h1/encode.rs b/third_party/rust/hyper/src/proto/h1/encode.rs
new file mode 100644
index 0000000000..f0aa261a4f
--- /dev/null
+++ b/third_party/rust/hyper/src/proto/h1/encode.rs
@@ -0,0 +1,439 @@
+use std::fmt;
+use std::io::IoSlice;
+
+use bytes::buf::{Chain, Take};
+use bytes::Buf;
+use tracing::trace;
+
+use super::io::WriteBuf;
+
+type StaticBuf = &'static [u8];
+
+/// Encoders to handle different Transfer-Encodings.
+#[derive(Debug, Clone, PartialEq)]
+pub(crate) struct Encoder {
+ kind: Kind,
+ is_last: bool,
+}
+
+#[derive(Debug)]
+pub(crate) struct EncodedBuf<B> {
+ kind: BufKind<B>,
+}
+
+#[derive(Debug)]
+pub(crate) struct NotEof(u64);
+
+#[derive(Debug, PartialEq, Clone)]
+enum Kind {
+ /// An Encoder for when Transfer-Encoding includes `chunked`.
+ Chunked,
+ /// An Encoder for when Content-Length is set.
+ ///
+ /// Enforces that the body is not longer than the Content-Length header.
+ Length(u64),
+ /// An Encoder for when neither Content-Length nor Chunked encoding is set.
+ ///
+ /// This is mostly only used with HTTP/1.0 with a length. This kind requires
+ /// the connection to be closed when the body is finished.
+ #[cfg(feature = "server")]
+ CloseDelimited,
+}
+
+#[derive(Debug)]
+enum BufKind<B> {
+ Exact(B),
+ Limited(Take<B>),
+ Chunked(Chain<Chain<ChunkSize, B>, StaticBuf>),
+ ChunkedEnd(StaticBuf),
+}
+
+impl Encoder {
+ fn new(kind: Kind) -> Encoder {
+ Encoder {
+ kind,
+ is_last: false,
+ }
+ }
+ pub(crate) fn chunked() -> Encoder {
+ Encoder::new(Kind::Chunked)
+ }
+
+ pub(crate) fn length(len: u64) -> Encoder {
+ Encoder::new(Kind::Length(len))
+ }
+
+ #[cfg(feature = "server")]
+ pub(crate) fn close_delimited() -> Encoder {
+ Encoder::new(Kind::CloseDelimited)
+ }
+
+ pub(crate) fn is_eof(&self) -> bool {
+ matches!(self.kind, Kind::Length(0))
+ }
+
+ #[cfg(feature = "server")]
+ pub(crate) fn set_last(mut self, is_last: bool) -> Self {
+ self.is_last = is_last;
+ self
+ }
+
+ pub(crate) fn is_last(&self) -> bool {
+ self.is_last
+ }
+
+ pub(crate) fn is_close_delimited(&self) -> bool {
+ match self.kind {
+ #[cfg(feature = "server")]
+ Kind::CloseDelimited => true,
+ _ => false,
+ }
+ }
+
+ pub(crate) fn end<B>(&self) -> Result<Option<EncodedBuf<B>>, NotEof> {
+ match self.kind {
+ Kind::Length(0) => Ok(None),
+ Kind::Chunked => Ok(Some(EncodedBuf {
+ kind: BufKind::ChunkedEnd(b"0\r\n\r\n"),
+ })),
+ #[cfg(feature = "server")]
+ Kind::CloseDelimited => Ok(None),
+ Kind::Length(n) => Err(NotEof(n)),
+ }
+ }
+
+ pub(crate) fn encode<B>(&mut self, msg: B) -> EncodedBuf<B>
+ where
+ B: Buf,
+ {
+ let len = msg.remaining();
+ debug_assert!(len > 0, "encode() called with empty buf");
+
+ let kind = match self.kind {
+ Kind::Chunked => {
+ trace!("encoding chunked {}B", len);
+ let buf = ChunkSize::new(len)
+ .chain(msg)
+ .chain(b"\r\n" as &'static [u8]);
+ BufKind::Chunked(buf)
+ }
+ Kind::Length(ref mut remaining) => {
+ trace!("sized write, len = {}", len);
+ if len as u64 > *remaining {
+ let limit = *remaining as usize;
+ *remaining = 0;
+ BufKind::Limited(msg.take(limit))
+ } else {
+ *remaining -= len as u64;
+ BufKind::Exact(msg)
+ }
+ }
+ #[cfg(feature = "server")]
+ Kind::CloseDelimited => {
+ trace!("close delimited write {}B", len);
+ BufKind::Exact(msg)
+ }
+ };
+ EncodedBuf { kind }
+ }
+
+ pub(super) fn encode_and_end<B>(&self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) -> bool
+ where
+ B: Buf,
+ {
+ let len = msg.remaining();
+ debug_assert!(len > 0, "encode() called with empty buf");
+
+ match self.kind {
+ Kind::Chunked => {
+ trace!("encoding chunked {}B", len);
+ let buf = ChunkSize::new(len)
+ .chain(msg)
+ .chain(b"\r\n0\r\n\r\n" as &'static [u8]);
+ dst.buffer(buf);
+ !self.is_last
+ }
+ Kind::Length(remaining) => {
+ use std::cmp::Ordering;
+
+ trace!("sized write, len = {}", len);
+ match (len as u64).cmp(&remaining) {
+ Ordering::Equal => {
+ dst.buffer(msg);
+ !self.is_last
+ }
+ Ordering::Greater => {
+ dst.buffer(msg.take(remaining as usize));
+ !self.is_last
+ }
+ Ordering::Less => {
+ dst.buffer(msg);
+ false
+ }
+ }
+ }
+ #[cfg(feature = "server")]
+ Kind::CloseDelimited => {
+ trace!("close delimited write {}B", len);
+ dst.buffer(msg);
+ false
+ }
+ }
+ }
+
+ /// Encodes the full body, without verifying the remaining length matches.
+ ///
+ /// This is used in conjunction with HttpBody::__hyper_full_data(), which
+ /// means we can trust that the buf has the correct size (the buf itself
+ /// was checked to make the headers).
+ pub(super) fn danger_full_buf<B>(self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>)
+ where
+ B: Buf,
+ {
+ debug_assert!(msg.remaining() > 0, "encode() called with empty buf");
+ debug_assert!(
+ match self.kind {
+ Kind::Length(len) => len == msg.remaining() as u64,
+ _ => true,
+ },
+ "danger_full_buf length mismatches"
+ );
+
+ match self.kind {
+ Kind::Chunked => {
+ let len = msg.remaining();
+ trace!("encoding chunked {}B", len);
+ let buf = ChunkSize::new(len)
+ .chain(msg)
+ .chain(b"\r\n0\r\n\r\n" as &'static [u8]);
+ dst.buffer(buf);
+ }
+ _ => {
+ dst.buffer(msg);
+ }
+ }
+ }
+}
+
+impl<B> Buf for EncodedBuf<B>
+where
+ B: Buf,
+{
+ #[inline]
+ fn remaining(&self) -> usize {
+ match self.kind {
+ BufKind::Exact(ref b) => b.remaining(),
+ BufKind::Limited(ref b) => b.remaining(),
+ BufKind::Chunked(ref b) => b.remaining(),
+ BufKind::ChunkedEnd(ref b) => b.remaining(),
+ }
+ }
+
+ #[inline]
+ fn chunk(&self) -> &[u8] {
+ match self.kind {
+ BufKind::Exact(ref b) => b.chunk(),
+ BufKind::Limited(ref b) => b.chunk(),
+ BufKind::Chunked(ref b) => b.chunk(),
+ BufKind::ChunkedEnd(ref b) => b.chunk(),
+ }
+ }
+
+ #[inline]
+ fn advance(&mut self, cnt: usize) {
+ match self.kind {
+ BufKind::Exact(ref mut b) => b.advance(cnt),
+ BufKind::Limited(ref mut b) => b.advance(cnt),
+ BufKind::Chunked(ref mut b) => b.advance(cnt),
+ BufKind::ChunkedEnd(ref mut b) => b.advance(cnt),
+ }
+ }
+
+ #[inline]
+ fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize {
+ match self.kind {
+ BufKind::Exact(ref b) => b.chunks_vectored(dst),
+ BufKind::Limited(ref b) => b.chunks_vectored(dst),
+ BufKind::Chunked(ref b) => b.chunks_vectored(dst),
+ BufKind::ChunkedEnd(ref b) => b.chunks_vectored(dst),
+ }
+ }
+}
+
+#[cfg(target_pointer_width = "32")]
+const USIZE_BYTES: usize = 4;
+
+#[cfg(target_pointer_width = "64")]
+const USIZE_BYTES: usize = 8;
+
+// each byte will become 2 hex
+const CHUNK_SIZE_MAX_BYTES: usize = USIZE_BYTES * 2;
+
+#[derive(Clone, Copy)]
+struct ChunkSize {
+ bytes: [u8; CHUNK_SIZE_MAX_BYTES + 2],
+ pos: u8,
+ len: u8,
+}
+
+impl ChunkSize {
+ fn new(len: usize) -> ChunkSize {
+ use std::fmt::Write;
+ let mut size = ChunkSize {
+ bytes: [0; CHUNK_SIZE_MAX_BYTES + 2],
+ pos: 0,
+ len: 0,
+ };
+ write!(&mut size, "{:X}\r\n", len).expect("CHUNK_SIZE_MAX_BYTES should fit any usize");
+ size
+ }
+}
+
+impl Buf for ChunkSize {
+ #[inline]
+ fn remaining(&self) -> usize {
+ (self.len - self.pos).into()
+ }
+
+ #[inline]
+ fn chunk(&self) -> &[u8] {
+ &self.bytes[self.pos.into()..self.len.into()]
+ }
+
+ #[inline]
+ fn advance(&mut self, cnt: usize) {
+ assert!(cnt <= self.remaining());
+ self.pos += cnt as u8; // just asserted cnt fits in u8
+ }
+}
+
+impl fmt::Debug for ChunkSize {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("ChunkSize")
+ .field("bytes", &&self.bytes[..self.len.into()])
+ .field("pos", &self.pos)
+ .finish()
+ }
+}
+
+impl fmt::Write for ChunkSize {
+ fn write_str(&mut self, num: &str) -> fmt::Result {
+ use std::io::Write;
+ (&mut self.bytes[self.len.into()..])
+ .write_all(num.as_bytes())
+ .expect("&mut [u8].write() cannot error");
+ self.len += num.len() as u8; // safe because bytes is never bigger than 256
+ Ok(())
+ }
+}
+
+impl<B: Buf> From<B> for EncodedBuf<B> {
+ fn from(buf: B) -> Self {
+ EncodedBuf {
+ kind: BufKind::Exact(buf),
+ }
+ }
+}
+
+impl<B: Buf> From<Take<B>> for EncodedBuf<B> {
+ fn from(buf: Take<B>) -> Self {
+ EncodedBuf {
+ kind: BufKind::Limited(buf),
+ }
+ }
+}
+
+impl<B: Buf> From<Chain<Chain<ChunkSize, B>, StaticBuf>> for EncodedBuf<B> {
+ fn from(buf: Chain<Chain<ChunkSize, B>, StaticBuf>) -> Self {
+ EncodedBuf {
+ kind: BufKind::Chunked(buf),
+ }
+ }
+}
+
+impl fmt::Display for NotEof {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "early end, expected {} more bytes", self.0)
+ }
+}
+
+impl std::error::Error for NotEof {}
+
+#[cfg(test)]
+mod tests {
+ use bytes::BufMut;
+
+ use super::super::io::Cursor;
+ use super::Encoder;
+
+ #[test]
+ fn chunked() {
+ let mut encoder = Encoder::chunked();
+ let mut dst = Vec::new();
+
+ let msg1 = b"foo bar".as_ref();
+ let buf1 = encoder.encode(msg1);
+ dst.put(buf1);
+ assert_eq!(dst, b"7\r\nfoo bar\r\n");
+
+ let msg2 = b"baz quux herp".as_ref();
+ let buf2 = encoder.encode(msg2);
+ dst.put(buf2);
+
+ assert_eq!(dst, b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n");
+
+ let end = encoder.end::<Cursor<Vec<u8>>>().unwrap().unwrap();
+ dst.put(end);
+
+ assert_eq!(
+ dst,
+ b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n0\r\n\r\n".as_ref()
+ );
+ }
+
+ #[test]
+ fn length() {
+ let max_len = 8;
+ let mut encoder = Encoder::length(max_len as u64);
+ let mut dst = Vec::new();
+
+ let msg1 = b"foo bar".as_ref();
+ let buf1 = encoder.encode(msg1);
+ dst.put(buf1);
+
+ assert_eq!(dst, b"foo bar");
+ assert!(!encoder.is_eof());
+ encoder.end::<()>().unwrap_err();
+
+ let msg2 = b"baz".as_ref();
+ let buf2 = encoder.encode(msg2);
+ dst.put(buf2);
+
+ assert_eq!(dst.len(), max_len);
+ assert_eq!(dst, b"foo barb");
+ assert!(encoder.is_eof());
+ assert!(encoder.end::<()>().unwrap().is_none());
+ }
+
+ #[test]
+ fn eof() {
+ let mut encoder = Encoder::close_delimited();
+ let mut dst = Vec::new();
+
+ let msg1 = b"foo bar".as_ref();
+ let buf1 = encoder.encode(msg1);
+ dst.put(buf1);
+
+ assert_eq!(dst, b"foo bar");
+ assert!(!encoder.is_eof());
+ encoder.end::<()>().unwrap();
+
+ let msg2 = b"baz".as_ref();
+ let buf2 = encoder.encode(msg2);
+ dst.put(buf2);
+
+ assert_eq!(dst, b"foo barbaz");
+ assert!(!encoder.is_eof());
+ encoder.end::<()>().unwrap();
+ }
+}
diff --git a/third_party/rust/hyper/src/proto/h1/io.rs b/third_party/rust/hyper/src/proto/h1/io.rs
new file mode 100644
index 0000000000..1d251e2c84
--- /dev/null
+++ b/third_party/rust/hyper/src/proto/h1/io.rs
@@ -0,0 +1,1002 @@
+use std::cmp;
+use std::fmt;
+#[cfg(all(feature = "server", feature = "runtime"))]
+use std::future::Future;
+use std::io::{self, IoSlice};
+use std::marker::Unpin;
+use std::mem::MaybeUninit;
+#[cfg(all(feature = "server", feature = "runtime"))]
+use std::time::Duration;
+
+use bytes::{Buf, BufMut, Bytes, BytesMut};
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+#[cfg(all(feature = "server", feature = "runtime"))]
+use tokio::time::Instant;
+use tracing::{debug, trace};
+
+use super::{Http1Transaction, ParseContext, ParsedMessage};
+use crate::common::buf::BufList;
+use crate::common::{task, Pin, Poll};
+
+/// The initial buffer size allocated before trying to read from IO.
+pub(crate) const INIT_BUFFER_SIZE: usize = 8192;
+
+/// The minimum value that can be set to max buffer size.
+pub(crate) const MINIMUM_MAX_BUFFER_SIZE: usize = INIT_BUFFER_SIZE;
+
+/// The default maximum read buffer size. If the buffer gets this big and
+/// a message is still not complete, a `TooLarge` error is triggered.
+// Note: if this changes, update server::conn::Http::max_buf_size docs.
+pub(crate) const DEFAULT_MAX_BUFFER_SIZE: usize = 8192 + 4096 * 100;
+
+/// The maximum number of distinct `Buf`s to hold in a list before requiring
+/// a flush. Only affects when the buffer strategy is to queue buffers.
+///
+/// Note that a flush can happen before reaching the maximum. This simply
+/// forces a flush if the queue gets this big.
+const MAX_BUF_LIST_BUFFERS: usize = 16;
+
+pub(crate) struct Buffered<T, B> {
+ flush_pipeline: bool,
+ io: T,
+ read_blocked: bool,
+ read_buf: BytesMut,
+ read_buf_strategy: ReadStrategy,
+ write_buf: WriteBuf<B>,
+}
+
+impl<T, B> fmt::Debug for Buffered<T, B>
+where
+ B: Buf,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("Buffered")
+ .field("read_buf", &self.read_buf)
+ .field("write_buf", &self.write_buf)
+ .finish()
+ }
+}
+
+impl<T, B> Buffered<T, B>
+where
+ T: AsyncRead + AsyncWrite + Unpin,
+ B: Buf,
+{
+ pub(crate) fn new(io: T) -> Buffered<T, B> {
+ let strategy = if io.is_write_vectored() {
+ WriteStrategy::Queue
+ } else {
+ WriteStrategy::Flatten
+ };
+ let write_buf = WriteBuf::new(strategy);
+ Buffered {
+ flush_pipeline: false,
+ io,
+ read_blocked: false,
+ read_buf: BytesMut::with_capacity(0),
+ read_buf_strategy: ReadStrategy::default(),
+ write_buf,
+ }
+ }
+
+ #[cfg(feature = "server")]
+ pub(crate) fn set_flush_pipeline(&mut self, enabled: bool) {
+ debug_assert!(!self.write_buf.has_remaining());
+ self.flush_pipeline = enabled;
+ if enabled {
+ self.set_write_strategy_flatten();
+ }
+ }
+
+ pub(crate) fn set_max_buf_size(&mut self, max: usize) {
+ assert!(
+ max >= MINIMUM_MAX_BUFFER_SIZE,
+ "The max_buf_size cannot be smaller than {}.",
+ MINIMUM_MAX_BUFFER_SIZE,
+ );
+ self.read_buf_strategy = ReadStrategy::with_max(max);
+ self.write_buf.max_buf_size = max;
+ }
+
+ #[cfg(feature = "client")]
+ pub(crate) fn set_read_buf_exact_size(&mut self, sz: usize) {
+ self.read_buf_strategy = ReadStrategy::Exact(sz);
+ }
+
+ pub(crate) fn set_write_strategy_flatten(&mut self) {
+ // this should always be called only at construction time,
+ // so this assert is here to catch myself
+ debug_assert!(self.write_buf.queue.bufs_cnt() == 0);
+ self.write_buf.set_strategy(WriteStrategy::Flatten);
+ }
+
+ pub(crate) fn set_write_strategy_queue(&mut self) {
+ // this should always be called only at construction time,
+ // so this assert is here to catch myself
+ debug_assert!(self.write_buf.queue.bufs_cnt() == 0);
+ self.write_buf.set_strategy(WriteStrategy::Queue);
+ }
+
+ pub(crate) fn read_buf(&self) -> &[u8] {
+ self.read_buf.as_ref()
+ }
+
+ #[cfg(test)]
+ #[cfg(feature = "nightly")]
+ pub(super) fn read_buf_mut(&mut self) -> &mut BytesMut {
+ &mut self.read_buf
+ }
+
+ /// Return the "allocated" available space, not the potential space
+ /// that could be allocated in the future.
+ fn read_buf_remaining_mut(&self) -> usize {
+ self.read_buf.capacity() - self.read_buf.len()
+ }
+
+ /// Return whether we can append to the headers buffer.
+ ///
+ /// Reasons we can't:
+ /// - The write buf is in queue mode, and some of the past body is still
+ /// needing to be flushed.
+ pub(crate) fn can_headers_buf(&self) -> bool {
+ !self.write_buf.queue.has_remaining()
+ }
+
+ pub(crate) fn headers_buf(&mut self) -> &mut Vec<u8> {
+ let buf = self.write_buf.headers_mut();
+ &mut buf.bytes
+ }
+
+ pub(super) fn write_buf(&mut self) -> &mut WriteBuf<B> {
+ &mut self.write_buf
+ }
+
+ pub(crate) fn buffer<BB: Buf + Into<B>>(&mut self, buf: BB) {
+ self.write_buf.buffer(buf)
+ }
+
+ pub(crate) fn can_buffer(&self) -> bool {
+ self.flush_pipeline || self.write_buf.can_buffer()
+ }
+
+ pub(crate) fn consume_leading_lines(&mut self) {
+ if !self.read_buf.is_empty() {
+ let mut i = 0;
+ while i < self.read_buf.len() {
+ match self.read_buf[i] {
+ b'\r' | b'\n' => i += 1,
+ _ => break,
+ }
+ }
+ self.read_buf.advance(i);
+ }
+ }
+
+ pub(super) fn parse<S>(
+ &mut self,
+ cx: &mut task::Context<'_>,
+ parse_ctx: ParseContext<'_>,
+ ) -> Poll<crate::Result<ParsedMessage<S::Incoming>>>
+ where
+ S: Http1Transaction,
+ {
+ loop {
+ match super::role::parse_headers::<S>(
+ &mut self.read_buf,
+ ParseContext {
+ cached_headers: parse_ctx.cached_headers,
+ req_method: parse_ctx.req_method,
+ h1_parser_config: parse_ctx.h1_parser_config.clone(),
+ #[cfg(all(feature = "server", feature = "runtime"))]
+ h1_header_read_timeout: parse_ctx.h1_header_read_timeout,
+ #[cfg(all(feature = "server", feature = "runtime"))]
+ h1_header_read_timeout_fut: parse_ctx.h1_header_read_timeout_fut,
+ #[cfg(all(feature = "server", feature = "runtime"))]
+ h1_header_read_timeout_running: parse_ctx.h1_header_read_timeout_running,
+ preserve_header_case: parse_ctx.preserve_header_case,
+ #[cfg(feature = "ffi")]
+ preserve_header_order: parse_ctx.preserve_header_order,
+ h09_responses: parse_ctx.h09_responses,
+ #[cfg(feature = "ffi")]
+ on_informational: parse_ctx.on_informational,
+ #[cfg(feature = "ffi")]
+ raw_headers: parse_ctx.raw_headers,
+ },
+ )? {
+ Some(msg) => {
+ debug!("parsed {} headers", msg.head.headers.len());
+
+ #[cfg(all(feature = "server", feature = "runtime"))]
+ {
+ *parse_ctx.h1_header_read_timeout_running = false;
+
+ if let Some(h1_header_read_timeout_fut) =
+ parse_ctx.h1_header_read_timeout_fut
+ {
+ // Reset the timer in order to avoid woken up when the timeout finishes
+ h1_header_read_timeout_fut
+ .as_mut()
+ .reset(Instant::now() + Duration::from_secs(30 * 24 * 60 * 60));
+ }
+ }
+ return Poll::Ready(Ok(msg));
+ }
+ None => {
+ let max = self.read_buf_strategy.max();
+ if self.read_buf.len() >= max {
+ debug!("max_buf_size ({}) reached, closing", max);
+ return Poll::Ready(Err(crate::Error::new_too_large()));
+ }
+
+ #[cfg(all(feature = "server", feature = "runtime"))]
+ if *parse_ctx.h1_header_read_timeout_running {
+ if let Some(h1_header_read_timeout_fut) =
+ parse_ctx.h1_header_read_timeout_fut
+ {
+ if Pin::new(h1_header_read_timeout_fut).poll(cx).is_ready() {
+ *parse_ctx.h1_header_read_timeout_running = false;
+
+ tracing::warn!("read header from client timeout");
+ return Poll::Ready(Err(crate::Error::new_header_timeout()));
+ }
+ }
+ }
+ }
+ }
+ if ready!(self.poll_read_from_io(cx)).map_err(crate::Error::new_io)? == 0 {
+ trace!("parse eof");
+ return Poll::Ready(Err(crate::Error::new_incomplete()));
+ }
+ }
+ }
+
+ pub(crate) fn poll_read_from_io(
+ &mut self,
+ cx: &mut task::Context<'_>,
+ ) -> Poll<io::Result<usize>> {
+ self.read_blocked = false;
+ let next = self.read_buf_strategy.next();
+ if self.read_buf_remaining_mut() < next {
+ self.read_buf.reserve(next);
+ }
+
+ let dst = self.read_buf.chunk_mut();
+ let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) };
+ let mut buf = ReadBuf::uninit(dst);
+ match Pin::new(&mut self.io).poll_read(cx, &mut buf) {
+ Poll::Ready(Ok(_)) => {
+ let n = buf.filled().len();
+ trace!("received {} bytes", n);
+ unsafe {
+ // Safety: we just read that many bytes into the
+ // uninitialized part of the buffer, so this is okay.
+ // @tokio pls give me back `poll_read_buf` thanks
+ self.read_buf.advance_mut(n);
+ }
+ self.read_buf_strategy.record(n);
+ Poll::Ready(Ok(n))
+ }
+ Poll::Pending => {
+ self.read_blocked = true;
+ Poll::Pending
+ }
+ Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
+ }
+ }
+
+ pub(crate) fn into_inner(self) -> (T, Bytes) {
+ (self.io, self.read_buf.freeze())
+ }
+
+ pub(crate) fn io_mut(&mut self) -> &mut T {
+ &mut self.io
+ }
+
+ pub(crate) fn is_read_blocked(&self) -> bool {
+ self.read_blocked
+ }
+
+ pub(crate) fn poll_flush(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
+ if self.flush_pipeline && !self.read_buf.is_empty() {
+ Poll::Ready(Ok(()))
+ } else if self.write_buf.remaining() == 0 {
+ Pin::new(&mut self.io).poll_flush(cx)
+ } else {
+ if let WriteStrategy::Flatten = self.write_buf.strategy {
+ return self.poll_flush_flattened(cx);
+ }
+
+ const MAX_WRITEV_BUFS: usize = 64;
+ loop {
+ let n = {
+ let mut iovs = [IoSlice::new(&[]); MAX_WRITEV_BUFS];
+ let len = self.write_buf.chunks_vectored(&mut iovs);
+ ready!(Pin::new(&mut self.io).poll_write_vectored(cx, &iovs[..len]))?
+ };
+ // TODO(eliza): we have to do this manually because
+ // `poll_write_buf` doesn't exist in Tokio 0.3 yet...when
+ // `poll_write_buf` comes back, the manual advance will need to leave!
+ self.write_buf.advance(n);
+ debug!("flushed {} bytes", n);
+ if self.write_buf.remaining() == 0 {
+ break;
+ } else if n == 0 {
+ trace!(
+ "write returned zero, but {} bytes remaining",
+ self.write_buf.remaining()
+ );
+ return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
+ }
+ }
+ Pin::new(&mut self.io).poll_flush(cx)
+ }
+ }
+
+ /// Specialized version of `flush` when strategy is Flatten.
+ ///
+ /// Since all buffered bytes are flattened into the single headers buffer,
+ /// that skips some bookkeeping around using multiple buffers.
+ fn poll_flush_flattened(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
+ loop {
+ let n = ready!(Pin::new(&mut self.io).poll_write(cx, self.write_buf.headers.chunk()))?;
+ debug!("flushed {} bytes", n);
+ self.write_buf.headers.advance(n);
+ if self.write_buf.headers.remaining() == 0 {
+ self.write_buf.headers.reset();
+ break;
+ } else if n == 0 {
+ trace!(
+ "write returned zero, but {} bytes remaining",
+ self.write_buf.remaining()
+ );
+ return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
+ }
+ }
+ Pin::new(&mut self.io).poll_flush(cx)
+ }
+
+ #[cfg(test)]
+ fn flush<'a>(&'a mut self) -> impl std::future::Future<Output = io::Result<()>> + 'a {
+ futures_util::future::poll_fn(move |cx| self.poll_flush(cx))
+ }
+}
+
+// The `B` is a `Buf`, we never project a pin to it
+impl<T: Unpin, B> Unpin for Buffered<T, B> {}
+
+// TODO: This trait is old... at least rename to PollBytes or something...
+pub(crate) trait MemRead {
+ fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>>;
+}
+
+impl<T, B> MemRead for Buffered<T, B>
+where
+ T: AsyncRead + AsyncWrite + Unpin,
+ B: Buf,
+{
+ fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> {
+ if !self.read_buf.is_empty() {
+ let n = std::cmp::min(len, self.read_buf.len());
+ Poll::Ready(Ok(self.read_buf.split_to(n).freeze()))
+ } else {
+ let n = ready!(self.poll_read_from_io(cx))?;
+ Poll::Ready(Ok(self.read_buf.split_to(::std::cmp::min(len, n)).freeze()))
+ }
+ }
+}
+
+#[derive(Clone, Copy, Debug)]
+enum ReadStrategy {
+ Adaptive {
+ decrease_now: bool,
+ next: usize,
+ max: usize,
+ },
+ #[cfg(feature = "client")]
+ Exact(usize),
+}
+
+impl ReadStrategy {
+ fn with_max(max: usize) -> ReadStrategy {
+ ReadStrategy::Adaptive {
+ decrease_now: false,
+ next: INIT_BUFFER_SIZE,
+ max,
+ }
+ }
+
+ fn next(&self) -> usize {
+ match *self {
+ ReadStrategy::Adaptive { next, .. } => next,
+ #[cfg(feature = "client")]
+ ReadStrategy::Exact(exact) => exact,
+ }
+ }
+
+ fn max(&self) -> usize {
+ match *self {
+ ReadStrategy::Adaptive { max, .. } => max,
+ #[cfg(feature = "client")]
+ ReadStrategy::Exact(exact) => exact,
+ }
+ }
+
+ fn record(&mut self, bytes_read: usize) {
+ match *self {
+ ReadStrategy::Adaptive {
+ ref mut decrease_now,
+ ref mut next,
+ max,
+ ..
+ } => {
+ if bytes_read >= *next {
+ *next = cmp::min(incr_power_of_two(*next), max);
+ *decrease_now = false;
+ } else {
+ let decr_to = prev_power_of_two(*next);
+ if bytes_read < decr_to {
+ if *decrease_now {
+ *next = cmp::max(decr_to, INIT_BUFFER_SIZE);
+ *decrease_now = false;
+ } else {
+ // Decreasing is a two "record" process.
+ *decrease_now = true;
+ }
+ } else {
+ // A read within the current range should cancel
+ // a potential decrease, since we just saw proof
+ // that we still need this size.
+ *decrease_now = false;
+ }
+ }
+ }
+ #[cfg(feature = "client")]
+ ReadStrategy::Exact(_) => (),
+ }
+ }
+}
+
+fn incr_power_of_two(n: usize) -> usize {
+ n.saturating_mul(2)
+}
+
+fn prev_power_of_two(n: usize) -> usize {
+ // Only way this shift can underflow is if n is less than 4.
+ // (Which would means `usize::MAX >> 64` and underflowed!)
+ debug_assert!(n >= 4);
+ (::std::usize::MAX >> (n.leading_zeros() + 2)) + 1
+}
+
+impl Default for ReadStrategy {
+ fn default() -> ReadStrategy {
+ ReadStrategy::with_max(DEFAULT_MAX_BUFFER_SIZE)
+ }
+}
+
+#[derive(Clone)]
+pub(crate) struct Cursor<T> {
+ bytes: T,
+ pos: usize,
+}
+
+impl<T: AsRef<[u8]>> Cursor<T> {
+ #[inline]
+ pub(crate) fn new(bytes: T) -> Cursor<T> {
+ Cursor { bytes, pos: 0 }
+ }
+}
+
+impl Cursor<Vec<u8>> {
+ /// If we've advanced the position a bit in this cursor, and wish to
+ /// extend the underlying vector, we may wish to unshift the "read" bytes
+ /// off, and move everything else over.
+ fn maybe_unshift(&mut self, additional: usize) {
+ if self.pos == 0 {
+ // nothing to do
+ return;
+ }
+
+ if self.bytes.capacity() - self.bytes.len() >= additional {
+ // there's room!
+ return;
+ }
+
+ self.bytes.drain(0..self.pos);
+ self.pos = 0;
+ }
+
+ fn reset(&mut self) {
+ self.pos = 0;
+ self.bytes.clear();
+ }
+}
+
+impl<T: AsRef<[u8]>> fmt::Debug for Cursor<T> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("Cursor")
+ .field("pos", &self.pos)
+ .field("len", &self.bytes.as_ref().len())
+ .finish()
+ }
+}
+
+impl<T: AsRef<[u8]>> Buf for Cursor<T> {
+ #[inline]
+ fn remaining(&self) -> usize {
+ self.bytes.as_ref().len() - self.pos
+ }
+
+ #[inline]
+ fn chunk(&self) -> &[u8] {
+ &self.bytes.as_ref()[self.pos..]
+ }
+
+ #[inline]
+ fn advance(&mut self, cnt: usize) {
+ debug_assert!(self.pos + cnt <= self.bytes.as_ref().len());
+ self.pos += cnt;
+ }
+}
+
+// an internal buffer to collect writes before flushes
+pub(super) struct WriteBuf<B> {
+ /// Re-usable buffer that holds message headers
+ headers: Cursor<Vec<u8>>,
+ max_buf_size: usize,
+ /// Deque of user buffers if strategy is Queue
+ queue: BufList<B>,
+ strategy: WriteStrategy,
+}
+
+impl<B: Buf> WriteBuf<B> {
+ fn new(strategy: WriteStrategy) -> WriteBuf<B> {
+ WriteBuf {
+ headers: Cursor::new(Vec::with_capacity(INIT_BUFFER_SIZE)),
+ max_buf_size: DEFAULT_MAX_BUFFER_SIZE,
+ queue: BufList::new(),
+ strategy,
+ }
+ }
+}
+
+impl<B> WriteBuf<B>
+where
+ B: Buf,
+{
+ fn set_strategy(&mut self, strategy: WriteStrategy) {
+ self.strategy = strategy;
+ }
+
+ pub(super) fn buffer<BB: Buf + Into<B>>(&mut self, mut buf: BB) {
+ debug_assert!(buf.has_remaining());
+ match self.strategy {
+ WriteStrategy::Flatten => {
+ let head = self.headers_mut();
+
+ head.maybe_unshift(buf.remaining());
+ trace!(
+ self.len = head.remaining(),
+ buf.len = buf.remaining(),
+ "buffer.flatten"
+ );
+ //perf: This is a little faster than <Vec as BufMut>>::put,
+ //but accomplishes the same result.
+ loop {
+ let adv = {
+ let slice = buf.chunk();
+ if slice.is_empty() {
+ return;
+ }
+ head.bytes.extend_from_slice(slice);
+ slice.len()
+ };
+ buf.advance(adv);
+ }
+ }
+ WriteStrategy::Queue => {
+ trace!(
+ self.len = self.remaining(),
+ buf.len = buf.remaining(),
+ "buffer.queue"
+ );
+ self.queue.push(buf.into());
+ }
+ }
+ }
+
+ fn can_buffer(&self) -> bool {
+ match self.strategy {
+ WriteStrategy::Flatten => self.remaining() < self.max_buf_size,
+ WriteStrategy::Queue => {
+ self.queue.bufs_cnt() < MAX_BUF_LIST_BUFFERS && self.remaining() < self.max_buf_size
+ }
+ }
+ }
+
+ fn headers_mut(&mut self) -> &mut Cursor<Vec<u8>> {
+ debug_assert!(!self.queue.has_remaining());
+ &mut self.headers
+ }
+}
+
+impl<B: Buf> fmt::Debug for WriteBuf<B> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("WriteBuf")
+ .field("remaining", &self.remaining())
+ .field("strategy", &self.strategy)
+ .finish()
+ }
+}
+
+impl<B: Buf> Buf for WriteBuf<B> {
+ #[inline]
+ fn remaining(&self) -> usize {
+ self.headers.remaining() + self.queue.remaining()
+ }
+
+ #[inline]
+ fn chunk(&self) -> &[u8] {
+ let headers = self.headers.chunk();
+ if !headers.is_empty() {
+ headers
+ } else {
+ self.queue.chunk()
+ }
+ }
+
+ #[inline]
+ fn advance(&mut self, cnt: usize) {
+ let hrem = self.headers.remaining();
+
+ match hrem.cmp(&cnt) {
+ cmp::Ordering::Equal => self.headers.reset(),
+ cmp::Ordering::Greater => self.headers.advance(cnt),
+ cmp::Ordering::Less => {
+ let qcnt = cnt - hrem;
+ self.headers.reset();
+ self.queue.advance(qcnt);
+ }
+ }
+ }
+
+ #[inline]
+ fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize {
+ let n = self.headers.chunks_vectored(dst);
+ self.queue.chunks_vectored(&mut dst[n..]) + n
+ }
+}
+
+#[derive(Debug)]
+enum WriteStrategy {
+ Flatten,
+ Queue,
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::time::Duration;
+
+ use tokio_test::io::Builder as Mock;
+
+ // #[cfg(feature = "nightly")]
+ // use test::Bencher;
+
+ /*
+ impl<T: Read> MemRead for AsyncIo<T> {
+ fn read_mem(&mut self, len: usize) -> Poll<Bytes, io::Error> {
+ let mut v = vec![0; len];
+ let n = try_nb!(self.read(v.as_mut_slice()));
+ Ok(Async::Ready(BytesMut::from(&v[..n]).freeze()))
+ }
+ }
+ */
+
+ #[tokio::test]
+ #[ignore]
+ async fn iobuf_write_empty_slice() {
+ // TODO(eliza): can i have writev back pls T_T
+ // // First, let's just check that the Mock would normally return an
+ // // error on an unexpected write, even if the buffer is empty...
+ // let mut mock = Mock::new().build();
+ // futures_util::future::poll_fn(|cx| {
+ // Pin::new(&mut mock).poll_write_buf(cx, &mut Cursor::new(&[]))
+ // })
+ // .await
+ // .expect_err("should be a broken pipe");
+
+ // // underlying io will return the logic error upon write,
+ // // so we are testing that the io_buf does not trigger a write
+ // // when there is nothing to flush
+ // let mock = Mock::new().build();
+ // let mut io_buf = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
+ // io_buf.flush().await.expect("should short-circuit flush");
+ }
+
+ #[tokio::test]
+ async fn parse_reads_until_blocked() {
+ use crate::proto::h1::ClientTransaction;
+
+ let _ = pretty_env_logger::try_init();
+ let mock = Mock::new()
+ // Split over multiple reads will read all of it
+ .read(b"HTTP/1.1 200 OK\r\n")
+ .read(b"Server: hyper\r\n")
+ // missing last line ending
+ .wait(Duration::from_secs(1))
+ .build();
+
+ let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
+
+ // We expect a `parse` to be not ready, and so can't await it directly.
+ // Rather, this `poll_fn` will wrap the `Poll` result.
+ futures_util::future::poll_fn(|cx| {
+ let parse_ctx = ParseContext {
+ cached_headers: &mut None,
+ req_method: &mut None,
+ h1_parser_config: Default::default(),
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout: None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_fut: &mut None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_running: &mut false,
+ preserve_header_case: false,
+ #[cfg(feature = "ffi")]
+ preserve_header_order: false,
+ h09_responses: false,
+ #[cfg(feature = "ffi")]
+ on_informational: &mut None,
+ #[cfg(feature = "ffi")]
+ raw_headers: false,
+ };
+ assert!(buffered
+ .parse::<ClientTransaction>(cx, parse_ctx)
+ .is_pending());
+ Poll::Ready(())
+ })
+ .await;
+
+ assert_eq!(
+ buffered.read_buf,
+ b"HTTP/1.1 200 OK\r\nServer: hyper\r\n"[..]
+ );
+ }
+
+ #[test]
+ fn read_strategy_adaptive_increments() {
+ let mut strategy = ReadStrategy::default();
+ assert_eq!(strategy.next(), 8192);
+
+ // Grows if record == next
+ strategy.record(8192);
+ assert_eq!(strategy.next(), 16384);
+
+ strategy.record(16384);
+ assert_eq!(strategy.next(), 32768);
+
+ // Enormous records still increment at same rate
+ strategy.record(::std::usize::MAX);
+ assert_eq!(strategy.next(), 65536);
+
+ let max = strategy.max();
+ while strategy.next() < max {
+ strategy.record(max);
+ }
+
+ assert_eq!(strategy.next(), max, "never goes over max");
+ strategy.record(max + 1);
+ assert_eq!(strategy.next(), max, "never goes over max");
+ }
+
+ #[test]
+ fn read_strategy_adaptive_decrements() {
+ let mut strategy = ReadStrategy::default();
+ strategy.record(8192);
+ assert_eq!(strategy.next(), 16384);
+
+ strategy.record(1);
+ assert_eq!(
+ strategy.next(),
+ 16384,
+ "first smaller record doesn't decrement yet"
+ );
+ strategy.record(8192);
+ assert_eq!(strategy.next(), 16384, "record was with range");
+
+ strategy.record(1);
+ assert_eq!(
+ strategy.next(),
+ 16384,
+ "in-range record should make this the 'first' again"
+ );
+
+ strategy.record(1);
+ assert_eq!(strategy.next(), 8192, "second smaller record decrements");
+
+ strategy.record(1);
+ assert_eq!(strategy.next(), 8192, "first doesn't decrement");
+ strategy.record(1);
+ assert_eq!(strategy.next(), 8192, "doesn't decrement under minimum");
+ }
+
+ #[test]
+ fn read_strategy_adaptive_stays_the_same() {
+ let mut strategy = ReadStrategy::default();
+ strategy.record(8192);
+ assert_eq!(strategy.next(), 16384);
+
+ strategy.record(8193);
+ assert_eq!(
+ strategy.next(),
+ 16384,
+ "first smaller record doesn't decrement yet"
+ );
+
+ strategy.record(8193);
+ assert_eq!(
+ strategy.next(),
+ 16384,
+ "with current step does not decrement"
+ );
+ }
+
+ #[test]
+ fn read_strategy_adaptive_max_fuzz() {
+ fn fuzz(max: usize) {
+ let mut strategy = ReadStrategy::with_max(max);
+ while strategy.next() < max {
+ strategy.record(::std::usize::MAX);
+ }
+ let mut next = strategy.next();
+ while next > 8192 {
+ strategy.record(1);
+ strategy.record(1);
+ next = strategy.next();
+ assert!(
+ next.is_power_of_two(),
+ "decrement should be powers of two: {} (max = {})",
+ next,
+ max,
+ );
+ }
+ }
+
+ let mut max = 8192;
+ while max < std::usize::MAX {
+ fuzz(max);
+ max = (max / 2).saturating_mul(3);
+ }
+ fuzz(::std::usize::MAX);
+ }
+
+ #[test]
+ #[should_panic]
+ #[cfg(debug_assertions)] // needs to trigger a debug_assert
+ fn write_buf_requires_non_empty_bufs() {
+ let mock = Mock::new().build();
+ let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
+
+ buffered.buffer(Cursor::new(Vec::new()));
+ }
+
+ /*
+ TODO: needs tokio_test::io to allow configure write_buf calls
+ #[test]
+ fn write_buf_queue() {
+ let _ = pretty_env_logger::try_init();
+
+ let mock = AsyncIo::new_buf(vec![], 1024);
+ let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
+
+
+ buffered.headers_buf().extend(b"hello ");
+ buffered.buffer(Cursor::new(b"world, ".to_vec()));
+ buffered.buffer(Cursor::new(b"it's ".to_vec()));
+ buffered.buffer(Cursor::new(b"hyper!".to_vec()));
+ assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3);
+ buffered.flush().unwrap();
+
+ assert_eq!(buffered.io, b"hello world, it's hyper!");
+ assert_eq!(buffered.io.num_writes(), 1);
+ assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0);
+ }
+ */
+
+ #[tokio::test]
+ async fn write_buf_flatten() {
+ let _ = pretty_env_logger::try_init();
+
+ let mock = Mock::new().write(b"hello world, it's hyper!").build();
+
+ let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
+ buffered.write_buf.set_strategy(WriteStrategy::Flatten);
+
+ buffered.headers_buf().extend(b"hello ");
+ buffered.buffer(Cursor::new(b"world, ".to_vec()));
+ buffered.buffer(Cursor::new(b"it's ".to_vec()));
+ buffered.buffer(Cursor::new(b"hyper!".to_vec()));
+ assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0);
+
+ buffered.flush().await.expect("flush");
+ }
+
+ #[test]
+ fn write_buf_flatten_partially_flushed() {
+ let _ = pretty_env_logger::try_init();
+
+ let b = |s: &str| Cursor::new(s.as_bytes().to_vec());
+
+ let mut write_buf = WriteBuf::<Cursor<Vec<u8>>>::new(WriteStrategy::Flatten);
+
+ write_buf.buffer(b("hello "));
+ write_buf.buffer(b("world, "));
+
+ assert_eq!(write_buf.chunk(), b"hello world, ");
+
+ // advance most of the way, but not all
+ write_buf.advance(11);
+
+ assert_eq!(write_buf.chunk(), b", ");
+ assert_eq!(write_buf.headers.pos, 11);
+ assert_eq!(write_buf.headers.bytes.capacity(), INIT_BUFFER_SIZE);
+
+ // there's still room in the headers buffer, so just push on the end
+ write_buf.buffer(b("it's hyper!"));
+
+ assert_eq!(write_buf.chunk(), b", it's hyper!");
+ assert_eq!(write_buf.headers.pos, 11);
+
+ let rem1 = write_buf.remaining();
+ let cap = write_buf.headers.bytes.capacity();
+
+ // but when this would go over capacity, don't copy the old bytes
+ write_buf.buffer(Cursor::new(vec![b'X'; cap]));
+ assert_eq!(write_buf.remaining(), cap + rem1);
+ assert_eq!(write_buf.headers.pos, 0);
+ }
+
+ #[tokio::test]
+ async fn write_buf_queue_disable_auto() {
+ let _ = pretty_env_logger::try_init();
+
+ let mock = Mock::new()
+ .write(b"hello ")
+ .write(b"world, ")
+ .write(b"it's ")
+ .write(b"hyper!")
+ .build();
+
+ let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
+ buffered.write_buf.set_strategy(WriteStrategy::Queue);
+
+ // we have 4 buffers, and vec IO disabled, but explicitly said
+ // don't try to auto detect (via setting strategy above)
+
+ buffered.headers_buf().extend(b"hello ");
+ buffered.buffer(Cursor::new(b"world, ".to_vec()));
+ buffered.buffer(Cursor::new(b"it's ".to_vec()));
+ buffered.buffer(Cursor::new(b"hyper!".to_vec()));
+ assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3);
+
+ buffered.flush().await.expect("flush");
+
+ assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0);
+ }
+
+ // #[cfg(feature = "nightly")]
+ // #[bench]
+ // fn bench_write_buf_flatten_buffer_chunk(b: &mut Bencher) {
+ // let s = "Hello, World!";
+ // b.bytes = s.len() as u64;
+
+ // let mut write_buf = WriteBuf::<bytes::Bytes>::new();
+ // write_buf.set_strategy(WriteStrategy::Flatten);
+ // b.iter(|| {
+ // let chunk = bytes::Bytes::from(s);
+ // write_buf.buffer(chunk);
+ // ::test::black_box(&write_buf);
+ // write_buf.headers.bytes.clear();
+ // })
+ // }
+}
diff --git a/third_party/rust/hyper/src/proto/h1/mod.rs b/third_party/rust/hyper/src/proto/h1/mod.rs
new file mode 100644
index 0000000000..5a2587a843
--- /dev/null
+++ b/third_party/rust/hyper/src/proto/h1/mod.rs
@@ -0,0 +1,122 @@
+#[cfg(all(feature = "server", feature = "runtime"))]
+use std::{pin::Pin, time::Duration};
+
+use bytes::BytesMut;
+use http::{HeaderMap, Method};
+use httparse::ParserConfig;
+#[cfg(all(feature = "server", feature = "runtime"))]
+use tokio::time::Sleep;
+
+use crate::body::DecodedLength;
+use crate::proto::{BodyLength, MessageHead};
+
+pub(crate) use self::conn::Conn;
+pub(crate) use self::decode::Decoder;
+pub(crate) use self::dispatch::Dispatcher;
+pub(crate) use self::encode::{EncodedBuf, Encoder};
+//TODO: move out of h1::io
+pub(crate) use self::io::MINIMUM_MAX_BUFFER_SIZE;
+
+mod conn;
+mod decode;
+pub(crate) mod dispatch;
+mod encode;
+mod io;
+mod role;
+
+cfg_client! {
+ pub(crate) type ClientTransaction = role::Client;
+}
+
+cfg_server! {
+ pub(crate) type ServerTransaction = role::Server;
+}
+
+pub(crate) trait Http1Transaction {
+ type Incoming;
+ type Outgoing: Default;
+ const LOG: &'static str;
+ fn parse(bytes: &mut BytesMut, ctx: ParseContext<'_>) -> ParseResult<Self::Incoming>;
+ fn encode(enc: Encode<'_, Self::Outgoing>, dst: &mut Vec<u8>) -> crate::Result<Encoder>;
+
+ fn on_error(err: &crate::Error) -> Option<MessageHead<Self::Outgoing>>;
+
+ fn is_client() -> bool {
+ !Self::is_server()
+ }
+
+ fn is_server() -> bool {
+ !Self::is_client()
+ }
+
+ fn should_error_on_parse_eof() -> bool {
+ Self::is_client()
+ }
+
+ fn should_read_first() -> bool {
+ Self::is_server()
+ }
+
+ fn update_date() {}
+}
+
+/// Result newtype for Http1Transaction::parse.
+pub(crate) type ParseResult<T> = Result<Option<ParsedMessage<T>>, crate::error::Parse>;
+
+#[derive(Debug)]
+pub(crate) struct ParsedMessage<T> {
+ head: MessageHead<T>,
+ decode: DecodedLength,
+ expect_continue: bool,
+ keep_alive: bool,
+ wants_upgrade: bool,
+}
+
+pub(crate) struct ParseContext<'a> {
+ cached_headers: &'a mut Option<HeaderMap>,
+ req_method: &'a mut Option<Method>,
+ h1_parser_config: ParserConfig,
+ #[cfg(all(feature = "server", feature = "runtime"))]
+ h1_header_read_timeout: Option<Duration>,
+ #[cfg(all(feature = "server", feature = "runtime"))]
+ h1_header_read_timeout_fut: &'a mut Option<Pin<Box<Sleep>>>,
+ #[cfg(all(feature = "server", feature = "runtime"))]
+ h1_header_read_timeout_running: &'a mut bool,
+ preserve_header_case: bool,
+ #[cfg(feature = "ffi")]
+ preserve_header_order: bool,
+ h09_responses: bool,
+ #[cfg(feature = "ffi")]
+ on_informational: &'a mut Option<crate::ffi::OnInformational>,
+ #[cfg(feature = "ffi")]
+ raw_headers: bool,
+}
+
+/// Passed to Http1Transaction::encode
+pub(crate) struct Encode<'a, T> {
+ head: &'a mut MessageHead<T>,
+ body: Option<BodyLength>,
+ #[cfg(feature = "server")]
+ keep_alive: bool,
+ req_method: &'a mut Option<Method>,
+ title_case_headers: bool,
+}
+
+/// Extra flags that a request "wants", like expect-continue or upgrades.
+#[derive(Clone, Copy, Debug)]
+struct Wants(u8);
+
+impl Wants {
+ const EMPTY: Wants = Wants(0b00);
+ const EXPECT: Wants = Wants(0b01);
+ const UPGRADE: Wants = Wants(0b10);
+
+ #[must_use]
+ fn add(self, other: Wants) -> Wants {
+ Wants(self.0 | other.0)
+ }
+
+ fn contains(&self, other: Wants) -> bool {
+ (self.0 & other.0) == other.0
+ }
+}
diff --git a/third_party/rust/hyper/src/proto/h1/role.rs b/third_party/rust/hyper/src/proto/h1/role.rs
new file mode 100644
index 0000000000..6252207baf
--- /dev/null
+++ b/third_party/rust/hyper/src/proto/h1/role.rs
@@ -0,0 +1,2847 @@
+use std::fmt::{self, Write};
+use std::mem::MaybeUninit;
+
+use bytes::Bytes;
+use bytes::BytesMut;
+#[cfg(feature = "server")]
+use http::header::ValueIter;
+use http::header::{self, Entry, HeaderName, HeaderValue};
+use http::{HeaderMap, Method, StatusCode, Version};
+#[cfg(all(feature = "server", feature = "runtime"))]
+use tokio::time::Instant;
+use tracing::{debug, error, trace, trace_span, warn};
+
+use crate::body::DecodedLength;
+#[cfg(feature = "server")]
+use crate::common::date;
+use crate::error::Parse;
+use crate::ext::HeaderCaseMap;
+#[cfg(feature = "ffi")]
+use crate::ext::OriginalHeaderOrder;
+use crate::headers;
+use crate::proto::h1::{
+ Encode, Encoder, Http1Transaction, ParseContext, ParseResult, ParsedMessage,
+};
+use crate::proto::{BodyLength, MessageHead, RequestHead, RequestLine};
+
+const MAX_HEADERS: usize = 100;
+const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific
+#[cfg(feature = "server")]
+const MAX_URI_LEN: usize = (u16::MAX - 1) as usize;
+
+macro_rules! header_name {
+ ($bytes:expr) => {{
+ {
+ match HeaderName::from_bytes($bytes) {
+ Ok(name) => name,
+ Err(e) => maybe_panic!(e),
+ }
+ }
+ }};
+}
+
+macro_rules! header_value {
+ ($bytes:expr) => {{
+ {
+ unsafe { HeaderValue::from_maybe_shared_unchecked($bytes) }
+ }
+ }};
+}
+
+macro_rules! maybe_panic {
+ ($($arg:tt)*) => ({
+ let _err = ($($arg)*);
+ if cfg!(debug_assertions) {
+ panic!("{:?}", _err);
+ } else {
+ error!("Internal Hyper error, please report {:?}", _err);
+ return Err(Parse::Internal)
+ }
+ })
+}
+
+pub(super) fn parse_headers<T>(
+ bytes: &mut BytesMut,
+ ctx: ParseContext<'_>,
+) -> ParseResult<T::Incoming>
+where
+ T: Http1Transaction,
+{
+ // If the buffer is empty, don't bother entering the span, it's just noise.
+ if bytes.is_empty() {
+ return Ok(None);
+ }
+
+ let span = trace_span!("parse_headers");
+ let _s = span.enter();
+
+ #[cfg(all(feature = "server", feature = "runtime"))]
+ if !*ctx.h1_header_read_timeout_running {
+ if let Some(h1_header_read_timeout) = ctx.h1_header_read_timeout {
+ let deadline = Instant::now() + h1_header_read_timeout;
+ *ctx.h1_header_read_timeout_running = true;
+ match ctx.h1_header_read_timeout_fut {
+ Some(h1_header_read_timeout_fut) => {
+ debug!("resetting h1 header read timeout timer");
+ h1_header_read_timeout_fut.as_mut().reset(deadline);
+ }
+ None => {
+ debug!("setting h1 header read timeout timer");
+ *ctx.h1_header_read_timeout_fut =
+ Some(Box::pin(tokio::time::sleep_until(deadline)));
+ }
+ }
+ }
+ }
+
+ T::parse(bytes, ctx)
+}
+
+pub(super) fn encode_headers<T>(
+ enc: Encode<'_, T::Outgoing>,
+ dst: &mut Vec<u8>,
+) -> crate::Result<Encoder>
+where
+ T: Http1Transaction,
+{
+ let span = trace_span!("encode_headers");
+ let _s = span.enter();
+ T::encode(enc, dst)
+}
+
+// There are 2 main roles, Client and Server.
+
+#[cfg(feature = "client")]
+pub(crate) enum Client {}
+
+#[cfg(feature = "server")]
+pub(crate) enum Server {}
+
+#[cfg(feature = "server")]
+impl Http1Transaction for Server {
+ type Incoming = RequestLine;
+ type Outgoing = StatusCode;
+ const LOG: &'static str = "{role=server}";
+
+ fn parse(buf: &mut BytesMut, ctx: ParseContext<'_>) -> ParseResult<RequestLine> {
+ debug_assert!(!buf.is_empty(), "parse called with empty buf");
+
+ let mut keep_alive;
+ let is_http_11;
+ let subject;
+ let version;
+ let len;
+ let headers_len;
+
+ // Unsafe: both headers_indices and headers are using uninitialized memory,
+ // but we *never* read any of it until after httparse has assigned
+ // values into it. By not zeroing out the stack memory, this saves
+ // a good ~5% on pipeline benchmarks.
+ let mut headers_indices: [MaybeUninit<HeaderIndices>; MAX_HEADERS] = unsafe {
+ // SAFETY: We can go safely from MaybeUninit array to array of MaybeUninit
+ MaybeUninit::uninit().assume_init()
+ };
+ {
+ /* SAFETY: it is safe to go from MaybeUninit array to array of MaybeUninit */
+ let mut headers: [MaybeUninit<httparse::Header<'_>>; MAX_HEADERS] =
+ unsafe { MaybeUninit::uninit().assume_init() };
+ trace!(bytes = buf.len(), "Request.parse");
+ let mut req = httparse::Request::new(&mut []);
+ let bytes = buf.as_ref();
+ match req.parse_with_uninit_headers(bytes, &mut headers) {
+ Ok(httparse::Status::Complete(parsed_len)) => {
+ trace!("Request.parse Complete({})", parsed_len);
+ len = parsed_len;
+ let uri = req.path.unwrap();
+ if uri.len() > MAX_URI_LEN {
+ return Err(Parse::UriTooLong);
+ }
+ subject = RequestLine(
+ Method::from_bytes(req.method.unwrap().as_bytes())?,
+ uri.parse()?,
+ );
+ version = if req.version.unwrap() == 1 {
+ keep_alive = true;
+ is_http_11 = true;
+ Version::HTTP_11
+ } else {
+ keep_alive = false;
+ is_http_11 = false;
+ Version::HTTP_10
+ };
+
+ record_header_indices(bytes, &req.headers, &mut headers_indices)?;
+ headers_len = req.headers.len();
+ }
+ Ok(httparse::Status::Partial) => return Ok(None),
+ Err(err) => {
+ return Err(match err {
+ // if invalid Token, try to determine if for method or path
+ httparse::Error::Token => {
+ if req.method.is_none() {
+ Parse::Method
+ } else {
+ debug_assert!(req.path.is_none());
+ Parse::Uri
+ }
+ }
+ other => other.into(),
+ });
+ }
+ }
+ };
+
+ let slice = buf.split_to(len).freeze();
+
+ // According to https://tools.ietf.org/html/rfc7230#section-3.3.3
+ // 1. (irrelevant to Request)
+ // 2. (irrelevant to Request)
+ // 3. Transfer-Encoding: chunked has a chunked body.
+ // 4. If multiple differing Content-Length headers or invalid, close connection.
+ // 5. Content-Length header has a sized body.
+ // 6. Length 0.
+ // 7. (irrelevant to Request)
+
+ let mut decoder = DecodedLength::ZERO;
+ let mut expect_continue = false;
+ let mut con_len = None;
+ let mut is_te = false;
+ let mut is_te_chunked = false;
+ let mut wants_upgrade = subject.0 == Method::CONNECT;
+
+ let mut header_case_map = if ctx.preserve_header_case {
+ Some(HeaderCaseMap::default())
+ } else {
+ None
+ };
+
+ #[cfg(feature = "ffi")]
+ let mut header_order = if ctx.preserve_header_order {
+ Some(OriginalHeaderOrder::default())
+ } else {
+ None
+ };
+
+ let mut headers = ctx.cached_headers.take().unwrap_or_else(HeaderMap::new);
+
+ headers.reserve(headers_len);
+
+ for header in &headers_indices[..headers_len] {
+ // SAFETY: array is valid up to `headers_len`
+ let header = unsafe { &*header.as_ptr() };
+ let name = header_name!(&slice[header.name.0..header.name.1]);
+ let value = header_value!(slice.slice(header.value.0..header.value.1));
+
+ match name {
+ header::TRANSFER_ENCODING => {
+ // https://tools.ietf.org/html/rfc7230#section-3.3.3
+ // If Transfer-Encoding header is present, and 'chunked' is
+ // not the final encoding, and this is a Request, then it is
+ // malformed. A server should respond with 400 Bad Request.
+ if !is_http_11 {
+ debug!("HTTP/1.0 cannot have Transfer-Encoding header");
+ return Err(Parse::transfer_encoding_unexpected());
+ }
+ is_te = true;
+ if headers::is_chunked_(&value) {
+ is_te_chunked = true;
+ decoder = DecodedLength::CHUNKED;
+ } else {
+ is_te_chunked = false;
+ }
+ }
+ header::CONTENT_LENGTH => {
+ if is_te {
+ continue;
+ }
+ let len = headers::content_length_parse(&value)
+ .ok_or_else(Parse::content_length_invalid)?;
+ if let Some(prev) = con_len {
+ if prev != len {
+ debug!(
+ "multiple Content-Length headers with different values: [{}, {}]",
+ prev, len,
+ );
+ return Err(Parse::content_length_invalid());
+ }
+ // we don't need to append this secondary length
+ continue;
+ }
+ decoder = DecodedLength::checked_new(len)?;
+ con_len = Some(len);
+ }
+ header::CONNECTION => {
+ // keep_alive was previously set to default for Version
+ if keep_alive {
+ // HTTP/1.1
+ keep_alive = !headers::connection_close(&value);
+ } else {
+ // HTTP/1.0
+ keep_alive = headers::connection_keep_alive(&value);
+ }
+ }
+ header::EXPECT => {
+ // According to https://datatracker.ietf.org/doc/html/rfc2616#section-14.20
+ // Comparison of expectation values is case-insensitive for unquoted tokens
+ // (including the 100-continue token)
+ expect_continue = value.as_bytes().eq_ignore_ascii_case(b"100-continue");
+ }
+ header::UPGRADE => {
+ // Upgrades are only allowed with HTTP/1.1
+ wants_upgrade = is_http_11;
+ }
+
+ _ => (),
+ }
+
+ if let Some(ref mut header_case_map) = header_case_map {
+ header_case_map.append(&name, slice.slice(header.name.0..header.name.1));
+ }
+
+ #[cfg(feature = "ffi")]
+ if let Some(ref mut header_order) = header_order {
+ header_order.append(&name);
+ }
+
+ headers.append(name, value);
+ }
+
+ if is_te && !is_te_chunked {
+ debug!("request with transfer-encoding header, but not chunked, bad request");
+ return Err(Parse::transfer_encoding_invalid());
+ }
+
+ let mut extensions = http::Extensions::default();
+
+ if let Some(header_case_map) = header_case_map {
+ extensions.insert(header_case_map);
+ }
+
+ #[cfg(feature = "ffi")]
+ if let Some(header_order) = header_order {
+ extensions.insert(header_order);
+ }
+
+ *ctx.req_method = Some(subject.0.clone());
+
+ Ok(Some(ParsedMessage {
+ head: MessageHead {
+ version,
+ subject,
+ headers,
+ extensions,
+ },
+ decode: decoder,
+ expect_continue,
+ keep_alive,
+ wants_upgrade,
+ }))
+ }
+
+ fn encode(mut msg: Encode<'_, Self::Outgoing>, dst: &mut Vec<u8>) -> crate::Result<Encoder> {
+ trace!(
+ "Server::encode status={:?}, body={:?}, req_method={:?}",
+ msg.head.subject,
+ msg.body,
+ msg.req_method
+ );
+
+ let mut wrote_len = false;
+
+ // hyper currently doesn't support returning 1xx status codes as a Response
+ // This is because Service only allows returning a single Response, and
+ // so if you try to reply with a e.g. 100 Continue, you have no way of
+ // replying with the latter status code response.
+ let (ret, is_last) = if msg.head.subject == StatusCode::SWITCHING_PROTOCOLS {
+ (Ok(()), true)
+ } else if msg.req_method == &Some(Method::CONNECT) && msg.head.subject.is_success() {
+ // Sending content-length or transfer-encoding header on 2xx response
+ // to CONNECT is forbidden in RFC 7231.
+ wrote_len = true;
+ (Ok(()), true)
+ } else if msg.head.subject.is_informational() {
+ warn!("response with 1xx status code not supported");
+ *msg.head = MessageHead::default();
+ msg.head.subject = StatusCode::INTERNAL_SERVER_ERROR;
+ msg.body = None;
+ (Err(crate::Error::new_user_unsupported_status_code()), true)
+ } else {
+ (Ok(()), !msg.keep_alive)
+ };
+
+ // In some error cases, we don't know about the invalid message until already
+ // pushing some bytes onto the `dst`. In those cases, we don't want to send
+ // the half-pushed message, so rewind to before.
+ let orig_len = dst.len();
+
+ let init_cap = 30 + msg.head.headers.len() * AVERAGE_HEADER_SIZE;
+ dst.reserve(init_cap);
+
+ let custom_reason_phrase = msg.head.extensions.get::<crate::ext::ReasonPhrase>();
+
+ if msg.head.version == Version::HTTP_11
+ && msg.head.subject == StatusCode::OK
+ && custom_reason_phrase.is_none()
+ {
+ extend(dst, b"HTTP/1.1 200 OK\r\n");
+ } else {
+ match msg.head.version {
+ Version::HTTP_10 => extend(dst, b"HTTP/1.0 "),
+ Version::HTTP_11 => extend(dst, b"HTTP/1.1 "),
+ Version::HTTP_2 => {
+ debug!("response with HTTP2 version coerced to HTTP/1.1");
+ extend(dst, b"HTTP/1.1 ");
+ }
+ other => panic!("unexpected response version: {:?}", other),
+ }
+
+ extend(dst, msg.head.subject.as_str().as_bytes());
+ extend(dst, b" ");
+
+ if let Some(reason) = custom_reason_phrase {
+ extend(dst, reason.as_bytes());
+ } else {
+ // a reason MUST be written, as many parsers will expect it.
+ extend(
+ dst,
+ msg.head
+ .subject
+ .canonical_reason()
+ .unwrap_or("<none>")
+ .as_bytes(),
+ );
+ }
+
+ extend(dst, b"\r\n");
+ }
+
+ let orig_headers;
+ let extensions = std::mem::take(&mut msg.head.extensions);
+ let orig_headers = match extensions.get::<HeaderCaseMap>() {
+ None if msg.title_case_headers => {
+ orig_headers = HeaderCaseMap::default();
+ Some(&orig_headers)
+ }
+ orig_headers => orig_headers,
+ };
+ let encoder = if let Some(orig_headers) = orig_headers {
+ Self::encode_headers_with_original_case(
+ msg,
+ dst,
+ is_last,
+ orig_len,
+ wrote_len,
+ orig_headers,
+ )?
+ } else {
+ Self::encode_headers_with_lower_case(msg, dst, is_last, orig_len, wrote_len)?
+ };
+
+ ret.map(|()| encoder)
+ }
+
+ fn on_error(err: &crate::Error) -> Option<MessageHead<Self::Outgoing>> {
+ use crate::error::Kind;
+ let status = match *err.kind() {
+ Kind::Parse(Parse::Method)
+ | Kind::Parse(Parse::Header(_))
+ | Kind::Parse(Parse::Uri)
+ | Kind::Parse(Parse::Version) => StatusCode::BAD_REQUEST,
+ Kind::Parse(Parse::TooLarge) => StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE,
+ Kind::Parse(Parse::UriTooLong) => StatusCode::URI_TOO_LONG,
+ _ => return None,
+ };
+
+ debug!("sending automatic response ({}) for parse error", status);
+ let mut msg = MessageHead::default();
+ msg.subject = status;
+ Some(msg)
+ }
+
+ fn is_server() -> bool {
+ true
+ }
+
+ fn update_date() {
+ date::update();
+ }
+}
+
+#[cfg(feature = "server")]
+impl Server {
+ fn can_have_body(method: &Option<Method>, status: StatusCode) -> bool {
+ Server::can_chunked(method, status)
+ }
+
+ fn can_chunked(method: &Option<Method>, status: StatusCode) -> bool {
+ if method == &Some(Method::HEAD) || method == &Some(Method::CONNECT) && status.is_success()
+ {
+ false
+ } else if status.is_informational() {
+ false
+ } else {
+ match status {
+ StatusCode::NO_CONTENT | StatusCode::NOT_MODIFIED => false,
+ _ => true,
+ }
+ }
+ }
+
+ fn can_have_content_length(method: &Option<Method>, status: StatusCode) -> bool {
+ if status.is_informational() || method == &Some(Method::CONNECT) && status.is_success() {
+ false
+ } else {
+ match status {
+ StatusCode::NO_CONTENT | StatusCode::NOT_MODIFIED => false,
+ _ => true,
+ }
+ }
+ }
+
+ fn can_have_implicit_zero_content_length(method: &Option<Method>, status: StatusCode) -> bool {
+ Server::can_have_content_length(method, status) && method != &Some(Method::HEAD)
+ }
+
+ fn encode_headers_with_lower_case(
+ msg: Encode<'_, StatusCode>,
+ dst: &mut Vec<u8>,
+ is_last: bool,
+ orig_len: usize,
+ wrote_len: bool,
+ ) -> crate::Result<Encoder> {
+ struct LowercaseWriter;
+
+ impl HeaderNameWriter for LowercaseWriter {
+ #[inline]
+ fn write_full_header_line(
+ &mut self,
+ dst: &mut Vec<u8>,
+ line: &str,
+ _: (HeaderName, &str),
+ ) {
+ extend(dst, line.as_bytes())
+ }
+
+ #[inline]
+ fn write_header_name_with_colon(
+ &mut self,
+ dst: &mut Vec<u8>,
+ name_with_colon: &str,
+ _: HeaderName,
+ ) {
+ extend(dst, name_with_colon.as_bytes())
+ }
+
+ #[inline]
+ fn write_header_name(&mut self, dst: &mut Vec<u8>, name: &HeaderName) {
+ extend(dst, name.as_str().as_bytes())
+ }
+ }
+
+ Self::encode_headers(msg, dst, is_last, orig_len, wrote_len, LowercaseWriter)
+ }
+
+ #[cold]
+ #[inline(never)]
+ fn encode_headers_with_original_case(
+ msg: Encode<'_, StatusCode>,
+ dst: &mut Vec<u8>,
+ is_last: bool,
+ orig_len: usize,
+ wrote_len: bool,
+ orig_headers: &HeaderCaseMap,
+ ) -> crate::Result<Encoder> {
+ struct OrigCaseWriter<'map> {
+ map: &'map HeaderCaseMap,
+ current: Option<(HeaderName, ValueIter<'map, Bytes>)>,
+ title_case_headers: bool,
+ }
+
+ impl HeaderNameWriter for OrigCaseWriter<'_> {
+ #[inline]
+ fn write_full_header_line(
+ &mut self,
+ dst: &mut Vec<u8>,
+ _: &str,
+ (name, rest): (HeaderName, &str),
+ ) {
+ self.write_header_name(dst, &name);
+ extend(dst, rest.as_bytes());
+ }
+
+ #[inline]
+ fn write_header_name_with_colon(
+ &mut self,
+ dst: &mut Vec<u8>,
+ _: &str,
+ name: HeaderName,
+ ) {
+ self.write_header_name(dst, &name);
+ extend(dst, b": ");
+ }
+
+ #[inline]
+ fn write_header_name(&mut self, dst: &mut Vec<u8>, name: &HeaderName) {
+ let Self {
+ map,
+ ref mut current,
+ title_case_headers,
+ } = *self;
+ if current.as_ref().map_or(true, |(last, _)| last != name) {
+ *current = None;
+ }
+ let (_, values) =
+ current.get_or_insert_with(|| (name.clone(), map.get_all_internal(name)));
+
+ if let Some(orig_name) = values.next() {
+ extend(dst, orig_name);
+ } else if title_case_headers {
+ title_case(dst, name.as_str().as_bytes());
+ } else {
+ extend(dst, name.as_str().as_bytes());
+ }
+ }
+ }
+
+ let header_name_writer = OrigCaseWriter {
+ map: orig_headers,
+ current: None,
+ title_case_headers: msg.title_case_headers,
+ };
+
+ Self::encode_headers(msg, dst, is_last, orig_len, wrote_len, header_name_writer)
+ }
+
+ #[inline]
+ fn encode_headers<W>(
+ msg: Encode<'_, StatusCode>,
+ dst: &mut Vec<u8>,
+ mut is_last: bool,
+ orig_len: usize,
+ mut wrote_len: bool,
+ mut header_name_writer: W,
+ ) -> crate::Result<Encoder>
+ where
+ W: HeaderNameWriter,
+ {
+ // In some error cases, we don't know about the invalid message until already
+ // pushing some bytes onto the `dst`. In those cases, we don't want to send
+ // the half-pushed message, so rewind to before.
+ let rewind = |dst: &mut Vec<u8>| {
+ dst.truncate(orig_len);
+ };
+
+ let mut encoder = Encoder::length(0);
+ let mut wrote_date = false;
+ let mut cur_name = None;
+ let mut is_name_written = false;
+ let mut must_write_chunked = false;
+ let mut prev_con_len = None;
+
+ macro_rules! handle_is_name_written {
+ () => {{
+ if is_name_written {
+ // we need to clean up and write the newline
+ debug_assert_ne!(
+ &dst[dst.len() - 2..],
+ b"\r\n",
+ "previous header wrote newline but set is_name_written"
+ );
+
+ if must_write_chunked {
+ extend(dst, b", chunked\r\n");
+ } else {
+ extend(dst, b"\r\n");
+ }
+ }
+ }};
+ }
+
+ 'headers: for (opt_name, value) in msg.head.headers.drain() {
+ if let Some(n) = opt_name {
+ cur_name = Some(n);
+ handle_is_name_written!();
+ is_name_written = false;
+ }
+ let name = cur_name.as_ref().expect("current header name");
+ match *name {
+ header::CONTENT_LENGTH => {
+ if wrote_len && !is_name_written {
+ warn!("unexpected content-length found, canceling");
+ rewind(dst);
+ return Err(crate::Error::new_user_header());
+ }
+ match msg.body {
+ Some(BodyLength::Known(known_len)) => {
+ // The HttpBody claims to know a length, and
+ // the headers are already set. For performance
+ // reasons, we are just going to trust that
+ // the values match.
+ //
+ // In debug builds, we'll assert they are the
+ // same to help developers find bugs.
+ #[cfg(debug_assertions)]
+ {
+ if let Some(len) = headers::content_length_parse(&value) {
+ assert!(
+ len == known_len,
+ "payload claims content-length of {}, custom content-length header claims {}",
+ known_len,
+ len,
+ );
+ }
+ }
+
+ if !is_name_written {
+ encoder = Encoder::length(known_len);
+ header_name_writer.write_header_name_with_colon(
+ dst,
+ "content-length: ",
+ header::CONTENT_LENGTH,
+ );
+ extend(dst, value.as_bytes());
+ wrote_len = true;
+ is_name_written = true;
+ }
+ continue 'headers;
+ }
+ Some(BodyLength::Unknown) => {
+ // The HttpBody impl didn't know how long the
+ // body is, but a length header was included.
+ // We have to parse the value to return our
+ // Encoder...
+
+ if let Some(len) = headers::content_length_parse(&value) {
+ if let Some(prev) = prev_con_len {
+ if prev != len {
+ warn!(
+ "multiple Content-Length values found: [{}, {}]",
+ prev, len
+ );
+ rewind(dst);
+ return Err(crate::Error::new_user_header());
+ }
+ debug_assert!(is_name_written);
+ continue 'headers;
+ } else {
+ // we haven't written content-length yet!
+ encoder = Encoder::length(len);
+ header_name_writer.write_header_name_with_colon(
+ dst,
+ "content-length: ",
+ header::CONTENT_LENGTH,
+ );
+ extend(dst, value.as_bytes());
+ wrote_len = true;
+ is_name_written = true;
+ prev_con_len = Some(len);
+ continue 'headers;
+ }
+ } else {
+ warn!("illegal Content-Length value: {:?}", value);
+ rewind(dst);
+ return Err(crate::Error::new_user_header());
+ }
+ }
+ None => {
+ // We have no body to actually send,
+ // but the headers claim a content-length.
+ // There's only 2 ways this makes sense:
+ //
+ // - The header says the length is `0`.
+ // - This is a response to a `HEAD` request.
+ if msg.req_method == &Some(Method::HEAD) {
+ debug_assert_eq!(encoder, Encoder::length(0));
+ } else {
+ if value.as_bytes() != b"0" {
+ warn!(
+ "content-length value found, but empty body provided: {:?}",
+ value
+ );
+ }
+ continue 'headers;
+ }
+ }
+ }
+ wrote_len = true;
+ }
+ header::TRANSFER_ENCODING => {
+ if wrote_len && !is_name_written {
+ warn!("unexpected transfer-encoding found, canceling");
+ rewind(dst);
+ return Err(crate::Error::new_user_header());
+ }
+ // check that we actually can send a chunked body...
+ if msg.head.version == Version::HTTP_10
+ || !Server::can_chunked(msg.req_method, msg.head.subject)
+ {
+ continue;
+ }
+ wrote_len = true;
+ // Must check each value, because `chunked` needs to be the
+ // last encoding, or else we add it.
+ must_write_chunked = !headers::is_chunked_(&value);
+
+ if !is_name_written {
+ encoder = Encoder::chunked();
+ is_name_written = true;
+ header_name_writer.write_header_name_with_colon(
+ dst,
+ "transfer-encoding: ",
+ header::TRANSFER_ENCODING,
+ );
+ extend(dst, value.as_bytes());
+ } else {
+ extend(dst, b", ");
+ extend(dst, value.as_bytes());
+ }
+ continue 'headers;
+ }
+ header::CONNECTION => {
+ if !is_last && headers::connection_close(&value) {
+ is_last = true;
+ }
+ if !is_name_written {
+ is_name_written = true;
+ header_name_writer.write_header_name_with_colon(
+ dst,
+ "connection: ",
+ header::CONNECTION,
+ );
+ extend(dst, value.as_bytes());
+ } else {
+ extend(dst, b", ");
+ extend(dst, value.as_bytes());
+ }
+ continue 'headers;
+ }
+ header::DATE => {
+ wrote_date = true;
+ }
+ _ => (),
+ }
+ //TODO: this should perhaps instead combine them into
+ //single lines, as RFC7230 suggests is preferable.
+
+ // non-special write Name and Value
+ debug_assert!(
+ !is_name_written,
+ "{:?} set is_name_written and didn't continue loop",
+ name,
+ );
+ header_name_writer.write_header_name(dst, name);
+ extend(dst, b": ");
+ extend(dst, value.as_bytes());
+ extend(dst, b"\r\n");
+ }
+
+ handle_is_name_written!();
+
+ if !wrote_len {
+ encoder = match msg.body {
+ Some(BodyLength::Unknown) => {
+ if msg.head.version == Version::HTTP_10
+ || !Server::can_chunked(msg.req_method, msg.head.subject)
+ {
+ Encoder::close_delimited()
+ } else {
+ header_name_writer.write_full_header_line(
+ dst,
+ "transfer-encoding: chunked\r\n",
+ (header::TRANSFER_ENCODING, ": chunked\r\n"),
+ );
+ Encoder::chunked()
+ }
+ }
+ None | Some(BodyLength::Known(0)) => {
+ if Server::can_have_implicit_zero_content_length(
+ msg.req_method,
+ msg.head.subject,
+ ) {
+ header_name_writer.write_full_header_line(
+ dst,
+ "content-length: 0\r\n",
+ (header::CONTENT_LENGTH, ": 0\r\n"),
+ )
+ }
+ Encoder::length(0)
+ }
+ Some(BodyLength::Known(len)) => {
+ if !Server::can_have_content_length(msg.req_method, msg.head.subject) {
+ Encoder::length(0)
+ } else {
+ header_name_writer.write_header_name_with_colon(
+ dst,
+ "content-length: ",
+ header::CONTENT_LENGTH,
+ );
+ extend(dst, ::itoa::Buffer::new().format(len).as_bytes());
+ extend(dst, b"\r\n");
+ Encoder::length(len)
+ }
+ }
+ };
+ }
+
+ if !Server::can_have_body(msg.req_method, msg.head.subject) {
+ trace!(
+ "server body forced to 0; method={:?}, status={:?}",
+ msg.req_method,
+ msg.head.subject
+ );
+ encoder = Encoder::length(0);
+ }
+
+ // cached date is much faster than formatting every request
+ if !wrote_date {
+ dst.reserve(date::DATE_VALUE_LENGTH + 8);
+ header_name_writer.write_header_name_with_colon(dst, "date: ", header::DATE);
+ date::extend(dst);
+ extend(dst, b"\r\n\r\n");
+ } else {
+ extend(dst, b"\r\n");
+ }
+
+ Ok(encoder.set_last(is_last))
+ }
+}
+
+#[cfg(feature = "server")]
+trait HeaderNameWriter {
+ fn write_full_header_line(
+ &mut self,
+ dst: &mut Vec<u8>,
+ line: &str,
+ name_value_pair: (HeaderName, &str),
+ );
+ fn write_header_name_with_colon(
+ &mut self,
+ dst: &mut Vec<u8>,
+ name_with_colon: &str,
+ name: HeaderName,
+ );
+ fn write_header_name(&mut self, dst: &mut Vec<u8>, name: &HeaderName);
+}
+
+#[cfg(feature = "client")]
+impl Http1Transaction for Client {
+ type Incoming = StatusCode;
+ type Outgoing = RequestLine;
+ const LOG: &'static str = "{role=client}";
+
+ fn parse(buf: &mut BytesMut, ctx: ParseContext<'_>) -> ParseResult<StatusCode> {
+ debug_assert!(!buf.is_empty(), "parse called with empty buf");
+
+ // Loop to skip information status code headers (100 Continue, etc).
+ loop {
+ // Unsafe: see comment in Server Http1Transaction, above.
+ let mut headers_indices: [MaybeUninit<HeaderIndices>; MAX_HEADERS] = unsafe {
+ // SAFETY: We can go safely from MaybeUninit array to array of MaybeUninit
+ MaybeUninit::uninit().assume_init()
+ };
+ let (len, status, reason, version, headers_len) = {
+ // SAFETY: We can go safely from MaybeUninit array to array of MaybeUninit
+ let mut headers: [MaybeUninit<httparse::Header<'_>>; MAX_HEADERS] =
+ unsafe { MaybeUninit::uninit().assume_init() };
+ trace!(bytes = buf.len(), "Response.parse");
+ let mut res = httparse::Response::new(&mut []);
+ let bytes = buf.as_ref();
+ match ctx.h1_parser_config.parse_response_with_uninit_headers(
+ &mut res,
+ bytes,
+ &mut headers,
+ ) {
+ Ok(httparse::Status::Complete(len)) => {
+ trace!("Response.parse Complete({})", len);
+ let status = StatusCode::from_u16(res.code.unwrap())?;
+
+ let reason = {
+ let reason = res.reason.unwrap();
+ // Only save the reason phrase if it isn't the canonical reason
+ if Some(reason) != status.canonical_reason() {
+ Some(Bytes::copy_from_slice(reason.as_bytes()))
+ } else {
+ None
+ }
+ };
+
+ let version = if res.version.unwrap() == 1 {
+ Version::HTTP_11
+ } else {
+ Version::HTTP_10
+ };
+ record_header_indices(bytes, &res.headers, &mut headers_indices)?;
+ let headers_len = res.headers.len();
+ (len, status, reason, version, headers_len)
+ }
+ Ok(httparse::Status::Partial) => return Ok(None),
+ Err(httparse::Error::Version) if ctx.h09_responses => {
+ trace!("Response.parse accepted HTTP/0.9 response");
+
+ (0, StatusCode::OK, None, Version::HTTP_09, 0)
+ }
+ Err(e) => return Err(e.into()),
+ }
+ };
+
+ let mut slice = buf.split_to(len);
+
+ if ctx
+ .h1_parser_config
+ .obsolete_multiline_headers_in_responses_are_allowed()
+ {
+ for header in &headers_indices[..headers_len] {
+ // SAFETY: array is valid up to `headers_len`
+ let header = unsafe { &*header.as_ptr() };
+ for b in &mut slice[header.value.0..header.value.1] {
+ if *b == b'\r' || *b == b'\n' {
+ *b = b' ';
+ }
+ }
+ }
+ }
+
+ let slice = slice.freeze();
+
+ let mut headers = ctx.cached_headers.take().unwrap_or_else(HeaderMap::new);
+
+ let mut keep_alive = version == Version::HTTP_11;
+
+ let mut header_case_map = if ctx.preserve_header_case {
+ Some(HeaderCaseMap::default())
+ } else {
+ None
+ };
+
+ #[cfg(feature = "ffi")]
+ let mut header_order = if ctx.preserve_header_order {
+ Some(OriginalHeaderOrder::default())
+ } else {
+ None
+ };
+
+ headers.reserve(headers_len);
+ for header in &headers_indices[..headers_len] {
+ // SAFETY: array is valid up to `headers_len`
+ let header = unsafe { &*header.as_ptr() };
+ let name = header_name!(&slice[header.name.0..header.name.1]);
+ let value = header_value!(slice.slice(header.value.0..header.value.1));
+
+ if let header::CONNECTION = name {
+ // keep_alive was previously set to default for Version
+ if keep_alive {
+ // HTTP/1.1
+ keep_alive = !headers::connection_close(&value);
+ } else {
+ // HTTP/1.0
+ keep_alive = headers::connection_keep_alive(&value);
+ }
+ }
+
+ if let Some(ref mut header_case_map) = header_case_map {
+ header_case_map.append(&name, slice.slice(header.name.0..header.name.1));
+ }
+
+ #[cfg(feature = "ffi")]
+ if let Some(ref mut header_order) = header_order {
+ header_order.append(&name);
+ }
+
+ headers.append(name, value);
+ }
+
+ let mut extensions = http::Extensions::default();
+
+ if let Some(header_case_map) = header_case_map {
+ extensions.insert(header_case_map);
+ }
+
+ #[cfg(feature = "ffi")]
+ if let Some(header_order) = header_order {
+ extensions.insert(header_order);
+ }
+
+ if let Some(reason) = reason {
+ // Safety: httparse ensures that only valid reason phrase bytes are present in this
+ // field.
+ let reason = unsafe { crate::ext::ReasonPhrase::from_bytes_unchecked(reason) };
+ extensions.insert(reason);
+ }
+
+ #[cfg(feature = "ffi")]
+ if ctx.raw_headers {
+ extensions.insert(crate::ffi::RawHeaders(crate::ffi::hyper_buf(slice)));
+ }
+
+ let head = MessageHead {
+ version,
+ subject: status,
+ headers,
+ extensions,
+ };
+ if let Some((decode, is_upgrade)) = Client::decoder(&head, ctx.req_method)? {
+ return Ok(Some(ParsedMessage {
+ head,
+ decode,
+ expect_continue: false,
+ // a client upgrade means the connection can't be used
+ // again, as it is definitely upgrading.
+ keep_alive: keep_alive && !is_upgrade,
+ wants_upgrade: is_upgrade,
+ }));
+ }
+
+ #[cfg(feature = "ffi")]
+ if head.subject.is_informational() {
+ if let Some(callback) = ctx.on_informational {
+ callback.call(head.into_response(crate::Body::empty()));
+ }
+ }
+
+ // Parsing a 1xx response could have consumed the buffer, check if
+ // it is empty now...
+ if buf.is_empty() {
+ return Ok(None);
+ }
+ }
+ }
+
+ fn encode(msg: Encode<'_, Self::Outgoing>, dst: &mut Vec<u8>) -> crate::Result<Encoder> {
+ trace!(
+ "Client::encode method={:?}, body={:?}",
+ msg.head.subject.0,
+ msg.body
+ );
+
+ *msg.req_method = Some(msg.head.subject.0.clone());
+
+ let body = Client::set_length(msg.head, msg.body);
+
+ let init_cap = 30 + msg.head.headers.len() * AVERAGE_HEADER_SIZE;
+ dst.reserve(init_cap);
+
+ extend(dst, msg.head.subject.0.as_str().as_bytes());
+ extend(dst, b" ");
+ //TODO: add API to http::Uri to encode without std::fmt
+ let _ = write!(FastWrite(dst), "{} ", msg.head.subject.1);
+
+ match msg.head.version {
+ Version::HTTP_10 => extend(dst, b"HTTP/1.0"),
+ Version::HTTP_11 => extend(dst, b"HTTP/1.1"),
+ Version::HTTP_2 => {
+ debug!("request with HTTP2 version coerced to HTTP/1.1");
+ extend(dst, b"HTTP/1.1");
+ }
+ other => panic!("unexpected request version: {:?}", other),
+ }
+ extend(dst, b"\r\n");
+
+ if let Some(orig_headers) = msg.head.extensions.get::<HeaderCaseMap>() {
+ write_headers_original_case(
+ &msg.head.headers,
+ orig_headers,
+ dst,
+ msg.title_case_headers,
+ );
+ } else if msg.title_case_headers {
+ write_headers_title_case(&msg.head.headers, dst);
+ } else {
+ write_headers(&msg.head.headers, dst);
+ }
+
+ extend(dst, b"\r\n");
+ msg.head.headers.clear(); //TODO: remove when switching to drain()
+
+ Ok(body)
+ }
+
+ fn on_error(_err: &crate::Error) -> Option<MessageHead<Self::Outgoing>> {
+ // we can't tell the server about any errors it creates
+ None
+ }
+
+ fn is_client() -> bool {
+ true
+ }
+}
+
+#[cfg(feature = "client")]
+impl Client {
+ /// Returns Some(length, wants_upgrade) if successful.
+ ///
+ /// Returns None if this message head should be skipped (like a 100 status).
+ fn decoder(
+ inc: &MessageHead<StatusCode>,
+ method: &mut Option<Method>,
+ ) -> Result<Option<(DecodedLength, bool)>, Parse> {
+ // According to https://tools.ietf.org/html/rfc7230#section-3.3.3
+ // 1. HEAD responses, and Status 1xx, 204, and 304 cannot have a body.
+ // 2. Status 2xx to a CONNECT cannot have a body.
+ // 3. Transfer-Encoding: chunked has a chunked body.
+ // 4. If multiple differing Content-Length headers or invalid, close connection.
+ // 5. Content-Length header has a sized body.
+ // 6. (irrelevant to Response)
+ // 7. Read till EOF.
+
+ match inc.subject.as_u16() {
+ 101 => {
+ return Ok(Some((DecodedLength::ZERO, true)));
+ }
+ 100 | 102..=199 => {
+ trace!("ignoring informational response: {}", inc.subject.as_u16());
+ return Ok(None);
+ }
+ 204 | 304 => return Ok(Some((DecodedLength::ZERO, false))),
+ _ => (),
+ }
+ match *method {
+ Some(Method::HEAD) => {
+ return Ok(Some((DecodedLength::ZERO, false)));
+ }
+ Some(Method::CONNECT) => {
+ if let 200..=299 = inc.subject.as_u16() {
+ return Ok(Some((DecodedLength::ZERO, true)));
+ }
+ }
+ Some(_) => {}
+ None => {
+ trace!("Client::decoder is missing the Method");
+ }
+ }
+
+ if inc.headers.contains_key(header::TRANSFER_ENCODING) {
+ // https://tools.ietf.org/html/rfc7230#section-3.3.3
+ // If Transfer-Encoding header is present, and 'chunked' is
+ // not the final encoding, and this is a Request, then it is
+ // malformed. A server should respond with 400 Bad Request.
+ if inc.version == Version::HTTP_10 {
+ debug!("HTTP/1.0 cannot have Transfer-Encoding header");
+ Err(Parse::transfer_encoding_unexpected())
+ } else if headers::transfer_encoding_is_chunked(&inc.headers) {
+ Ok(Some((DecodedLength::CHUNKED, false)))
+ } else {
+ trace!("not chunked, read till eof");
+ Ok(Some((DecodedLength::CLOSE_DELIMITED, false)))
+ }
+ } else if let Some(len) = headers::content_length_parse_all(&inc.headers) {
+ Ok(Some((DecodedLength::checked_new(len)?, false)))
+ } else if inc.headers.contains_key(header::CONTENT_LENGTH) {
+ debug!("illegal Content-Length header");
+ Err(Parse::content_length_invalid())
+ } else {
+ trace!("neither Transfer-Encoding nor Content-Length");
+ Ok(Some((DecodedLength::CLOSE_DELIMITED, false)))
+ }
+ }
+ fn set_length(head: &mut RequestHead, body: Option<BodyLength>) -> Encoder {
+ let body = if let Some(body) = body {
+ body
+ } else {
+ head.headers.remove(header::TRANSFER_ENCODING);
+ return Encoder::length(0);
+ };
+
+ // HTTP/1.0 doesn't know about chunked
+ let can_chunked = head.version == Version::HTTP_11;
+ let headers = &mut head.headers;
+
+ // If the user already set specific headers, we should respect them, regardless
+ // of what the HttpBody knows about itself. They set them for a reason.
+
+ // Because of the borrow checker, we can't check the for an existing
+ // Content-Length header while holding an `Entry` for the Transfer-Encoding
+ // header, so unfortunately, we must do the check here, first.
+
+ let existing_con_len = headers::content_length_parse_all(headers);
+ let mut should_remove_con_len = false;
+
+ if !can_chunked {
+ // Chunked isn't legal, so if it is set, we need to remove it.
+ if headers.remove(header::TRANSFER_ENCODING).is_some() {
+ trace!("removing illegal transfer-encoding header");
+ }
+
+ return if let Some(len) = existing_con_len {
+ Encoder::length(len)
+ } else if let BodyLength::Known(len) = body {
+ set_content_length(headers, len)
+ } else {
+ // HTTP/1.0 client requests without a content-length
+ // cannot have any body at all.
+ Encoder::length(0)
+ };
+ }
+
+ // If the user set a transfer-encoding, respect that. Let's just
+ // make sure `chunked` is the final encoding.
+ let encoder = match headers.entry(header::TRANSFER_ENCODING) {
+ Entry::Occupied(te) => {
+ should_remove_con_len = true;
+ if headers::is_chunked(te.iter()) {
+ Some(Encoder::chunked())
+ } else {
+ warn!("user provided transfer-encoding does not end in 'chunked'");
+
+ // There's a Transfer-Encoding, but it doesn't end in 'chunked'!
+ // An example that could trigger this:
+ //
+ // Transfer-Encoding: gzip
+ //
+ // This can be bad, depending on if this is a request or a
+ // response.
+ //
+ // - A request is illegal if there is a `Transfer-Encoding`
+ // but it doesn't end in `chunked`.
+ // - A response that has `Transfer-Encoding` but doesn't
+ // end in `chunked` isn't illegal, it just forces this
+ // to be close-delimited.
+ //
+ // We can try to repair this, by adding `chunked` ourselves.
+
+ headers::add_chunked(te);
+ Some(Encoder::chunked())
+ }
+ }
+ Entry::Vacant(te) => {
+ if let Some(len) = existing_con_len {
+ Some(Encoder::length(len))
+ } else if let BodyLength::Unknown = body {
+ // GET, HEAD, and CONNECT almost never have bodies.
+ //
+ // So instead of sending a "chunked" body with a 0-chunk,
+ // assume no body here. If you *must* send a body,
+ // set the headers explicitly.
+ match head.subject.0 {
+ Method::GET | Method::HEAD | Method::CONNECT => Some(Encoder::length(0)),
+ _ => {
+ te.insert(HeaderValue::from_static("chunked"));
+ Some(Encoder::chunked())
+ }
+ }
+ } else {
+ None
+ }
+ }
+ };
+
+ // This is because we need a second mutable borrow to remove
+ // content-length header.
+ if let Some(encoder) = encoder {
+ if should_remove_con_len && existing_con_len.is_some() {
+ headers.remove(header::CONTENT_LENGTH);
+ }
+ return encoder;
+ }
+
+ // User didn't set transfer-encoding, AND we know body length,
+ // so we can just set the Content-Length automatically.
+
+ let len = if let BodyLength::Known(len) = body {
+ len
+ } else {
+ unreachable!("BodyLength::Unknown would set chunked");
+ };
+
+ set_content_length(headers, len)
+ }
+}
+
+fn set_content_length(headers: &mut HeaderMap, len: u64) -> Encoder {
+ // At this point, there should not be a valid Content-Length
+ // header. However, since we'll be indexing in anyways, we can
+ // warn the user if there was an existing illegal header.
+ //
+ // Or at least, we can in theory. It's actually a little bit slower,
+ // so perhaps only do that while the user is developing/testing.
+
+ if cfg!(debug_assertions) {
+ match headers.entry(header::CONTENT_LENGTH) {
+ Entry::Occupied(mut cl) => {
+ // Internal sanity check, we should have already determined
+ // that the header was illegal before calling this function.
+ debug_assert!(headers::content_length_parse_all_values(cl.iter()).is_none());
+ // Uh oh, the user set `Content-Length` headers, but set bad ones.
+ // This would be an illegal message anyways, so let's try to repair
+ // with our known good length.
+ error!("user provided content-length header was invalid");
+
+ cl.insert(HeaderValue::from(len));
+ Encoder::length(len)
+ }
+ Entry::Vacant(cl) => {
+ cl.insert(HeaderValue::from(len));
+ Encoder::length(len)
+ }
+ }
+ } else {
+ headers.insert(header::CONTENT_LENGTH, HeaderValue::from(len));
+ Encoder::length(len)
+ }
+}
+
+#[derive(Clone, Copy)]
+struct HeaderIndices {
+ name: (usize, usize),
+ value: (usize, usize),
+}
+
+fn record_header_indices(
+ bytes: &[u8],
+ headers: &[httparse::Header<'_>],
+ indices: &mut [MaybeUninit<HeaderIndices>],
+) -> Result<(), crate::error::Parse> {
+ let bytes_ptr = bytes.as_ptr() as usize;
+
+ for (header, indices) in headers.iter().zip(indices.iter_mut()) {
+ if header.name.len() >= (1 << 16) {
+ debug!("header name larger than 64kb: {:?}", header.name);
+ return Err(crate::error::Parse::TooLarge);
+ }
+ let name_start = header.name.as_ptr() as usize - bytes_ptr;
+ let name_end = name_start + header.name.len();
+ let value_start = header.value.as_ptr() as usize - bytes_ptr;
+ let value_end = value_start + header.value.len();
+
+ // FIXME(maybe_uninit_extra)
+ // FIXME(addr_of)
+ // Currently we don't have `ptr::addr_of_mut` in stable rust or
+ // MaybeUninit::write, so this is some way of assigning into a MaybeUninit
+ // safely
+ let new_header_indices = HeaderIndices {
+ name: (name_start, name_end),
+ value: (value_start, value_end),
+ };
+ *indices = MaybeUninit::new(new_header_indices);
+ }
+
+ Ok(())
+}
+
+// Write header names as title case. The header name is assumed to be ASCII.
+fn title_case(dst: &mut Vec<u8>, name: &[u8]) {
+ dst.reserve(name.len());
+
+ // Ensure first character is uppercased
+ let mut prev = b'-';
+ for &(mut c) in name {
+ if prev == b'-' {
+ c.make_ascii_uppercase();
+ }
+ dst.push(c);
+ prev = c;
+ }
+}
+
+fn write_headers_title_case(headers: &HeaderMap, dst: &mut Vec<u8>) {
+ for (name, value) in headers {
+ title_case(dst, name.as_str().as_bytes());
+ extend(dst, b": ");
+ extend(dst, value.as_bytes());
+ extend(dst, b"\r\n");
+ }
+}
+
+fn write_headers(headers: &HeaderMap, dst: &mut Vec<u8>) {
+ for (name, value) in headers {
+ extend(dst, name.as_str().as_bytes());
+ extend(dst, b": ");
+ extend(dst, value.as_bytes());
+ extend(dst, b"\r\n");
+ }
+}
+
+#[cold]
+fn write_headers_original_case(
+ headers: &HeaderMap,
+ orig_case: &HeaderCaseMap,
+ dst: &mut Vec<u8>,
+ title_case_headers: bool,
+) {
+ // For each header name/value pair, there may be a value in the casemap
+ // that corresponds to the HeaderValue. So, we iterator all the keys,
+ // and for each one, try to pair the originally cased name with the value.
+ //
+ // TODO: consider adding http::HeaderMap::entries() iterator
+ for name in headers.keys() {
+ let mut names = orig_case.get_all(name);
+
+ for value in headers.get_all(name) {
+ if let Some(orig_name) = names.next() {
+ extend(dst, orig_name.as_ref());
+ } else if title_case_headers {
+ title_case(dst, name.as_str().as_bytes());
+ } else {
+ extend(dst, name.as_str().as_bytes());
+ }
+
+ // Wanted for curl test cases that send `X-Custom-Header:\r\n`
+ if value.is_empty() {
+ extend(dst, b":\r\n");
+ } else {
+ extend(dst, b": ");
+ extend(dst, value.as_bytes());
+ extend(dst, b"\r\n");
+ }
+ }
+ }
+}
+
+struct FastWrite<'a>(&'a mut Vec<u8>);
+
+impl<'a> fmt::Write for FastWrite<'a> {
+ #[inline]
+ fn write_str(&mut self, s: &str) -> fmt::Result {
+ extend(self.0, s.as_bytes());
+ Ok(())
+ }
+
+ #[inline]
+ fn write_fmt(&mut self, args: fmt::Arguments<'_>) -> fmt::Result {
+ fmt::write(self, args)
+ }
+}
+
+#[inline]
+fn extend(dst: &mut Vec<u8>, data: &[u8]) {
+ dst.extend_from_slice(data);
+}
+
+#[cfg(test)]
+mod tests {
+ use bytes::BytesMut;
+
+ use super::*;
+
+ #[test]
+ fn test_parse_request() {
+ let _ = pretty_env_logger::try_init();
+ let mut raw = BytesMut::from("GET /echo HTTP/1.1\r\nHost: hyper.rs\r\n\r\n");
+ let mut method = None;
+ let msg = Server::parse(
+ &mut raw,
+ ParseContext {
+ cached_headers: &mut None,
+ req_method: &mut method,
+ h1_parser_config: Default::default(),
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout: None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_fut: &mut None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_running: &mut false,
+ preserve_header_case: false,
+ #[cfg(feature = "ffi")]
+ preserve_header_order: false,
+ h09_responses: false,
+ #[cfg(feature = "ffi")]
+ on_informational: &mut None,
+ #[cfg(feature = "ffi")]
+ raw_headers: false,
+ },
+ )
+ .unwrap()
+ .unwrap();
+ assert_eq!(raw.len(), 0);
+ assert_eq!(msg.head.subject.0, crate::Method::GET);
+ assert_eq!(msg.head.subject.1, "/echo");
+ assert_eq!(msg.head.version, crate::Version::HTTP_11);
+ assert_eq!(msg.head.headers.len(), 1);
+ assert_eq!(msg.head.headers["Host"], "hyper.rs");
+ assert_eq!(method, Some(crate::Method::GET));
+ }
+
+ #[test]
+ fn test_parse_response() {
+ let _ = pretty_env_logger::try_init();
+ let mut raw = BytesMut::from("HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n");
+ let ctx = ParseContext {
+ cached_headers: &mut None,
+ req_method: &mut Some(crate::Method::GET),
+ h1_parser_config: Default::default(),
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout: None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_fut: &mut None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_running: &mut false,
+ preserve_header_case: false,
+ #[cfg(feature = "ffi")]
+ preserve_header_order: false,
+ h09_responses: false,
+ #[cfg(feature = "ffi")]
+ on_informational: &mut None,
+ #[cfg(feature = "ffi")]
+ raw_headers: false,
+ };
+ let msg = Client::parse(&mut raw, ctx).unwrap().unwrap();
+ assert_eq!(raw.len(), 0);
+ assert_eq!(msg.head.subject, crate::StatusCode::OK);
+ assert_eq!(msg.head.version, crate::Version::HTTP_11);
+ assert_eq!(msg.head.headers.len(), 1);
+ assert_eq!(msg.head.headers["Content-Length"], "0");
+ }
+
+ #[test]
+ fn test_parse_request_errors() {
+ let mut raw = BytesMut::from("GET htt:p// HTTP/1.1\r\nHost: hyper.rs\r\n\r\n");
+ let ctx = ParseContext {
+ cached_headers: &mut None,
+ req_method: &mut None,
+ h1_parser_config: Default::default(),
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout: None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_fut: &mut None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_running: &mut false,
+ preserve_header_case: false,
+ #[cfg(feature = "ffi")]
+ preserve_header_order: false,
+ h09_responses: false,
+ #[cfg(feature = "ffi")]
+ on_informational: &mut None,
+ #[cfg(feature = "ffi")]
+ raw_headers: false,
+ };
+ Server::parse(&mut raw, ctx).unwrap_err();
+ }
+
+ const H09_RESPONSE: &'static str = "Baguettes are super delicious, don't you agree?";
+
+ #[test]
+ fn test_parse_response_h09_allowed() {
+ let _ = pretty_env_logger::try_init();
+ let mut raw = BytesMut::from(H09_RESPONSE);
+ let ctx = ParseContext {
+ cached_headers: &mut None,
+ req_method: &mut Some(crate::Method::GET),
+ h1_parser_config: Default::default(),
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout: None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_fut: &mut None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_running: &mut false,
+ preserve_header_case: false,
+ #[cfg(feature = "ffi")]
+ preserve_header_order: false,
+ h09_responses: true,
+ #[cfg(feature = "ffi")]
+ on_informational: &mut None,
+ #[cfg(feature = "ffi")]
+ raw_headers: false,
+ };
+ let msg = Client::parse(&mut raw, ctx).unwrap().unwrap();
+ assert_eq!(raw, H09_RESPONSE);
+ assert_eq!(msg.head.subject, crate::StatusCode::OK);
+ assert_eq!(msg.head.version, crate::Version::HTTP_09);
+ assert_eq!(msg.head.headers.len(), 0);
+ }
+
+ #[test]
+ fn test_parse_response_h09_rejected() {
+ let _ = pretty_env_logger::try_init();
+ let mut raw = BytesMut::from(H09_RESPONSE);
+ let ctx = ParseContext {
+ cached_headers: &mut None,
+ req_method: &mut Some(crate::Method::GET),
+ h1_parser_config: Default::default(),
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout: None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_fut: &mut None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_running: &mut false,
+ preserve_header_case: false,
+ #[cfg(feature = "ffi")]
+ preserve_header_order: false,
+ h09_responses: false,
+ #[cfg(feature = "ffi")]
+ on_informational: &mut None,
+ #[cfg(feature = "ffi")]
+ raw_headers: false,
+ };
+ Client::parse(&mut raw, ctx).unwrap_err();
+ assert_eq!(raw, H09_RESPONSE);
+ }
+
+ const RESPONSE_WITH_WHITESPACE_BETWEEN_HEADER_NAME_AND_COLON: &'static str =
+ "HTTP/1.1 200 OK\r\nAccess-Control-Allow-Credentials : true\r\n\r\n";
+
+ #[test]
+ fn test_parse_allow_response_with_spaces_before_colons() {
+ use httparse::ParserConfig;
+
+ let _ = pretty_env_logger::try_init();
+ let mut raw = BytesMut::from(RESPONSE_WITH_WHITESPACE_BETWEEN_HEADER_NAME_AND_COLON);
+ let mut h1_parser_config = ParserConfig::default();
+ h1_parser_config.allow_spaces_after_header_name_in_responses(true);
+ let ctx = ParseContext {
+ cached_headers: &mut None,
+ req_method: &mut Some(crate::Method::GET),
+ h1_parser_config,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout: None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_fut: &mut None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_running: &mut false,
+ preserve_header_case: false,
+ #[cfg(feature = "ffi")]
+ preserve_header_order: false,
+ h09_responses: false,
+ #[cfg(feature = "ffi")]
+ on_informational: &mut None,
+ #[cfg(feature = "ffi")]
+ raw_headers: false,
+ };
+ let msg = Client::parse(&mut raw, ctx).unwrap().unwrap();
+ assert_eq!(raw.len(), 0);
+ assert_eq!(msg.head.subject, crate::StatusCode::OK);
+ assert_eq!(msg.head.version, crate::Version::HTTP_11);
+ assert_eq!(msg.head.headers.len(), 1);
+ assert_eq!(msg.head.headers["Access-Control-Allow-Credentials"], "true");
+ }
+
+ #[test]
+ fn test_parse_reject_response_with_spaces_before_colons() {
+ let _ = pretty_env_logger::try_init();
+ let mut raw = BytesMut::from(RESPONSE_WITH_WHITESPACE_BETWEEN_HEADER_NAME_AND_COLON);
+ let ctx = ParseContext {
+ cached_headers: &mut None,
+ req_method: &mut Some(crate::Method::GET),
+ h1_parser_config: Default::default(),
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout: None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_fut: &mut None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_running: &mut false,
+ preserve_header_case: false,
+ #[cfg(feature = "ffi")]
+ preserve_header_order: false,
+ h09_responses: false,
+ #[cfg(feature = "ffi")]
+ on_informational: &mut None,
+ #[cfg(feature = "ffi")]
+ raw_headers: false,
+ };
+ Client::parse(&mut raw, ctx).unwrap_err();
+ }
+
+ #[test]
+ fn test_parse_preserve_header_case_in_request() {
+ let mut raw =
+ BytesMut::from("GET / HTTP/1.1\r\nHost: hyper.rs\r\nX-BREAD: baguette\r\n\r\n");
+ let ctx = ParseContext {
+ cached_headers: &mut None,
+ req_method: &mut None,
+ h1_parser_config: Default::default(),
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout: None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_fut: &mut None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_running: &mut false,
+ preserve_header_case: true,
+ #[cfg(feature = "ffi")]
+ preserve_header_order: false,
+ h09_responses: false,
+ #[cfg(feature = "ffi")]
+ on_informational: &mut None,
+ #[cfg(feature = "ffi")]
+ raw_headers: false,
+ };
+ let parsed_message = Server::parse(&mut raw, ctx).unwrap().unwrap();
+ let orig_headers = parsed_message
+ .head
+ .extensions
+ .get::<HeaderCaseMap>()
+ .unwrap();
+ assert_eq!(
+ orig_headers
+ .get_all_internal(&HeaderName::from_static("host"))
+ .into_iter()
+ .collect::<Vec<_>>(),
+ vec![&Bytes::from("Host")]
+ );
+ assert_eq!(
+ orig_headers
+ .get_all_internal(&HeaderName::from_static("x-bread"))
+ .into_iter()
+ .collect::<Vec<_>>(),
+ vec![&Bytes::from("X-BREAD")]
+ );
+ }
+
+ #[test]
+ fn test_decoder_request() {
+ fn parse(s: &str) -> ParsedMessage<RequestLine> {
+ let mut bytes = BytesMut::from(s);
+ Server::parse(
+ &mut bytes,
+ ParseContext {
+ cached_headers: &mut None,
+ req_method: &mut None,
+ h1_parser_config: Default::default(),
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout: None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_fut: &mut None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_running: &mut false,
+ preserve_header_case: false,
+ #[cfg(feature = "ffi")]
+ preserve_header_order: false,
+ h09_responses: false,
+ #[cfg(feature = "ffi")]
+ on_informational: &mut None,
+ #[cfg(feature = "ffi")]
+ raw_headers: false,
+ },
+ )
+ .expect("parse ok")
+ .expect("parse complete")
+ }
+
+ fn parse_err(s: &str, comment: &str) -> crate::error::Parse {
+ let mut bytes = BytesMut::from(s);
+ Server::parse(
+ &mut bytes,
+ ParseContext {
+ cached_headers: &mut None,
+ req_method: &mut None,
+ h1_parser_config: Default::default(),
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout: None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_fut: &mut None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_running: &mut false,
+ preserve_header_case: false,
+ #[cfg(feature = "ffi")]
+ preserve_header_order: false,
+ h09_responses: false,
+ #[cfg(feature = "ffi")]
+ on_informational: &mut None,
+ #[cfg(feature = "ffi")]
+ raw_headers: false,
+ },
+ )
+ .expect_err(comment)
+ }
+
+ // no length or transfer-encoding means 0-length body
+ assert_eq!(
+ parse(
+ "\
+ GET / HTTP/1.1\r\n\
+ \r\n\
+ "
+ )
+ .decode,
+ DecodedLength::ZERO
+ );
+
+ assert_eq!(
+ parse(
+ "\
+ POST / HTTP/1.1\r\n\
+ \r\n\
+ "
+ )
+ .decode,
+ DecodedLength::ZERO
+ );
+
+ // transfer-encoding: chunked
+ assert_eq!(
+ parse(
+ "\
+ POST / HTTP/1.1\r\n\
+ transfer-encoding: chunked\r\n\
+ \r\n\
+ "
+ )
+ .decode,
+ DecodedLength::CHUNKED
+ );
+
+ assert_eq!(
+ parse(
+ "\
+ POST / HTTP/1.1\r\n\
+ transfer-encoding: gzip, chunked\r\n\
+ \r\n\
+ "
+ )
+ .decode,
+ DecodedLength::CHUNKED
+ );
+
+ assert_eq!(
+ parse(
+ "\
+ POST / HTTP/1.1\r\n\
+ transfer-encoding: gzip\r\n\
+ transfer-encoding: chunked\r\n\
+ \r\n\
+ "
+ )
+ .decode,
+ DecodedLength::CHUNKED
+ );
+
+ // content-length
+ assert_eq!(
+ parse(
+ "\
+ POST / HTTP/1.1\r\n\
+ content-length: 10\r\n\
+ \r\n\
+ "
+ )
+ .decode,
+ DecodedLength::new(10)
+ );
+
+ // transfer-encoding and content-length = chunked
+ assert_eq!(
+ parse(
+ "\
+ POST / HTTP/1.1\r\n\
+ content-length: 10\r\n\
+ transfer-encoding: chunked\r\n\
+ \r\n\
+ "
+ )
+ .decode,
+ DecodedLength::CHUNKED
+ );
+
+ assert_eq!(
+ parse(
+ "\
+ POST / HTTP/1.1\r\n\
+ transfer-encoding: chunked\r\n\
+ content-length: 10\r\n\
+ \r\n\
+ "
+ )
+ .decode,
+ DecodedLength::CHUNKED
+ );
+
+ assert_eq!(
+ parse(
+ "\
+ POST / HTTP/1.1\r\n\
+ transfer-encoding: gzip\r\n\
+ content-length: 10\r\n\
+ transfer-encoding: chunked\r\n\
+ \r\n\
+ "
+ )
+ .decode,
+ DecodedLength::CHUNKED
+ );
+
+ // multiple content-lengths of same value are fine
+ assert_eq!(
+ parse(
+ "\
+ POST / HTTP/1.1\r\n\
+ content-length: 10\r\n\
+ content-length: 10\r\n\
+ \r\n\
+ "
+ )
+ .decode,
+ DecodedLength::new(10)
+ );
+
+ // multiple content-lengths with different values is an error
+ parse_err(
+ "\
+ POST / HTTP/1.1\r\n\
+ content-length: 10\r\n\
+ content-length: 11\r\n\
+ \r\n\
+ ",
+ "multiple content-lengths",
+ );
+
+ // content-length with prefix is not allowed
+ parse_err(
+ "\
+ POST / HTTP/1.1\r\n\
+ content-length: +10\r\n\
+ \r\n\
+ ",
+ "prefixed content-length",
+ );
+
+ // transfer-encoding that isn't chunked is an error
+ parse_err(
+ "\
+ POST / HTTP/1.1\r\n\
+ transfer-encoding: gzip\r\n\
+ \r\n\
+ ",
+ "transfer-encoding but not chunked",
+ );
+
+ parse_err(
+ "\
+ POST / HTTP/1.1\r\n\
+ transfer-encoding: chunked, gzip\r\n\
+ \r\n\
+ ",
+ "transfer-encoding doesn't end in chunked",
+ );
+
+ parse_err(
+ "\
+ POST / HTTP/1.1\r\n\
+ transfer-encoding: chunked\r\n\
+ transfer-encoding: afterlol\r\n\
+ \r\n\
+ ",
+ "transfer-encoding multiple lines doesn't end in chunked",
+ );
+
+ // http/1.0
+
+ assert_eq!(
+ parse(
+ "\
+ POST / HTTP/1.0\r\n\
+ content-length: 10\r\n\
+ \r\n\
+ "
+ )
+ .decode,
+ DecodedLength::new(10)
+ );
+
+ // 1.0 doesn't understand chunked, so its an error
+ parse_err(
+ "\
+ POST / HTTP/1.0\r\n\
+ transfer-encoding: chunked\r\n\
+ \r\n\
+ ",
+ "1.0 chunked",
+ );
+ }
+
+ #[test]
+ fn test_decoder_response() {
+ fn parse(s: &str) -> ParsedMessage<StatusCode> {
+ parse_with_method(s, Method::GET)
+ }
+
+ fn parse_ignores(s: &str) {
+ let mut bytes = BytesMut::from(s);
+ assert!(Client::parse(
+ &mut bytes,
+ ParseContext {
+ cached_headers: &mut None,
+ req_method: &mut Some(Method::GET),
+ h1_parser_config: Default::default(),
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout: None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_fut: &mut None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_running: &mut false,
+ preserve_header_case: false,
+ #[cfg(feature = "ffi")]
+ preserve_header_order: false,
+ h09_responses: false,
+ #[cfg(feature = "ffi")]
+ on_informational: &mut None,
+ #[cfg(feature = "ffi")]
+ raw_headers: false,
+ }
+ )
+ .expect("parse ok")
+ .is_none())
+ }
+
+ fn parse_with_method(s: &str, m: Method) -> ParsedMessage<StatusCode> {
+ let mut bytes = BytesMut::from(s);
+ Client::parse(
+ &mut bytes,
+ ParseContext {
+ cached_headers: &mut None,
+ req_method: &mut Some(m),
+ h1_parser_config: Default::default(),
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout: None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_fut: &mut None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_running: &mut false,
+ preserve_header_case: false,
+ #[cfg(feature = "ffi")]
+ preserve_header_order: false,
+ h09_responses: false,
+ #[cfg(feature = "ffi")]
+ on_informational: &mut None,
+ #[cfg(feature = "ffi")]
+ raw_headers: false,
+ },
+ )
+ .expect("parse ok")
+ .expect("parse complete")
+ }
+
+ fn parse_err(s: &str) -> crate::error::Parse {
+ let mut bytes = BytesMut::from(s);
+ Client::parse(
+ &mut bytes,
+ ParseContext {
+ cached_headers: &mut None,
+ req_method: &mut Some(Method::GET),
+ h1_parser_config: Default::default(),
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout: None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_fut: &mut None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_running: &mut false,
+ preserve_header_case: false,
+ #[cfg(feature = "ffi")]
+ preserve_header_order: false,
+ h09_responses: false,
+ #[cfg(feature = "ffi")]
+ on_informational: &mut None,
+ #[cfg(feature = "ffi")]
+ raw_headers: false,
+ },
+ )
+ .expect_err("parse should err")
+ }
+
+ // no content-length or transfer-encoding means close-delimited
+ assert_eq!(
+ parse(
+ "\
+ HTTP/1.1 200 OK\r\n\
+ \r\n\
+ "
+ )
+ .decode,
+ DecodedLength::CLOSE_DELIMITED
+ );
+
+ // 204 and 304 never have a body
+ assert_eq!(
+ parse(
+ "\
+ HTTP/1.1 204 No Content\r\n\
+ \r\n\
+ "
+ )
+ .decode,
+ DecodedLength::ZERO
+ );
+
+ assert_eq!(
+ parse(
+ "\
+ HTTP/1.1 304 Not Modified\r\n\
+ \r\n\
+ "
+ )
+ .decode,
+ DecodedLength::ZERO
+ );
+
+ // content-length
+ assert_eq!(
+ parse(
+ "\
+ HTTP/1.1 200 OK\r\n\
+ content-length: 8\r\n\
+ \r\n\
+ "
+ )
+ .decode,
+ DecodedLength::new(8)
+ );
+
+ assert_eq!(
+ parse(
+ "\
+ HTTP/1.1 200 OK\r\n\
+ content-length: 8\r\n\
+ content-length: 8\r\n\
+ \r\n\
+ "
+ )
+ .decode,
+ DecodedLength::new(8)
+ );
+
+ parse_err(
+ "\
+ HTTP/1.1 200 OK\r\n\
+ content-length: 8\r\n\
+ content-length: 9\r\n\
+ \r\n\
+ ",
+ );
+
+ parse_err(
+ "\
+ HTTP/1.1 200 OK\r\n\
+ content-length: +8\r\n\
+ \r\n\
+ ",
+ );
+
+ // transfer-encoding: chunked
+ assert_eq!(
+ parse(
+ "\
+ HTTP/1.1 200 OK\r\n\
+ transfer-encoding: chunked\r\n\
+ \r\n\
+ "
+ )
+ .decode,
+ DecodedLength::CHUNKED
+ );
+
+ // transfer-encoding not-chunked is close-delimited
+ assert_eq!(
+ parse(
+ "\
+ HTTP/1.1 200 OK\r\n\
+ transfer-encoding: yolo\r\n\
+ \r\n\
+ "
+ )
+ .decode,
+ DecodedLength::CLOSE_DELIMITED
+ );
+
+ // transfer-encoding and content-length = chunked
+ assert_eq!(
+ parse(
+ "\
+ HTTP/1.1 200 OK\r\n\
+ content-length: 10\r\n\
+ transfer-encoding: chunked\r\n\
+ \r\n\
+ "
+ )
+ .decode,
+ DecodedLength::CHUNKED
+ );
+
+ // HEAD can have content-length, but not body
+ assert_eq!(
+ parse_with_method(
+ "\
+ HTTP/1.1 200 OK\r\n\
+ content-length: 8\r\n\
+ \r\n\
+ ",
+ Method::HEAD
+ )
+ .decode,
+ DecodedLength::ZERO
+ );
+
+ // CONNECT with 200 never has body
+ {
+ let msg = parse_with_method(
+ "\
+ HTTP/1.1 200 OK\r\n\
+ \r\n\
+ ",
+ Method::CONNECT,
+ );
+ assert_eq!(msg.decode, DecodedLength::ZERO);
+ assert!(!msg.keep_alive, "should be upgrade");
+ assert!(msg.wants_upgrade, "should be upgrade");
+ }
+
+ // CONNECT receiving non 200 can have a body
+ assert_eq!(
+ parse_with_method(
+ "\
+ HTTP/1.1 400 Bad Request\r\n\
+ \r\n\
+ ",
+ Method::CONNECT
+ )
+ .decode,
+ DecodedLength::CLOSE_DELIMITED
+ );
+
+ // 1xx status codes
+ parse_ignores(
+ "\
+ HTTP/1.1 100 Continue\r\n\
+ \r\n\
+ ",
+ );
+
+ parse_ignores(
+ "\
+ HTTP/1.1 103 Early Hints\r\n\
+ \r\n\
+ ",
+ );
+
+ // 101 upgrade not supported yet
+ {
+ let msg = parse(
+ "\
+ HTTP/1.1 101 Switching Protocols\r\n\
+ \r\n\
+ ",
+ );
+ assert_eq!(msg.decode, DecodedLength::ZERO);
+ assert!(!msg.keep_alive, "should be last");
+ assert!(msg.wants_upgrade, "should be upgrade");
+ }
+
+ // http/1.0
+ assert_eq!(
+ parse(
+ "\
+ HTTP/1.0 200 OK\r\n\
+ \r\n\
+ "
+ )
+ .decode,
+ DecodedLength::CLOSE_DELIMITED
+ );
+
+ // 1.0 doesn't understand chunked
+ parse_err(
+ "\
+ HTTP/1.0 200 OK\r\n\
+ transfer-encoding: chunked\r\n\
+ \r\n\
+ ",
+ );
+
+ // keep-alive
+ assert!(
+ parse(
+ "\
+ HTTP/1.1 200 OK\r\n\
+ content-length: 0\r\n\
+ \r\n\
+ "
+ )
+ .keep_alive,
+ "HTTP/1.1 keep-alive is default"
+ );
+
+ assert!(
+ !parse(
+ "\
+ HTTP/1.1 200 OK\r\n\
+ content-length: 0\r\n\
+ connection: foo, close, bar\r\n\
+ \r\n\
+ "
+ )
+ .keep_alive,
+ "connection close is always close"
+ );
+
+ assert!(
+ !parse(
+ "\
+ HTTP/1.0 200 OK\r\n\
+ content-length: 0\r\n\
+ \r\n\
+ "
+ )
+ .keep_alive,
+ "HTTP/1.0 close is default"
+ );
+
+ assert!(
+ parse(
+ "\
+ HTTP/1.0 200 OK\r\n\
+ content-length: 0\r\n\
+ connection: foo, keep-alive, bar\r\n\
+ \r\n\
+ "
+ )
+ .keep_alive,
+ "connection keep-alive is always keep-alive"
+ );
+ }
+
+ #[test]
+ fn test_client_request_encode_title_case() {
+ use crate::proto::BodyLength;
+ use http::header::HeaderValue;
+
+ let mut head = MessageHead::default();
+ head.headers
+ .insert("content-length", HeaderValue::from_static("10"));
+ head.headers
+ .insert("content-type", HeaderValue::from_static("application/json"));
+ head.headers.insert("*-*", HeaderValue::from_static("o_o"));
+
+ let mut vec = Vec::new();
+ Client::encode(
+ Encode {
+ head: &mut head,
+ body: Some(BodyLength::Known(10)),
+ keep_alive: true,
+ req_method: &mut None,
+ title_case_headers: true,
+ },
+ &mut vec,
+ )
+ .unwrap();
+
+ assert_eq!(vec, b"GET / HTTP/1.1\r\nContent-Length: 10\r\nContent-Type: application/json\r\n*-*: o_o\r\n\r\n".to_vec());
+ }
+
+ #[test]
+ fn test_client_request_encode_orig_case() {
+ use crate::proto::BodyLength;
+ use http::header::{HeaderValue, CONTENT_LENGTH};
+
+ let mut head = MessageHead::default();
+ head.headers
+ .insert("content-length", HeaderValue::from_static("10"));
+ head.headers
+ .insert("content-type", HeaderValue::from_static("application/json"));
+
+ let mut orig_headers = HeaderCaseMap::default();
+ orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into());
+ head.extensions.insert(orig_headers);
+
+ let mut vec = Vec::new();
+ Client::encode(
+ Encode {
+ head: &mut head,
+ body: Some(BodyLength::Known(10)),
+ keep_alive: true,
+ req_method: &mut None,
+ title_case_headers: false,
+ },
+ &mut vec,
+ )
+ .unwrap();
+
+ assert_eq!(
+ &*vec,
+ b"GET / HTTP/1.1\r\nCONTENT-LENGTH: 10\r\ncontent-type: application/json\r\n\r\n"
+ .as_ref(),
+ );
+ }
+ #[test]
+ fn test_client_request_encode_orig_and_title_case() {
+ use crate::proto::BodyLength;
+ use http::header::{HeaderValue, CONTENT_LENGTH};
+
+ let mut head = MessageHead::default();
+ head.headers
+ .insert("content-length", HeaderValue::from_static("10"));
+ head.headers
+ .insert("content-type", HeaderValue::from_static("application/json"));
+
+ let mut orig_headers = HeaderCaseMap::default();
+ orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into());
+ head.extensions.insert(orig_headers);
+
+ let mut vec = Vec::new();
+ Client::encode(
+ Encode {
+ head: &mut head,
+ body: Some(BodyLength::Known(10)),
+ keep_alive: true,
+ req_method: &mut None,
+ title_case_headers: true,
+ },
+ &mut vec,
+ )
+ .unwrap();
+
+ assert_eq!(
+ &*vec,
+ b"GET / HTTP/1.1\r\nCONTENT-LENGTH: 10\r\nContent-Type: application/json\r\n\r\n"
+ .as_ref(),
+ );
+ }
+
+ #[test]
+ fn test_server_encode_connect_method() {
+ let mut head = MessageHead::default();
+
+ let mut vec = Vec::new();
+ let encoder = Server::encode(
+ Encode {
+ head: &mut head,
+ body: None,
+ keep_alive: true,
+ req_method: &mut Some(Method::CONNECT),
+ title_case_headers: false,
+ },
+ &mut vec,
+ )
+ .unwrap();
+
+ assert!(encoder.is_last());
+ }
+
+ #[test]
+ fn test_server_response_encode_title_case() {
+ use crate::proto::BodyLength;
+ use http::header::HeaderValue;
+
+ let mut head = MessageHead::default();
+ head.headers
+ .insert("content-length", HeaderValue::from_static("10"));
+ head.headers
+ .insert("content-type", HeaderValue::from_static("application/json"));
+ head.headers
+ .insert("weird--header", HeaderValue::from_static(""));
+
+ let mut vec = Vec::new();
+ Server::encode(
+ Encode {
+ head: &mut head,
+ body: Some(BodyLength::Known(10)),
+ keep_alive: true,
+ req_method: &mut None,
+ title_case_headers: true,
+ },
+ &mut vec,
+ )
+ .unwrap();
+
+ let expected_response =
+ b"HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Type: application/json\r\nWeird--Header: \r\n";
+
+ assert_eq!(&vec[..expected_response.len()], &expected_response[..]);
+ }
+
+ #[test]
+ fn test_server_response_encode_orig_case() {
+ use crate::proto::BodyLength;
+ use http::header::{HeaderValue, CONTENT_LENGTH};
+
+ let mut head = MessageHead::default();
+ head.headers
+ .insert("content-length", HeaderValue::from_static("10"));
+ head.headers
+ .insert("content-type", HeaderValue::from_static("application/json"));
+
+ let mut orig_headers = HeaderCaseMap::default();
+ orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into());
+ head.extensions.insert(orig_headers);
+
+ let mut vec = Vec::new();
+ Server::encode(
+ Encode {
+ head: &mut head,
+ body: Some(BodyLength::Known(10)),
+ keep_alive: true,
+ req_method: &mut None,
+ title_case_headers: false,
+ },
+ &mut vec,
+ )
+ .unwrap();
+
+ let expected_response =
+ b"HTTP/1.1 200 OK\r\nCONTENT-LENGTH: 10\r\ncontent-type: application/json\r\ndate: ";
+
+ assert_eq!(&vec[..expected_response.len()], &expected_response[..]);
+ }
+
+ #[test]
+ fn test_server_response_encode_orig_and_title_case() {
+ use crate::proto::BodyLength;
+ use http::header::{HeaderValue, CONTENT_LENGTH};
+
+ let mut head = MessageHead::default();
+ head.headers
+ .insert("content-length", HeaderValue::from_static("10"));
+ head.headers
+ .insert("content-type", HeaderValue::from_static("application/json"));
+
+ let mut orig_headers = HeaderCaseMap::default();
+ orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into());
+ head.extensions.insert(orig_headers);
+
+ let mut vec = Vec::new();
+ Server::encode(
+ Encode {
+ head: &mut head,
+ body: Some(BodyLength::Known(10)),
+ keep_alive: true,
+ req_method: &mut None,
+ title_case_headers: true,
+ },
+ &mut vec,
+ )
+ .unwrap();
+
+ let expected_response =
+ b"HTTP/1.1 200 OK\r\nCONTENT-LENGTH: 10\r\nContent-Type: application/json\r\nDate: ";
+
+ assert_eq!(&vec[..expected_response.len()], &expected_response[..]);
+ }
+
+ #[test]
+ fn parse_header_htabs() {
+ let mut bytes = BytesMut::from("HTTP/1.1 200 OK\r\nserver: hello\tworld\r\n\r\n");
+ let parsed = Client::parse(
+ &mut bytes,
+ ParseContext {
+ cached_headers: &mut None,
+ req_method: &mut Some(Method::GET),
+ h1_parser_config: Default::default(),
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout: None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_fut: &mut None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_running: &mut false,
+ preserve_header_case: false,
+ #[cfg(feature = "ffi")]
+ preserve_header_order: false,
+ h09_responses: false,
+ #[cfg(feature = "ffi")]
+ on_informational: &mut None,
+ #[cfg(feature = "ffi")]
+ raw_headers: false,
+ },
+ )
+ .expect("parse ok")
+ .expect("parse complete");
+
+ assert_eq!(parsed.head.headers["server"], "hello\tworld");
+ }
+
+ #[test]
+ fn test_write_headers_orig_case_empty_value() {
+ let mut headers = HeaderMap::new();
+ let name = http::header::HeaderName::from_static("x-empty");
+ headers.insert(&name, "".parse().expect("parse empty"));
+ let mut orig_cases = HeaderCaseMap::default();
+ orig_cases.insert(name, Bytes::from_static(b"X-EmptY"));
+
+ let mut dst = Vec::new();
+ super::write_headers_original_case(&headers, &orig_cases, &mut dst, false);
+
+ assert_eq!(
+ dst, b"X-EmptY:\r\n",
+ "there should be no space between the colon and CRLF"
+ );
+ }
+
+ #[test]
+ fn test_write_headers_orig_case_multiple_entries() {
+ let mut headers = HeaderMap::new();
+ let name = http::header::HeaderName::from_static("x-empty");
+ headers.insert(&name, "a".parse().unwrap());
+ headers.append(&name, "b".parse().unwrap());
+
+ let mut orig_cases = HeaderCaseMap::default();
+ orig_cases.insert(name.clone(), Bytes::from_static(b"X-Empty"));
+ orig_cases.append(name, Bytes::from_static(b"X-EMPTY"));
+
+ let mut dst = Vec::new();
+ super::write_headers_original_case(&headers, &orig_cases, &mut dst, false);
+
+ assert_eq!(dst, b"X-Empty: a\r\nX-EMPTY: b\r\n");
+ }
+
+ #[cfg(feature = "nightly")]
+ use test::Bencher;
+
+ #[cfg(feature = "nightly")]
+ #[bench]
+ fn bench_parse_incoming(b: &mut Bencher) {
+ let mut raw = BytesMut::from(
+ &b"GET /super_long_uri/and_whatever?what_should_we_talk_about/\
+ I_wonder/Hard_to_write_in_an_uri_after_all/you_have_to_make\
+ _up_the_punctuation_yourself/how_fun_is_that?test=foo&test1=\
+ foo1&test2=foo2&test3=foo3&test4=foo4 HTTP/1.1\r\nHost: \
+ hyper.rs\r\nAccept: a lot of things\r\nAccept-Charset: \
+ utf8\r\nAccept-Encoding: *\r\nAccess-Control-Allow-\
+ Credentials: None\r\nAccess-Control-Allow-Origin: None\r\n\
+ Access-Control-Allow-Methods: None\r\nAccess-Control-Allow-\
+ Headers: None\r\nContent-Encoding: utf8\r\nContent-Security-\
+ Policy: None\r\nContent-Type: text/html\r\nOrigin: hyper\
+ \r\nSec-Websocket-Extensions: It looks super important!\r\n\
+ Sec-Websocket-Origin: hyper\r\nSec-Websocket-Version: 4.3\r\
+ \nStrict-Transport-Security: None\r\nUser-Agent: hyper\r\n\
+ X-Content-Duration: None\r\nX-Content-Security-Policy: None\
+ \r\nX-DNSPrefetch-Control: None\r\nX-Frame-Options: \
+ Something important obviously\r\nX-Requested-With: Nothing\
+ \r\n\r\n"[..],
+ );
+ let len = raw.len();
+ let mut headers = Some(HeaderMap::new());
+
+ b.bytes = len as u64;
+ b.iter(|| {
+ let mut msg = Server::parse(
+ &mut raw,
+ ParseContext {
+ cached_headers: &mut headers,
+ req_method: &mut None,
+ h1_parser_config: Default::default(),
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout: None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_fut: &mut None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_running: &mut false,
+ preserve_header_case: false,
+ #[cfg(feature = "ffi")]
+ preserve_header_order: false,
+ h09_responses: false,
+ #[cfg(feature = "ffi")]
+ on_informational: &mut None,
+ #[cfg(feature = "ffi")]
+ raw_headers: false,
+ },
+ )
+ .unwrap()
+ .unwrap();
+ ::test::black_box(&msg);
+ msg.head.headers.clear();
+ headers = Some(msg.head.headers);
+ restart(&mut raw, len);
+ });
+
+ fn restart(b: &mut BytesMut, len: usize) {
+ b.reserve(1);
+ unsafe {
+ b.set_len(len);
+ }
+ }
+ }
+
+ #[cfg(feature = "nightly")]
+ #[bench]
+ fn bench_parse_short(b: &mut Bencher) {
+ let s = &b"GET / HTTP/1.1\r\nHost: localhost:8080\r\n\r\n"[..];
+ let mut raw = BytesMut::from(s);
+ let len = raw.len();
+ let mut headers = Some(HeaderMap::new());
+
+ b.bytes = len as u64;
+ b.iter(|| {
+ let mut msg = Server::parse(
+ &mut raw,
+ ParseContext {
+ cached_headers: &mut headers,
+ req_method: &mut None,
+ h1_parser_config: Default::default(),
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout: None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_fut: &mut None,
+ #[cfg(feature = "runtime")]
+ h1_header_read_timeout_running: &mut false,
+ preserve_header_case: false,
+ #[cfg(feature = "ffi")]
+ preserve_header_order: false,
+ h09_responses: false,
+ #[cfg(feature = "ffi")]
+ on_informational: &mut None,
+ #[cfg(feature = "ffi")]
+ raw_headers: false,
+ },
+ )
+ .unwrap()
+ .unwrap();
+ ::test::black_box(&msg);
+ msg.head.headers.clear();
+ headers = Some(msg.head.headers);
+ restart(&mut raw, len);
+ });
+
+ fn restart(b: &mut BytesMut, len: usize) {
+ b.reserve(1);
+ unsafe {
+ b.set_len(len);
+ }
+ }
+ }
+
+ #[cfg(feature = "nightly")]
+ #[bench]
+ fn bench_server_encode_headers_preset(b: &mut Bencher) {
+ use crate::proto::BodyLength;
+ use http::header::HeaderValue;
+
+ let len = 108;
+ b.bytes = len as u64;
+
+ let mut head = MessageHead::default();
+ let mut headers = HeaderMap::new();
+ headers.insert("content-length", HeaderValue::from_static("10"));
+ headers.insert("content-type", HeaderValue::from_static("application/json"));
+
+ b.iter(|| {
+ let mut vec = Vec::new();
+ head.headers = headers.clone();
+ Server::encode(
+ Encode {
+ head: &mut head,
+ body: Some(BodyLength::Known(10)),
+ keep_alive: true,
+ req_method: &mut Some(Method::GET),
+ title_case_headers: false,
+ },
+ &mut vec,
+ )
+ .unwrap();
+ assert_eq!(vec.len(), len);
+ ::test::black_box(vec);
+ })
+ }
+
+ #[cfg(feature = "nightly")]
+ #[bench]
+ fn bench_server_encode_no_headers(b: &mut Bencher) {
+ use crate::proto::BodyLength;
+
+ let len = 76;
+ b.bytes = len as u64;
+
+ let mut head = MessageHead::default();
+ let mut vec = Vec::with_capacity(128);
+
+ b.iter(|| {
+ Server::encode(
+ Encode {
+ head: &mut head,
+ body: Some(BodyLength::Known(10)),
+ keep_alive: true,
+ req_method: &mut Some(Method::GET),
+ title_case_headers: false,
+ },
+ &mut vec,
+ )
+ .unwrap();
+ assert_eq!(vec.len(), len);
+ ::test::black_box(&vec);
+
+ vec.clear();
+ })
+ }
+}
diff --git a/third_party/rust/hyper/src/proto/h2/client.rs b/third_party/rust/hyper/src/proto/h2/client.rs
new file mode 100644
index 0000000000..bac8eceb3a
--- /dev/null
+++ b/third_party/rust/hyper/src/proto/h2/client.rs
@@ -0,0 +1,450 @@
+use std::error::Error as StdError;
+#[cfg(feature = "runtime")]
+use std::time::Duration;
+
+use bytes::Bytes;
+use futures_channel::{mpsc, oneshot};
+use futures_util::future::{self, Either, FutureExt as _, TryFutureExt as _};
+use futures_util::stream::StreamExt as _;
+use h2::client::{Builder, SendRequest};
+use h2::SendStream;
+use http::{Method, StatusCode};
+use tokio::io::{AsyncRead, AsyncWrite};
+use tracing::{debug, trace, warn};
+
+use super::{ping, H2Upgraded, PipeToSendStream, SendBuf};
+use crate::body::HttpBody;
+use crate::client::dispatch::Callback;
+use crate::common::{exec::Exec, task, Future, Never, Pin, Poll};
+use crate::ext::Protocol;
+use crate::headers;
+use crate::proto::h2::UpgradedSendStream;
+use crate::proto::Dispatched;
+use crate::upgrade::Upgraded;
+use crate::{Body, Request, Response};
+use h2::client::ResponseFuture;
+
+type ClientRx<B> = crate::client::dispatch::Receiver<Request<B>, Response<Body>>;
+
+///// An mpsc channel is used to help notify the `Connection` task when *all*
+///// other handles to it have been dropped, so that it can shutdown.
+type ConnDropRef = mpsc::Sender<Never>;
+
+///// A oneshot channel watches the `Connection` task, and when it completes,
+///// the "dispatch" task will be notified and can shutdown sooner.
+type ConnEof = oneshot::Receiver<Never>;
+
+// Our defaults are chosen for the "majority" case, which usually are not
+// resource constrained, and so the spec default of 64kb can be too limiting
+// for performance.
+const DEFAULT_CONN_WINDOW: u32 = 1024 * 1024 * 5; // 5mb
+const DEFAULT_STREAM_WINDOW: u32 = 1024 * 1024 * 2; // 2mb
+const DEFAULT_MAX_FRAME_SIZE: u32 = 1024 * 16; // 16kb
+const DEFAULT_MAX_SEND_BUF_SIZE: usize = 1024 * 1024; // 1mb
+
+#[derive(Clone, Debug)]
+pub(crate) struct Config {
+ pub(crate) adaptive_window: bool,
+ pub(crate) initial_conn_window_size: u32,
+ pub(crate) initial_stream_window_size: u32,
+ pub(crate) max_frame_size: u32,
+ #[cfg(feature = "runtime")]
+ pub(crate) keep_alive_interval: Option<Duration>,
+ #[cfg(feature = "runtime")]
+ pub(crate) keep_alive_timeout: Duration,
+ #[cfg(feature = "runtime")]
+ pub(crate) keep_alive_while_idle: bool,
+ pub(crate) max_concurrent_reset_streams: Option<usize>,
+ pub(crate) max_send_buffer_size: usize,
+}
+
+impl Default for Config {
+ fn default() -> Config {
+ Config {
+ adaptive_window: false,
+ initial_conn_window_size: DEFAULT_CONN_WINDOW,
+ initial_stream_window_size: DEFAULT_STREAM_WINDOW,
+ max_frame_size: DEFAULT_MAX_FRAME_SIZE,
+ #[cfg(feature = "runtime")]
+ keep_alive_interval: None,
+ #[cfg(feature = "runtime")]
+ keep_alive_timeout: Duration::from_secs(20),
+ #[cfg(feature = "runtime")]
+ keep_alive_while_idle: false,
+ max_concurrent_reset_streams: None,
+ max_send_buffer_size: DEFAULT_MAX_SEND_BUF_SIZE,
+ }
+ }
+}
+
+fn new_builder(config: &Config) -> Builder {
+ let mut builder = Builder::default();
+ builder
+ .initial_window_size(config.initial_stream_window_size)
+ .initial_connection_window_size(config.initial_conn_window_size)
+ .max_frame_size(config.max_frame_size)
+ .max_send_buffer_size(config.max_send_buffer_size)
+ .enable_push(false);
+ if let Some(max) = config.max_concurrent_reset_streams {
+ builder.max_concurrent_reset_streams(max);
+ }
+ builder
+}
+
+fn new_ping_config(config: &Config) -> ping::Config {
+ ping::Config {
+ bdp_initial_window: if config.adaptive_window {
+ Some(config.initial_stream_window_size)
+ } else {
+ None
+ },
+ #[cfg(feature = "runtime")]
+ keep_alive_interval: config.keep_alive_interval,
+ #[cfg(feature = "runtime")]
+ keep_alive_timeout: config.keep_alive_timeout,
+ #[cfg(feature = "runtime")]
+ keep_alive_while_idle: config.keep_alive_while_idle,
+ }
+}
+
+pub(crate) async fn handshake<T, B>(
+ io: T,
+ req_rx: ClientRx<B>,
+ config: &Config,
+ exec: Exec,
+) -> crate::Result<ClientTask<B>>
+where
+ T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
+ B: HttpBody,
+ B::Data: Send + 'static,
+{
+ let (h2_tx, mut conn) = new_builder(config)
+ .handshake::<_, SendBuf<B::Data>>(io)
+ .await
+ .map_err(crate::Error::new_h2)?;
+
+ // An mpsc channel is used entirely to detect when the
+ // 'Client' has been dropped. This is to get around a bug
+ // in h2 where dropping all SendRequests won't notify a
+ // parked Connection.
+ let (conn_drop_ref, rx) = mpsc::channel(1);
+ let (cancel_tx, conn_eof) = oneshot::channel();
+
+ let conn_drop_rx = rx.into_future().map(|(item, _rx)| {
+ if let Some(never) = item {
+ match never {}
+ }
+ });
+
+ let ping_config = new_ping_config(&config);
+
+ let (conn, ping) = if ping_config.is_enabled() {
+ let pp = conn.ping_pong().expect("conn.ping_pong");
+ let (recorder, mut ponger) = ping::channel(pp, ping_config);
+
+ let conn = future::poll_fn(move |cx| {
+ match ponger.poll(cx) {
+ Poll::Ready(ping::Ponged::SizeUpdate(wnd)) => {
+ conn.set_target_window_size(wnd);
+ conn.set_initial_window_size(wnd)?;
+ }
+ #[cfg(feature = "runtime")]
+ Poll::Ready(ping::Ponged::KeepAliveTimedOut) => {
+ debug!("connection keep-alive timed out");
+ return Poll::Ready(Ok(()));
+ }
+ Poll::Pending => {}
+ }
+
+ Pin::new(&mut conn).poll(cx)
+ });
+ (Either::Left(conn), recorder)
+ } else {
+ (Either::Right(conn), ping::disabled())
+ };
+ let conn = conn.map_err(|e| debug!("connection error: {}", e));
+
+ exec.execute(conn_task(conn, conn_drop_rx, cancel_tx));
+
+ Ok(ClientTask {
+ ping,
+ conn_drop_ref,
+ conn_eof,
+ executor: exec,
+ h2_tx,
+ req_rx,
+ fut_ctx: None,
+ })
+}
+
+async fn conn_task<C, D>(conn: C, drop_rx: D, cancel_tx: oneshot::Sender<Never>)
+where
+ C: Future + Unpin,
+ D: Future<Output = ()> + Unpin,
+{
+ match future::select(conn, drop_rx).await {
+ Either::Left(_) => {
+ // ok or err, the `conn` has finished
+ }
+ Either::Right(((), conn)) => {
+ // mpsc has been dropped, hopefully polling
+ // the connection some more should start shutdown
+ // and then close
+ trace!("send_request dropped, starting conn shutdown");
+ drop(cancel_tx);
+ let _ = conn.await;
+ }
+ }
+}
+
+struct FutCtx<B>
+where
+ B: HttpBody,
+{
+ is_connect: bool,
+ eos: bool,
+ fut: ResponseFuture,
+ body_tx: SendStream<SendBuf<B::Data>>,
+ body: B,
+ cb: Callback<Request<B>, Response<Body>>,
+}
+
+impl<B: HttpBody> Unpin for FutCtx<B> {}
+
+pub(crate) struct ClientTask<B>
+where
+ B: HttpBody,
+{
+ ping: ping::Recorder,
+ conn_drop_ref: ConnDropRef,
+ conn_eof: ConnEof,
+ executor: Exec,
+ h2_tx: SendRequest<SendBuf<B::Data>>,
+ req_rx: ClientRx<B>,
+ fut_ctx: Option<FutCtx<B>>,
+}
+
+impl<B> ClientTask<B>
+where
+ B: HttpBody + 'static,
+{
+ pub(crate) fn is_extended_connect_protocol_enabled(&self) -> bool {
+ self.h2_tx.is_extended_connect_protocol_enabled()
+ }
+}
+
+impl<B> ClientTask<B>
+where
+ B: HttpBody + Send + 'static,
+ B::Data: Send,
+ B::Error: Into<Box<dyn StdError + Send + Sync>>,
+{
+ fn poll_pipe(&mut self, f: FutCtx<B>, cx: &mut task::Context<'_>) {
+ let ping = self.ping.clone();
+ let send_stream = if !f.is_connect {
+ if !f.eos {
+ let mut pipe = Box::pin(PipeToSendStream::new(f.body, f.body_tx)).map(|res| {
+ if let Err(e) = res {
+ debug!("client request body error: {}", e);
+ }
+ });
+
+ // eagerly see if the body pipe is ready and
+ // can thus skip allocating in the executor
+ match Pin::new(&mut pipe).poll(cx) {
+ Poll::Ready(_) => (),
+ Poll::Pending => {
+ let conn_drop_ref = self.conn_drop_ref.clone();
+ // keep the ping recorder's knowledge of an
+ // "open stream" alive while this body is
+ // still sending...
+ let ping = ping.clone();
+ let pipe = pipe.map(move |x| {
+ drop(conn_drop_ref);
+ drop(ping);
+ x
+ });
+ // Clear send task
+ self.executor.execute(pipe);
+ }
+ }
+ }
+
+ None
+ } else {
+ Some(f.body_tx)
+ };
+
+ let fut = f.fut.map(move |result| match result {
+ Ok(res) => {
+ // record that we got the response headers
+ ping.record_non_data();
+
+ let content_length = headers::content_length_parse_all(res.headers());
+ if let (Some(mut send_stream), StatusCode::OK) = (send_stream, res.status()) {
+ if content_length.map_or(false, |len| len != 0) {
+ warn!("h2 connect response with non-zero body not supported");
+
+ send_stream.send_reset(h2::Reason::INTERNAL_ERROR);
+ return Err((
+ crate::Error::new_h2(h2::Reason::INTERNAL_ERROR.into()),
+ None,
+ ));
+ }
+ let (parts, recv_stream) = res.into_parts();
+ let mut res = Response::from_parts(parts, Body::empty());
+
+ let (pending, on_upgrade) = crate::upgrade::pending();
+ let io = H2Upgraded {
+ ping,
+ send_stream: unsafe { UpgradedSendStream::new(send_stream) },
+ recv_stream,
+ buf: Bytes::new(),
+ };
+ let upgraded = Upgraded::new(io, Bytes::new());
+
+ pending.fulfill(upgraded);
+ res.extensions_mut().insert(on_upgrade);
+
+ Ok(res)
+ } else {
+ let res = res.map(|stream| {
+ let ping = ping.for_stream(&stream);
+ crate::Body::h2(stream, content_length.into(), ping)
+ });
+ Ok(res)
+ }
+ }
+ Err(err) => {
+ ping.ensure_not_timed_out().map_err(|e| (e, None))?;
+
+ debug!("client response error: {}", err);
+ Err((crate::Error::new_h2(err), None))
+ }
+ });
+ self.executor.execute(f.cb.send_when(fut));
+ }
+}
+
+impl<B> Future for ClientTask<B>
+where
+ B: HttpBody + Send + 'static,
+ B::Data: Send,
+ B::Error: Into<Box<dyn StdError + Send + Sync>>,
+{
+ type Output = crate::Result<Dispatched>;
+
+ fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
+ loop {
+ match ready!(self.h2_tx.poll_ready(cx)) {
+ Ok(()) => (),
+ Err(err) => {
+ self.ping.ensure_not_timed_out()?;
+ return if err.reason() == Some(::h2::Reason::NO_ERROR) {
+ trace!("connection gracefully shutdown");
+ Poll::Ready(Ok(Dispatched::Shutdown))
+ } else {
+ Poll::Ready(Err(crate::Error::new_h2(err)))
+ };
+ }
+ };
+
+ match self.fut_ctx.take() {
+ // If we were waiting on pending open
+ // continue where we left off.
+ Some(f) => {
+ self.poll_pipe(f, cx);
+ continue;
+ }
+ None => (),
+ }
+
+ match self.req_rx.poll_recv(cx) {
+ Poll::Ready(Some((req, cb))) => {
+ // check that future hasn't been canceled already
+ if cb.is_canceled() {
+ trace!("request callback is canceled");
+ continue;
+ }
+ let (head, body) = req.into_parts();
+ let mut req = ::http::Request::from_parts(head, ());
+ super::strip_connection_headers(req.headers_mut(), true);
+ if let Some(len) = body.size_hint().exact() {
+ if len != 0 || headers::method_has_defined_payload_semantics(req.method()) {
+ headers::set_content_length_if_missing(req.headers_mut(), len);
+ }
+ }
+
+ let is_connect = req.method() == Method::CONNECT;
+ let eos = body.is_end_stream();
+
+ if is_connect {
+ if headers::content_length_parse_all(req.headers())
+ .map_or(false, |len| len != 0)
+ {
+ warn!("h2 connect request with non-zero body not supported");
+ cb.send(Err((
+ crate::Error::new_h2(h2::Reason::INTERNAL_ERROR.into()),
+ None,
+ )));
+ continue;
+ }
+ }
+
+ if let Some(protocol) = req.extensions_mut().remove::<Protocol>() {
+ req.extensions_mut().insert(protocol.into_inner());
+ }
+
+ let (fut, body_tx) = match self.h2_tx.send_request(req, !is_connect && eos) {
+ Ok(ok) => ok,
+ Err(err) => {
+ debug!("client send request error: {}", err);
+ cb.send(Err((crate::Error::new_h2(err), None)));
+ continue;
+ }
+ };
+
+ let f = FutCtx {
+ is_connect,
+ eos,
+ fut,
+ body_tx,
+ body,
+ cb,
+ };
+
+ // Check poll_ready() again.
+ // If the call to send_request() resulted in the new stream being pending open
+ // we have to wait for the open to complete before accepting new requests.
+ match self.h2_tx.poll_ready(cx) {
+ Poll::Pending => {
+ // Save Context
+ self.fut_ctx = Some(f);
+ return Poll::Pending;
+ }
+ Poll::Ready(Ok(())) => (),
+ Poll::Ready(Err(err)) => {
+ f.cb.send(Err((crate::Error::new_h2(err), None)));
+ continue;
+ }
+ }
+ self.poll_pipe(f, cx);
+ continue;
+ }
+
+ Poll::Ready(None) => {
+ trace!("client::dispatch::Sender dropped");
+ return Poll::Ready(Ok(Dispatched::Shutdown));
+ }
+
+ Poll::Pending => match ready!(Pin::new(&mut self.conn_eof).poll(cx)) {
+ Ok(never) => match never {},
+ Err(_conn_is_eof) => {
+ trace!("connection task is closed, closing dispatch task");
+ return Poll::Ready(Ok(Dispatched::Shutdown));
+ }
+ },
+ }
+ }
+ }
+}
diff --git a/third_party/rust/hyper/src/proto/h2/mod.rs b/third_party/rust/hyper/src/proto/h2/mod.rs
new file mode 100644
index 0000000000..5857c919d1
--- /dev/null
+++ b/third_party/rust/hyper/src/proto/h2/mod.rs
@@ -0,0 +1,471 @@
+use bytes::{Buf, Bytes};
+use h2::{Reason, RecvStream, SendStream};
+use http::header::{HeaderName, CONNECTION, TE, TRAILER, TRANSFER_ENCODING, UPGRADE};
+use http::HeaderMap;
+use pin_project_lite::pin_project;
+use std::error::Error as StdError;
+use std::io::{self, Cursor, IoSlice};
+use std::mem;
+use std::task::Context;
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+use tracing::{debug, trace, warn};
+
+use crate::body::HttpBody;
+use crate::common::{task, Future, Pin, Poll};
+use crate::proto::h2::ping::Recorder;
+
+pub(crate) mod ping;
+
+cfg_client! {
+ pub(crate) mod client;
+ pub(crate) use self::client::ClientTask;
+}
+
+cfg_server! {
+ pub(crate) mod server;
+ pub(crate) use self::server::Server;
+}
+
+/// Default initial stream window size defined in HTTP2 spec.
+pub(crate) const SPEC_WINDOW_SIZE: u32 = 65_535;
+
+fn strip_connection_headers(headers: &mut HeaderMap, is_request: bool) {
+ // List of connection headers from:
+ // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Connection
+ //
+ // TE headers are allowed in HTTP/2 requests as long as the value is "trailers", so they're
+ // tested separately.
+ let connection_headers = [
+ HeaderName::from_lowercase(b"keep-alive").unwrap(),
+ HeaderName::from_lowercase(b"proxy-connection").unwrap(),
+ TRAILER,
+ TRANSFER_ENCODING,
+ UPGRADE,
+ ];
+
+ for header in connection_headers.iter() {
+ if headers.remove(header).is_some() {
+ warn!("Connection header illegal in HTTP/2: {}", header.as_str());
+ }
+ }
+
+ if is_request {
+ if headers
+ .get(TE)
+ .map(|te_header| te_header != "trailers")
+ .unwrap_or(false)
+ {
+ warn!("TE headers not set to \"trailers\" are illegal in HTTP/2 requests");
+ headers.remove(TE);
+ }
+ } else if headers.remove(TE).is_some() {
+ warn!("TE headers illegal in HTTP/2 responses");
+ }
+
+ if let Some(header) = headers.remove(CONNECTION) {
+ warn!(
+ "Connection header illegal in HTTP/2: {}",
+ CONNECTION.as_str()
+ );
+ let header_contents = header.to_str().unwrap();
+
+ // A `Connection` header may have a comma-separated list of names of other headers that
+ // are meant for only this specific connection.
+ //
+ // Iterate these names and remove them as headers. Connection-specific headers are
+ // forbidden in HTTP2, as that information has been moved into frame types of the h2
+ // protocol.
+ for name in header_contents.split(',') {
+ let name = name.trim();
+ headers.remove(name);
+ }
+ }
+}
+
+// body adapters used by both Client and Server
+
+pin_project! {
+ struct PipeToSendStream<S>
+ where
+ S: HttpBody,
+ {
+ body_tx: SendStream<SendBuf<S::Data>>,
+ data_done: bool,
+ #[pin]
+ stream: S,
+ }
+}
+
+impl<S> PipeToSendStream<S>
+where
+ S: HttpBody,
+{
+ fn new(stream: S, tx: SendStream<SendBuf<S::Data>>) -> PipeToSendStream<S> {
+ PipeToSendStream {
+ body_tx: tx,
+ data_done: false,
+ stream,
+ }
+ }
+}
+
+impl<S> Future for PipeToSendStream<S>
+where
+ S: HttpBody,
+ S::Error: Into<Box<dyn StdError + Send + Sync>>,
+{
+ type Output = crate::Result<()>;
+
+ fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
+ let mut me = self.project();
+ loop {
+ if !*me.data_done {
+ // we don't have the next chunk of data yet, so just reserve 1 byte to make
+ // sure there's some capacity available. h2 will handle the capacity management
+ // for the actual body chunk.
+ me.body_tx.reserve_capacity(1);
+
+ if me.body_tx.capacity() == 0 {
+ loop {
+ match ready!(me.body_tx.poll_capacity(cx)) {
+ Some(Ok(0)) => {}
+ Some(Ok(_)) => break,
+ Some(Err(e)) => {
+ return Poll::Ready(Err(crate::Error::new_body_write(e)))
+ }
+ None => {
+ // None means the stream is no longer in a
+ // streaming state, we either finished it
+ // somehow, or the remote reset us.
+ return Poll::Ready(Err(crate::Error::new_body_write(
+ "send stream capacity unexpectedly closed",
+ )));
+ }
+ }
+ }
+ } else if let Poll::Ready(reason) = me
+ .body_tx
+ .poll_reset(cx)
+ .map_err(crate::Error::new_body_write)?
+ {
+ debug!("stream received RST_STREAM: {:?}", reason);
+ return Poll::Ready(Err(crate::Error::new_body_write(::h2::Error::from(
+ reason,
+ ))));
+ }
+
+ match ready!(me.stream.as_mut().poll_data(cx)) {
+ Some(Ok(chunk)) => {
+ let is_eos = me.stream.is_end_stream();
+ trace!(
+ "send body chunk: {} bytes, eos={}",
+ chunk.remaining(),
+ is_eos,
+ );
+
+ let buf = SendBuf::Buf(chunk);
+ me.body_tx
+ .send_data(buf, is_eos)
+ .map_err(crate::Error::new_body_write)?;
+
+ if is_eos {
+ return Poll::Ready(Ok(()));
+ }
+ }
+ Some(Err(e)) => return Poll::Ready(Err(me.body_tx.on_user_err(e))),
+ None => {
+ me.body_tx.reserve_capacity(0);
+ let is_eos = me.stream.is_end_stream();
+ if is_eos {
+ return Poll::Ready(me.body_tx.send_eos_frame());
+ } else {
+ *me.data_done = true;
+ // loop again to poll_trailers
+ }
+ }
+ }
+ } else {
+ if let Poll::Ready(reason) = me
+ .body_tx
+ .poll_reset(cx)
+ .map_err(crate::Error::new_body_write)?
+ {
+ debug!("stream received RST_STREAM: {:?}", reason);
+ return Poll::Ready(Err(crate::Error::new_body_write(::h2::Error::from(
+ reason,
+ ))));
+ }
+
+ match ready!(me.stream.poll_trailers(cx)) {
+ Ok(Some(trailers)) => {
+ me.body_tx
+ .send_trailers(trailers)
+ .map_err(crate::Error::new_body_write)?;
+ return Poll::Ready(Ok(()));
+ }
+ Ok(None) => {
+ // There were no trailers, so send an empty DATA frame...
+ return Poll::Ready(me.body_tx.send_eos_frame());
+ }
+ Err(e) => return Poll::Ready(Err(me.body_tx.on_user_err(e))),
+ }
+ }
+ }
+ }
+}
+
+trait SendStreamExt {
+ fn on_user_err<E>(&mut self, err: E) -> crate::Error
+ where
+ E: Into<Box<dyn std::error::Error + Send + Sync>>;
+ fn send_eos_frame(&mut self) -> crate::Result<()>;
+}
+
+impl<B: Buf> SendStreamExt for SendStream<SendBuf<B>> {
+ fn on_user_err<E>(&mut self, err: E) -> crate::Error
+ where
+ E: Into<Box<dyn std::error::Error + Send + Sync>>,
+ {
+ let err = crate::Error::new_user_body(err);
+ debug!("send body user stream error: {}", err);
+ self.send_reset(err.h2_reason());
+ err
+ }
+
+ fn send_eos_frame(&mut self) -> crate::Result<()> {
+ trace!("send body eos");
+ self.send_data(SendBuf::None, true)
+ .map_err(crate::Error::new_body_write)
+ }
+}
+
+#[repr(usize)]
+enum SendBuf<B> {
+ Buf(B),
+ Cursor(Cursor<Box<[u8]>>),
+ None,
+}
+
+impl<B: Buf> Buf for SendBuf<B> {
+ #[inline]
+ fn remaining(&self) -> usize {
+ match *self {
+ Self::Buf(ref b) => b.remaining(),
+ Self::Cursor(ref c) => Buf::remaining(c),
+ Self::None => 0,
+ }
+ }
+
+ #[inline]
+ fn chunk(&self) -> &[u8] {
+ match *self {
+ Self::Buf(ref b) => b.chunk(),
+ Self::Cursor(ref c) => c.chunk(),
+ Self::None => &[],
+ }
+ }
+
+ #[inline]
+ fn advance(&mut self, cnt: usize) {
+ match *self {
+ Self::Buf(ref mut b) => b.advance(cnt),
+ Self::Cursor(ref mut c) => c.advance(cnt),
+ Self::None => {}
+ }
+ }
+
+ fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize {
+ match *self {
+ Self::Buf(ref b) => b.chunks_vectored(dst),
+ Self::Cursor(ref c) => c.chunks_vectored(dst),
+ Self::None => 0,
+ }
+ }
+}
+
+struct H2Upgraded<B>
+where
+ B: Buf,
+{
+ ping: Recorder,
+ send_stream: UpgradedSendStream<B>,
+ recv_stream: RecvStream,
+ buf: Bytes,
+}
+
+impl<B> AsyncRead for H2Upgraded<B>
+where
+ B: Buf,
+{
+ fn poll_read(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ read_buf: &mut ReadBuf<'_>,
+ ) -> Poll<Result<(), io::Error>> {
+ if self.buf.is_empty() {
+ self.buf = loop {
+ match ready!(self.recv_stream.poll_data(cx)) {
+ None => return Poll::Ready(Ok(())),
+ Some(Ok(buf)) if buf.is_empty() && !self.recv_stream.is_end_stream() => {
+ continue
+ }
+ Some(Ok(buf)) => {
+ self.ping.record_data(buf.len());
+ break buf;
+ }
+ Some(Err(e)) => {
+ return Poll::Ready(match e.reason() {
+ Some(Reason::NO_ERROR) | Some(Reason::CANCEL) => Ok(()),
+ Some(Reason::STREAM_CLOSED) => {
+ Err(io::Error::new(io::ErrorKind::BrokenPipe, e))
+ }
+ _ => Err(h2_to_io_error(e)),
+ })
+ }
+ }
+ };
+ }
+ let cnt = std::cmp::min(self.buf.len(), read_buf.remaining());
+ read_buf.put_slice(&self.buf[..cnt]);
+ self.buf.advance(cnt);
+ let _ = self.recv_stream.flow_control().release_capacity(cnt);
+ Poll::Ready(Ok(()))
+ }
+}
+
+impl<B> AsyncWrite for H2Upgraded<B>
+where
+ B: Buf,
+{
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<Result<usize, io::Error>> {
+ if buf.is_empty() {
+ return Poll::Ready(Ok(0));
+ }
+ self.send_stream.reserve_capacity(buf.len());
+
+ // We ignore all errors returned by `poll_capacity` and `write`, as we
+ // will get the correct from `poll_reset` anyway.
+ let cnt = match ready!(self.send_stream.poll_capacity(cx)) {
+ None => Some(0),
+ Some(Ok(cnt)) => self
+ .send_stream
+ .write(&buf[..cnt], false)
+ .ok()
+ .map(|()| cnt),
+ Some(Err(_)) => None,
+ };
+
+ if let Some(cnt) = cnt {
+ return Poll::Ready(Ok(cnt));
+ }
+
+ Poll::Ready(Err(h2_to_io_error(
+ match ready!(self.send_stream.poll_reset(cx)) {
+ Ok(Reason::NO_ERROR) | Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => {
+ return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()))
+ }
+ Ok(reason) => reason.into(),
+ Err(e) => e,
+ },
+ )))
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn poll_shutdown(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Result<(), io::Error>> {
+ if self.send_stream.write(&[], true).is_ok() {
+ return Poll::Ready(Ok(()))
+ }
+
+ Poll::Ready(Err(h2_to_io_error(
+ match ready!(self.send_stream.poll_reset(cx)) {
+ Ok(Reason::NO_ERROR) => {
+ return Poll::Ready(Ok(()))
+ }
+ Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => {
+ return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()))
+ }
+ Ok(reason) => reason.into(),
+ Err(e) => e,
+ },
+ )))
+ }
+}
+
+fn h2_to_io_error(e: h2::Error) -> io::Error {
+ if e.is_io() {
+ e.into_io().unwrap()
+ } else {
+ io::Error::new(io::ErrorKind::Other, e)
+ }
+}
+
+struct UpgradedSendStream<B>(SendStream<SendBuf<Neutered<B>>>);
+
+impl<B> UpgradedSendStream<B>
+where
+ B: Buf,
+{
+ unsafe fn new(inner: SendStream<SendBuf<B>>) -> Self {
+ assert_eq!(mem::size_of::<B>(), mem::size_of::<Neutered<B>>());
+ Self(mem::transmute(inner))
+ }
+
+ fn reserve_capacity(&mut self, cnt: usize) {
+ unsafe { self.as_inner_unchecked().reserve_capacity(cnt) }
+ }
+
+ fn poll_capacity(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<usize, h2::Error>>> {
+ unsafe { self.as_inner_unchecked().poll_capacity(cx) }
+ }
+
+ fn poll_reset(&mut self, cx: &mut Context<'_>) -> Poll<Result<h2::Reason, h2::Error>> {
+ unsafe { self.as_inner_unchecked().poll_reset(cx) }
+ }
+
+ fn write(&mut self, buf: &[u8], end_of_stream: bool) -> Result<(), io::Error> {
+ let send_buf = SendBuf::Cursor(Cursor::new(buf.into()));
+ unsafe {
+ self.as_inner_unchecked()
+ .send_data(send_buf, end_of_stream)
+ .map_err(h2_to_io_error)
+ }
+ }
+
+ unsafe fn as_inner_unchecked(&mut self) -> &mut SendStream<SendBuf<B>> {
+ &mut *(&mut self.0 as *mut _ as *mut _)
+ }
+}
+
+#[repr(transparent)]
+struct Neutered<B> {
+ _inner: B,
+ impossible: Impossible,
+}
+
+enum Impossible {}
+
+unsafe impl<B> Send for Neutered<B> {}
+
+impl<B> Buf for Neutered<B> {
+ fn remaining(&self) -> usize {
+ match self.impossible {}
+ }
+
+ fn chunk(&self) -> &[u8] {
+ match self.impossible {}
+ }
+
+ fn advance(&mut self, _cnt: usize) {
+ match self.impossible {}
+ }
+}
diff --git a/third_party/rust/hyper/src/proto/h2/ping.rs b/third_party/rust/hyper/src/proto/h2/ping.rs
new file mode 100644
index 0000000000..1e8386497c
--- /dev/null
+++ b/third_party/rust/hyper/src/proto/h2/ping.rs
@@ -0,0 +1,555 @@
+/// HTTP2 Ping usage
+///
+/// hyper uses HTTP2 pings for two purposes:
+///
+/// 1. Adaptive flow control using BDP
+/// 2. Connection keep-alive
+///
+/// Both cases are optional.
+///
+/// # BDP Algorithm
+///
+/// 1. When receiving a DATA frame, if a BDP ping isn't outstanding:
+/// 1a. Record current time.
+/// 1b. Send a BDP ping.
+/// 2. Increment the number of received bytes.
+/// 3. When the BDP ping ack is received:
+/// 3a. Record duration from sent time.
+/// 3b. Merge RTT with a running average.
+/// 3c. Calculate bdp as bytes/rtt.
+/// 3d. If bdp is over 2/3 max, set new max to bdp and update windows.
+
+#[cfg(feature = "runtime")]
+use std::fmt;
+#[cfg(feature = "runtime")]
+use std::future::Future;
+#[cfg(feature = "runtime")]
+use std::pin::Pin;
+use std::sync::{Arc, Mutex};
+use std::task::{self, Poll};
+use std::time::Duration;
+#[cfg(not(feature = "runtime"))]
+use std::time::Instant;
+
+use h2::{Ping, PingPong};
+#[cfg(feature = "runtime")]
+use tokio::time::{Instant, Sleep};
+use tracing::{debug, trace};
+
+type WindowSize = u32;
+
+pub(super) fn disabled() -> Recorder {
+ Recorder { shared: None }
+}
+
+pub(super) fn channel(ping_pong: PingPong, config: Config) -> (Recorder, Ponger) {
+ debug_assert!(
+ config.is_enabled(),
+ "ping channel requires bdp or keep-alive config",
+ );
+
+ let bdp = config.bdp_initial_window.map(|wnd| Bdp {
+ bdp: wnd,
+ max_bandwidth: 0.0,
+ rtt: 0.0,
+ ping_delay: Duration::from_millis(100),
+ stable_count: 0,
+ });
+
+ let (bytes, next_bdp_at) = if bdp.is_some() {
+ (Some(0), Some(Instant::now()))
+ } else {
+ (None, None)
+ };
+
+ #[cfg(feature = "runtime")]
+ let keep_alive = config.keep_alive_interval.map(|interval| KeepAlive {
+ interval,
+ timeout: config.keep_alive_timeout,
+ while_idle: config.keep_alive_while_idle,
+ timer: Box::pin(tokio::time::sleep(interval)),
+ state: KeepAliveState::Init,
+ });
+
+ #[cfg(feature = "runtime")]
+ let last_read_at = keep_alive.as_ref().map(|_| Instant::now());
+
+ let shared = Arc::new(Mutex::new(Shared {
+ bytes,
+ #[cfg(feature = "runtime")]
+ last_read_at,
+ #[cfg(feature = "runtime")]
+ is_keep_alive_timed_out: false,
+ ping_pong,
+ ping_sent_at: None,
+ next_bdp_at,
+ }));
+
+ (
+ Recorder {
+ shared: Some(shared.clone()),
+ },
+ Ponger {
+ bdp,
+ #[cfg(feature = "runtime")]
+ keep_alive,
+ shared,
+ },
+ )
+}
+
+#[derive(Clone)]
+pub(super) struct Config {
+ pub(super) bdp_initial_window: Option<WindowSize>,
+ /// If no frames are received in this amount of time, a PING frame is sent.
+ #[cfg(feature = "runtime")]
+ pub(super) keep_alive_interval: Option<Duration>,
+ /// After sending a keepalive PING, the connection will be closed if
+ /// a pong is not received in this amount of time.
+ #[cfg(feature = "runtime")]
+ pub(super) keep_alive_timeout: Duration,
+ /// If true, sends pings even when there are no active streams.
+ #[cfg(feature = "runtime")]
+ pub(super) keep_alive_while_idle: bool,
+}
+
+#[derive(Clone)]
+pub(crate) struct Recorder {
+ shared: Option<Arc<Mutex<Shared>>>,
+}
+
+pub(super) struct Ponger {
+ bdp: Option<Bdp>,
+ #[cfg(feature = "runtime")]
+ keep_alive: Option<KeepAlive>,
+ shared: Arc<Mutex<Shared>>,
+}
+
+struct Shared {
+ ping_pong: PingPong,
+ ping_sent_at: Option<Instant>,
+
+ // bdp
+ /// If `Some`, bdp is enabled, and this tracks how many bytes have been
+ /// read during the current sample.
+ bytes: Option<usize>,
+ /// We delay a variable amount of time between BDP pings. This allows us
+ /// to send less pings as the bandwidth stabilizes.
+ next_bdp_at: Option<Instant>,
+
+ // keep-alive
+ /// If `Some`, keep-alive is enabled, and the Instant is how long ago
+ /// the connection read the last frame.
+ #[cfg(feature = "runtime")]
+ last_read_at: Option<Instant>,
+
+ #[cfg(feature = "runtime")]
+ is_keep_alive_timed_out: bool,
+}
+
+struct Bdp {
+ /// Current BDP in bytes
+ bdp: u32,
+ /// Largest bandwidth we've seen so far.
+ max_bandwidth: f64,
+ /// Round trip time in seconds
+ rtt: f64,
+ /// Delay the next ping by this amount.
+ ///
+ /// This will change depending on how stable the current bandwidth is.
+ ping_delay: Duration,
+ /// The count of ping round trips where BDP has stayed the same.
+ stable_count: u32,
+}
+
+#[cfg(feature = "runtime")]
+struct KeepAlive {
+ /// If no frames are received in this amount of time, a PING frame is sent.
+ interval: Duration,
+ /// After sending a keepalive PING, the connection will be closed if
+ /// a pong is not received in this amount of time.
+ timeout: Duration,
+ /// If true, sends pings even when there are no active streams.
+ while_idle: bool,
+
+ state: KeepAliveState,
+ timer: Pin<Box<Sleep>>,
+}
+
+#[cfg(feature = "runtime")]
+enum KeepAliveState {
+ Init,
+ Scheduled,
+ PingSent,
+}
+
+pub(super) enum Ponged {
+ SizeUpdate(WindowSize),
+ #[cfg(feature = "runtime")]
+ KeepAliveTimedOut,
+}
+
+#[cfg(feature = "runtime")]
+#[derive(Debug)]
+pub(super) struct KeepAliveTimedOut;
+
+// ===== impl Config =====
+
+impl Config {
+ pub(super) fn is_enabled(&self) -> bool {
+ #[cfg(feature = "runtime")]
+ {
+ self.bdp_initial_window.is_some() || self.keep_alive_interval.is_some()
+ }
+
+ #[cfg(not(feature = "runtime"))]
+ {
+ self.bdp_initial_window.is_some()
+ }
+ }
+}
+
+// ===== impl Recorder =====
+
+impl Recorder {
+ pub(crate) fn record_data(&self, len: usize) {
+ let shared = if let Some(ref shared) = self.shared {
+ shared
+ } else {
+ return;
+ };
+
+ let mut locked = shared.lock().unwrap();
+
+ #[cfg(feature = "runtime")]
+ locked.update_last_read_at();
+
+ // are we ready to send another bdp ping?
+ // if not, we don't need to record bytes either
+
+ if let Some(ref next_bdp_at) = locked.next_bdp_at {
+ if Instant::now() < *next_bdp_at {
+ return;
+ } else {
+ locked.next_bdp_at = None;
+ }
+ }
+
+ if let Some(ref mut bytes) = locked.bytes {
+ *bytes += len;
+ } else {
+ // no need to send bdp ping if bdp is disabled
+ return;
+ }
+
+ if !locked.is_ping_sent() {
+ locked.send_ping();
+ }
+ }
+
+ pub(crate) fn record_non_data(&self) {
+ #[cfg(feature = "runtime")]
+ {
+ let shared = if let Some(ref shared) = self.shared {
+ shared
+ } else {
+ return;
+ };
+
+ let mut locked = shared.lock().unwrap();
+
+ locked.update_last_read_at();
+ }
+ }
+
+ /// If the incoming stream is already closed, convert self into
+ /// a disabled reporter.
+ #[cfg(feature = "client")]
+ pub(super) fn for_stream(self, stream: &h2::RecvStream) -> Self {
+ if stream.is_end_stream() {
+ disabled()
+ } else {
+ self
+ }
+ }
+
+ pub(super) fn ensure_not_timed_out(&self) -> crate::Result<()> {
+ #[cfg(feature = "runtime")]
+ {
+ if let Some(ref shared) = self.shared {
+ let locked = shared.lock().unwrap();
+ if locked.is_keep_alive_timed_out {
+ return Err(KeepAliveTimedOut.crate_error());
+ }
+ }
+ }
+
+ // else
+ Ok(())
+ }
+}
+
+// ===== impl Ponger =====
+
+impl Ponger {
+ pub(super) fn poll(&mut self, cx: &mut task::Context<'_>) -> Poll<Ponged> {
+ let now = Instant::now();
+ let mut locked = self.shared.lock().unwrap();
+ #[cfg(feature = "runtime")]
+ let is_idle = self.is_idle();
+
+ #[cfg(feature = "runtime")]
+ {
+ if let Some(ref mut ka) = self.keep_alive {
+ ka.schedule(is_idle, &locked);
+ ka.maybe_ping(cx, &mut locked);
+ }
+ }
+
+ if !locked.is_ping_sent() {
+ // XXX: this doesn't register a waker...?
+ return Poll::Pending;
+ }
+
+ match locked.ping_pong.poll_pong(cx) {
+ Poll::Ready(Ok(_pong)) => {
+ let start = locked
+ .ping_sent_at
+ .expect("pong received implies ping_sent_at");
+ locked.ping_sent_at = None;
+ let rtt = now - start;
+ trace!("recv pong");
+
+ #[cfg(feature = "runtime")]
+ {
+ if let Some(ref mut ka) = self.keep_alive {
+ locked.update_last_read_at();
+ ka.schedule(is_idle, &locked);
+ }
+ }
+
+ if let Some(ref mut bdp) = self.bdp {
+ let bytes = locked.bytes.expect("bdp enabled implies bytes");
+ locked.bytes = Some(0); // reset
+ trace!("received BDP ack; bytes = {}, rtt = {:?}", bytes, rtt);
+
+ let update = bdp.calculate(bytes, rtt);
+ locked.next_bdp_at = Some(now + bdp.ping_delay);
+ if let Some(update) = update {
+ return Poll::Ready(Ponged::SizeUpdate(update))
+ }
+ }
+ }
+ Poll::Ready(Err(e)) => {
+ debug!("pong error: {}", e);
+ }
+ Poll::Pending => {
+ #[cfg(feature = "runtime")]
+ {
+ if let Some(ref mut ka) = self.keep_alive {
+ if let Err(KeepAliveTimedOut) = ka.maybe_timeout(cx) {
+ self.keep_alive = None;
+ locked.is_keep_alive_timed_out = true;
+ return Poll::Ready(Ponged::KeepAliveTimedOut);
+ }
+ }
+ }
+ }
+ }
+
+ // XXX: this doesn't register a waker...?
+ Poll::Pending
+ }
+
+ #[cfg(feature = "runtime")]
+ fn is_idle(&self) -> bool {
+ Arc::strong_count(&self.shared) <= 2
+ }
+}
+
+// ===== impl Shared =====
+
+impl Shared {
+ fn send_ping(&mut self) {
+ match self.ping_pong.send_ping(Ping::opaque()) {
+ Ok(()) => {
+ self.ping_sent_at = Some(Instant::now());
+ trace!("sent ping");
+ }
+ Err(err) => {
+ debug!("error sending ping: {}", err);
+ }
+ }
+ }
+
+ fn is_ping_sent(&self) -> bool {
+ self.ping_sent_at.is_some()
+ }
+
+ #[cfg(feature = "runtime")]
+ fn update_last_read_at(&mut self) {
+ if self.last_read_at.is_some() {
+ self.last_read_at = Some(Instant::now());
+ }
+ }
+
+ #[cfg(feature = "runtime")]
+ fn last_read_at(&self) -> Instant {
+ self.last_read_at.expect("keep_alive expects last_read_at")
+ }
+}
+
+// ===== impl Bdp =====
+
+/// Any higher than this likely will be hitting the TCP flow control.
+const BDP_LIMIT: usize = 1024 * 1024 * 16;
+
+impl Bdp {
+ fn calculate(&mut self, bytes: usize, rtt: Duration) -> Option<WindowSize> {
+ // No need to do any math if we're at the limit.
+ if self.bdp as usize == BDP_LIMIT {
+ self.stabilize_delay();
+ return None;
+ }
+
+ // average the rtt
+ let rtt = seconds(rtt);
+ if self.rtt == 0.0 {
+ // First sample means rtt is first rtt.
+ self.rtt = rtt;
+ } else {
+ // Weigh this rtt as 1/8 for a moving average.
+ self.rtt += (rtt - self.rtt) * 0.125;
+ }
+
+ // calculate the current bandwidth
+ let bw = (bytes as f64) / (self.rtt * 1.5);
+ trace!("current bandwidth = {:.1}B/s", bw);
+
+ if bw < self.max_bandwidth {
+ // not a faster bandwidth, so don't update
+ self.stabilize_delay();
+ return None;
+ } else {
+ self.max_bandwidth = bw;
+ }
+
+ // if the current `bytes` sample is at least 2/3 the previous
+ // bdp, increase to double the current sample.
+ if bytes >= self.bdp as usize * 2 / 3 {
+ self.bdp = (bytes * 2).min(BDP_LIMIT) as WindowSize;
+ trace!("BDP increased to {}", self.bdp);
+
+ self.stable_count = 0;
+ self.ping_delay /= 2;
+ Some(self.bdp)
+ } else {
+ self.stabilize_delay();
+ None
+ }
+ }
+
+ fn stabilize_delay(&mut self) {
+ if self.ping_delay < Duration::from_secs(10) {
+ self.stable_count += 1;
+
+ if self.stable_count >= 2 {
+ self.ping_delay *= 4;
+ self.stable_count = 0;
+ }
+ }
+ }
+}
+
+fn seconds(dur: Duration) -> f64 {
+ const NANOS_PER_SEC: f64 = 1_000_000_000.0;
+ let secs = dur.as_secs() as f64;
+ secs + (dur.subsec_nanos() as f64) / NANOS_PER_SEC
+}
+
+// ===== impl KeepAlive =====
+
+#[cfg(feature = "runtime")]
+impl KeepAlive {
+ fn schedule(&mut self, is_idle: bool, shared: &Shared) {
+ match self.state {
+ KeepAliveState::Init => {
+ if !self.while_idle && is_idle {
+ return;
+ }
+
+ self.state = KeepAliveState::Scheduled;
+ let interval = shared.last_read_at() + self.interval;
+ self.timer.as_mut().reset(interval);
+ }
+ KeepAliveState::PingSent => {
+ if shared.is_ping_sent() {
+ return;
+ }
+
+ self.state = KeepAliveState::Scheduled;
+ let interval = shared.last_read_at() + self.interval;
+ self.timer.as_mut().reset(interval);
+ }
+ KeepAliveState::Scheduled => (),
+ }
+ }
+
+ fn maybe_ping(&mut self, cx: &mut task::Context<'_>, shared: &mut Shared) {
+ match self.state {
+ KeepAliveState::Scheduled => {
+ if Pin::new(&mut self.timer).poll(cx).is_pending() {
+ return;
+ }
+ // check if we've received a frame while we were scheduled
+ if shared.last_read_at() + self.interval > self.timer.deadline() {
+ self.state = KeepAliveState::Init;
+ cx.waker().wake_by_ref(); // schedule us again
+ return;
+ }
+ trace!("keep-alive interval ({:?}) reached", self.interval);
+ shared.send_ping();
+ self.state = KeepAliveState::PingSent;
+ let timeout = Instant::now() + self.timeout;
+ self.timer.as_mut().reset(timeout);
+ }
+ KeepAliveState::Init | KeepAliveState::PingSent => (),
+ }
+ }
+
+ fn maybe_timeout(&mut self, cx: &mut task::Context<'_>) -> Result<(), KeepAliveTimedOut> {
+ match self.state {
+ KeepAliveState::PingSent => {
+ if Pin::new(&mut self.timer).poll(cx).is_pending() {
+ return Ok(());
+ }
+ trace!("keep-alive timeout ({:?}) reached", self.timeout);
+ Err(KeepAliveTimedOut)
+ }
+ KeepAliveState::Init | KeepAliveState::Scheduled => Ok(()),
+ }
+ }
+}
+
+// ===== impl KeepAliveTimedOut =====
+
+#[cfg(feature = "runtime")]
+impl KeepAliveTimedOut {
+ pub(super) fn crate_error(self) -> crate::Error {
+ crate::Error::new(crate::error::Kind::Http2).with(self)
+ }
+}
+
+#[cfg(feature = "runtime")]
+impl fmt::Display for KeepAliveTimedOut {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.write_str("keep-alive timed out")
+ }
+}
+
+#[cfg(feature = "runtime")]
+impl std::error::Error for KeepAliveTimedOut {
+ fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+ Some(&crate::error::TimedOut)
+ }
+}
diff --git a/third_party/rust/hyper/src/proto/h2/server.rs b/third_party/rust/hyper/src/proto/h2/server.rs
new file mode 100644
index 0000000000..d24e6bac5f
--- /dev/null
+++ b/third_party/rust/hyper/src/proto/h2/server.rs
@@ -0,0 +1,548 @@
+use std::error::Error as StdError;
+use std::marker::Unpin;
+#[cfg(feature = "runtime")]
+use std::time::Duration;
+
+use bytes::Bytes;
+use h2::server::{Connection, Handshake, SendResponse};
+use h2::{Reason, RecvStream};
+use http::{Method, Request};
+use pin_project_lite::pin_project;
+use tokio::io::{AsyncRead, AsyncWrite};
+use tracing::{debug, trace, warn};
+
+use super::{ping, PipeToSendStream, SendBuf};
+use crate::body::HttpBody;
+use crate::common::exec::ConnStreamExec;
+use crate::common::{date, task, Future, Pin, Poll};
+use crate::ext::Protocol;
+use crate::headers;
+use crate::proto::h2::ping::Recorder;
+use crate::proto::h2::{H2Upgraded, UpgradedSendStream};
+use crate::proto::Dispatched;
+use crate::service::HttpService;
+
+use crate::upgrade::{OnUpgrade, Pending, Upgraded};
+use crate::{Body, Response};
+
+// Our defaults are chosen for the "majority" case, which usually are not
+// resource constrained, and so the spec default of 64kb can be too limiting
+// for performance.
+//
+// At the same time, a server more often has multiple clients connected, and
+// so is more likely to use more resources than a client would.
+const DEFAULT_CONN_WINDOW: u32 = 1024 * 1024; // 1mb
+const DEFAULT_STREAM_WINDOW: u32 = 1024 * 1024; // 1mb
+const DEFAULT_MAX_FRAME_SIZE: u32 = 1024 * 16; // 16kb
+const DEFAULT_MAX_SEND_BUF_SIZE: usize = 1024 * 400; // 400kb
+// 16 MB "sane default" taken from golang http2
+const DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE: u32 = 16 << 20;
+
+#[derive(Clone, Debug)]
+pub(crate) struct Config {
+ pub(crate) adaptive_window: bool,
+ pub(crate) initial_conn_window_size: u32,
+ pub(crate) initial_stream_window_size: u32,
+ pub(crate) max_frame_size: u32,
+ pub(crate) enable_connect_protocol: bool,
+ pub(crate) max_concurrent_streams: Option<u32>,
+ #[cfg(feature = "runtime")]
+ pub(crate) keep_alive_interval: Option<Duration>,
+ #[cfg(feature = "runtime")]
+ pub(crate) keep_alive_timeout: Duration,
+ pub(crate) max_send_buffer_size: usize,
+ pub(crate) max_header_list_size: u32,
+}
+
+impl Default for Config {
+ fn default() -> Config {
+ Config {
+ adaptive_window: false,
+ initial_conn_window_size: DEFAULT_CONN_WINDOW,
+ initial_stream_window_size: DEFAULT_STREAM_WINDOW,
+ max_frame_size: DEFAULT_MAX_FRAME_SIZE,
+ enable_connect_protocol: false,
+ max_concurrent_streams: None,
+ #[cfg(feature = "runtime")]
+ keep_alive_interval: None,
+ #[cfg(feature = "runtime")]
+ keep_alive_timeout: Duration::from_secs(20),
+ max_send_buffer_size: DEFAULT_MAX_SEND_BUF_SIZE,
+ max_header_list_size: DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE,
+ }
+ }
+}
+
+pin_project! {
+ pub(crate) struct Server<T, S, B, E>
+ where
+ S: HttpService<Body>,
+ B: HttpBody,
+ {
+ exec: E,
+ service: S,
+ state: State<T, B>,
+ }
+}
+
+enum State<T, B>
+where
+ B: HttpBody,
+{
+ Handshaking {
+ ping_config: ping::Config,
+ hs: Handshake<T, SendBuf<B::Data>>,
+ },
+ Serving(Serving<T, B>),
+ Closed,
+}
+
+struct Serving<T, B>
+where
+ B: HttpBody,
+{
+ ping: Option<(ping::Recorder, ping::Ponger)>,
+ conn: Connection<T, SendBuf<B::Data>>,
+ closing: Option<crate::Error>,
+}
+
+impl<T, S, B, E> Server<T, S, B, E>
+where
+ T: AsyncRead + AsyncWrite + Unpin,
+ S: HttpService<Body, ResBody = B>,
+ S::Error: Into<Box<dyn StdError + Send + Sync>>,
+ B: HttpBody + 'static,
+ E: ConnStreamExec<S::Future, B>,
+{
+ pub(crate) fn new(io: T, service: S, config: &Config, exec: E) -> Server<T, S, B, E> {
+ let mut builder = h2::server::Builder::default();
+ builder
+ .initial_window_size(config.initial_stream_window_size)
+ .initial_connection_window_size(config.initial_conn_window_size)
+ .max_frame_size(config.max_frame_size)
+ .max_header_list_size(config.max_header_list_size)
+ .max_send_buffer_size(config.max_send_buffer_size);
+ if let Some(max) = config.max_concurrent_streams {
+ builder.max_concurrent_streams(max);
+ }
+ if config.enable_connect_protocol {
+ builder.enable_connect_protocol();
+ }
+ let handshake = builder.handshake(io);
+
+ let bdp = if config.adaptive_window {
+ Some(config.initial_stream_window_size)
+ } else {
+ None
+ };
+
+ let ping_config = ping::Config {
+ bdp_initial_window: bdp,
+ #[cfg(feature = "runtime")]
+ keep_alive_interval: config.keep_alive_interval,
+ #[cfg(feature = "runtime")]
+ keep_alive_timeout: config.keep_alive_timeout,
+ // If keep-alive is enabled for servers, always enabled while
+ // idle, so it can more aggressively close dead connections.
+ #[cfg(feature = "runtime")]
+ keep_alive_while_idle: true,
+ };
+
+ Server {
+ exec,
+ state: State::Handshaking {
+ ping_config,
+ hs: handshake,
+ },
+ service,
+ }
+ }
+
+ pub(crate) fn graceful_shutdown(&mut self) {
+ trace!("graceful_shutdown");
+ match self.state {
+ State::Handshaking { .. } => {
+ // fall-through, to replace state with Closed
+ }
+ State::Serving(ref mut srv) => {
+ if srv.closing.is_none() {
+ srv.conn.graceful_shutdown();
+ }
+ return;
+ }
+ State::Closed => {
+ return;
+ }
+ }
+ self.state = State::Closed;
+ }
+}
+
+impl<T, S, B, E> Future for Server<T, S, B, E>
+where
+ T: AsyncRead + AsyncWrite + Unpin,
+ S: HttpService<Body, ResBody = B>,
+ S::Error: Into<Box<dyn StdError + Send + Sync>>,
+ B: HttpBody + 'static,
+ E: ConnStreamExec<S::Future, B>,
+{
+ type Output = crate::Result<Dispatched>;
+
+ fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
+ let me = &mut *self;
+ loop {
+ let next = match me.state {
+ State::Handshaking {
+ ref mut hs,
+ ref ping_config,
+ } => {
+ let mut conn = ready!(Pin::new(hs).poll(cx).map_err(crate::Error::new_h2))?;
+ let ping = if ping_config.is_enabled() {
+ let pp = conn.ping_pong().expect("conn.ping_pong");
+ Some(ping::channel(pp, ping_config.clone()))
+ } else {
+ None
+ };
+ State::Serving(Serving {
+ ping,
+ conn,
+ closing: None,
+ })
+ }
+ State::Serving(ref mut srv) => {
+ ready!(srv.poll_server(cx, &mut me.service, &mut me.exec))?;
+ return Poll::Ready(Ok(Dispatched::Shutdown));
+ }
+ State::Closed => {
+ // graceful_shutdown was called before handshaking finished,
+ // nothing to do here...
+ return Poll::Ready(Ok(Dispatched::Shutdown));
+ }
+ };
+ me.state = next;
+ }
+ }
+}
+
+impl<T, B> Serving<T, B>
+where
+ T: AsyncRead + AsyncWrite + Unpin,
+ B: HttpBody + 'static,
+{
+ fn poll_server<S, E>(
+ &mut self,
+ cx: &mut task::Context<'_>,
+ service: &mut S,
+ exec: &mut E,
+ ) -> Poll<crate::Result<()>>
+ where
+ S: HttpService<Body, ResBody = B>,
+ S::Error: Into<Box<dyn StdError + Send + Sync>>,
+ E: ConnStreamExec<S::Future, B>,
+ {
+ if self.closing.is_none() {
+ loop {
+ self.poll_ping(cx);
+
+ // Check that the service is ready to accept a new request.
+ //
+ // - If not, just drive the connection some.
+ // - If ready, try to accept a new request from the connection.
+ match service.poll_ready(cx) {
+ Poll::Ready(Ok(())) => (),
+ Poll::Pending => {
+ // use `poll_closed` instead of `poll_accept`,
+ // in order to avoid accepting a request.
+ ready!(self.conn.poll_closed(cx).map_err(crate::Error::new_h2))?;
+ trace!("incoming connection complete");
+ return Poll::Ready(Ok(()));
+ }
+ Poll::Ready(Err(err)) => {
+ let err = crate::Error::new_user_service(err);
+ debug!("service closed: {}", err);
+
+ let reason = err.h2_reason();
+ if reason == Reason::NO_ERROR {
+ // NO_ERROR is only used for graceful shutdowns...
+ trace!("interpreting NO_ERROR user error as graceful_shutdown");
+ self.conn.graceful_shutdown();
+ } else {
+ trace!("abruptly shutting down with {:?}", reason);
+ self.conn.abrupt_shutdown(reason);
+ }
+ self.closing = Some(err);
+ break;
+ }
+ }
+
+ // When the service is ready, accepts an incoming request.
+ match ready!(self.conn.poll_accept(cx)) {
+ Some(Ok((req, mut respond))) => {
+ trace!("incoming request");
+ let content_length = headers::content_length_parse_all(req.headers());
+ let ping = self
+ .ping
+ .as_ref()
+ .map(|ping| ping.0.clone())
+ .unwrap_or_else(ping::disabled);
+
+ // Record the headers received
+ ping.record_non_data();
+
+ let is_connect = req.method() == Method::CONNECT;
+ let (mut parts, stream) = req.into_parts();
+ let (mut req, connect_parts) = if !is_connect {
+ (
+ Request::from_parts(
+ parts,
+ crate::Body::h2(stream, content_length.into(), ping),
+ ),
+ None,
+ )
+ } else {
+ if content_length.map_or(false, |len| len != 0) {
+ warn!("h2 connect request with non-zero body not supported");
+ respond.send_reset(h2::Reason::INTERNAL_ERROR);
+ return Poll::Ready(Ok(()));
+ }
+ let (pending, upgrade) = crate::upgrade::pending();
+ debug_assert!(parts.extensions.get::<OnUpgrade>().is_none());
+ parts.extensions.insert(upgrade);
+ (
+ Request::from_parts(parts, crate::Body::empty()),
+ Some(ConnectParts {
+ pending,
+ ping,
+ recv_stream: stream,
+ }),
+ )
+ };
+
+ if let Some(protocol) = req.extensions_mut().remove::<h2::ext::Protocol>() {
+ req.extensions_mut().insert(Protocol::from_inner(protocol));
+ }
+
+ let fut = H2Stream::new(service.call(req), connect_parts, respond);
+ exec.execute_h2stream(fut);
+ }
+ Some(Err(e)) => {
+ return Poll::Ready(Err(crate::Error::new_h2(e)));
+ }
+ None => {
+ // no more incoming streams...
+ if let Some((ref ping, _)) = self.ping {
+ ping.ensure_not_timed_out()?;
+ }
+
+ trace!("incoming connection complete");
+ return Poll::Ready(Ok(()));
+ }
+ }
+ }
+ }
+
+ debug_assert!(
+ self.closing.is_some(),
+ "poll_server broke loop without closing"
+ );
+
+ ready!(self.conn.poll_closed(cx).map_err(crate::Error::new_h2))?;
+
+ Poll::Ready(Err(self.closing.take().expect("polled after error")))
+ }
+
+ fn poll_ping(&mut self, cx: &mut task::Context<'_>) {
+ if let Some((_, ref mut estimator)) = self.ping {
+ match estimator.poll(cx) {
+ Poll::Ready(ping::Ponged::SizeUpdate(wnd)) => {
+ self.conn.set_target_window_size(wnd);
+ let _ = self.conn.set_initial_window_size(wnd);
+ }
+ #[cfg(feature = "runtime")]
+ Poll::Ready(ping::Ponged::KeepAliveTimedOut) => {
+ debug!("keep-alive timed out, closing connection");
+ self.conn.abrupt_shutdown(h2::Reason::NO_ERROR);
+ }
+ Poll::Pending => {}
+ }
+ }
+ }
+}
+
+pin_project! {
+ #[allow(missing_debug_implementations)]
+ pub struct H2Stream<F, B>
+ where
+ B: HttpBody,
+ {
+ reply: SendResponse<SendBuf<B::Data>>,
+ #[pin]
+ state: H2StreamState<F, B>,
+ }
+}
+
+pin_project! {
+ #[project = H2StreamStateProj]
+ enum H2StreamState<F, B>
+ where
+ B: HttpBody,
+ {
+ Service {
+ #[pin]
+ fut: F,
+ connect_parts: Option<ConnectParts>,
+ },
+ Body {
+ #[pin]
+ pipe: PipeToSendStream<B>,
+ },
+ }
+}
+
+struct ConnectParts {
+ pending: Pending,
+ ping: Recorder,
+ recv_stream: RecvStream,
+}
+
+impl<F, B> H2Stream<F, B>
+where
+ B: HttpBody,
+{
+ fn new(
+ fut: F,
+ connect_parts: Option<ConnectParts>,
+ respond: SendResponse<SendBuf<B::Data>>,
+ ) -> H2Stream<F, B> {
+ H2Stream {
+ reply: respond,
+ state: H2StreamState::Service { fut, connect_parts },
+ }
+ }
+}
+
+macro_rules! reply {
+ ($me:expr, $res:expr, $eos:expr) => {{
+ match $me.reply.send_response($res, $eos) {
+ Ok(tx) => tx,
+ Err(e) => {
+ debug!("send response error: {}", e);
+ $me.reply.send_reset(Reason::INTERNAL_ERROR);
+ return Poll::Ready(Err(crate::Error::new_h2(e)));
+ }
+ }
+ }};
+}
+
+impl<F, B, E> H2Stream<F, B>
+where
+ F: Future<Output = Result<Response<B>, E>>,
+ B: HttpBody,
+ B::Data: 'static,
+ B::Error: Into<Box<dyn StdError + Send + Sync>>,
+ E: Into<Box<dyn StdError + Send + Sync>>,
+{
+ fn poll2(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
+ let mut me = self.project();
+ loop {
+ let next = match me.state.as_mut().project() {
+ H2StreamStateProj::Service {
+ fut: h,
+ connect_parts,
+ } => {
+ let res = match h.poll(cx) {
+ Poll::Ready(Ok(r)) => r,
+ Poll::Pending => {
+ // Response is not yet ready, so we want to check if the client has sent a
+ // RST_STREAM frame which would cancel the current request.
+ if let Poll::Ready(reason) =
+ me.reply.poll_reset(cx).map_err(crate::Error::new_h2)?
+ {
+ debug!("stream received RST_STREAM: {:?}", reason);
+ return Poll::Ready(Err(crate::Error::new_h2(reason.into())));
+ }
+ return Poll::Pending;
+ }
+ Poll::Ready(Err(e)) => {
+ let err = crate::Error::new_user_service(e);
+ warn!("http2 service errored: {}", err);
+ me.reply.send_reset(err.h2_reason());
+ return Poll::Ready(Err(err));
+ }
+ };
+
+ let (head, body) = res.into_parts();
+ let mut res = ::http::Response::from_parts(head, ());
+ super::strip_connection_headers(res.headers_mut(), false);
+
+ // set Date header if it isn't already set...
+ res.headers_mut()
+ .entry(::http::header::DATE)
+ .or_insert_with(date::update_and_header_value);
+
+ if let Some(connect_parts) = connect_parts.take() {
+ if res.status().is_success() {
+ if headers::content_length_parse_all(res.headers())
+ .map_or(false, |len| len != 0)
+ {
+ warn!("h2 successful response to CONNECT request with body not supported");
+ me.reply.send_reset(h2::Reason::INTERNAL_ERROR);
+ return Poll::Ready(Err(crate::Error::new_user_header()));
+ }
+ let send_stream = reply!(me, res, false);
+ connect_parts.pending.fulfill(Upgraded::new(
+ H2Upgraded {
+ ping: connect_parts.ping,
+ recv_stream: connect_parts.recv_stream,
+ send_stream: unsafe { UpgradedSendStream::new(send_stream) },
+ buf: Bytes::new(),
+ },
+ Bytes::new(),
+ ));
+ return Poll::Ready(Ok(()));
+ }
+ }
+
+
+ if !body.is_end_stream() {
+ // automatically set Content-Length from body...
+ if let Some(len) = body.size_hint().exact() {
+ headers::set_content_length_if_missing(res.headers_mut(), len);
+ }
+
+ let body_tx = reply!(me, res, false);
+ H2StreamState::Body {
+ pipe: PipeToSendStream::new(body, body_tx),
+ }
+ } else {
+ reply!(me, res, true);
+ return Poll::Ready(Ok(()));
+ }
+ }
+ H2StreamStateProj::Body { pipe } => {
+ return pipe.poll(cx);
+ }
+ };
+ me.state.set(next);
+ }
+ }
+}
+
+impl<F, B, E> Future for H2Stream<F, B>
+where
+ F: Future<Output = Result<Response<B>, E>>,
+ B: HttpBody,
+ B::Data: 'static,
+ B::Error: Into<Box<dyn StdError + Send + Sync>>,
+ E: Into<Box<dyn StdError + Send + Sync>>,
+{
+ type Output = ();
+
+ fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
+ self.poll2(cx).map(|res| {
+ if let Err(e) = res {
+ debug!("stream error: {}", e);
+ }
+ })
+ }
+}
diff --git a/third_party/rust/hyper/src/proto/mod.rs b/third_party/rust/hyper/src/proto/mod.rs
new file mode 100644
index 0000000000..f938bf532b
--- /dev/null
+++ b/third_party/rust/hyper/src/proto/mod.rs
@@ -0,0 +1,71 @@
+//! Pieces pertaining to the HTTP message protocol.
+
+cfg_feature! {
+ #![feature = "http1"]
+
+ pub(crate) mod h1;
+
+ pub(crate) use self::h1::Conn;
+
+ #[cfg(feature = "client")]
+ pub(crate) use self::h1::dispatch;
+ #[cfg(feature = "server")]
+ pub(crate) use self::h1::ServerTransaction;
+}
+
+#[cfg(feature = "http2")]
+pub(crate) mod h2;
+
+/// An Incoming Message head. Includes request/status line, and headers.
+#[derive(Debug, Default)]
+pub(crate) struct MessageHead<S> {
+ /// HTTP version of the message.
+ pub(crate) version: http::Version,
+ /// Subject (request line or status line) of Incoming message.
+ pub(crate) subject: S,
+ /// Headers of the Incoming message.
+ pub(crate) headers: http::HeaderMap,
+ /// Extensions.
+ extensions: http::Extensions,
+}
+
+/// An incoming request message.
+#[cfg(feature = "http1")]
+pub(crate) type RequestHead = MessageHead<RequestLine>;
+
+#[derive(Debug, Default, PartialEq)]
+#[cfg(feature = "http1")]
+pub(crate) struct RequestLine(pub(crate) http::Method, pub(crate) http::Uri);
+
+/// An incoming response message.
+#[cfg(all(feature = "http1", feature = "client"))]
+pub(crate) type ResponseHead = MessageHead<http::StatusCode>;
+
+#[derive(Debug)]
+#[cfg(feature = "http1")]
+pub(crate) enum BodyLength {
+ /// Content-Length
+ Known(u64),
+ /// Transfer-Encoding: chunked (if h1)
+ Unknown,
+}
+
+/// Status of when a Disaptcher future completes.
+pub(crate) enum Dispatched {
+ /// Dispatcher completely shutdown connection.
+ Shutdown,
+ /// Dispatcher has pending upgrade, and so did not shutdown.
+ #[cfg(feature = "http1")]
+ Upgrade(crate::upgrade::Pending),
+}
+
+impl MessageHead<http::StatusCode> {
+ fn into_response<B>(self, body: B) -> http::Response<B> {
+ let mut res = http::Response::new(body);
+ *res.status_mut() = self.subject;
+ *res.headers_mut() = self.headers;
+ *res.version_mut() = self.version;
+ *res.extensions_mut() = self.extensions;
+ res
+ }
+}