diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
commit | 36d22d82aa202bb199967e9512281e9a53db42c9 (patch) | |
tree | 105e8c98ddea1c1e4784a60a5a6410fa416be2de /third_party/rust/hyper/src/proto/h1 | |
parent | Initial commit. (diff) | |
download | firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip |
Adding upstream version 115.7.0esr.upstream/115.7.0esrupstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/hyper/src/proto/h1')
-rw-r--r-- | third_party/rust/hyper/src/proto/h1/conn.rs | 1425 | ||||
-rw-r--r-- | third_party/rust/hyper/src/proto/h1/decode.rs | 731 | ||||
-rw-r--r-- | third_party/rust/hyper/src/proto/h1/dispatch.rs | 750 | ||||
-rw-r--r-- | third_party/rust/hyper/src/proto/h1/encode.rs | 439 | ||||
-rw-r--r-- | third_party/rust/hyper/src/proto/h1/io.rs | 1002 | ||||
-rw-r--r-- | third_party/rust/hyper/src/proto/h1/mod.rs | 122 | ||||
-rw-r--r-- | third_party/rust/hyper/src/proto/h1/role.rs | 2847 |
7 files changed, 7316 insertions, 0 deletions
diff --git a/third_party/rust/hyper/src/proto/h1/conn.rs b/third_party/rust/hyper/src/proto/h1/conn.rs new file mode 100644 index 0000000000..5ebff2803e --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/conn.rs @@ -0,0 +1,1425 @@ +use std::fmt; +use std::io; +use std::marker::PhantomData; +#[cfg(all(feature = "server", feature = "runtime"))] +use std::time::Duration; + +use bytes::{Buf, Bytes}; +use http::header::{HeaderValue, CONNECTION}; +use http::{HeaderMap, Method, Version}; +use httparse::ParserConfig; +use tokio::io::{AsyncRead, AsyncWrite}; +#[cfg(all(feature = "server", feature = "runtime"))] +use tokio::time::Sleep; +use tracing::{debug, error, trace}; + +use super::io::Buffered; +use super::{Decoder, Encode, EncodedBuf, Encoder, Http1Transaction, ParseContext, Wants}; +use crate::body::DecodedLength; +use crate::common::{task, Pin, Poll, Unpin}; +use crate::headers::connection_keep_alive; +use crate::proto::{BodyLength, MessageHead}; + +const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; + +/// This handles a connection, which will have been established over an +/// `AsyncRead + AsyncWrite` (like a socket), and will likely include multiple +/// `Transaction`s over HTTP. +/// +/// The connection will determine when a message begins and ends as well as +/// determine if this connection can be kept alive after the message, +/// or if it is complete. +pub(crate) struct Conn<I, B, T> { + io: Buffered<I, EncodedBuf<B>>, + state: State, + _marker: PhantomData<fn(T)>, +} + +impl<I, B, T> Conn<I, B, T> +where + I: AsyncRead + AsyncWrite + Unpin, + B: Buf, + T: Http1Transaction, +{ + pub(crate) fn new(io: I) -> Conn<I, B, T> { + Conn { + io: Buffered::new(io), + state: State { + allow_half_close: false, + cached_headers: None, + error: None, + keep_alive: KA::Busy, + method: None, + h1_parser_config: ParserConfig::default(), + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout: None, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_fut: None, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_running: false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + title_case_headers: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: None, + #[cfg(feature = "ffi")] + raw_headers: false, + notify_read: false, + reading: Reading::Init, + writing: Writing::Init, + upgrade: None, + // We assume a modern world where the remote speaks HTTP/1.1. + // If they tell us otherwise, we'll downgrade in `read_head`. + version: Version::HTTP_11, + }, + _marker: PhantomData, + } + } + + #[cfg(feature = "server")] + pub(crate) fn set_flush_pipeline(&mut self, enabled: bool) { + self.io.set_flush_pipeline(enabled); + } + + pub(crate) fn set_write_strategy_queue(&mut self) { + self.io.set_write_strategy_queue(); + } + + pub(crate) fn set_max_buf_size(&mut self, max: usize) { + self.io.set_max_buf_size(max); + } + + #[cfg(feature = "client")] + pub(crate) fn set_read_buf_exact_size(&mut self, sz: usize) { + self.io.set_read_buf_exact_size(sz); + } + + pub(crate) fn set_write_strategy_flatten(&mut self) { + self.io.set_write_strategy_flatten(); + } + + #[cfg(feature = "client")] + pub(crate) fn set_h1_parser_config(&mut self, parser_config: ParserConfig) { + self.state.h1_parser_config = parser_config; + } + + pub(crate) fn set_title_case_headers(&mut self) { + self.state.title_case_headers = true; + } + + pub(crate) fn set_preserve_header_case(&mut self) { + self.state.preserve_header_case = true; + } + + #[cfg(feature = "ffi")] + pub(crate) fn set_preserve_header_order(&mut self) { + self.state.preserve_header_order = true; + } + + #[cfg(feature = "client")] + pub(crate) fn set_h09_responses(&mut self) { + self.state.h09_responses = true; + } + + #[cfg(all(feature = "server", feature = "runtime"))] + pub(crate) fn set_http1_header_read_timeout(&mut self, val: Duration) { + self.state.h1_header_read_timeout = Some(val); + } + + #[cfg(feature = "server")] + pub(crate) fn set_allow_half_close(&mut self) { + self.state.allow_half_close = true; + } + + #[cfg(feature = "ffi")] + pub(crate) fn set_raw_headers(&mut self, enabled: bool) { + self.state.raw_headers = enabled; + } + + pub(crate) fn into_inner(self) -> (I, Bytes) { + self.io.into_inner() + } + + pub(crate) fn pending_upgrade(&mut self) -> Option<crate::upgrade::Pending> { + self.state.upgrade.take() + } + + pub(crate) fn is_read_closed(&self) -> bool { + self.state.is_read_closed() + } + + pub(crate) fn is_write_closed(&self) -> bool { + self.state.is_write_closed() + } + + pub(crate) fn can_read_head(&self) -> bool { + if !matches!(self.state.reading, Reading::Init) { + return false; + } + + if T::should_read_first() { + return true; + } + + !matches!(self.state.writing, Writing::Init) + } + + pub(crate) fn can_read_body(&self) -> bool { + match self.state.reading { + Reading::Body(..) | Reading::Continue(..) => true, + _ => false, + } + } + + fn should_error_on_eof(&self) -> bool { + // If we're idle, it's probably just the connection closing gracefully. + T::should_error_on_parse_eof() && !self.state.is_idle() + } + + fn has_h2_prefix(&self) -> bool { + let read_buf = self.io.read_buf(); + read_buf.len() >= 24 && read_buf[..24] == *H2_PREFACE + } + + pub(super) fn poll_read_head( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Option<crate::Result<(MessageHead<T::Incoming>, DecodedLength, Wants)>>> { + debug_assert!(self.can_read_head()); + trace!("Conn::read_head"); + + let msg = match ready!(self.io.parse::<T>( + cx, + ParseContext { + cached_headers: &mut self.state.cached_headers, + req_method: &mut self.state.method, + h1_parser_config: self.state.h1_parser_config.clone(), + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout: self.state.h1_header_read_timeout, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_fut: &mut self.state.h1_header_read_timeout_fut, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_running: &mut self.state.h1_header_read_timeout_running, + preserve_header_case: self.state.preserve_header_case, + #[cfg(feature = "ffi")] + preserve_header_order: self.state.preserve_header_order, + h09_responses: self.state.h09_responses, + #[cfg(feature = "ffi")] + on_informational: &mut self.state.on_informational, + #[cfg(feature = "ffi")] + raw_headers: self.state.raw_headers, + } + )) { + Ok(msg) => msg, + Err(e) => return self.on_read_head_error(e), + }; + + // Note: don't deconstruct `msg` into local variables, it appears + // the optimizer doesn't remove the extra copies. + + debug!("incoming body is {}", msg.decode); + + // Prevent accepting HTTP/0.9 responses after the initial one, if any. + self.state.h09_responses = false; + + // Drop any OnInformational callbacks, we're done there! + #[cfg(feature = "ffi")] + { + self.state.on_informational = None; + } + + self.state.busy(); + self.state.keep_alive &= msg.keep_alive; + self.state.version = msg.head.version; + + let mut wants = if msg.wants_upgrade { + Wants::UPGRADE + } else { + Wants::EMPTY + }; + + if msg.decode == DecodedLength::ZERO { + if msg.expect_continue { + debug!("ignoring expect-continue since body is empty"); + } + self.state.reading = Reading::KeepAlive; + if !T::should_read_first() { + self.try_keep_alive(cx); + } + } else if msg.expect_continue { + self.state.reading = Reading::Continue(Decoder::new(msg.decode)); + wants = wants.add(Wants::EXPECT); + } else { + self.state.reading = Reading::Body(Decoder::new(msg.decode)); + } + + Poll::Ready(Some(Ok((msg.head, msg.decode, wants)))) + } + + fn on_read_head_error<Z>(&mut self, e: crate::Error) -> Poll<Option<crate::Result<Z>>> { + // If we are currently waiting on a message, then an empty + // message should be reported as an error. If not, it is just + // the connection closing gracefully. + let must_error = self.should_error_on_eof(); + self.close_read(); + self.io.consume_leading_lines(); + let was_mid_parse = e.is_parse() || !self.io.read_buf().is_empty(); + if was_mid_parse || must_error { + // We check if the buf contains the h2 Preface + debug!( + "parse error ({}) with {} bytes", + e, + self.io.read_buf().len() + ); + match self.on_parse_error(e) { + Ok(()) => Poll::Pending, // XXX: wat? + Err(e) => Poll::Ready(Some(Err(e))), + } + } else { + debug!("read eof"); + self.close_write(); + Poll::Ready(None) + } + } + + pub(crate) fn poll_read_body( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Option<io::Result<Bytes>>> { + debug_assert!(self.can_read_body()); + + let (reading, ret) = match self.state.reading { + Reading::Body(ref mut decoder) => { + match ready!(decoder.decode(cx, &mut self.io)) { + Ok(slice) => { + let (reading, chunk) = if decoder.is_eof() { + debug!("incoming body completed"); + ( + Reading::KeepAlive, + if !slice.is_empty() { + Some(Ok(slice)) + } else { + None + }, + ) + } else if slice.is_empty() { + error!("incoming body unexpectedly ended"); + // This should be unreachable, since all 3 decoders + // either set eof=true or return an Err when reading + // an empty slice... + (Reading::Closed, None) + } else { + return Poll::Ready(Some(Ok(slice))); + }; + (reading, Poll::Ready(chunk)) + } + Err(e) => { + debug!("incoming body decode error: {}", e); + (Reading::Closed, Poll::Ready(Some(Err(e)))) + } + } + } + Reading::Continue(ref decoder) => { + // Write the 100 Continue if not already responded... + if let Writing::Init = self.state.writing { + trace!("automatically sending 100 Continue"); + let cont = b"HTTP/1.1 100 Continue\r\n\r\n"; + self.io.headers_buf().extend_from_slice(cont); + } + + // And now recurse once in the Reading::Body state... + self.state.reading = Reading::Body(decoder.clone()); + return self.poll_read_body(cx); + } + _ => unreachable!("poll_read_body invalid state: {:?}", self.state.reading), + }; + + self.state.reading = reading; + self.try_keep_alive(cx); + ret + } + + pub(crate) fn wants_read_again(&mut self) -> bool { + let ret = self.state.notify_read; + self.state.notify_read = false; + ret + } + + pub(crate) fn poll_read_keep_alive( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<crate::Result<()>> { + debug_assert!(!self.can_read_head() && !self.can_read_body()); + + if self.is_read_closed() { + Poll::Pending + } else if self.is_mid_message() { + self.mid_message_detect_eof(cx) + } else { + self.require_empty_read(cx) + } + } + + fn is_mid_message(&self) -> bool { + !matches!( + (&self.state.reading, &self.state.writing), + (&Reading::Init, &Writing::Init) + ) + } + + // This will check to make sure the io object read is empty. + // + // This should only be called for Clients wanting to enter the idle + // state. + fn require_empty_read(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + debug_assert!(!self.can_read_head() && !self.can_read_body() && !self.is_read_closed()); + debug_assert!(!self.is_mid_message()); + debug_assert!(T::is_client()); + + if !self.io.read_buf().is_empty() { + debug!("received an unexpected {} bytes", self.io.read_buf().len()); + return Poll::Ready(Err(crate::Error::new_unexpected_message())); + } + + let num_read = ready!(self.force_io_read(cx)).map_err(crate::Error::new_io)?; + + if num_read == 0 { + let ret = if self.should_error_on_eof() { + trace!("found unexpected EOF on busy connection: {:?}", self.state); + Poll::Ready(Err(crate::Error::new_incomplete())) + } else { + trace!("found EOF on idle connection, closing"); + Poll::Ready(Ok(())) + }; + + // order is important: should_error needs state BEFORE close_read + self.state.close_read(); + return ret; + } + + debug!( + "received unexpected {} bytes on an idle connection", + num_read + ); + Poll::Ready(Err(crate::Error::new_unexpected_message())) + } + + fn mid_message_detect_eof(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + debug_assert!(!self.can_read_head() && !self.can_read_body() && !self.is_read_closed()); + debug_assert!(self.is_mid_message()); + + if self.state.allow_half_close || !self.io.read_buf().is_empty() { + return Poll::Pending; + } + + let num_read = ready!(self.force_io_read(cx)).map_err(crate::Error::new_io)?; + + if num_read == 0 { + trace!("found unexpected EOF on busy connection: {:?}", self.state); + self.state.close_read(); + Poll::Ready(Err(crate::Error::new_incomplete())) + } else { + Poll::Ready(Ok(())) + } + } + + fn force_io_read(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<usize>> { + debug_assert!(!self.state.is_read_closed()); + + let result = ready!(self.io.poll_read_from_io(cx)); + Poll::Ready(result.map_err(|e| { + trace!("force_io_read; io error = {:?}", e); + self.state.close(); + e + })) + } + + fn maybe_notify(&mut self, cx: &mut task::Context<'_>) { + // its possible that we returned NotReady from poll() without having + // exhausted the underlying Io. We would have done this when we + // determined we couldn't keep reading until we knew how writing + // would finish. + + match self.state.reading { + Reading::Continue(..) | Reading::Body(..) | Reading::KeepAlive | Reading::Closed => { + return + } + Reading::Init => (), + }; + + match self.state.writing { + Writing::Body(..) => return, + Writing::Init | Writing::KeepAlive | Writing::Closed => (), + } + + if !self.io.is_read_blocked() { + if self.io.read_buf().is_empty() { + match self.io.poll_read_from_io(cx) { + Poll::Ready(Ok(n)) => { + if n == 0 { + trace!("maybe_notify; read eof"); + if self.state.is_idle() { + self.state.close(); + } else { + self.close_read() + } + return; + } + } + Poll::Pending => { + trace!("maybe_notify; read_from_io blocked"); + return; + } + Poll::Ready(Err(e)) => { + trace!("maybe_notify; read_from_io error: {}", e); + self.state.close(); + self.state.error = Some(crate::Error::new_io(e)); + } + } + } + self.state.notify_read = true; + } + } + + fn try_keep_alive(&mut self, cx: &mut task::Context<'_>) { + self.state.try_keep_alive::<T>(); + self.maybe_notify(cx); + } + + pub(crate) fn can_write_head(&self) -> bool { + if !T::should_read_first() && matches!(self.state.reading, Reading::Closed) { + return false; + } + + match self.state.writing { + Writing::Init => self.io.can_headers_buf(), + _ => false, + } + } + + pub(crate) fn can_write_body(&self) -> bool { + match self.state.writing { + Writing::Body(..) => true, + Writing::Init | Writing::KeepAlive | Writing::Closed => false, + } + } + + pub(crate) fn can_buffer_body(&self) -> bool { + self.io.can_buffer() + } + + pub(crate) fn write_head(&mut self, head: MessageHead<T::Outgoing>, body: Option<BodyLength>) { + if let Some(encoder) = self.encode_head(head, body) { + self.state.writing = if !encoder.is_eof() { + Writing::Body(encoder) + } else if encoder.is_last() { + Writing::Closed + } else { + Writing::KeepAlive + }; + } + } + + pub(crate) fn write_full_msg(&mut self, head: MessageHead<T::Outgoing>, body: B) { + if let Some(encoder) = + self.encode_head(head, Some(BodyLength::Known(body.remaining() as u64))) + { + let is_last = encoder.is_last(); + // Make sure we don't write a body if we weren't actually allowed + // to do so, like because its a HEAD request. + if !encoder.is_eof() { + encoder.danger_full_buf(body, self.io.write_buf()); + } + self.state.writing = if is_last { + Writing::Closed + } else { + Writing::KeepAlive + } + } + } + + fn encode_head( + &mut self, + mut head: MessageHead<T::Outgoing>, + body: Option<BodyLength>, + ) -> Option<Encoder> { + debug_assert!(self.can_write_head()); + + if !T::should_read_first() { + self.state.busy(); + } + + self.enforce_version(&mut head); + + let buf = self.io.headers_buf(); + match super::role::encode_headers::<T>( + Encode { + head: &mut head, + body, + #[cfg(feature = "server")] + keep_alive: self.state.wants_keep_alive(), + req_method: &mut self.state.method, + title_case_headers: self.state.title_case_headers, + }, + buf, + ) { + Ok(encoder) => { + debug_assert!(self.state.cached_headers.is_none()); + debug_assert!(head.headers.is_empty()); + self.state.cached_headers = Some(head.headers); + + #[cfg(feature = "ffi")] + { + self.state.on_informational = + head.extensions.remove::<crate::ffi::OnInformational>(); + } + + Some(encoder) + } + Err(err) => { + self.state.error = Some(err); + self.state.writing = Writing::Closed; + None + } + } + } + + // Fix keep-alive when Connection: keep-alive header is not present + fn fix_keep_alive(&mut self, head: &mut MessageHead<T::Outgoing>) { + let outgoing_is_keep_alive = head + .headers + .get(CONNECTION) + .map(connection_keep_alive) + .unwrap_or(false); + + if !outgoing_is_keep_alive { + match head.version { + // If response is version 1.0 and keep-alive is not present in the response, + // disable keep-alive so the server closes the connection + Version::HTTP_10 => self.state.disable_keep_alive(), + // If response is version 1.1 and keep-alive is wanted, add + // Connection: keep-alive header when not present + Version::HTTP_11 => { + if self.state.wants_keep_alive() { + head.headers + .insert(CONNECTION, HeaderValue::from_static("keep-alive")); + } + } + _ => (), + } + } + } + + // If we know the remote speaks an older version, we try to fix up any messages + // to work with our older peer. + fn enforce_version(&mut self, head: &mut MessageHead<T::Outgoing>) { + if let Version::HTTP_10 = self.state.version { + // Fixes response or connection when keep-alive header is not present + self.fix_keep_alive(head); + // If the remote only knows HTTP/1.0, we should force ourselves + // to do only speak HTTP/1.0 as well. + head.version = Version::HTTP_10; + } + // If the remote speaks HTTP/1.1, then it *should* be fine with + // both HTTP/1.0 and HTTP/1.1 from us. So again, we just let + // the user's headers be. + } + + pub(crate) fn write_body(&mut self, chunk: B) { + debug_assert!(self.can_write_body() && self.can_buffer_body()); + // empty chunks should be discarded at Dispatcher level + debug_assert!(chunk.remaining() != 0); + + let state = match self.state.writing { + Writing::Body(ref mut encoder) => { + self.io.buffer(encoder.encode(chunk)); + + if !encoder.is_eof() { + return; + } + + if encoder.is_last() { + Writing::Closed + } else { + Writing::KeepAlive + } + } + _ => unreachable!("write_body invalid state: {:?}", self.state.writing), + }; + + self.state.writing = state; + } + + pub(crate) fn write_body_and_end(&mut self, chunk: B) { + debug_assert!(self.can_write_body() && self.can_buffer_body()); + // empty chunks should be discarded at Dispatcher level + debug_assert!(chunk.remaining() != 0); + + let state = match self.state.writing { + Writing::Body(ref encoder) => { + let can_keep_alive = encoder.encode_and_end(chunk, self.io.write_buf()); + if can_keep_alive { + Writing::KeepAlive + } else { + Writing::Closed + } + } + _ => unreachable!("write_body invalid state: {:?}", self.state.writing), + }; + + self.state.writing = state; + } + + pub(crate) fn end_body(&mut self) -> crate::Result<()> { + debug_assert!(self.can_write_body()); + + let encoder = match self.state.writing { + Writing::Body(ref mut enc) => enc, + _ => return Ok(()), + }; + + // end of stream, that means we should try to eof + match encoder.end() { + Ok(end) => { + if let Some(end) = end { + self.io.buffer(end); + } + + self.state.writing = if encoder.is_last() || encoder.is_close_delimited() { + Writing::Closed + } else { + Writing::KeepAlive + }; + + Ok(()) + } + Err(not_eof) => { + self.state.writing = Writing::Closed; + Err(crate::Error::new_body_write_aborted().with(not_eof)) + } + } + } + + // When we get a parse error, depending on what side we are, we might be able + // to write a response before closing the connection. + // + // - Client: there is nothing we can do + // - Server: if Response hasn't been written yet, we can send a 4xx response + fn on_parse_error(&mut self, err: crate::Error) -> crate::Result<()> { + if let Writing::Init = self.state.writing { + if self.has_h2_prefix() { + return Err(crate::Error::new_version_h2()); + } + if let Some(msg) = T::on_error(&err) { + // Drop the cached headers so as to not trigger a debug + // assert in `write_head`... + self.state.cached_headers.take(); + self.write_head(msg, None); + self.state.error = Some(err); + return Ok(()); + } + } + + // fallback is pass the error back up + Err(err) + } + + pub(crate) fn poll_flush(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + ready!(Pin::new(&mut self.io).poll_flush(cx))?; + self.try_keep_alive(cx); + trace!("flushed({}): {:?}", T::LOG, self.state); + Poll::Ready(Ok(())) + } + + pub(crate) fn poll_shutdown(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + match ready!(Pin::new(self.io.io_mut()).poll_shutdown(cx)) { + Ok(()) => { + trace!("shut down IO complete"); + Poll::Ready(Ok(())) + } + Err(e) => { + debug!("error shutting down IO: {}", e); + Poll::Ready(Err(e)) + } + } + } + + /// If the read side can be cheaply drained, do so. Otherwise, close. + pub(super) fn poll_drain_or_close_read(&mut self, cx: &mut task::Context<'_>) { + if let Reading::Continue(ref decoder) = self.state.reading { + // skip sending the 100-continue + // just move forward to a read, in case a tiny body was included + self.state.reading = Reading::Body(decoder.clone()); + } + + let _ = self.poll_read_body(cx); + + // If still in Reading::Body, just give up + match self.state.reading { + Reading::Init | Reading::KeepAlive => trace!("body drained"), + _ => self.close_read(), + } + } + + pub(crate) fn close_read(&mut self) { + self.state.close_read(); + } + + pub(crate) fn close_write(&mut self) { + self.state.close_write(); + } + + #[cfg(feature = "server")] + pub(crate) fn disable_keep_alive(&mut self) { + if self.state.is_idle() { + trace!("disable_keep_alive; closing idle connection"); + self.state.close(); + } else { + trace!("disable_keep_alive; in-progress connection"); + self.state.disable_keep_alive(); + } + } + + pub(crate) fn take_error(&mut self) -> crate::Result<()> { + if let Some(err) = self.state.error.take() { + Err(err) + } else { + Ok(()) + } + } + + pub(super) fn on_upgrade(&mut self) -> crate::upgrade::OnUpgrade { + trace!("{}: prepare possible HTTP upgrade", T::LOG); + self.state.prepare_upgrade() + } +} + +impl<I, B: Buf, T> fmt::Debug for Conn<I, B, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Conn") + .field("state", &self.state) + .field("io", &self.io) + .finish() + } +} + +// B and T are never pinned +impl<I: Unpin, B, T> Unpin for Conn<I, B, T> {} + +struct State { + allow_half_close: bool, + /// Re-usable HeaderMap to reduce allocating new ones. + cached_headers: Option<HeaderMap>, + /// If an error occurs when there wasn't a direct way to return it + /// back to the user, this is set. + error: Option<crate::Error>, + /// Current keep-alive status. + keep_alive: KA, + /// If mid-message, the HTTP Method that started it. + /// + /// This is used to know things such as if the message can include + /// a body or not. + method: Option<Method>, + h1_parser_config: ParserConfig, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout: Option<Duration>, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_fut: Option<Pin<Box<Sleep>>>, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_running: bool, + preserve_header_case: bool, + #[cfg(feature = "ffi")] + preserve_header_order: bool, + title_case_headers: bool, + h09_responses: bool, + /// If set, called with each 1xx informational response received for + /// the current request. MUST be unset after a non-1xx response is + /// received. + #[cfg(feature = "ffi")] + on_informational: Option<crate::ffi::OnInformational>, + #[cfg(feature = "ffi")] + raw_headers: bool, + /// Set to true when the Dispatcher should poll read operations + /// again. See the `maybe_notify` method for more. + notify_read: bool, + /// State of allowed reads + reading: Reading, + /// State of allowed writes + writing: Writing, + /// An expected pending HTTP upgrade. + upgrade: Option<crate::upgrade::Pending>, + /// Either HTTP/1.0 or 1.1 connection + version: Version, +} + +#[derive(Debug)] +enum Reading { + Init, + Continue(Decoder), + Body(Decoder), + KeepAlive, + Closed, +} + +enum Writing { + Init, + Body(Encoder), + KeepAlive, + Closed, +} + +impl fmt::Debug for State { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = f.debug_struct("State"); + builder + .field("reading", &self.reading) + .field("writing", &self.writing) + .field("keep_alive", &self.keep_alive); + + // Only show error field if it's interesting... + if let Some(ref error) = self.error { + builder.field("error", error); + } + + if self.allow_half_close { + builder.field("allow_half_close", &true); + } + + // Purposefully leaving off other fields.. + + builder.finish() + } +} + +impl fmt::Debug for Writing { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Writing::Init => f.write_str("Init"), + Writing::Body(ref enc) => f.debug_tuple("Body").field(enc).finish(), + Writing::KeepAlive => f.write_str("KeepAlive"), + Writing::Closed => f.write_str("Closed"), + } + } +} + +impl std::ops::BitAndAssign<bool> for KA { + fn bitand_assign(&mut self, enabled: bool) { + if !enabled { + trace!("remote disabling keep-alive"); + *self = KA::Disabled; + } + } +} + +#[derive(Clone, Copy, Debug)] +enum KA { + Idle, + Busy, + Disabled, +} + +impl Default for KA { + fn default() -> KA { + KA::Busy + } +} + +impl KA { + fn idle(&mut self) { + *self = KA::Idle; + } + + fn busy(&mut self) { + *self = KA::Busy; + } + + fn disable(&mut self) { + *self = KA::Disabled; + } + + fn status(&self) -> KA { + *self + } +} + +impl State { + fn close(&mut self) { + trace!("State::close()"); + self.reading = Reading::Closed; + self.writing = Writing::Closed; + self.keep_alive.disable(); + } + + fn close_read(&mut self) { + trace!("State::close_read()"); + self.reading = Reading::Closed; + self.keep_alive.disable(); + } + + fn close_write(&mut self) { + trace!("State::close_write()"); + self.writing = Writing::Closed; + self.keep_alive.disable(); + } + + fn wants_keep_alive(&self) -> bool { + if let KA::Disabled = self.keep_alive.status() { + false + } else { + true + } + } + + fn try_keep_alive<T: Http1Transaction>(&mut self) { + match (&self.reading, &self.writing) { + (&Reading::KeepAlive, &Writing::KeepAlive) => { + if let KA::Busy = self.keep_alive.status() { + self.idle::<T>(); + } else { + trace!( + "try_keep_alive({}): could keep-alive, but status = {:?}", + T::LOG, + self.keep_alive + ); + self.close(); + } + } + (&Reading::Closed, &Writing::KeepAlive) | (&Reading::KeepAlive, &Writing::Closed) => { + self.close() + } + _ => (), + } + } + + fn disable_keep_alive(&mut self) { + self.keep_alive.disable() + } + + fn busy(&mut self) { + if let KA::Disabled = self.keep_alive.status() { + return; + } + self.keep_alive.busy(); + } + + fn idle<T: Http1Transaction>(&mut self) { + debug_assert!(!self.is_idle(), "State::idle() called while idle"); + + self.method = None; + self.keep_alive.idle(); + + if !self.is_idle() { + self.close(); + return; + } + + self.reading = Reading::Init; + self.writing = Writing::Init; + + // !T::should_read_first() means Client. + // + // If Client connection has just gone idle, the Dispatcher + // should try the poll loop one more time, so as to poll the + // pending requests stream. + if !T::should_read_first() { + self.notify_read = true; + } + } + + fn is_idle(&self) -> bool { + matches!(self.keep_alive.status(), KA::Idle) + } + + fn is_read_closed(&self) -> bool { + matches!(self.reading, Reading::Closed) + } + + fn is_write_closed(&self) -> bool { + matches!(self.writing, Writing::Closed) + } + + fn prepare_upgrade(&mut self) -> crate::upgrade::OnUpgrade { + let (tx, rx) = crate::upgrade::pending(); + self.upgrade = Some(tx); + rx + } +} + +#[cfg(test)] +mod tests { + #[cfg(feature = "nightly")] + #[bench] + fn bench_read_head_short(b: &mut ::test::Bencher) { + use super::*; + let s = b"GET / HTTP/1.1\r\nHost: localhost:8080\r\n\r\n"; + let len = s.len(); + b.bytes = len as u64; + + // an empty IO, we'll be skipping and using the read buffer anyways + let io = tokio_test::io::Builder::new().build(); + let mut conn = Conn::<_, bytes::Bytes, crate::proto::h1::ServerTransaction>::new(io); + *conn.io.read_buf_mut() = ::bytes::BytesMut::from(&s[..]); + conn.state.cached_headers = Some(HeaderMap::with_capacity(2)); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + b.iter(|| { + rt.block_on(futures_util::future::poll_fn(|cx| { + match conn.poll_read_head(cx) { + Poll::Ready(Some(Ok(x))) => { + ::test::black_box(&x); + let mut headers = x.0.headers; + headers.clear(); + conn.state.cached_headers = Some(headers); + } + f => panic!("expected Ready(Some(Ok(..))): {:?}", f), + } + + conn.io.read_buf_mut().reserve(1); + unsafe { + conn.io.read_buf_mut().set_len(len); + } + conn.state.reading = Reading::Init; + Poll::Ready(()) + })); + }); + } + + /* + //TODO: rewrite these using dispatch... someday... + use futures::{Async, Future, Stream, Sink}; + use futures::future; + + use proto::{self, ClientTransaction, MessageHead, ServerTransaction}; + use super::super::Encoder; + use mock::AsyncIo; + + use super::{Conn, Decoder, Reading, Writing}; + use ::uri::Uri; + + use std::str::FromStr; + + #[test] + fn test_conn_init_read() { + let good_message = b"GET / HTTP/1.1\r\n\r\n".to_vec(); + let len = good_message.len(); + let io = AsyncIo::new_buf(good_message, len); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + + match conn.poll().unwrap() { + Async::Ready(Some(Frame::Message { message, body: false })) => { + assert_eq!(message, MessageHead { + subject: ::proto::RequestLine(::Get, Uri::from_str("/").unwrap()), + .. MessageHead::default() + }) + }, + f => panic!("frame is not Frame::Message: {:?}", f) + } + } + + #[test] + fn test_conn_parse_partial() { + let _: Result<(), ()> = future::lazy(|| { + let good_message = b"GET / HTTP/1.1\r\nHost: foo.bar\r\n\r\n".to_vec(); + let io = AsyncIo::new_buf(good_message, 10); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + assert!(conn.poll().unwrap().is_not_ready()); + conn.io.io_mut().block_in(50); + let async = conn.poll().unwrap(); + assert!(async.is_ready()); + match async { + Async::Ready(Some(Frame::Message { .. })) => (), + f => panic!("frame is not Message: {:?}", f), + } + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_init_read_eof_idle() { + let io = AsyncIo::new_buf(vec![], 1); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.idle(); + + match conn.poll().unwrap() { + Async::Ready(None) => {}, + other => panic!("frame is not None: {:?}", other) + } + } + + #[test] + fn test_conn_init_read_eof_idle_partial_parse() { + let io = AsyncIo::new_buf(b"GET / HTTP/1.1".to_vec(), 100); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.idle(); + + match conn.poll() { + Err(ref err) if err.kind() == std::io::ErrorKind::UnexpectedEof => {}, + other => panic!("unexpected frame: {:?}", other) + } + } + + #[test] + fn test_conn_init_read_eof_busy() { + let _: Result<(), ()> = future::lazy(|| { + // server ignores + let io = AsyncIo::new_eof(); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.busy(); + + match conn.poll().unwrap() { + Async::Ready(None) => {}, + other => panic!("unexpected frame: {:?}", other) + } + + // client + let io = AsyncIo::new_eof(); + let mut conn = Conn::<_, proto::Bytes, ClientTransaction>::new(io); + conn.state.busy(); + + match conn.poll() { + Err(ref err) if err.kind() == std::io::ErrorKind::UnexpectedEof => {}, + other => panic!("unexpected frame: {:?}", other) + } + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_body_finish_read_eof() { + let _: Result<(), ()> = future::lazy(|| { + let io = AsyncIo::new_eof(); + let mut conn = Conn::<_, proto::Bytes, ClientTransaction>::new(io); + conn.state.busy(); + conn.state.writing = Writing::KeepAlive; + conn.state.reading = Reading::Body(Decoder::length(0)); + + match conn.poll() { + Ok(Async::Ready(Some(Frame::Body { chunk: None }))) => (), + other => panic!("unexpected frame: {:?}", other) + } + + // conn eofs, but tokio-proto will call poll() again, before calling flush() + // the conn eof in this case is perfectly fine + + match conn.poll() { + Ok(Async::Ready(None)) => (), + other => panic!("unexpected frame: {:?}", other) + } + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_message_empty_body_read_eof() { + let _: Result<(), ()> = future::lazy(|| { + let io = AsyncIo::new_buf(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n".to_vec(), 1024); + let mut conn = Conn::<_, proto::Bytes, ClientTransaction>::new(io); + conn.state.busy(); + conn.state.writing = Writing::KeepAlive; + + match conn.poll() { + Ok(Async::Ready(Some(Frame::Message { body: false, .. }))) => (), + other => panic!("unexpected frame: {:?}", other) + } + + // conn eofs, but tokio-proto will call poll() again, before calling flush() + // the conn eof in this case is perfectly fine + + match conn.poll() { + Ok(Async::Ready(None)) => (), + other => panic!("unexpected frame: {:?}", other) + } + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_read_body_end() { + let _: Result<(), ()> = future::lazy(|| { + let io = AsyncIo::new_buf(b"POST / HTTP/1.1\r\nContent-Length: 5\r\n\r\n12345".to_vec(), 1024); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.busy(); + + match conn.poll() { + Ok(Async::Ready(Some(Frame::Message { body: true, .. }))) => (), + other => panic!("unexpected frame: {:?}", other) + } + + match conn.poll() { + Ok(Async::Ready(Some(Frame::Body { chunk: Some(_) }))) => (), + other => panic!("unexpected frame: {:?}", other) + } + + // When the body is done, `poll` MUST return a `Body` frame with chunk set to `None` + match conn.poll() { + Ok(Async::Ready(Some(Frame::Body { chunk: None }))) => (), + other => panic!("unexpected frame: {:?}", other) + } + + match conn.poll() { + Ok(Async::NotReady) => (), + other => panic!("unexpected frame: {:?}", other) + } + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_closed_read() { + let io = AsyncIo::new_buf(vec![], 0); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.close(); + + match conn.poll().unwrap() { + Async::Ready(None) => {}, + other => panic!("frame is not None: {:?}", other) + } + } + + #[test] + fn test_conn_body_write_length() { + let _ = pretty_env_logger::try_init(); + let _: Result<(), ()> = future::lazy(|| { + let io = AsyncIo::new_buf(vec![], 0); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + let max = super::super::io::DEFAULT_MAX_BUFFER_SIZE + 4096; + conn.state.writing = Writing::Body(Encoder::length((max * 2) as u64)); + + assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'a'; max].into()) }).unwrap().is_ready()); + assert!(!conn.can_buffer_body()); + + assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'b'; 1024 * 8].into()) }).unwrap().is_not_ready()); + + conn.io.io_mut().block_in(1024 * 3); + assert!(conn.poll_complete().unwrap().is_not_ready()); + conn.io.io_mut().block_in(1024 * 3); + assert!(conn.poll_complete().unwrap().is_not_ready()); + conn.io.io_mut().block_in(max * 2); + assert!(conn.poll_complete().unwrap().is_ready()); + + assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'c'; 1024 * 8].into()) }).unwrap().is_ready()); + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_body_write_chunked() { + let _: Result<(), ()> = future::lazy(|| { + let io = AsyncIo::new_buf(vec![], 4096); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.writing = Writing::Body(Encoder::chunked()); + + assert!(conn.start_send(Frame::Body { chunk: Some("headers".into()) }).unwrap().is_ready()); + assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'x'; 8192].into()) }).unwrap().is_ready()); + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_body_flush() { + let _: Result<(), ()> = future::lazy(|| { + let io = AsyncIo::new_buf(vec![], 1024 * 1024 * 5); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.writing = Writing::Body(Encoder::length(1024 * 1024)); + assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'a'; 1024 * 1024].into()) }).unwrap().is_ready()); + assert!(!conn.can_buffer_body()); + conn.io.io_mut().block_in(1024 * 1024 * 5); + assert!(conn.poll_complete().unwrap().is_ready()); + assert!(conn.can_buffer_body()); + assert!(conn.io.io_mut().flushed()); + + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_parking() { + use std::sync::Arc; + use futures::executor::Notify; + use futures::executor::NotifyHandle; + + struct Car { + permit: bool, + } + impl Notify for Car { + fn notify(&self, _id: usize) { + assert!(self.permit, "unparked without permit"); + } + } + + fn car(permit: bool) -> NotifyHandle { + Arc::new(Car { + permit: permit, + }).into() + } + + // test that once writing is done, unparks + let f = future::lazy(|| { + let io = AsyncIo::new_buf(vec![], 4096); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.reading = Reading::KeepAlive; + assert!(conn.poll().unwrap().is_not_ready()); + + conn.state.writing = Writing::KeepAlive; + assert!(conn.poll_complete().unwrap().is_ready()); + Ok::<(), ()>(()) + }); + ::futures::executor::spawn(f).poll_future_notify(&car(true), 0).unwrap(); + + + // test that flushing when not waiting on read doesn't unpark + let f = future::lazy(|| { + let io = AsyncIo::new_buf(vec![], 4096); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.writing = Writing::KeepAlive; + assert!(conn.poll_complete().unwrap().is_ready()); + Ok::<(), ()>(()) + }); + ::futures::executor::spawn(f).poll_future_notify(&car(false), 0).unwrap(); + + + // test that flushing and writing isn't done doesn't unpark + let f = future::lazy(|| { + let io = AsyncIo::new_buf(vec![], 4096); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.reading = Reading::KeepAlive; + assert!(conn.poll().unwrap().is_not_ready()); + conn.state.writing = Writing::Body(Encoder::length(5_000)); + assert!(conn.poll_complete().unwrap().is_ready()); + Ok::<(), ()>(()) + }); + ::futures::executor::spawn(f).poll_future_notify(&car(false), 0).unwrap(); + } + + #[test] + fn test_conn_closed_write() { + let io = AsyncIo::new_buf(vec![], 0); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.close(); + + match conn.start_send(Frame::Body { chunk: Some(b"foobar".to_vec().into()) }) { + Err(_e) => {}, + other => panic!("did not return Err: {:?}", other) + } + + assert!(conn.state.is_write_closed()); + } + + #[test] + fn test_conn_write_empty_chunk() { + let io = AsyncIo::new_buf(vec![], 0); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.writing = Writing::KeepAlive; + + assert!(conn.start_send(Frame::Body { chunk: None }).unwrap().is_ready()); + assert!(conn.start_send(Frame::Body { chunk: Some(Vec::new().into()) }).unwrap().is_ready()); + conn.start_send(Frame::Body { chunk: Some(vec![b'a'].into()) }).unwrap_err(); + } + */ +} diff --git a/third_party/rust/hyper/src/proto/h1/decode.rs b/third_party/rust/hyper/src/proto/h1/decode.rs new file mode 100644 index 0000000000..1e3a38effc --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/decode.rs @@ -0,0 +1,731 @@ +use std::error::Error as StdError; +use std::fmt; +use std::io; +use std::usize; + +use bytes::Bytes; +use tracing::{debug, trace}; + +use crate::common::{task, Poll}; + +use super::io::MemRead; +use super::DecodedLength; + +use self::Kind::{Chunked, Eof, Length}; + +/// Decoders to handle different Transfer-Encodings. +/// +/// If a message body does not include a Transfer-Encoding, it *should* +/// include a Content-Length header. +#[derive(Clone, PartialEq)] +pub(crate) struct Decoder { + kind: Kind, +} + +#[derive(Debug, Clone, Copy, PartialEq)] +enum Kind { + /// A Reader used when a Content-Length header is passed with a positive integer. + Length(u64), + /// A Reader used when Transfer-Encoding is `chunked`. + Chunked(ChunkedState, u64), + /// A Reader used for responses that don't indicate a length or chunked. + /// + /// The bool tracks when EOF is seen on the transport. + /// + /// Note: This should only used for `Response`s. It is illegal for a + /// `Request` to be made with both `Content-Length` and + /// `Transfer-Encoding: chunked` missing, as explained from the spec: + /// + /// > If a Transfer-Encoding header field is present in a response and + /// > the chunked transfer coding is not the final encoding, the + /// > message body length is determined by reading the connection until + /// > it is closed by the server. If a Transfer-Encoding header field + /// > is present in a request and the chunked transfer coding is not + /// > the final encoding, the message body length cannot be determined + /// > reliably; the server MUST respond with the 400 (Bad Request) + /// > status code and then close the connection. + Eof(bool), +} + +#[derive(Debug, PartialEq, Clone, Copy)] +enum ChunkedState { + Size, + SizeLws, + Extension, + SizeLf, + Body, + BodyCr, + BodyLf, + Trailer, + TrailerLf, + EndCr, + EndLf, + End, +} + +impl Decoder { + // constructors + + pub(crate) fn length(x: u64) -> Decoder { + Decoder { + kind: Kind::Length(x), + } + } + + pub(crate) fn chunked() -> Decoder { + Decoder { + kind: Kind::Chunked(ChunkedState::Size, 0), + } + } + + pub(crate) fn eof() -> Decoder { + Decoder { + kind: Kind::Eof(false), + } + } + + pub(super) fn new(len: DecodedLength) -> Self { + match len { + DecodedLength::CHUNKED => Decoder::chunked(), + DecodedLength::CLOSE_DELIMITED => Decoder::eof(), + length => Decoder::length(length.danger_len()), + } + } + + // methods + + pub(crate) fn is_eof(&self) -> bool { + matches!(self.kind, Length(0) | Chunked(ChunkedState::End, _) | Eof(true)) + } + + pub(crate) fn decode<R: MemRead>( + &mut self, + cx: &mut task::Context<'_>, + body: &mut R, + ) -> Poll<Result<Bytes, io::Error>> { + trace!("decode; state={:?}", self.kind); + match self.kind { + Length(ref mut remaining) => { + if *remaining == 0 { + Poll::Ready(Ok(Bytes::new())) + } else { + let to_read = *remaining as usize; + let buf = ready!(body.read_mem(cx, to_read))?; + let num = buf.as_ref().len() as u64; + if num > *remaining { + *remaining = 0; + } else if num == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + IncompleteBody, + ))); + } else { + *remaining -= num; + } + Poll::Ready(Ok(buf)) + } + } + Chunked(ref mut state, ref mut size) => { + loop { + let mut buf = None; + // advances the chunked state + *state = ready!(state.step(cx, body, size, &mut buf))?; + if *state == ChunkedState::End { + trace!("end of chunked"); + return Poll::Ready(Ok(Bytes::new())); + } + if let Some(buf) = buf { + return Poll::Ready(Ok(buf)); + } + } + } + Eof(ref mut is_eof) => { + if *is_eof { + Poll::Ready(Ok(Bytes::new())) + } else { + // 8192 chosen because its about 2 packets, there probably + // won't be that much available, so don't have MemReaders + // allocate buffers to big + body.read_mem(cx, 8192).map_ok(|slice| { + *is_eof = slice.is_empty(); + slice + }) + } + } + } + } + + #[cfg(test)] + async fn decode_fut<R: MemRead>(&mut self, body: &mut R) -> Result<Bytes, io::Error> { + futures_util::future::poll_fn(move |cx| self.decode(cx, body)).await + } +} + +impl fmt::Debug for Decoder { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&self.kind, f) + } +} + +macro_rules! byte ( + ($rdr:ident, $cx:expr) => ({ + let buf = ready!($rdr.read_mem($cx, 1))?; + if !buf.is_empty() { + buf[0] + } else { + return Poll::Ready(Err(io::Error::new(io::ErrorKind::UnexpectedEof, + "unexpected EOF during chunk size line"))); + } + }) +); + +impl ChunkedState { + fn step<R: MemRead>( + &self, + cx: &mut task::Context<'_>, + body: &mut R, + size: &mut u64, + buf: &mut Option<Bytes>, + ) -> Poll<Result<ChunkedState, io::Error>> { + use self::ChunkedState::*; + match *self { + Size => ChunkedState::read_size(cx, body, size), + SizeLws => ChunkedState::read_size_lws(cx, body), + Extension => ChunkedState::read_extension(cx, body), + SizeLf => ChunkedState::read_size_lf(cx, body, *size), + Body => ChunkedState::read_body(cx, body, size, buf), + BodyCr => ChunkedState::read_body_cr(cx, body), + BodyLf => ChunkedState::read_body_lf(cx, body), + Trailer => ChunkedState::read_trailer(cx, body), + TrailerLf => ChunkedState::read_trailer_lf(cx, body), + EndCr => ChunkedState::read_end_cr(cx, body), + EndLf => ChunkedState::read_end_lf(cx, body), + End => Poll::Ready(Ok(ChunkedState::End)), + } + } + fn read_size<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + size: &mut u64, + ) -> Poll<Result<ChunkedState, io::Error>> { + trace!("Read chunk hex size"); + + macro_rules! or_overflow { + ($e:expr) => ( + match $e { + Some(val) => val, + None => return Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid chunk size: overflow", + ))), + } + ) + } + + let radix = 16; + match byte!(rdr, cx) { + b @ b'0'..=b'9' => { + *size = or_overflow!(size.checked_mul(radix)); + *size = or_overflow!(size.checked_add((b - b'0') as u64)); + } + b @ b'a'..=b'f' => { + *size = or_overflow!(size.checked_mul(radix)); + *size = or_overflow!(size.checked_add((b + 10 - b'a') as u64)); + } + b @ b'A'..=b'F' => { + *size = or_overflow!(size.checked_mul(radix)); + *size = or_overflow!(size.checked_add((b + 10 - b'A') as u64)); + } + b'\t' | b' ' => return Poll::Ready(Ok(ChunkedState::SizeLws)), + b';' => return Poll::Ready(Ok(ChunkedState::Extension)), + b'\r' => return Poll::Ready(Ok(ChunkedState::SizeLf)), + _ => { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size line: Invalid Size", + ))); + } + } + Poll::Ready(Ok(ChunkedState::Size)) + } + fn read_size_lws<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + trace!("read_size_lws"); + match byte!(rdr, cx) { + // LWS can follow the chunk size, but no more digits can come + b'\t' | b' ' => Poll::Ready(Ok(ChunkedState::SizeLws)), + b';' => Poll::Ready(Ok(ChunkedState::Extension)), + b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size linear white space", + ))), + } + } + fn read_extension<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + trace!("read_extension"); + // We don't care about extensions really at all. Just ignore them. + // They "end" at the next CRLF. + // + // However, some implementations may not check for the CR, so to save + // them from themselves, we reject extensions containing plain LF as + // well. + match byte!(rdr, cx) { + b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), + b'\n' => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid chunk extension contains newline", + ))), + _ => Poll::Ready(Ok(ChunkedState::Extension)), // no supported extensions + } + } + fn read_size_lf<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + size: u64, + ) -> Poll<Result<ChunkedState, io::Error>> { + trace!("Chunk size is {:?}", size); + match byte!(rdr, cx) { + b'\n' => { + if size == 0 { + Poll::Ready(Ok(ChunkedState::EndCr)) + } else { + debug!("incoming chunked header: {0:#X} ({0} bytes)", size); + Poll::Ready(Ok(ChunkedState::Body)) + } + } + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size LF", + ))), + } + } + + fn read_body<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + rem: &mut u64, + buf: &mut Option<Bytes>, + ) -> Poll<Result<ChunkedState, io::Error>> { + trace!("Chunked read, remaining={:?}", rem); + + // cap remaining bytes at the max capacity of usize + let rem_cap = match *rem { + r if r > usize::MAX as u64 => usize::MAX, + r => r as usize, + }; + + let to_read = rem_cap; + let slice = ready!(rdr.read_mem(cx, to_read))?; + let count = slice.len(); + + if count == 0 { + *rem = 0; + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + IncompleteBody, + ))); + } + *buf = Some(slice); + *rem -= count as u64; + + if *rem > 0 { + Poll::Ready(Ok(ChunkedState::Body)) + } else { + Poll::Ready(Ok(ChunkedState::BodyCr)) + } + } + fn read_body_cr<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + match byte!(rdr, cx) { + b'\r' => Poll::Ready(Ok(ChunkedState::BodyLf)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk body CR", + ))), + } + } + fn read_body_lf<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + match byte!(rdr, cx) { + b'\n' => Poll::Ready(Ok(ChunkedState::Size)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk body LF", + ))), + } + } + + fn read_trailer<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + trace!("read_trailer"); + match byte!(rdr, cx) { + b'\r' => Poll::Ready(Ok(ChunkedState::TrailerLf)), + _ => Poll::Ready(Ok(ChunkedState::Trailer)), + } + } + fn read_trailer_lf<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + match byte!(rdr, cx) { + b'\n' => Poll::Ready(Ok(ChunkedState::EndCr)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid trailer end LF", + ))), + } + } + + fn read_end_cr<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + match byte!(rdr, cx) { + b'\r' => Poll::Ready(Ok(ChunkedState::EndLf)), + _ => Poll::Ready(Ok(ChunkedState::Trailer)), + } + } + fn read_end_lf<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + match byte!(rdr, cx) { + b'\n' => Poll::Ready(Ok(ChunkedState::End)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk end LF", + ))), + } + } +} + +#[derive(Debug)] +struct IncompleteBody; + +impl fmt::Display for IncompleteBody { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "end of file before message length reached") + } +} + +impl StdError for IncompleteBody {} + +#[cfg(test)] +mod tests { + use super::*; + use std::pin::Pin; + use std::time::Duration; + use tokio::io::{AsyncRead, ReadBuf}; + + impl<'a> MemRead for &'a [u8] { + fn read_mem(&mut self, _: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> { + let n = std::cmp::min(len, self.len()); + if n > 0 { + let (a, b) = self.split_at(n); + let buf = Bytes::copy_from_slice(a); + *self = b; + Poll::Ready(Ok(buf)) + } else { + Poll::Ready(Ok(Bytes::new())) + } + } + } + + impl<'a> MemRead for &'a mut (dyn AsyncRead + Unpin) { + fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> { + let mut v = vec![0; len]; + let mut buf = ReadBuf::new(&mut v); + ready!(Pin::new(self).poll_read(cx, &mut buf)?); + Poll::Ready(Ok(Bytes::copy_from_slice(&buf.filled()))) + } + } + + #[cfg(feature = "nightly")] + impl MemRead for Bytes { + fn read_mem(&mut self, _: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> { + let n = std::cmp::min(len, self.len()); + let ret = self.split_to(n); + Poll::Ready(Ok(ret)) + } + } + + /* + use std::io; + use std::io::Write; + use super::Decoder; + use super::ChunkedState; + use futures::{Async, Poll}; + use bytes::{BytesMut, Bytes}; + use crate::mock::AsyncIo; + */ + + #[tokio::test] + async fn test_read_chunk_size() { + use std::io::ErrorKind::{InvalidData, InvalidInput, UnexpectedEof}; + + async fn read(s: &str) -> u64 { + let mut state = ChunkedState::Size; + let rdr = &mut s.as_bytes(); + let mut size = 0; + loop { + let result = + futures_util::future::poll_fn(|cx| state.step(cx, rdr, &mut size, &mut None)) + .await; + let desc = format!("read_size failed for {:?}", s); + state = result.expect(desc.as_str()); + if state == ChunkedState::Body || state == ChunkedState::EndCr { + break; + } + } + size + } + + async fn read_err(s: &str, expected_err: io::ErrorKind) { + let mut state = ChunkedState::Size; + let rdr = &mut s.as_bytes(); + let mut size = 0; + loop { + let result = + futures_util::future::poll_fn(|cx| state.step(cx, rdr, &mut size, &mut None)) + .await; + state = match result { + Ok(s) => s, + Err(e) => { + assert!( + expected_err == e.kind(), + "Reading {:?}, expected {:?}, but got {:?}", + s, + expected_err, + e.kind() + ); + return; + } + }; + if state == ChunkedState::Body || state == ChunkedState::End { + panic!("Was Ok. Expected Err for {:?}", s); + } + } + } + + assert_eq!(1, read("1\r\n").await); + assert_eq!(1, read("01\r\n").await); + assert_eq!(0, read("0\r\n").await); + assert_eq!(0, read("00\r\n").await); + assert_eq!(10, read("A\r\n").await); + assert_eq!(10, read("a\r\n").await); + assert_eq!(255, read("Ff\r\n").await); + assert_eq!(255, read("Ff \r\n").await); + // Missing LF or CRLF + read_err("F\rF", InvalidInput).await; + read_err("F", UnexpectedEof).await; + // Invalid hex digit + read_err("X\r\n", InvalidInput).await; + read_err("1X\r\n", InvalidInput).await; + read_err("-\r\n", InvalidInput).await; + read_err("-1\r\n", InvalidInput).await; + // Acceptable (if not fully valid) extensions do not influence the size + assert_eq!(1, read("1;extension\r\n").await); + assert_eq!(10, read("a;ext name=value\r\n").await); + assert_eq!(1, read("1;extension;extension2\r\n").await); + assert_eq!(1, read("1;;; ;\r\n").await); + assert_eq!(2, read("2; extension...\r\n").await); + assert_eq!(3, read("3 ; extension=123\r\n").await); + assert_eq!(3, read("3 ;\r\n").await); + assert_eq!(3, read("3 ; \r\n").await); + // Invalid extensions cause an error + read_err("1 invalid extension\r\n", InvalidInput).await; + read_err("1 A\r\n", InvalidInput).await; + read_err("1;no CRLF", UnexpectedEof).await; + read_err("1;reject\nnewlines\r\n", InvalidData).await; + // Overflow + read_err("f0000000000000003\r\n", InvalidData).await; + } + + #[tokio::test] + async fn test_read_sized_early_eof() { + let mut bytes = &b"foo bar"[..]; + let mut decoder = Decoder::length(10); + assert_eq!(decoder.decode_fut(&mut bytes).await.unwrap().len(), 7); + let e = decoder.decode_fut(&mut bytes).await.unwrap_err(); + assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof); + } + + #[tokio::test] + async fn test_read_chunked_early_eof() { + let mut bytes = &b"\ + 9\r\n\ + foo bar\ + "[..]; + let mut decoder = Decoder::chunked(); + assert_eq!(decoder.decode_fut(&mut bytes).await.unwrap().len(), 7); + let e = decoder.decode_fut(&mut bytes).await.unwrap_err(); + assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof); + } + + #[tokio::test] + async fn test_read_chunked_single_read() { + let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\n"[..]; + let buf = Decoder::chunked() + .decode_fut(&mut mock_buf) + .await + .expect("decode"); + assert_eq!(16, buf.len()); + let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String"); + assert_eq!("1234567890abcdef", &result); + } + + #[tokio::test] + async fn test_read_chunked_trailer_with_missing_lf() { + let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\nbad\r\r\n"[..]; + let mut decoder = Decoder::chunked(); + decoder.decode_fut(&mut mock_buf).await.expect("decode"); + let e = decoder.decode_fut(&mut mock_buf).await.unwrap_err(); + assert_eq!(e.kind(), io::ErrorKind::InvalidInput); + } + + #[tokio::test] + async fn test_read_chunked_after_eof() { + let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\n\r\n"[..]; + let mut decoder = Decoder::chunked(); + + // normal read + let buf = decoder.decode_fut(&mut mock_buf).await.unwrap(); + assert_eq!(16, buf.len()); + let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String"); + assert_eq!("1234567890abcdef", &result); + + // eof read + let buf = decoder.decode_fut(&mut mock_buf).await.expect("decode"); + assert_eq!(0, buf.len()); + + // ensure read after eof also returns eof + let buf = decoder.decode_fut(&mut mock_buf).await.expect("decode"); + assert_eq!(0, buf.len()); + } + + // perform an async read using a custom buffer size and causing a blocking + // read at the specified byte + async fn read_async(mut decoder: Decoder, content: &[u8], block_at: usize) -> String { + let mut outs = Vec::new(); + + let mut ins = if block_at == 0 { + tokio_test::io::Builder::new() + .wait(Duration::from_millis(10)) + .read(content) + .build() + } else { + tokio_test::io::Builder::new() + .read(&content[..block_at]) + .wait(Duration::from_millis(10)) + .read(&content[block_at..]) + .build() + }; + + let mut ins = &mut ins as &mut (dyn AsyncRead + Unpin); + + loop { + let buf = decoder + .decode_fut(&mut ins) + .await + .expect("unexpected decode error"); + if buf.is_empty() { + break; // eof + } + outs.extend(buf.as_ref()); + } + + String::from_utf8(outs).expect("decode String") + } + + // iterate over the different ways that this async read could go. + // tests blocking a read at each byte along the content - The shotgun approach + async fn all_async_cases(content: &str, expected: &str, decoder: Decoder) { + let content_len = content.len(); + for block_at in 0..content_len { + let actual = read_async(decoder.clone(), content.as_bytes(), block_at).await; + assert_eq!(expected, &actual) //, "Failed async. Blocking at {}", block_at); + } + } + + #[tokio::test] + async fn test_read_length_async() { + let content = "foobar"; + all_async_cases(content, content, Decoder::length(content.len() as u64)).await; + } + + #[tokio::test] + async fn test_read_chunked_async() { + let content = "3\r\nfoo\r\n3\r\nbar\r\n0\r\n\r\n"; + let expected = "foobar"; + all_async_cases(content, expected, Decoder::chunked()).await; + } + + #[tokio::test] + async fn test_read_eof_async() { + let content = "foobar"; + all_async_cases(content, content, Decoder::eof()).await; + } + + #[cfg(feature = "nightly")] + #[bench] + fn bench_decode_chunked_1kb(b: &mut test::Bencher) { + let rt = new_runtime(); + + const LEN: usize = 1024; + let mut vec = Vec::new(); + vec.extend(format!("{:x}\r\n", LEN).as_bytes()); + vec.extend(&[0; LEN][..]); + vec.extend(b"\r\n"); + let content = Bytes::from(vec); + + b.bytes = LEN as u64; + + b.iter(|| { + let mut decoder = Decoder::chunked(); + rt.block_on(async { + let mut raw = content.clone(); + let chunk = decoder.decode_fut(&mut raw).await.unwrap(); + assert_eq!(chunk.len(), LEN); + }); + }); + } + + #[cfg(feature = "nightly")] + #[bench] + fn bench_decode_length_1kb(b: &mut test::Bencher) { + let rt = new_runtime(); + + const LEN: usize = 1024; + let content = Bytes::from(&[0; LEN][..]); + b.bytes = LEN as u64; + + b.iter(|| { + let mut decoder = Decoder::length(LEN as u64); + rt.block_on(async { + let mut raw = content.clone(); + let chunk = decoder.decode_fut(&mut raw).await.unwrap(); + assert_eq!(chunk.len(), LEN); + }); + }); + } + + #[cfg(feature = "nightly")] + fn new_runtime() -> tokio::runtime::Runtime { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("rt build") + } +} diff --git a/third_party/rust/hyper/src/proto/h1/dispatch.rs b/third_party/rust/hyper/src/proto/h1/dispatch.rs new file mode 100644 index 0000000000..677131bfdd --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/dispatch.rs @@ -0,0 +1,750 @@ +use std::error::Error as StdError; + +use bytes::{Buf, Bytes}; +use http::Request; +use tokio::io::{AsyncRead, AsyncWrite}; +use tracing::{debug, trace}; + +use super::{Http1Transaction, Wants}; +use crate::body::{Body, DecodedLength, HttpBody}; +use crate::common::{task, Future, Pin, Poll, Unpin}; +use crate::proto::{ + BodyLength, Conn, Dispatched, MessageHead, RequestHead, +}; +use crate::upgrade::OnUpgrade; + +pub(crate) struct Dispatcher<D, Bs: HttpBody, I, T> { + conn: Conn<I, Bs::Data, T>, + dispatch: D, + body_tx: Option<crate::body::Sender>, + body_rx: Pin<Box<Option<Bs>>>, + is_closing: bool, +} + +pub(crate) trait Dispatch { + type PollItem; + type PollBody; + type PollError; + type RecvItem; + fn poll_msg( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), Self::PollError>>>; + fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, Body)>) -> crate::Result<()>; + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), ()>>; + fn should_poll(&self) -> bool; +} + +cfg_server! { + use crate::service::HttpService; + + pub(crate) struct Server<S: HttpService<B>, B> { + in_flight: Pin<Box<Option<S::Future>>>, + pub(crate) service: S, + } +} + +cfg_client! { + pin_project_lite::pin_project! { + pub(crate) struct Client<B> { + callback: Option<crate::client::dispatch::Callback<Request<B>, http::Response<Body>>>, + #[pin] + rx: ClientRx<B>, + rx_closed: bool, + } + } + + type ClientRx<B> = crate::client::dispatch::Receiver<Request<B>, http::Response<Body>>; +} + +impl<D, Bs, I, T> Dispatcher<D, Bs, I, T> +where + D: Dispatch< + PollItem = MessageHead<T::Outgoing>, + PollBody = Bs, + RecvItem = MessageHead<T::Incoming>, + > + Unpin, + D::PollError: Into<Box<dyn StdError + Send + Sync>>, + I: AsyncRead + AsyncWrite + Unpin, + T: Http1Transaction + Unpin, + Bs: HttpBody + 'static, + Bs::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + pub(crate) fn new(dispatch: D, conn: Conn<I, Bs::Data, T>) -> Self { + Dispatcher { + conn, + dispatch, + body_tx: None, + body_rx: Box::pin(None), + is_closing: false, + } + } + + #[cfg(feature = "server")] + pub(crate) fn disable_keep_alive(&mut self) { + self.conn.disable_keep_alive(); + if self.conn.is_write_closed() { + self.close(); + } + } + + pub(crate) fn into_inner(self) -> (I, Bytes, D) { + let (io, buf) = self.conn.into_inner(); + (io, buf, self.dispatch) + } + + /// Run this dispatcher until HTTP says this connection is done, + /// but don't call `AsyncWrite::shutdown` on the underlying IO. + /// + /// This is useful for old-style HTTP upgrades, but ignores + /// newer-style upgrade API. + pub(crate) fn poll_without_shutdown( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<crate::Result<()>> + where + Self: Unpin, + { + Pin::new(self).poll_catch(cx, false).map_ok(|ds| { + if let Dispatched::Upgrade(pending) = ds { + pending.manual(); + } + }) + } + + fn poll_catch( + &mut self, + cx: &mut task::Context<'_>, + should_shutdown: bool, + ) -> Poll<crate::Result<Dispatched>> { + Poll::Ready(ready!(self.poll_inner(cx, should_shutdown)).or_else(|e| { + // An error means we're shutting down either way. + // We just try to give the error to the user, + // and close the connection with an Ok. If we + // cannot give it to the user, then return the Err. + self.dispatch.recv_msg(Err(e))?; + Ok(Dispatched::Shutdown) + })) + } + + fn poll_inner( + &mut self, + cx: &mut task::Context<'_>, + should_shutdown: bool, + ) -> Poll<crate::Result<Dispatched>> { + T::update_date(); + + ready!(self.poll_loop(cx))?; + + if self.is_done() { + if let Some(pending) = self.conn.pending_upgrade() { + self.conn.take_error()?; + return Poll::Ready(Ok(Dispatched::Upgrade(pending))); + } else if should_shutdown { + ready!(self.conn.poll_shutdown(cx)).map_err(crate::Error::new_shutdown)?; + } + self.conn.take_error()?; + Poll::Ready(Ok(Dispatched::Shutdown)) + } else { + Poll::Pending + } + } + + fn poll_loop(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + // Limit the looping on this connection, in case it is ready far too + // often, so that other futures don't starve. + // + // 16 was chosen arbitrarily, as that is number of pipelined requests + // benchmarks often use. Perhaps it should be a config option instead. + for _ in 0..16 { + let _ = self.poll_read(cx)?; + let _ = self.poll_write(cx)?; + let _ = self.poll_flush(cx)?; + + // This could happen if reading paused before blocking on IO, + // such as getting to the end of a framed message, but then + // writing/flushing set the state back to Init. In that case, + // if the read buffer still had bytes, we'd want to try poll_read + // again, or else we wouldn't ever be woken up again. + // + // Using this instead of task::current() and notify() inside + // the Conn is noticeably faster in pipelined benchmarks. + if !self.conn.wants_read_again() { + //break; + return Poll::Ready(Ok(())); + } + } + + trace!("poll_loop yielding (self = {:p})", self); + + task::yield_now(cx).map(|never| match never {}) + } + + fn poll_read(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + loop { + if self.is_closing { + return Poll::Ready(Ok(())); + } else if self.conn.can_read_head() { + ready!(self.poll_read_head(cx))?; + } else if let Some(mut body) = self.body_tx.take() { + if self.conn.can_read_body() { + match body.poll_ready(cx) { + Poll::Ready(Ok(())) => (), + Poll::Pending => { + self.body_tx = Some(body); + return Poll::Pending; + } + Poll::Ready(Err(_canceled)) => { + // user doesn't care about the body + // so we should stop reading + trace!("body receiver dropped before eof, draining or closing"); + self.conn.poll_drain_or_close_read(cx); + continue; + } + } + match self.conn.poll_read_body(cx) { + Poll::Ready(Some(Ok(chunk))) => match body.try_send_data(chunk) { + Ok(()) => { + self.body_tx = Some(body); + } + Err(_canceled) => { + if self.conn.can_read_body() { + trace!("body receiver dropped before eof, closing"); + self.conn.close_read(); + } + } + }, + Poll::Ready(None) => { + // just drop, the body will close automatically + } + Poll::Pending => { + self.body_tx = Some(body); + return Poll::Pending; + } + Poll::Ready(Some(Err(e))) => { + body.send_error(crate::Error::new_body(e)); + } + } + } else { + // just drop, the body will close automatically + } + } else { + return self.conn.poll_read_keep_alive(cx); + } + } + } + + fn poll_read_head(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + // can dispatch receive, or does it still care about, an incoming message? + match ready!(self.dispatch.poll_ready(cx)) { + Ok(()) => (), + Err(()) => { + trace!("dispatch no longer receiving messages"); + self.close(); + return Poll::Ready(Ok(())); + } + } + // dispatch is ready for a message, try to read one + match ready!(self.conn.poll_read_head(cx)) { + Some(Ok((mut head, body_len, wants))) => { + let body = match body_len { + DecodedLength::ZERO => Body::empty(), + other => { + let (tx, rx) = Body::new_channel(other, wants.contains(Wants::EXPECT)); + self.body_tx = Some(tx); + rx + } + }; + if wants.contains(Wants::UPGRADE) { + let upgrade = self.conn.on_upgrade(); + debug_assert!(!upgrade.is_none(), "empty upgrade"); + debug_assert!(head.extensions.get::<OnUpgrade>().is_none(), "OnUpgrade already set"); + head.extensions.insert(upgrade); + } + self.dispatch.recv_msg(Ok((head, body)))?; + Poll::Ready(Ok(())) + } + Some(Err(err)) => { + debug!("read_head error: {}", err); + self.dispatch.recv_msg(Err(err))?; + // if here, the dispatcher gave the user the error + // somewhere else. we still need to shutdown, but + // not as a second error. + self.close(); + Poll::Ready(Ok(())) + } + None => { + // read eof, the write side will have been closed too unless + // allow_read_close was set to true, in which case just do + // nothing... + debug_assert!(self.conn.is_read_closed()); + if self.conn.is_write_closed() { + self.close(); + } + Poll::Ready(Ok(())) + } + } + } + + fn poll_write(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + loop { + if self.is_closing { + return Poll::Ready(Ok(())); + } else if self.body_rx.is_none() + && self.conn.can_write_head() + && self.dispatch.should_poll() + { + if let Some(msg) = ready!(Pin::new(&mut self.dispatch).poll_msg(cx)) { + let (head, mut body) = msg.map_err(crate::Error::new_user_service)?; + + // Check if the body knows its full data immediately. + // + // If so, we can skip a bit of bookkeeping that streaming + // bodies need to do. + if let Some(full) = crate::body::take_full_data(&mut body) { + self.conn.write_full_msg(head, full); + return Poll::Ready(Ok(())); + } + + let body_type = if body.is_end_stream() { + self.body_rx.set(None); + None + } else { + let btype = body + .size_hint() + .exact() + .map(BodyLength::Known) + .or_else(|| Some(BodyLength::Unknown)); + self.body_rx.set(Some(body)); + btype + }; + self.conn.write_head(head, body_type); + } else { + self.close(); + return Poll::Ready(Ok(())); + } + } else if !self.conn.can_buffer_body() { + ready!(self.poll_flush(cx))?; + } else { + // A new scope is needed :( + if let (Some(mut body), clear_body) = + OptGuard::new(self.body_rx.as_mut()).guard_mut() + { + debug_assert!(!*clear_body, "opt guard defaults to keeping body"); + if !self.conn.can_write_body() { + trace!( + "no more write body allowed, user body is_end_stream = {}", + body.is_end_stream(), + ); + *clear_body = true; + continue; + } + + let item = ready!(body.as_mut().poll_data(cx)); + if let Some(item) = item { + let chunk = item.map_err(|e| { + *clear_body = true; + crate::Error::new_user_body(e) + })?; + let eos = body.is_end_stream(); + if eos { + *clear_body = true; + if chunk.remaining() == 0 { + trace!("discarding empty chunk"); + self.conn.end_body()?; + } else { + self.conn.write_body_and_end(chunk); + } + } else { + if chunk.remaining() == 0 { + trace!("discarding empty chunk"); + continue; + } + self.conn.write_body(chunk); + } + } else { + *clear_body = true; + self.conn.end_body()?; + } + } else { + return Poll::Pending; + } + } + } + } + + fn poll_flush(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + self.conn.poll_flush(cx).map_err(|err| { + debug!("error writing: {}", err); + crate::Error::new_body_write(err) + }) + } + + fn close(&mut self) { + self.is_closing = true; + self.conn.close_read(); + self.conn.close_write(); + } + + fn is_done(&self) -> bool { + if self.is_closing { + return true; + } + + let read_done = self.conn.is_read_closed(); + + if !T::should_read_first() && read_done { + // a client that cannot read may was well be done. + true + } else { + let write_done = self.conn.is_write_closed() + || (!self.dispatch.should_poll() && self.body_rx.is_none()); + read_done && write_done + } + } +} + +impl<D, Bs, I, T> Future for Dispatcher<D, Bs, I, T> +where + D: Dispatch< + PollItem = MessageHead<T::Outgoing>, + PollBody = Bs, + RecvItem = MessageHead<T::Incoming>, + > + Unpin, + D::PollError: Into<Box<dyn StdError + Send + Sync>>, + I: AsyncRead + AsyncWrite + Unpin, + T: Http1Transaction + Unpin, + Bs: HttpBody + 'static, + Bs::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + type Output = crate::Result<Dispatched>; + + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + self.poll_catch(cx, true) + } +} + +// ===== impl OptGuard ===== + +/// A drop guard to allow a mutable borrow of an Option while being able to +/// set whether the `Option` should be cleared on drop. +struct OptGuard<'a, T>(Pin<&'a mut Option<T>>, bool); + +impl<'a, T> OptGuard<'a, T> { + fn new(pin: Pin<&'a mut Option<T>>) -> Self { + OptGuard(pin, false) + } + + fn guard_mut(&mut self) -> (Option<Pin<&mut T>>, &mut bool) { + (self.0.as_mut().as_pin_mut(), &mut self.1) + } +} + +impl<'a, T> Drop for OptGuard<'a, T> { + fn drop(&mut self) { + if self.1 { + self.0.set(None); + } + } +} + +// ===== impl Server ===== + +cfg_server! { + impl<S, B> Server<S, B> + where + S: HttpService<B>, + { + pub(crate) fn new(service: S) -> Server<S, B> { + Server { + in_flight: Box::pin(None), + service, + } + } + + pub(crate) fn into_service(self) -> S { + self.service + } + } + + // Service is never pinned + impl<S: HttpService<B>, B> Unpin for Server<S, B> {} + + impl<S, Bs> Dispatch for Server<S, Body> + where + S: HttpService<Body, ResBody = Bs>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + Bs: HttpBody, + { + type PollItem = MessageHead<http::StatusCode>; + type PollBody = Bs; + type PollError = S::Error; + type RecvItem = RequestHead; + + fn poll_msg( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), Self::PollError>>> { + let mut this = self.as_mut(); + let ret = if let Some(ref mut fut) = this.in_flight.as_mut().as_pin_mut() { + let resp = ready!(fut.as_mut().poll(cx)?); + let (parts, body) = resp.into_parts(); + let head = MessageHead { + version: parts.version, + subject: parts.status, + headers: parts.headers, + extensions: parts.extensions, + }; + Poll::Ready(Some(Ok((head, body)))) + } else { + unreachable!("poll_msg shouldn't be called if no inflight"); + }; + + // Since in_flight finished, remove it + this.in_flight.set(None); + ret + } + + fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, Body)>) -> crate::Result<()> { + let (msg, body) = msg?; + let mut req = Request::new(body); + *req.method_mut() = msg.subject.0; + *req.uri_mut() = msg.subject.1; + *req.headers_mut() = msg.headers; + *req.version_mut() = msg.version; + *req.extensions_mut() = msg.extensions; + let fut = self.service.call(req); + self.in_flight.set(Some(fut)); + Ok(()) + } + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), ()>> { + if self.in_flight.is_some() { + Poll::Pending + } else { + self.service.poll_ready(cx).map_err(|_e| { + // FIXME: return error value. + trace!("service closed"); + }) + } + } + + fn should_poll(&self) -> bool { + self.in_flight.is_some() + } + } +} + +// ===== impl Client ===== + +cfg_client! { + impl<B> Client<B> { + pub(crate) fn new(rx: ClientRx<B>) -> Client<B> { + Client { + callback: None, + rx, + rx_closed: false, + } + } + } + + impl<B> Dispatch for Client<B> + where + B: HttpBody, + { + type PollItem = RequestHead; + type PollBody = B; + type PollError = crate::common::Never; + type RecvItem = crate::proto::ResponseHead; + + fn poll_msg( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), crate::common::Never>>> { + let mut this = self.as_mut(); + debug_assert!(!this.rx_closed); + match this.rx.poll_recv(cx) { + Poll::Ready(Some((req, mut cb))) => { + // check that future hasn't been canceled already + match cb.poll_canceled(cx) { + Poll::Ready(()) => { + trace!("request canceled"); + Poll::Ready(None) + } + Poll::Pending => { + let (parts, body) = req.into_parts(); + let head = RequestHead { + version: parts.version, + subject: crate::proto::RequestLine(parts.method, parts.uri), + headers: parts.headers, + extensions: parts.extensions, + }; + this.callback = Some(cb); + Poll::Ready(Some(Ok((head, body)))) + } + } + } + Poll::Ready(None) => { + // user has dropped sender handle + trace!("client tx closed"); + this.rx_closed = true; + Poll::Ready(None) + } + Poll::Pending => Poll::Pending, + } + } + + fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, Body)>) -> crate::Result<()> { + match msg { + Ok((msg, body)) => { + if let Some(cb) = self.callback.take() { + let res = msg.into_response(body); + cb.send(Ok(res)); + Ok(()) + } else { + // Getting here is likely a bug! An error should have happened + // in Conn::require_empty_read() before ever parsing a + // full message! + Err(crate::Error::new_unexpected_message()) + } + } + Err(err) => { + if let Some(cb) = self.callback.take() { + cb.send(Err((err, None))); + Ok(()) + } else if !self.rx_closed { + self.rx.close(); + if let Some((req, cb)) = self.rx.try_recv() { + trace!("canceling queued request with connection error: {}", err); + // in this case, the message was never even started, so it's safe to tell + // the user that the request was completely canceled + cb.send(Err((crate::Error::new_canceled().with(err), Some(req)))); + Ok(()) + } else { + Err(err) + } + } else { + Err(err) + } + } + } + } + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), ()>> { + match self.callback { + Some(ref mut cb) => match cb.poll_canceled(cx) { + Poll::Ready(()) => { + trace!("callback receiver has dropped"); + Poll::Ready(Err(())) + } + Poll::Pending => Poll::Ready(Ok(())), + }, + None => Poll::Ready(Err(())), + } + } + + fn should_poll(&self) -> bool { + self.callback.is_none() + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::proto::h1::ClientTransaction; + use std::time::Duration; + + #[test] + fn client_read_bytes_before_writing_request() { + let _ = pretty_env_logger::try_init(); + + tokio_test::task::spawn(()).enter(|cx, _| { + let (io, mut handle) = tokio_test::io::Builder::new().build_with_handle(); + + // Block at 0 for now, but we will release this response before + // the request is ready to write later... + let (mut tx, rx) = crate::client::dispatch::channel(); + let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(io); + let mut dispatcher = Dispatcher::new(Client::new(rx), conn); + + // First poll is needed to allow tx to send... + assert!(Pin::new(&mut dispatcher).poll(cx).is_pending()); + + // Unblock our IO, which has a response before we've sent request! + // + handle.read(b"HTTP/1.1 200 OK\r\n\r\n"); + + let mut res_rx = tx + .try_send(crate::Request::new(crate::Body::empty())) + .unwrap(); + + tokio_test::assert_ready_ok!(Pin::new(&mut dispatcher).poll(cx)); + let err = tokio_test::assert_ready_ok!(Pin::new(&mut res_rx).poll(cx)) + .expect_err("callback should send error"); + + match (err.0.kind(), err.1) { + (&crate::error::Kind::Canceled, Some(_)) => (), + other => panic!("expected Canceled, got {:?}", other), + } + }); + } + + #[tokio::test] + async fn client_flushing_is_not_ready_for_next_request() { + let _ = pretty_env_logger::try_init(); + + let (io, _handle) = tokio_test::io::Builder::new() + .write(b"POST / HTTP/1.1\r\ncontent-length: 4\r\n\r\n") + .read(b"HTTP/1.1 200 OK\r\ncontent-length: 0\r\n\r\n") + .wait(std::time::Duration::from_secs(2)) + .build_with_handle(); + + let (mut tx, rx) = crate::client::dispatch::channel(); + let mut conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(io); + conn.set_write_strategy_queue(); + + let dispatcher = Dispatcher::new(Client::new(rx), conn); + let _dispatcher = tokio::spawn(async move { dispatcher.await }); + + let req = crate::Request::builder() + .method("POST") + .body(crate::Body::from("reee")) + .unwrap(); + + let res = tx.try_send(req).unwrap().await.expect("response"); + drop(res); + + assert!(!tx.is_ready()); + } + + #[tokio::test] + async fn body_empty_chunks_ignored() { + let _ = pretty_env_logger::try_init(); + + let io = tokio_test::io::Builder::new() + // no reading or writing, just be blocked for the test... + .wait(Duration::from_secs(5)) + .build(); + + let (mut tx, rx) = crate::client::dispatch::channel(); + let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(io); + let mut dispatcher = tokio_test::task::spawn(Dispatcher::new(Client::new(rx), conn)); + + // First poll is needed to allow tx to send... + assert!(dispatcher.poll().is_pending()); + + let body = { + let (mut tx, body) = crate::Body::channel(); + tx.try_send_data("".into()).unwrap(); + body + }; + + let _res_rx = tx.try_send(crate::Request::new(body)).unwrap(); + + // Ensure conn.write_body wasn't called with the empty chunk. + // If it is, it will trigger an assertion. + assert!(dispatcher.poll().is_pending()); + } +} diff --git a/third_party/rust/hyper/src/proto/h1/encode.rs b/third_party/rust/hyper/src/proto/h1/encode.rs new file mode 100644 index 0000000000..f0aa261a4f --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/encode.rs @@ -0,0 +1,439 @@ +use std::fmt; +use std::io::IoSlice; + +use bytes::buf::{Chain, Take}; +use bytes::Buf; +use tracing::trace; + +use super::io::WriteBuf; + +type StaticBuf = &'static [u8]; + +/// Encoders to handle different Transfer-Encodings. +#[derive(Debug, Clone, PartialEq)] +pub(crate) struct Encoder { + kind: Kind, + is_last: bool, +} + +#[derive(Debug)] +pub(crate) struct EncodedBuf<B> { + kind: BufKind<B>, +} + +#[derive(Debug)] +pub(crate) struct NotEof(u64); + +#[derive(Debug, PartialEq, Clone)] +enum Kind { + /// An Encoder for when Transfer-Encoding includes `chunked`. + Chunked, + /// An Encoder for when Content-Length is set. + /// + /// Enforces that the body is not longer than the Content-Length header. + Length(u64), + /// An Encoder for when neither Content-Length nor Chunked encoding is set. + /// + /// This is mostly only used with HTTP/1.0 with a length. This kind requires + /// the connection to be closed when the body is finished. + #[cfg(feature = "server")] + CloseDelimited, +} + +#[derive(Debug)] +enum BufKind<B> { + Exact(B), + Limited(Take<B>), + Chunked(Chain<Chain<ChunkSize, B>, StaticBuf>), + ChunkedEnd(StaticBuf), +} + +impl Encoder { + fn new(kind: Kind) -> Encoder { + Encoder { + kind, + is_last: false, + } + } + pub(crate) fn chunked() -> Encoder { + Encoder::new(Kind::Chunked) + } + + pub(crate) fn length(len: u64) -> Encoder { + Encoder::new(Kind::Length(len)) + } + + #[cfg(feature = "server")] + pub(crate) fn close_delimited() -> Encoder { + Encoder::new(Kind::CloseDelimited) + } + + pub(crate) fn is_eof(&self) -> bool { + matches!(self.kind, Kind::Length(0)) + } + + #[cfg(feature = "server")] + pub(crate) fn set_last(mut self, is_last: bool) -> Self { + self.is_last = is_last; + self + } + + pub(crate) fn is_last(&self) -> bool { + self.is_last + } + + pub(crate) fn is_close_delimited(&self) -> bool { + match self.kind { + #[cfg(feature = "server")] + Kind::CloseDelimited => true, + _ => false, + } + } + + pub(crate) fn end<B>(&self) -> Result<Option<EncodedBuf<B>>, NotEof> { + match self.kind { + Kind::Length(0) => Ok(None), + Kind::Chunked => Ok(Some(EncodedBuf { + kind: BufKind::ChunkedEnd(b"0\r\n\r\n"), + })), + #[cfg(feature = "server")] + Kind::CloseDelimited => Ok(None), + Kind::Length(n) => Err(NotEof(n)), + } + } + + pub(crate) fn encode<B>(&mut self, msg: B) -> EncodedBuf<B> + where + B: Buf, + { + let len = msg.remaining(); + debug_assert!(len > 0, "encode() called with empty buf"); + + let kind = match self.kind { + Kind::Chunked => { + trace!("encoding chunked {}B", len); + let buf = ChunkSize::new(len) + .chain(msg) + .chain(b"\r\n" as &'static [u8]); + BufKind::Chunked(buf) + } + Kind::Length(ref mut remaining) => { + trace!("sized write, len = {}", len); + if len as u64 > *remaining { + let limit = *remaining as usize; + *remaining = 0; + BufKind::Limited(msg.take(limit)) + } else { + *remaining -= len as u64; + BufKind::Exact(msg) + } + } + #[cfg(feature = "server")] + Kind::CloseDelimited => { + trace!("close delimited write {}B", len); + BufKind::Exact(msg) + } + }; + EncodedBuf { kind } + } + + pub(super) fn encode_and_end<B>(&self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) -> bool + where + B: Buf, + { + let len = msg.remaining(); + debug_assert!(len > 0, "encode() called with empty buf"); + + match self.kind { + Kind::Chunked => { + trace!("encoding chunked {}B", len); + let buf = ChunkSize::new(len) + .chain(msg) + .chain(b"\r\n0\r\n\r\n" as &'static [u8]); + dst.buffer(buf); + !self.is_last + } + Kind::Length(remaining) => { + use std::cmp::Ordering; + + trace!("sized write, len = {}", len); + match (len as u64).cmp(&remaining) { + Ordering::Equal => { + dst.buffer(msg); + !self.is_last + } + Ordering::Greater => { + dst.buffer(msg.take(remaining as usize)); + !self.is_last + } + Ordering::Less => { + dst.buffer(msg); + false + } + } + } + #[cfg(feature = "server")] + Kind::CloseDelimited => { + trace!("close delimited write {}B", len); + dst.buffer(msg); + false + } + } + } + + /// Encodes the full body, without verifying the remaining length matches. + /// + /// This is used in conjunction with HttpBody::__hyper_full_data(), which + /// means we can trust that the buf has the correct size (the buf itself + /// was checked to make the headers). + pub(super) fn danger_full_buf<B>(self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) + where + B: Buf, + { + debug_assert!(msg.remaining() > 0, "encode() called with empty buf"); + debug_assert!( + match self.kind { + Kind::Length(len) => len == msg.remaining() as u64, + _ => true, + }, + "danger_full_buf length mismatches" + ); + + match self.kind { + Kind::Chunked => { + let len = msg.remaining(); + trace!("encoding chunked {}B", len); + let buf = ChunkSize::new(len) + .chain(msg) + .chain(b"\r\n0\r\n\r\n" as &'static [u8]); + dst.buffer(buf); + } + _ => { + dst.buffer(msg); + } + } + } +} + +impl<B> Buf for EncodedBuf<B> +where + B: Buf, +{ + #[inline] + fn remaining(&self) -> usize { + match self.kind { + BufKind::Exact(ref b) => b.remaining(), + BufKind::Limited(ref b) => b.remaining(), + BufKind::Chunked(ref b) => b.remaining(), + BufKind::ChunkedEnd(ref b) => b.remaining(), + } + } + + #[inline] + fn chunk(&self) -> &[u8] { + match self.kind { + BufKind::Exact(ref b) => b.chunk(), + BufKind::Limited(ref b) => b.chunk(), + BufKind::Chunked(ref b) => b.chunk(), + BufKind::ChunkedEnd(ref b) => b.chunk(), + } + } + + #[inline] + fn advance(&mut self, cnt: usize) { + match self.kind { + BufKind::Exact(ref mut b) => b.advance(cnt), + BufKind::Limited(ref mut b) => b.advance(cnt), + BufKind::Chunked(ref mut b) => b.advance(cnt), + BufKind::ChunkedEnd(ref mut b) => b.advance(cnt), + } + } + + #[inline] + fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { + match self.kind { + BufKind::Exact(ref b) => b.chunks_vectored(dst), + BufKind::Limited(ref b) => b.chunks_vectored(dst), + BufKind::Chunked(ref b) => b.chunks_vectored(dst), + BufKind::ChunkedEnd(ref b) => b.chunks_vectored(dst), + } + } +} + +#[cfg(target_pointer_width = "32")] +const USIZE_BYTES: usize = 4; + +#[cfg(target_pointer_width = "64")] +const USIZE_BYTES: usize = 8; + +// each byte will become 2 hex +const CHUNK_SIZE_MAX_BYTES: usize = USIZE_BYTES * 2; + +#[derive(Clone, Copy)] +struct ChunkSize { + bytes: [u8; CHUNK_SIZE_MAX_BYTES + 2], + pos: u8, + len: u8, +} + +impl ChunkSize { + fn new(len: usize) -> ChunkSize { + use std::fmt::Write; + let mut size = ChunkSize { + bytes: [0; CHUNK_SIZE_MAX_BYTES + 2], + pos: 0, + len: 0, + }; + write!(&mut size, "{:X}\r\n", len).expect("CHUNK_SIZE_MAX_BYTES should fit any usize"); + size + } +} + +impl Buf for ChunkSize { + #[inline] + fn remaining(&self) -> usize { + (self.len - self.pos).into() + } + + #[inline] + fn chunk(&self) -> &[u8] { + &self.bytes[self.pos.into()..self.len.into()] + } + + #[inline] + fn advance(&mut self, cnt: usize) { + assert!(cnt <= self.remaining()); + self.pos += cnt as u8; // just asserted cnt fits in u8 + } +} + +impl fmt::Debug for ChunkSize { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ChunkSize") + .field("bytes", &&self.bytes[..self.len.into()]) + .field("pos", &self.pos) + .finish() + } +} + +impl fmt::Write for ChunkSize { + fn write_str(&mut self, num: &str) -> fmt::Result { + use std::io::Write; + (&mut self.bytes[self.len.into()..]) + .write_all(num.as_bytes()) + .expect("&mut [u8].write() cannot error"); + self.len += num.len() as u8; // safe because bytes is never bigger than 256 + Ok(()) + } +} + +impl<B: Buf> From<B> for EncodedBuf<B> { + fn from(buf: B) -> Self { + EncodedBuf { + kind: BufKind::Exact(buf), + } + } +} + +impl<B: Buf> From<Take<B>> for EncodedBuf<B> { + fn from(buf: Take<B>) -> Self { + EncodedBuf { + kind: BufKind::Limited(buf), + } + } +} + +impl<B: Buf> From<Chain<Chain<ChunkSize, B>, StaticBuf>> for EncodedBuf<B> { + fn from(buf: Chain<Chain<ChunkSize, B>, StaticBuf>) -> Self { + EncodedBuf { + kind: BufKind::Chunked(buf), + } + } +} + +impl fmt::Display for NotEof { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "early end, expected {} more bytes", self.0) + } +} + +impl std::error::Error for NotEof {} + +#[cfg(test)] +mod tests { + use bytes::BufMut; + + use super::super::io::Cursor; + use super::Encoder; + + #[test] + fn chunked() { + let mut encoder = Encoder::chunked(); + let mut dst = Vec::new(); + + let msg1 = b"foo bar".as_ref(); + let buf1 = encoder.encode(msg1); + dst.put(buf1); + assert_eq!(dst, b"7\r\nfoo bar\r\n"); + + let msg2 = b"baz quux herp".as_ref(); + let buf2 = encoder.encode(msg2); + dst.put(buf2); + + assert_eq!(dst, b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n"); + + let end = encoder.end::<Cursor<Vec<u8>>>().unwrap().unwrap(); + dst.put(end); + + assert_eq!( + dst, + b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n0\r\n\r\n".as_ref() + ); + } + + #[test] + fn length() { + let max_len = 8; + let mut encoder = Encoder::length(max_len as u64); + let mut dst = Vec::new(); + + let msg1 = b"foo bar".as_ref(); + let buf1 = encoder.encode(msg1); + dst.put(buf1); + + assert_eq!(dst, b"foo bar"); + assert!(!encoder.is_eof()); + encoder.end::<()>().unwrap_err(); + + let msg2 = b"baz".as_ref(); + let buf2 = encoder.encode(msg2); + dst.put(buf2); + + assert_eq!(dst.len(), max_len); + assert_eq!(dst, b"foo barb"); + assert!(encoder.is_eof()); + assert!(encoder.end::<()>().unwrap().is_none()); + } + + #[test] + fn eof() { + let mut encoder = Encoder::close_delimited(); + let mut dst = Vec::new(); + + let msg1 = b"foo bar".as_ref(); + let buf1 = encoder.encode(msg1); + dst.put(buf1); + + assert_eq!(dst, b"foo bar"); + assert!(!encoder.is_eof()); + encoder.end::<()>().unwrap(); + + let msg2 = b"baz".as_ref(); + let buf2 = encoder.encode(msg2); + dst.put(buf2); + + assert_eq!(dst, b"foo barbaz"); + assert!(!encoder.is_eof()); + encoder.end::<()>().unwrap(); + } +} diff --git a/third_party/rust/hyper/src/proto/h1/io.rs b/third_party/rust/hyper/src/proto/h1/io.rs new file mode 100644 index 0000000000..1d251e2c84 --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/io.rs @@ -0,0 +1,1002 @@ +use std::cmp; +use std::fmt; +#[cfg(all(feature = "server", feature = "runtime"))] +use std::future::Future; +use std::io::{self, IoSlice}; +use std::marker::Unpin; +use std::mem::MaybeUninit; +#[cfg(all(feature = "server", feature = "runtime"))] +use std::time::Duration; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +#[cfg(all(feature = "server", feature = "runtime"))] +use tokio::time::Instant; +use tracing::{debug, trace}; + +use super::{Http1Transaction, ParseContext, ParsedMessage}; +use crate::common::buf::BufList; +use crate::common::{task, Pin, Poll}; + +/// The initial buffer size allocated before trying to read from IO. +pub(crate) const INIT_BUFFER_SIZE: usize = 8192; + +/// The minimum value that can be set to max buffer size. +pub(crate) const MINIMUM_MAX_BUFFER_SIZE: usize = INIT_BUFFER_SIZE; + +/// The default maximum read buffer size. If the buffer gets this big and +/// a message is still not complete, a `TooLarge` error is triggered. +// Note: if this changes, update server::conn::Http::max_buf_size docs. +pub(crate) const DEFAULT_MAX_BUFFER_SIZE: usize = 8192 + 4096 * 100; + +/// The maximum number of distinct `Buf`s to hold in a list before requiring +/// a flush. Only affects when the buffer strategy is to queue buffers. +/// +/// Note that a flush can happen before reaching the maximum. This simply +/// forces a flush if the queue gets this big. +const MAX_BUF_LIST_BUFFERS: usize = 16; + +pub(crate) struct Buffered<T, B> { + flush_pipeline: bool, + io: T, + read_blocked: bool, + read_buf: BytesMut, + read_buf_strategy: ReadStrategy, + write_buf: WriteBuf<B>, +} + +impl<T, B> fmt::Debug for Buffered<T, B> +where + B: Buf, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Buffered") + .field("read_buf", &self.read_buf) + .field("write_buf", &self.write_buf) + .finish() + } +} + +impl<T, B> Buffered<T, B> +where + T: AsyncRead + AsyncWrite + Unpin, + B: Buf, +{ + pub(crate) fn new(io: T) -> Buffered<T, B> { + let strategy = if io.is_write_vectored() { + WriteStrategy::Queue + } else { + WriteStrategy::Flatten + }; + let write_buf = WriteBuf::new(strategy); + Buffered { + flush_pipeline: false, + io, + read_blocked: false, + read_buf: BytesMut::with_capacity(0), + read_buf_strategy: ReadStrategy::default(), + write_buf, + } + } + + #[cfg(feature = "server")] + pub(crate) fn set_flush_pipeline(&mut self, enabled: bool) { + debug_assert!(!self.write_buf.has_remaining()); + self.flush_pipeline = enabled; + if enabled { + self.set_write_strategy_flatten(); + } + } + + pub(crate) fn set_max_buf_size(&mut self, max: usize) { + assert!( + max >= MINIMUM_MAX_BUFFER_SIZE, + "The max_buf_size cannot be smaller than {}.", + MINIMUM_MAX_BUFFER_SIZE, + ); + self.read_buf_strategy = ReadStrategy::with_max(max); + self.write_buf.max_buf_size = max; + } + + #[cfg(feature = "client")] + pub(crate) fn set_read_buf_exact_size(&mut self, sz: usize) { + self.read_buf_strategy = ReadStrategy::Exact(sz); + } + + pub(crate) fn set_write_strategy_flatten(&mut self) { + // this should always be called only at construction time, + // so this assert is here to catch myself + debug_assert!(self.write_buf.queue.bufs_cnt() == 0); + self.write_buf.set_strategy(WriteStrategy::Flatten); + } + + pub(crate) fn set_write_strategy_queue(&mut self) { + // this should always be called only at construction time, + // so this assert is here to catch myself + debug_assert!(self.write_buf.queue.bufs_cnt() == 0); + self.write_buf.set_strategy(WriteStrategy::Queue); + } + + pub(crate) fn read_buf(&self) -> &[u8] { + self.read_buf.as_ref() + } + + #[cfg(test)] + #[cfg(feature = "nightly")] + pub(super) fn read_buf_mut(&mut self) -> &mut BytesMut { + &mut self.read_buf + } + + /// Return the "allocated" available space, not the potential space + /// that could be allocated in the future. + fn read_buf_remaining_mut(&self) -> usize { + self.read_buf.capacity() - self.read_buf.len() + } + + /// Return whether we can append to the headers buffer. + /// + /// Reasons we can't: + /// - The write buf is in queue mode, and some of the past body is still + /// needing to be flushed. + pub(crate) fn can_headers_buf(&self) -> bool { + !self.write_buf.queue.has_remaining() + } + + pub(crate) fn headers_buf(&mut self) -> &mut Vec<u8> { + let buf = self.write_buf.headers_mut(); + &mut buf.bytes + } + + pub(super) fn write_buf(&mut self) -> &mut WriteBuf<B> { + &mut self.write_buf + } + + pub(crate) fn buffer<BB: Buf + Into<B>>(&mut self, buf: BB) { + self.write_buf.buffer(buf) + } + + pub(crate) fn can_buffer(&self) -> bool { + self.flush_pipeline || self.write_buf.can_buffer() + } + + pub(crate) fn consume_leading_lines(&mut self) { + if !self.read_buf.is_empty() { + let mut i = 0; + while i < self.read_buf.len() { + match self.read_buf[i] { + b'\r' | b'\n' => i += 1, + _ => break, + } + } + self.read_buf.advance(i); + } + } + + pub(super) fn parse<S>( + &mut self, + cx: &mut task::Context<'_>, + parse_ctx: ParseContext<'_>, + ) -> Poll<crate::Result<ParsedMessage<S::Incoming>>> + where + S: Http1Transaction, + { + loop { + match super::role::parse_headers::<S>( + &mut self.read_buf, + ParseContext { + cached_headers: parse_ctx.cached_headers, + req_method: parse_ctx.req_method, + h1_parser_config: parse_ctx.h1_parser_config.clone(), + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout: parse_ctx.h1_header_read_timeout, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_fut: parse_ctx.h1_header_read_timeout_fut, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_running: parse_ctx.h1_header_read_timeout_running, + preserve_header_case: parse_ctx.preserve_header_case, + #[cfg(feature = "ffi")] + preserve_header_order: parse_ctx.preserve_header_order, + h09_responses: parse_ctx.h09_responses, + #[cfg(feature = "ffi")] + on_informational: parse_ctx.on_informational, + #[cfg(feature = "ffi")] + raw_headers: parse_ctx.raw_headers, + }, + )? { + Some(msg) => { + debug!("parsed {} headers", msg.head.headers.len()); + + #[cfg(all(feature = "server", feature = "runtime"))] + { + *parse_ctx.h1_header_read_timeout_running = false; + + if let Some(h1_header_read_timeout_fut) = + parse_ctx.h1_header_read_timeout_fut + { + // Reset the timer in order to avoid woken up when the timeout finishes + h1_header_read_timeout_fut + .as_mut() + .reset(Instant::now() + Duration::from_secs(30 * 24 * 60 * 60)); + } + } + return Poll::Ready(Ok(msg)); + } + None => { + let max = self.read_buf_strategy.max(); + if self.read_buf.len() >= max { + debug!("max_buf_size ({}) reached, closing", max); + return Poll::Ready(Err(crate::Error::new_too_large())); + } + + #[cfg(all(feature = "server", feature = "runtime"))] + if *parse_ctx.h1_header_read_timeout_running { + if let Some(h1_header_read_timeout_fut) = + parse_ctx.h1_header_read_timeout_fut + { + if Pin::new(h1_header_read_timeout_fut).poll(cx).is_ready() { + *parse_ctx.h1_header_read_timeout_running = false; + + tracing::warn!("read header from client timeout"); + return Poll::Ready(Err(crate::Error::new_header_timeout())); + } + } + } + } + } + if ready!(self.poll_read_from_io(cx)).map_err(crate::Error::new_io)? == 0 { + trace!("parse eof"); + return Poll::Ready(Err(crate::Error::new_incomplete())); + } + } + } + + pub(crate) fn poll_read_from_io( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<io::Result<usize>> { + self.read_blocked = false; + let next = self.read_buf_strategy.next(); + if self.read_buf_remaining_mut() < next { + self.read_buf.reserve(next); + } + + let dst = self.read_buf.chunk_mut(); + let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) }; + let mut buf = ReadBuf::uninit(dst); + match Pin::new(&mut self.io).poll_read(cx, &mut buf) { + Poll::Ready(Ok(_)) => { + let n = buf.filled().len(); + trace!("received {} bytes", n); + unsafe { + // Safety: we just read that many bytes into the + // uninitialized part of the buffer, so this is okay. + // @tokio pls give me back `poll_read_buf` thanks + self.read_buf.advance_mut(n); + } + self.read_buf_strategy.record(n); + Poll::Ready(Ok(n)) + } + Poll::Pending => { + self.read_blocked = true; + Poll::Pending + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + } + } + + pub(crate) fn into_inner(self) -> (T, Bytes) { + (self.io, self.read_buf.freeze()) + } + + pub(crate) fn io_mut(&mut self) -> &mut T { + &mut self.io + } + + pub(crate) fn is_read_blocked(&self) -> bool { + self.read_blocked + } + + pub(crate) fn poll_flush(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + if self.flush_pipeline && !self.read_buf.is_empty() { + Poll::Ready(Ok(())) + } else if self.write_buf.remaining() == 0 { + Pin::new(&mut self.io).poll_flush(cx) + } else { + if let WriteStrategy::Flatten = self.write_buf.strategy { + return self.poll_flush_flattened(cx); + } + + const MAX_WRITEV_BUFS: usize = 64; + loop { + let n = { + let mut iovs = [IoSlice::new(&[]); MAX_WRITEV_BUFS]; + let len = self.write_buf.chunks_vectored(&mut iovs); + ready!(Pin::new(&mut self.io).poll_write_vectored(cx, &iovs[..len]))? + }; + // TODO(eliza): we have to do this manually because + // `poll_write_buf` doesn't exist in Tokio 0.3 yet...when + // `poll_write_buf` comes back, the manual advance will need to leave! + self.write_buf.advance(n); + debug!("flushed {} bytes", n); + if self.write_buf.remaining() == 0 { + break; + } else if n == 0 { + trace!( + "write returned zero, but {} bytes remaining", + self.write_buf.remaining() + ); + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + } + Pin::new(&mut self.io).poll_flush(cx) + } + } + + /// Specialized version of `flush` when strategy is Flatten. + /// + /// Since all buffered bytes are flattened into the single headers buffer, + /// that skips some bookkeeping around using multiple buffers. + fn poll_flush_flattened(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + loop { + let n = ready!(Pin::new(&mut self.io).poll_write(cx, self.write_buf.headers.chunk()))?; + debug!("flushed {} bytes", n); + self.write_buf.headers.advance(n); + if self.write_buf.headers.remaining() == 0 { + self.write_buf.headers.reset(); + break; + } else if n == 0 { + trace!( + "write returned zero, but {} bytes remaining", + self.write_buf.remaining() + ); + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + } + Pin::new(&mut self.io).poll_flush(cx) + } + + #[cfg(test)] + fn flush<'a>(&'a mut self) -> impl std::future::Future<Output = io::Result<()>> + 'a { + futures_util::future::poll_fn(move |cx| self.poll_flush(cx)) + } +} + +// The `B` is a `Buf`, we never project a pin to it +impl<T: Unpin, B> Unpin for Buffered<T, B> {} + +// TODO: This trait is old... at least rename to PollBytes or something... +pub(crate) trait MemRead { + fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>>; +} + +impl<T, B> MemRead for Buffered<T, B> +where + T: AsyncRead + AsyncWrite + Unpin, + B: Buf, +{ + fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> { + if !self.read_buf.is_empty() { + let n = std::cmp::min(len, self.read_buf.len()); + Poll::Ready(Ok(self.read_buf.split_to(n).freeze())) + } else { + let n = ready!(self.poll_read_from_io(cx))?; + Poll::Ready(Ok(self.read_buf.split_to(::std::cmp::min(len, n)).freeze())) + } + } +} + +#[derive(Clone, Copy, Debug)] +enum ReadStrategy { + Adaptive { + decrease_now: bool, + next: usize, + max: usize, + }, + #[cfg(feature = "client")] + Exact(usize), +} + +impl ReadStrategy { + fn with_max(max: usize) -> ReadStrategy { + ReadStrategy::Adaptive { + decrease_now: false, + next: INIT_BUFFER_SIZE, + max, + } + } + + fn next(&self) -> usize { + match *self { + ReadStrategy::Adaptive { next, .. } => next, + #[cfg(feature = "client")] + ReadStrategy::Exact(exact) => exact, + } + } + + fn max(&self) -> usize { + match *self { + ReadStrategy::Adaptive { max, .. } => max, + #[cfg(feature = "client")] + ReadStrategy::Exact(exact) => exact, + } + } + + fn record(&mut self, bytes_read: usize) { + match *self { + ReadStrategy::Adaptive { + ref mut decrease_now, + ref mut next, + max, + .. + } => { + if bytes_read >= *next { + *next = cmp::min(incr_power_of_two(*next), max); + *decrease_now = false; + } else { + let decr_to = prev_power_of_two(*next); + if bytes_read < decr_to { + if *decrease_now { + *next = cmp::max(decr_to, INIT_BUFFER_SIZE); + *decrease_now = false; + } else { + // Decreasing is a two "record" process. + *decrease_now = true; + } + } else { + // A read within the current range should cancel + // a potential decrease, since we just saw proof + // that we still need this size. + *decrease_now = false; + } + } + } + #[cfg(feature = "client")] + ReadStrategy::Exact(_) => (), + } + } +} + +fn incr_power_of_two(n: usize) -> usize { + n.saturating_mul(2) +} + +fn prev_power_of_two(n: usize) -> usize { + // Only way this shift can underflow is if n is less than 4. + // (Which would means `usize::MAX >> 64` and underflowed!) + debug_assert!(n >= 4); + (::std::usize::MAX >> (n.leading_zeros() + 2)) + 1 +} + +impl Default for ReadStrategy { + fn default() -> ReadStrategy { + ReadStrategy::with_max(DEFAULT_MAX_BUFFER_SIZE) + } +} + +#[derive(Clone)] +pub(crate) struct Cursor<T> { + bytes: T, + pos: usize, +} + +impl<T: AsRef<[u8]>> Cursor<T> { + #[inline] + pub(crate) fn new(bytes: T) -> Cursor<T> { + Cursor { bytes, pos: 0 } + } +} + +impl Cursor<Vec<u8>> { + /// If we've advanced the position a bit in this cursor, and wish to + /// extend the underlying vector, we may wish to unshift the "read" bytes + /// off, and move everything else over. + fn maybe_unshift(&mut self, additional: usize) { + if self.pos == 0 { + // nothing to do + return; + } + + if self.bytes.capacity() - self.bytes.len() >= additional { + // there's room! + return; + } + + self.bytes.drain(0..self.pos); + self.pos = 0; + } + + fn reset(&mut self) { + self.pos = 0; + self.bytes.clear(); + } +} + +impl<T: AsRef<[u8]>> fmt::Debug for Cursor<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Cursor") + .field("pos", &self.pos) + .field("len", &self.bytes.as_ref().len()) + .finish() + } +} + +impl<T: AsRef<[u8]>> Buf for Cursor<T> { + #[inline] + fn remaining(&self) -> usize { + self.bytes.as_ref().len() - self.pos + } + + #[inline] + fn chunk(&self) -> &[u8] { + &self.bytes.as_ref()[self.pos..] + } + + #[inline] + fn advance(&mut self, cnt: usize) { + debug_assert!(self.pos + cnt <= self.bytes.as_ref().len()); + self.pos += cnt; + } +} + +// an internal buffer to collect writes before flushes +pub(super) struct WriteBuf<B> { + /// Re-usable buffer that holds message headers + headers: Cursor<Vec<u8>>, + max_buf_size: usize, + /// Deque of user buffers if strategy is Queue + queue: BufList<B>, + strategy: WriteStrategy, +} + +impl<B: Buf> WriteBuf<B> { + fn new(strategy: WriteStrategy) -> WriteBuf<B> { + WriteBuf { + headers: Cursor::new(Vec::with_capacity(INIT_BUFFER_SIZE)), + max_buf_size: DEFAULT_MAX_BUFFER_SIZE, + queue: BufList::new(), + strategy, + } + } +} + +impl<B> WriteBuf<B> +where + B: Buf, +{ + fn set_strategy(&mut self, strategy: WriteStrategy) { + self.strategy = strategy; + } + + pub(super) fn buffer<BB: Buf + Into<B>>(&mut self, mut buf: BB) { + debug_assert!(buf.has_remaining()); + match self.strategy { + WriteStrategy::Flatten => { + let head = self.headers_mut(); + + head.maybe_unshift(buf.remaining()); + trace!( + self.len = head.remaining(), + buf.len = buf.remaining(), + "buffer.flatten" + ); + //perf: This is a little faster than <Vec as BufMut>>::put, + //but accomplishes the same result. + loop { + let adv = { + let slice = buf.chunk(); + if slice.is_empty() { + return; + } + head.bytes.extend_from_slice(slice); + slice.len() + }; + buf.advance(adv); + } + } + WriteStrategy::Queue => { + trace!( + self.len = self.remaining(), + buf.len = buf.remaining(), + "buffer.queue" + ); + self.queue.push(buf.into()); + } + } + } + + fn can_buffer(&self) -> bool { + match self.strategy { + WriteStrategy::Flatten => self.remaining() < self.max_buf_size, + WriteStrategy::Queue => { + self.queue.bufs_cnt() < MAX_BUF_LIST_BUFFERS && self.remaining() < self.max_buf_size + } + } + } + + fn headers_mut(&mut self) -> &mut Cursor<Vec<u8>> { + debug_assert!(!self.queue.has_remaining()); + &mut self.headers + } +} + +impl<B: Buf> fmt::Debug for WriteBuf<B> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WriteBuf") + .field("remaining", &self.remaining()) + .field("strategy", &self.strategy) + .finish() + } +} + +impl<B: Buf> Buf for WriteBuf<B> { + #[inline] + fn remaining(&self) -> usize { + self.headers.remaining() + self.queue.remaining() + } + + #[inline] + fn chunk(&self) -> &[u8] { + let headers = self.headers.chunk(); + if !headers.is_empty() { + headers + } else { + self.queue.chunk() + } + } + + #[inline] + fn advance(&mut self, cnt: usize) { + let hrem = self.headers.remaining(); + + match hrem.cmp(&cnt) { + cmp::Ordering::Equal => self.headers.reset(), + cmp::Ordering::Greater => self.headers.advance(cnt), + cmp::Ordering::Less => { + let qcnt = cnt - hrem; + self.headers.reset(); + self.queue.advance(qcnt); + } + } + } + + #[inline] + fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { + let n = self.headers.chunks_vectored(dst); + self.queue.chunks_vectored(&mut dst[n..]) + n + } +} + +#[derive(Debug)] +enum WriteStrategy { + Flatten, + Queue, +} + +#[cfg(test)] +mod tests { + use super::*; + use std::time::Duration; + + use tokio_test::io::Builder as Mock; + + // #[cfg(feature = "nightly")] + // use test::Bencher; + + /* + impl<T: Read> MemRead for AsyncIo<T> { + fn read_mem(&mut self, len: usize) -> Poll<Bytes, io::Error> { + let mut v = vec![0; len]; + let n = try_nb!(self.read(v.as_mut_slice())); + Ok(Async::Ready(BytesMut::from(&v[..n]).freeze())) + } + } + */ + + #[tokio::test] + #[ignore] + async fn iobuf_write_empty_slice() { + // TODO(eliza): can i have writev back pls T_T + // // First, let's just check that the Mock would normally return an + // // error on an unexpected write, even if the buffer is empty... + // let mut mock = Mock::new().build(); + // futures_util::future::poll_fn(|cx| { + // Pin::new(&mut mock).poll_write_buf(cx, &mut Cursor::new(&[])) + // }) + // .await + // .expect_err("should be a broken pipe"); + + // // underlying io will return the logic error upon write, + // // so we are testing that the io_buf does not trigger a write + // // when there is nothing to flush + // let mock = Mock::new().build(); + // let mut io_buf = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + // io_buf.flush().await.expect("should short-circuit flush"); + } + + #[tokio::test] + async fn parse_reads_until_blocked() { + use crate::proto::h1::ClientTransaction; + + let _ = pretty_env_logger::try_init(); + let mock = Mock::new() + // Split over multiple reads will read all of it + .read(b"HTTP/1.1 200 OK\r\n") + .read(b"Server: hyper\r\n") + // missing last line ending + .wait(Duration::from_secs(1)) + .build(); + + let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + + // We expect a `parse` to be not ready, and so can't await it directly. + // Rather, this `poll_fn` will wrap the `Poll` result. + futures_util::future::poll_fn(|cx| { + let parse_ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut None, + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }; + assert!(buffered + .parse::<ClientTransaction>(cx, parse_ctx) + .is_pending()); + Poll::Ready(()) + }) + .await; + + assert_eq!( + buffered.read_buf, + b"HTTP/1.1 200 OK\r\nServer: hyper\r\n"[..] + ); + } + + #[test] + fn read_strategy_adaptive_increments() { + let mut strategy = ReadStrategy::default(); + assert_eq!(strategy.next(), 8192); + + // Grows if record == next + strategy.record(8192); + assert_eq!(strategy.next(), 16384); + + strategy.record(16384); + assert_eq!(strategy.next(), 32768); + + // Enormous records still increment at same rate + strategy.record(::std::usize::MAX); + assert_eq!(strategy.next(), 65536); + + let max = strategy.max(); + while strategy.next() < max { + strategy.record(max); + } + + assert_eq!(strategy.next(), max, "never goes over max"); + strategy.record(max + 1); + assert_eq!(strategy.next(), max, "never goes over max"); + } + + #[test] + fn read_strategy_adaptive_decrements() { + let mut strategy = ReadStrategy::default(); + strategy.record(8192); + assert_eq!(strategy.next(), 16384); + + strategy.record(1); + assert_eq!( + strategy.next(), + 16384, + "first smaller record doesn't decrement yet" + ); + strategy.record(8192); + assert_eq!(strategy.next(), 16384, "record was with range"); + + strategy.record(1); + assert_eq!( + strategy.next(), + 16384, + "in-range record should make this the 'first' again" + ); + + strategy.record(1); + assert_eq!(strategy.next(), 8192, "second smaller record decrements"); + + strategy.record(1); + assert_eq!(strategy.next(), 8192, "first doesn't decrement"); + strategy.record(1); + assert_eq!(strategy.next(), 8192, "doesn't decrement under minimum"); + } + + #[test] + fn read_strategy_adaptive_stays_the_same() { + let mut strategy = ReadStrategy::default(); + strategy.record(8192); + assert_eq!(strategy.next(), 16384); + + strategy.record(8193); + assert_eq!( + strategy.next(), + 16384, + "first smaller record doesn't decrement yet" + ); + + strategy.record(8193); + assert_eq!( + strategy.next(), + 16384, + "with current step does not decrement" + ); + } + + #[test] + fn read_strategy_adaptive_max_fuzz() { + fn fuzz(max: usize) { + let mut strategy = ReadStrategy::with_max(max); + while strategy.next() < max { + strategy.record(::std::usize::MAX); + } + let mut next = strategy.next(); + while next > 8192 { + strategy.record(1); + strategy.record(1); + next = strategy.next(); + assert!( + next.is_power_of_two(), + "decrement should be powers of two: {} (max = {})", + next, + max, + ); + } + } + + let mut max = 8192; + while max < std::usize::MAX { + fuzz(max); + max = (max / 2).saturating_mul(3); + } + fuzz(::std::usize::MAX); + } + + #[test] + #[should_panic] + #[cfg(debug_assertions)] // needs to trigger a debug_assert + fn write_buf_requires_non_empty_bufs() { + let mock = Mock::new().build(); + let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + + buffered.buffer(Cursor::new(Vec::new())); + } + + /* + TODO: needs tokio_test::io to allow configure write_buf calls + #[test] + fn write_buf_queue() { + let _ = pretty_env_logger::try_init(); + + let mock = AsyncIo::new_buf(vec![], 1024); + let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + + + buffered.headers_buf().extend(b"hello "); + buffered.buffer(Cursor::new(b"world, ".to_vec())); + buffered.buffer(Cursor::new(b"it's ".to_vec())); + buffered.buffer(Cursor::new(b"hyper!".to_vec())); + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3); + buffered.flush().unwrap(); + + assert_eq!(buffered.io, b"hello world, it's hyper!"); + assert_eq!(buffered.io.num_writes(), 1); + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0); + } + */ + + #[tokio::test] + async fn write_buf_flatten() { + let _ = pretty_env_logger::try_init(); + + let mock = Mock::new().write(b"hello world, it's hyper!").build(); + + let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + buffered.write_buf.set_strategy(WriteStrategy::Flatten); + + buffered.headers_buf().extend(b"hello "); + buffered.buffer(Cursor::new(b"world, ".to_vec())); + buffered.buffer(Cursor::new(b"it's ".to_vec())); + buffered.buffer(Cursor::new(b"hyper!".to_vec())); + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0); + + buffered.flush().await.expect("flush"); + } + + #[test] + fn write_buf_flatten_partially_flushed() { + let _ = pretty_env_logger::try_init(); + + let b = |s: &str| Cursor::new(s.as_bytes().to_vec()); + + let mut write_buf = WriteBuf::<Cursor<Vec<u8>>>::new(WriteStrategy::Flatten); + + write_buf.buffer(b("hello ")); + write_buf.buffer(b("world, ")); + + assert_eq!(write_buf.chunk(), b"hello world, "); + + // advance most of the way, but not all + write_buf.advance(11); + + assert_eq!(write_buf.chunk(), b", "); + assert_eq!(write_buf.headers.pos, 11); + assert_eq!(write_buf.headers.bytes.capacity(), INIT_BUFFER_SIZE); + + // there's still room in the headers buffer, so just push on the end + write_buf.buffer(b("it's hyper!")); + + assert_eq!(write_buf.chunk(), b", it's hyper!"); + assert_eq!(write_buf.headers.pos, 11); + + let rem1 = write_buf.remaining(); + let cap = write_buf.headers.bytes.capacity(); + + // but when this would go over capacity, don't copy the old bytes + write_buf.buffer(Cursor::new(vec![b'X'; cap])); + assert_eq!(write_buf.remaining(), cap + rem1); + assert_eq!(write_buf.headers.pos, 0); + } + + #[tokio::test] + async fn write_buf_queue_disable_auto() { + let _ = pretty_env_logger::try_init(); + + let mock = Mock::new() + .write(b"hello ") + .write(b"world, ") + .write(b"it's ") + .write(b"hyper!") + .build(); + + let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + buffered.write_buf.set_strategy(WriteStrategy::Queue); + + // we have 4 buffers, and vec IO disabled, but explicitly said + // don't try to auto detect (via setting strategy above) + + buffered.headers_buf().extend(b"hello "); + buffered.buffer(Cursor::new(b"world, ".to_vec())); + buffered.buffer(Cursor::new(b"it's ".to_vec())); + buffered.buffer(Cursor::new(b"hyper!".to_vec())); + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3); + + buffered.flush().await.expect("flush"); + + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0); + } + + // #[cfg(feature = "nightly")] + // #[bench] + // fn bench_write_buf_flatten_buffer_chunk(b: &mut Bencher) { + // let s = "Hello, World!"; + // b.bytes = s.len() as u64; + + // let mut write_buf = WriteBuf::<bytes::Bytes>::new(); + // write_buf.set_strategy(WriteStrategy::Flatten); + // b.iter(|| { + // let chunk = bytes::Bytes::from(s); + // write_buf.buffer(chunk); + // ::test::black_box(&write_buf); + // write_buf.headers.bytes.clear(); + // }) + // } +} diff --git a/third_party/rust/hyper/src/proto/h1/mod.rs b/third_party/rust/hyper/src/proto/h1/mod.rs new file mode 100644 index 0000000000..5a2587a843 --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/mod.rs @@ -0,0 +1,122 @@ +#[cfg(all(feature = "server", feature = "runtime"))] +use std::{pin::Pin, time::Duration}; + +use bytes::BytesMut; +use http::{HeaderMap, Method}; +use httparse::ParserConfig; +#[cfg(all(feature = "server", feature = "runtime"))] +use tokio::time::Sleep; + +use crate::body::DecodedLength; +use crate::proto::{BodyLength, MessageHead}; + +pub(crate) use self::conn::Conn; +pub(crate) use self::decode::Decoder; +pub(crate) use self::dispatch::Dispatcher; +pub(crate) use self::encode::{EncodedBuf, Encoder}; +//TODO: move out of h1::io +pub(crate) use self::io::MINIMUM_MAX_BUFFER_SIZE; + +mod conn; +mod decode; +pub(crate) mod dispatch; +mod encode; +mod io; +mod role; + +cfg_client! { + pub(crate) type ClientTransaction = role::Client; +} + +cfg_server! { + pub(crate) type ServerTransaction = role::Server; +} + +pub(crate) trait Http1Transaction { + type Incoming; + type Outgoing: Default; + const LOG: &'static str; + fn parse(bytes: &mut BytesMut, ctx: ParseContext<'_>) -> ParseResult<Self::Incoming>; + fn encode(enc: Encode<'_, Self::Outgoing>, dst: &mut Vec<u8>) -> crate::Result<Encoder>; + + fn on_error(err: &crate::Error) -> Option<MessageHead<Self::Outgoing>>; + + fn is_client() -> bool { + !Self::is_server() + } + + fn is_server() -> bool { + !Self::is_client() + } + + fn should_error_on_parse_eof() -> bool { + Self::is_client() + } + + fn should_read_first() -> bool { + Self::is_server() + } + + fn update_date() {} +} + +/// Result newtype for Http1Transaction::parse. +pub(crate) type ParseResult<T> = Result<Option<ParsedMessage<T>>, crate::error::Parse>; + +#[derive(Debug)] +pub(crate) struct ParsedMessage<T> { + head: MessageHead<T>, + decode: DecodedLength, + expect_continue: bool, + keep_alive: bool, + wants_upgrade: bool, +} + +pub(crate) struct ParseContext<'a> { + cached_headers: &'a mut Option<HeaderMap>, + req_method: &'a mut Option<Method>, + h1_parser_config: ParserConfig, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout: Option<Duration>, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_fut: &'a mut Option<Pin<Box<Sleep>>>, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_running: &'a mut bool, + preserve_header_case: bool, + #[cfg(feature = "ffi")] + preserve_header_order: bool, + h09_responses: bool, + #[cfg(feature = "ffi")] + on_informational: &'a mut Option<crate::ffi::OnInformational>, + #[cfg(feature = "ffi")] + raw_headers: bool, +} + +/// Passed to Http1Transaction::encode +pub(crate) struct Encode<'a, T> { + head: &'a mut MessageHead<T>, + body: Option<BodyLength>, + #[cfg(feature = "server")] + keep_alive: bool, + req_method: &'a mut Option<Method>, + title_case_headers: bool, +} + +/// Extra flags that a request "wants", like expect-continue or upgrades. +#[derive(Clone, Copy, Debug)] +struct Wants(u8); + +impl Wants { + const EMPTY: Wants = Wants(0b00); + const EXPECT: Wants = Wants(0b01); + const UPGRADE: Wants = Wants(0b10); + + #[must_use] + fn add(self, other: Wants) -> Wants { + Wants(self.0 | other.0) + } + + fn contains(&self, other: Wants) -> bool { + (self.0 & other.0) == other.0 + } +} diff --git a/third_party/rust/hyper/src/proto/h1/role.rs b/third_party/rust/hyper/src/proto/h1/role.rs new file mode 100644 index 0000000000..6252207baf --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/role.rs @@ -0,0 +1,2847 @@ +use std::fmt::{self, Write}; +use std::mem::MaybeUninit; + +use bytes::Bytes; +use bytes::BytesMut; +#[cfg(feature = "server")] +use http::header::ValueIter; +use http::header::{self, Entry, HeaderName, HeaderValue}; +use http::{HeaderMap, Method, StatusCode, Version}; +#[cfg(all(feature = "server", feature = "runtime"))] +use tokio::time::Instant; +use tracing::{debug, error, trace, trace_span, warn}; + +use crate::body::DecodedLength; +#[cfg(feature = "server")] +use crate::common::date; +use crate::error::Parse; +use crate::ext::HeaderCaseMap; +#[cfg(feature = "ffi")] +use crate::ext::OriginalHeaderOrder; +use crate::headers; +use crate::proto::h1::{ + Encode, Encoder, Http1Transaction, ParseContext, ParseResult, ParsedMessage, +}; +use crate::proto::{BodyLength, MessageHead, RequestHead, RequestLine}; + +const MAX_HEADERS: usize = 100; +const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific +#[cfg(feature = "server")] +const MAX_URI_LEN: usize = (u16::MAX - 1) as usize; + +macro_rules! header_name { + ($bytes:expr) => {{ + { + match HeaderName::from_bytes($bytes) { + Ok(name) => name, + Err(e) => maybe_panic!(e), + } + } + }}; +} + +macro_rules! header_value { + ($bytes:expr) => {{ + { + unsafe { HeaderValue::from_maybe_shared_unchecked($bytes) } + } + }}; +} + +macro_rules! maybe_panic { + ($($arg:tt)*) => ({ + let _err = ($($arg)*); + if cfg!(debug_assertions) { + panic!("{:?}", _err); + } else { + error!("Internal Hyper error, please report {:?}", _err); + return Err(Parse::Internal) + } + }) +} + +pub(super) fn parse_headers<T>( + bytes: &mut BytesMut, + ctx: ParseContext<'_>, +) -> ParseResult<T::Incoming> +where + T: Http1Transaction, +{ + // If the buffer is empty, don't bother entering the span, it's just noise. + if bytes.is_empty() { + return Ok(None); + } + + let span = trace_span!("parse_headers"); + let _s = span.enter(); + + #[cfg(all(feature = "server", feature = "runtime"))] + if !*ctx.h1_header_read_timeout_running { + if let Some(h1_header_read_timeout) = ctx.h1_header_read_timeout { + let deadline = Instant::now() + h1_header_read_timeout; + *ctx.h1_header_read_timeout_running = true; + match ctx.h1_header_read_timeout_fut { + Some(h1_header_read_timeout_fut) => { + debug!("resetting h1 header read timeout timer"); + h1_header_read_timeout_fut.as_mut().reset(deadline); + } + None => { + debug!("setting h1 header read timeout timer"); + *ctx.h1_header_read_timeout_fut = + Some(Box::pin(tokio::time::sleep_until(deadline))); + } + } + } + } + + T::parse(bytes, ctx) +} + +pub(super) fn encode_headers<T>( + enc: Encode<'_, T::Outgoing>, + dst: &mut Vec<u8>, +) -> crate::Result<Encoder> +where + T: Http1Transaction, +{ + let span = trace_span!("encode_headers"); + let _s = span.enter(); + T::encode(enc, dst) +} + +// There are 2 main roles, Client and Server. + +#[cfg(feature = "client")] +pub(crate) enum Client {} + +#[cfg(feature = "server")] +pub(crate) enum Server {} + +#[cfg(feature = "server")] +impl Http1Transaction for Server { + type Incoming = RequestLine; + type Outgoing = StatusCode; + const LOG: &'static str = "{role=server}"; + + fn parse(buf: &mut BytesMut, ctx: ParseContext<'_>) -> ParseResult<RequestLine> { + debug_assert!(!buf.is_empty(), "parse called with empty buf"); + + let mut keep_alive; + let is_http_11; + let subject; + let version; + let len; + let headers_len; + + // Unsafe: both headers_indices and headers are using uninitialized memory, + // but we *never* read any of it until after httparse has assigned + // values into it. By not zeroing out the stack memory, this saves + // a good ~5% on pipeline benchmarks. + let mut headers_indices: [MaybeUninit<HeaderIndices>; MAX_HEADERS] = unsafe { + // SAFETY: We can go safely from MaybeUninit array to array of MaybeUninit + MaybeUninit::uninit().assume_init() + }; + { + /* SAFETY: it is safe to go from MaybeUninit array to array of MaybeUninit */ + let mut headers: [MaybeUninit<httparse::Header<'_>>; MAX_HEADERS] = + unsafe { MaybeUninit::uninit().assume_init() }; + trace!(bytes = buf.len(), "Request.parse"); + let mut req = httparse::Request::new(&mut []); + let bytes = buf.as_ref(); + match req.parse_with_uninit_headers(bytes, &mut headers) { + Ok(httparse::Status::Complete(parsed_len)) => { + trace!("Request.parse Complete({})", parsed_len); + len = parsed_len; + let uri = req.path.unwrap(); + if uri.len() > MAX_URI_LEN { + return Err(Parse::UriTooLong); + } + subject = RequestLine( + Method::from_bytes(req.method.unwrap().as_bytes())?, + uri.parse()?, + ); + version = if req.version.unwrap() == 1 { + keep_alive = true; + is_http_11 = true; + Version::HTTP_11 + } else { + keep_alive = false; + is_http_11 = false; + Version::HTTP_10 + }; + + record_header_indices(bytes, &req.headers, &mut headers_indices)?; + headers_len = req.headers.len(); + } + Ok(httparse::Status::Partial) => return Ok(None), + Err(err) => { + return Err(match err { + // if invalid Token, try to determine if for method or path + httparse::Error::Token => { + if req.method.is_none() { + Parse::Method + } else { + debug_assert!(req.path.is_none()); + Parse::Uri + } + } + other => other.into(), + }); + } + } + }; + + let slice = buf.split_to(len).freeze(); + + // According to https://tools.ietf.org/html/rfc7230#section-3.3.3 + // 1. (irrelevant to Request) + // 2. (irrelevant to Request) + // 3. Transfer-Encoding: chunked has a chunked body. + // 4. If multiple differing Content-Length headers or invalid, close connection. + // 5. Content-Length header has a sized body. + // 6. Length 0. + // 7. (irrelevant to Request) + + let mut decoder = DecodedLength::ZERO; + let mut expect_continue = false; + let mut con_len = None; + let mut is_te = false; + let mut is_te_chunked = false; + let mut wants_upgrade = subject.0 == Method::CONNECT; + + let mut header_case_map = if ctx.preserve_header_case { + Some(HeaderCaseMap::default()) + } else { + None + }; + + #[cfg(feature = "ffi")] + let mut header_order = if ctx.preserve_header_order { + Some(OriginalHeaderOrder::default()) + } else { + None + }; + + let mut headers = ctx.cached_headers.take().unwrap_or_else(HeaderMap::new); + + headers.reserve(headers_len); + + for header in &headers_indices[..headers_len] { + // SAFETY: array is valid up to `headers_len` + let header = unsafe { &*header.as_ptr() }; + let name = header_name!(&slice[header.name.0..header.name.1]); + let value = header_value!(slice.slice(header.value.0..header.value.1)); + + match name { + header::TRANSFER_ENCODING => { + // https://tools.ietf.org/html/rfc7230#section-3.3.3 + // If Transfer-Encoding header is present, and 'chunked' is + // not the final encoding, and this is a Request, then it is + // malformed. A server should respond with 400 Bad Request. + if !is_http_11 { + debug!("HTTP/1.0 cannot have Transfer-Encoding header"); + return Err(Parse::transfer_encoding_unexpected()); + } + is_te = true; + if headers::is_chunked_(&value) { + is_te_chunked = true; + decoder = DecodedLength::CHUNKED; + } else { + is_te_chunked = false; + } + } + header::CONTENT_LENGTH => { + if is_te { + continue; + } + let len = headers::content_length_parse(&value) + .ok_or_else(Parse::content_length_invalid)?; + if let Some(prev) = con_len { + if prev != len { + debug!( + "multiple Content-Length headers with different values: [{}, {}]", + prev, len, + ); + return Err(Parse::content_length_invalid()); + } + // we don't need to append this secondary length + continue; + } + decoder = DecodedLength::checked_new(len)?; + con_len = Some(len); + } + header::CONNECTION => { + // keep_alive was previously set to default for Version + if keep_alive { + // HTTP/1.1 + keep_alive = !headers::connection_close(&value); + } else { + // HTTP/1.0 + keep_alive = headers::connection_keep_alive(&value); + } + } + header::EXPECT => { + // According to https://datatracker.ietf.org/doc/html/rfc2616#section-14.20 + // Comparison of expectation values is case-insensitive for unquoted tokens + // (including the 100-continue token) + expect_continue = value.as_bytes().eq_ignore_ascii_case(b"100-continue"); + } + header::UPGRADE => { + // Upgrades are only allowed with HTTP/1.1 + wants_upgrade = is_http_11; + } + + _ => (), + } + + if let Some(ref mut header_case_map) = header_case_map { + header_case_map.append(&name, slice.slice(header.name.0..header.name.1)); + } + + #[cfg(feature = "ffi")] + if let Some(ref mut header_order) = header_order { + header_order.append(&name); + } + + headers.append(name, value); + } + + if is_te && !is_te_chunked { + debug!("request with transfer-encoding header, but not chunked, bad request"); + return Err(Parse::transfer_encoding_invalid()); + } + + let mut extensions = http::Extensions::default(); + + if let Some(header_case_map) = header_case_map { + extensions.insert(header_case_map); + } + + #[cfg(feature = "ffi")] + if let Some(header_order) = header_order { + extensions.insert(header_order); + } + + *ctx.req_method = Some(subject.0.clone()); + + Ok(Some(ParsedMessage { + head: MessageHead { + version, + subject, + headers, + extensions, + }, + decode: decoder, + expect_continue, + keep_alive, + wants_upgrade, + })) + } + + fn encode(mut msg: Encode<'_, Self::Outgoing>, dst: &mut Vec<u8>) -> crate::Result<Encoder> { + trace!( + "Server::encode status={:?}, body={:?}, req_method={:?}", + msg.head.subject, + msg.body, + msg.req_method + ); + + let mut wrote_len = false; + + // hyper currently doesn't support returning 1xx status codes as a Response + // This is because Service only allows returning a single Response, and + // so if you try to reply with a e.g. 100 Continue, you have no way of + // replying with the latter status code response. + let (ret, is_last) = if msg.head.subject == StatusCode::SWITCHING_PROTOCOLS { + (Ok(()), true) + } else if msg.req_method == &Some(Method::CONNECT) && msg.head.subject.is_success() { + // Sending content-length or transfer-encoding header on 2xx response + // to CONNECT is forbidden in RFC 7231. + wrote_len = true; + (Ok(()), true) + } else if msg.head.subject.is_informational() { + warn!("response with 1xx status code not supported"); + *msg.head = MessageHead::default(); + msg.head.subject = StatusCode::INTERNAL_SERVER_ERROR; + msg.body = None; + (Err(crate::Error::new_user_unsupported_status_code()), true) + } else { + (Ok(()), !msg.keep_alive) + }; + + // In some error cases, we don't know about the invalid message until already + // pushing some bytes onto the `dst`. In those cases, we don't want to send + // the half-pushed message, so rewind to before. + let orig_len = dst.len(); + + let init_cap = 30 + msg.head.headers.len() * AVERAGE_HEADER_SIZE; + dst.reserve(init_cap); + + let custom_reason_phrase = msg.head.extensions.get::<crate::ext::ReasonPhrase>(); + + if msg.head.version == Version::HTTP_11 + && msg.head.subject == StatusCode::OK + && custom_reason_phrase.is_none() + { + extend(dst, b"HTTP/1.1 200 OK\r\n"); + } else { + match msg.head.version { + Version::HTTP_10 => extend(dst, b"HTTP/1.0 "), + Version::HTTP_11 => extend(dst, b"HTTP/1.1 "), + Version::HTTP_2 => { + debug!("response with HTTP2 version coerced to HTTP/1.1"); + extend(dst, b"HTTP/1.1 "); + } + other => panic!("unexpected response version: {:?}", other), + } + + extend(dst, msg.head.subject.as_str().as_bytes()); + extend(dst, b" "); + + if let Some(reason) = custom_reason_phrase { + extend(dst, reason.as_bytes()); + } else { + // a reason MUST be written, as many parsers will expect it. + extend( + dst, + msg.head + .subject + .canonical_reason() + .unwrap_or("<none>") + .as_bytes(), + ); + } + + extend(dst, b"\r\n"); + } + + let orig_headers; + let extensions = std::mem::take(&mut msg.head.extensions); + let orig_headers = match extensions.get::<HeaderCaseMap>() { + None if msg.title_case_headers => { + orig_headers = HeaderCaseMap::default(); + Some(&orig_headers) + } + orig_headers => orig_headers, + }; + let encoder = if let Some(orig_headers) = orig_headers { + Self::encode_headers_with_original_case( + msg, + dst, + is_last, + orig_len, + wrote_len, + orig_headers, + )? + } else { + Self::encode_headers_with_lower_case(msg, dst, is_last, orig_len, wrote_len)? + }; + + ret.map(|()| encoder) + } + + fn on_error(err: &crate::Error) -> Option<MessageHead<Self::Outgoing>> { + use crate::error::Kind; + let status = match *err.kind() { + Kind::Parse(Parse::Method) + | Kind::Parse(Parse::Header(_)) + | Kind::Parse(Parse::Uri) + | Kind::Parse(Parse::Version) => StatusCode::BAD_REQUEST, + Kind::Parse(Parse::TooLarge) => StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE, + Kind::Parse(Parse::UriTooLong) => StatusCode::URI_TOO_LONG, + _ => return None, + }; + + debug!("sending automatic response ({}) for parse error", status); + let mut msg = MessageHead::default(); + msg.subject = status; + Some(msg) + } + + fn is_server() -> bool { + true + } + + fn update_date() { + date::update(); + } +} + +#[cfg(feature = "server")] +impl Server { + fn can_have_body(method: &Option<Method>, status: StatusCode) -> bool { + Server::can_chunked(method, status) + } + + fn can_chunked(method: &Option<Method>, status: StatusCode) -> bool { + if method == &Some(Method::HEAD) || method == &Some(Method::CONNECT) && status.is_success() + { + false + } else if status.is_informational() { + false + } else { + match status { + StatusCode::NO_CONTENT | StatusCode::NOT_MODIFIED => false, + _ => true, + } + } + } + + fn can_have_content_length(method: &Option<Method>, status: StatusCode) -> bool { + if status.is_informational() || method == &Some(Method::CONNECT) && status.is_success() { + false + } else { + match status { + StatusCode::NO_CONTENT | StatusCode::NOT_MODIFIED => false, + _ => true, + } + } + } + + fn can_have_implicit_zero_content_length(method: &Option<Method>, status: StatusCode) -> bool { + Server::can_have_content_length(method, status) && method != &Some(Method::HEAD) + } + + fn encode_headers_with_lower_case( + msg: Encode<'_, StatusCode>, + dst: &mut Vec<u8>, + is_last: bool, + orig_len: usize, + wrote_len: bool, + ) -> crate::Result<Encoder> { + struct LowercaseWriter; + + impl HeaderNameWriter for LowercaseWriter { + #[inline] + fn write_full_header_line( + &mut self, + dst: &mut Vec<u8>, + line: &str, + _: (HeaderName, &str), + ) { + extend(dst, line.as_bytes()) + } + + #[inline] + fn write_header_name_with_colon( + &mut self, + dst: &mut Vec<u8>, + name_with_colon: &str, + _: HeaderName, + ) { + extend(dst, name_with_colon.as_bytes()) + } + + #[inline] + fn write_header_name(&mut self, dst: &mut Vec<u8>, name: &HeaderName) { + extend(dst, name.as_str().as_bytes()) + } + } + + Self::encode_headers(msg, dst, is_last, orig_len, wrote_len, LowercaseWriter) + } + + #[cold] + #[inline(never)] + fn encode_headers_with_original_case( + msg: Encode<'_, StatusCode>, + dst: &mut Vec<u8>, + is_last: bool, + orig_len: usize, + wrote_len: bool, + orig_headers: &HeaderCaseMap, + ) -> crate::Result<Encoder> { + struct OrigCaseWriter<'map> { + map: &'map HeaderCaseMap, + current: Option<(HeaderName, ValueIter<'map, Bytes>)>, + title_case_headers: bool, + } + + impl HeaderNameWriter for OrigCaseWriter<'_> { + #[inline] + fn write_full_header_line( + &mut self, + dst: &mut Vec<u8>, + _: &str, + (name, rest): (HeaderName, &str), + ) { + self.write_header_name(dst, &name); + extend(dst, rest.as_bytes()); + } + + #[inline] + fn write_header_name_with_colon( + &mut self, + dst: &mut Vec<u8>, + _: &str, + name: HeaderName, + ) { + self.write_header_name(dst, &name); + extend(dst, b": "); + } + + #[inline] + fn write_header_name(&mut self, dst: &mut Vec<u8>, name: &HeaderName) { + let Self { + map, + ref mut current, + title_case_headers, + } = *self; + if current.as_ref().map_or(true, |(last, _)| last != name) { + *current = None; + } + let (_, values) = + current.get_or_insert_with(|| (name.clone(), map.get_all_internal(name))); + + if let Some(orig_name) = values.next() { + extend(dst, orig_name); + } else if title_case_headers { + title_case(dst, name.as_str().as_bytes()); + } else { + extend(dst, name.as_str().as_bytes()); + } + } + } + + let header_name_writer = OrigCaseWriter { + map: orig_headers, + current: None, + title_case_headers: msg.title_case_headers, + }; + + Self::encode_headers(msg, dst, is_last, orig_len, wrote_len, header_name_writer) + } + + #[inline] + fn encode_headers<W>( + msg: Encode<'_, StatusCode>, + dst: &mut Vec<u8>, + mut is_last: bool, + orig_len: usize, + mut wrote_len: bool, + mut header_name_writer: W, + ) -> crate::Result<Encoder> + where + W: HeaderNameWriter, + { + // In some error cases, we don't know about the invalid message until already + // pushing some bytes onto the `dst`. In those cases, we don't want to send + // the half-pushed message, so rewind to before. + let rewind = |dst: &mut Vec<u8>| { + dst.truncate(orig_len); + }; + + let mut encoder = Encoder::length(0); + let mut wrote_date = false; + let mut cur_name = None; + let mut is_name_written = false; + let mut must_write_chunked = false; + let mut prev_con_len = None; + + macro_rules! handle_is_name_written { + () => {{ + if is_name_written { + // we need to clean up and write the newline + debug_assert_ne!( + &dst[dst.len() - 2..], + b"\r\n", + "previous header wrote newline but set is_name_written" + ); + + if must_write_chunked { + extend(dst, b", chunked\r\n"); + } else { + extend(dst, b"\r\n"); + } + } + }}; + } + + 'headers: for (opt_name, value) in msg.head.headers.drain() { + if let Some(n) = opt_name { + cur_name = Some(n); + handle_is_name_written!(); + is_name_written = false; + } + let name = cur_name.as_ref().expect("current header name"); + match *name { + header::CONTENT_LENGTH => { + if wrote_len && !is_name_written { + warn!("unexpected content-length found, canceling"); + rewind(dst); + return Err(crate::Error::new_user_header()); + } + match msg.body { + Some(BodyLength::Known(known_len)) => { + // The HttpBody claims to know a length, and + // the headers are already set. For performance + // reasons, we are just going to trust that + // the values match. + // + // In debug builds, we'll assert they are the + // same to help developers find bugs. + #[cfg(debug_assertions)] + { + if let Some(len) = headers::content_length_parse(&value) { + assert!( + len == known_len, + "payload claims content-length of {}, custom content-length header claims {}", + known_len, + len, + ); + } + } + + if !is_name_written { + encoder = Encoder::length(known_len); + header_name_writer.write_header_name_with_colon( + dst, + "content-length: ", + header::CONTENT_LENGTH, + ); + extend(dst, value.as_bytes()); + wrote_len = true; + is_name_written = true; + } + continue 'headers; + } + Some(BodyLength::Unknown) => { + // The HttpBody impl didn't know how long the + // body is, but a length header was included. + // We have to parse the value to return our + // Encoder... + + if let Some(len) = headers::content_length_parse(&value) { + if let Some(prev) = prev_con_len { + if prev != len { + warn!( + "multiple Content-Length values found: [{}, {}]", + prev, len + ); + rewind(dst); + return Err(crate::Error::new_user_header()); + } + debug_assert!(is_name_written); + continue 'headers; + } else { + // we haven't written content-length yet! + encoder = Encoder::length(len); + header_name_writer.write_header_name_with_colon( + dst, + "content-length: ", + header::CONTENT_LENGTH, + ); + extend(dst, value.as_bytes()); + wrote_len = true; + is_name_written = true; + prev_con_len = Some(len); + continue 'headers; + } + } else { + warn!("illegal Content-Length value: {:?}", value); + rewind(dst); + return Err(crate::Error::new_user_header()); + } + } + None => { + // We have no body to actually send, + // but the headers claim a content-length. + // There's only 2 ways this makes sense: + // + // - The header says the length is `0`. + // - This is a response to a `HEAD` request. + if msg.req_method == &Some(Method::HEAD) { + debug_assert_eq!(encoder, Encoder::length(0)); + } else { + if value.as_bytes() != b"0" { + warn!( + "content-length value found, but empty body provided: {:?}", + value + ); + } + continue 'headers; + } + } + } + wrote_len = true; + } + header::TRANSFER_ENCODING => { + if wrote_len && !is_name_written { + warn!("unexpected transfer-encoding found, canceling"); + rewind(dst); + return Err(crate::Error::new_user_header()); + } + // check that we actually can send a chunked body... + if msg.head.version == Version::HTTP_10 + || !Server::can_chunked(msg.req_method, msg.head.subject) + { + continue; + } + wrote_len = true; + // Must check each value, because `chunked` needs to be the + // last encoding, or else we add it. + must_write_chunked = !headers::is_chunked_(&value); + + if !is_name_written { + encoder = Encoder::chunked(); + is_name_written = true; + header_name_writer.write_header_name_with_colon( + dst, + "transfer-encoding: ", + header::TRANSFER_ENCODING, + ); + extend(dst, value.as_bytes()); + } else { + extend(dst, b", "); + extend(dst, value.as_bytes()); + } + continue 'headers; + } + header::CONNECTION => { + if !is_last && headers::connection_close(&value) { + is_last = true; + } + if !is_name_written { + is_name_written = true; + header_name_writer.write_header_name_with_colon( + dst, + "connection: ", + header::CONNECTION, + ); + extend(dst, value.as_bytes()); + } else { + extend(dst, b", "); + extend(dst, value.as_bytes()); + } + continue 'headers; + } + header::DATE => { + wrote_date = true; + } + _ => (), + } + //TODO: this should perhaps instead combine them into + //single lines, as RFC7230 suggests is preferable. + + // non-special write Name and Value + debug_assert!( + !is_name_written, + "{:?} set is_name_written and didn't continue loop", + name, + ); + header_name_writer.write_header_name(dst, name); + extend(dst, b": "); + extend(dst, value.as_bytes()); + extend(dst, b"\r\n"); + } + + handle_is_name_written!(); + + if !wrote_len { + encoder = match msg.body { + Some(BodyLength::Unknown) => { + if msg.head.version == Version::HTTP_10 + || !Server::can_chunked(msg.req_method, msg.head.subject) + { + Encoder::close_delimited() + } else { + header_name_writer.write_full_header_line( + dst, + "transfer-encoding: chunked\r\n", + (header::TRANSFER_ENCODING, ": chunked\r\n"), + ); + Encoder::chunked() + } + } + None | Some(BodyLength::Known(0)) => { + if Server::can_have_implicit_zero_content_length( + msg.req_method, + msg.head.subject, + ) { + header_name_writer.write_full_header_line( + dst, + "content-length: 0\r\n", + (header::CONTENT_LENGTH, ": 0\r\n"), + ) + } + Encoder::length(0) + } + Some(BodyLength::Known(len)) => { + if !Server::can_have_content_length(msg.req_method, msg.head.subject) { + Encoder::length(0) + } else { + header_name_writer.write_header_name_with_colon( + dst, + "content-length: ", + header::CONTENT_LENGTH, + ); + extend(dst, ::itoa::Buffer::new().format(len).as_bytes()); + extend(dst, b"\r\n"); + Encoder::length(len) + } + } + }; + } + + if !Server::can_have_body(msg.req_method, msg.head.subject) { + trace!( + "server body forced to 0; method={:?}, status={:?}", + msg.req_method, + msg.head.subject + ); + encoder = Encoder::length(0); + } + + // cached date is much faster than formatting every request + if !wrote_date { + dst.reserve(date::DATE_VALUE_LENGTH + 8); + header_name_writer.write_header_name_with_colon(dst, "date: ", header::DATE); + date::extend(dst); + extend(dst, b"\r\n\r\n"); + } else { + extend(dst, b"\r\n"); + } + + Ok(encoder.set_last(is_last)) + } +} + +#[cfg(feature = "server")] +trait HeaderNameWriter { + fn write_full_header_line( + &mut self, + dst: &mut Vec<u8>, + line: &str, + name_value_pair: (HeaderName, &str), + ); + fn write_header_name_with_colon( + &mut self, + dst: &mut Vec<u8>, + name_with_colon: &str, + name: HeaderName, + ); + fn write_header_name(&mut self, dst: &mut Vec<u8>, name: &HeaderName); +} + +#[cfg(feature = "client")] +impl Http1Transaction for Client { + type Incoming = StatusCode; + type Outgoing = RequestLine; + const LOG: &'static str = "{role=client}"; + + fn parse(buf: &mut BytesMut, ctx: ParseContext<'_>) -> ParseResult<StatusCode> { + debug_assert!(!buf.is_empty(), "parse called with empty buf"); + + // Loop to skip information status code headers (100 Continue, etc). + loop { + // Unsafe: see comment in Server Http1Transaction, above. + let mut headers_indices: [MaybeUninit<HeaderIndices>; MAX_HEADERS] = unsafe { + // SAFETY: We can go safely from MaybeUninit array to array of MaybeUninit + MaybeUninit::uninit().assume_init() + }; + let (len, status, reason, version, headers_len) = { + // SAFETY: We can go safely from MaybeUninit array to array of MaybeUninit + let mut headers: [MaybeUninit<httparse::Header<'_>>; MAX_HEADERS] = + unsafe { MaybeUninit::uninit().assume_init() }; + trace!(bytes = buf.len(), "Response.parse"); + let mut res = httparse::Response::new(&mut []); + let bytes = buf.as_ref(); + match ctx.h1_parser_config.parse_response_with_uninit_headers( + &mut res, + bytes, + &mut headers, + ) { + Ok(httparse::Status::Complete(len)) => { + trace!("Response.parse Complete({})", len); + let status = StatusCode::from_u16(res.code.unwrap())?; + + let reason = { + let reason = res.reason.unwrap(); + // Only save the reason phrase if it isn't the canonical reason + if Some(reason) != status.canonical_reason() { + Some(Bytes::copy_from_slice(reason.as_bytes())) + } else { + None + } + }; + + let version = if res.version.unwrap() == 1 { + Version::HTTP_11 + } else { + Version::HTTP_10 + }; + record_header_indices(bytes, &res.headers, &mut headers_indices)?; + let headers_len = res.headers.len(); + (len, status, reason, version, headers_len) + } + Ok(httparse::Status::Partial) => return Ok(None), + Err(httparse::Error::Version) if ctx.h09_responses => { + trace!("Response.parse accepted HTTP/0.9 response"); + + (0, StatusCode::OK, None, Version::HTTP_09, 0) + } + Err(e) => return Err(e.into()), + } + }; + + let mut slice = buf.split_to(len); + + if ctx + .h1_parser_config + .obsolete_multiline_headers_in_responses_are_allowed() + { + for header in &headers_indices[..headers_len] { + // SAFETY: array is valid up to `headers_len` + let header = unsafe { &*header.as_ptr() }; + for b in &mut slice[header.value.0..header.value.1] { + if *b == b'\r' || *b == b'\n' { + *b = b' '; + } + } + } + } + + let slice = slice.freeze(); + + let mut headers = ctx.cached_headers.take().unwrap_or_else(HeaderMap::new); + + let mut keep_alive = version == Version::HTTP_11; + + let mut header_case_map = if ctx.preserve_header_case { + Some(HeaderCaseMap::default()) + } else { + None + }; + + #[cfg(feature = "ffi")] + let mut header_order = if ctx.preserve_header_order { + Some(OriginalHeaderOrder::default()) + } else { + None + }; + + headers.reserve(headers_len); + for header in &headers_indices[..headers_len] { + // SAFETY: array is valid up to `headers_len` + let header = unsafe { &*header.as_ptr() }; + let name = header_name!(&slice[header.name.0..header.name.1]); + let value = header_value!(slice.slice(header.value.0..header.value.1)); + + if let header::CONNECTION = name { + // keep_alive was previously set to default for Version + if keep_alive { + // HTTP/1.1 + keep_alive = !headers::connection_close(&value); + } else { + // HTTP/1.0 + keep_alive = headers::connection_keep_alive(&value); + } + } + + if let Some(ref mut header_case_map) = header_case_map { + header_case_map.append(&name, slice.slice(header.name.0..header.name.1)); + } + + #[cfg(feature = "ffi")] + if let Some(ref mut header_order) = header_order { + header_order.append(&name); + } + + headers.append(name, value); + } + + let mut extensions = http::Extensions::default(); + + if let Some(header_case_map) = header_case_map { + extensions.insert(header_case_map); + } + + #[cfg(feature = "ffi")] + if let Some(header_order) = header_order { + extensions.insert(header_order); + } + + if let Some(reason) = reason { + // Safety: httparse ensures that only valid reason phrase bytes are present in this + // field. + let reason = unsafe { crate::ext::ReasonPhrase::from_bytes_unchecked(reason) }; + extensions.insert(reason); + } + + #[cfg(feature = "ffi")] + if ctx.raw_headers { + extensions.insert(crate::ffi::RawHeaders(crate::ffi::hyper_buf(slice))); + } + + let head = MessageHead { + version, + subject: status, + headers, + extensions, + }; + if let Some((decode, is_upgrade)) = Client::decoder(&head, ctx.req_method)? { + return Ok(Some(ParsedMessage { + head, + decode, + expect_continue: false, + // a client upgrade means the connection can't be used + // again, as it is definitely upgrading. + keep_alive: keep_alive && !is_upgrade, + wants_upgrade: is_upgrade, + })); + } + + #[cfg(feature = "ffi")] + if head.subject.is_informational() { + if let Some(callback) = ctx.on_informational { + callback.call(head.into_response(crate::Body::empty())); + } + } + + // Parsing a 1xx response could have consumed the buffer, check if + // it is empty now... + if buf.is_empty() { + return Ok(None); + } + } + } + + fn encode(msg: Encode<'_, Self::Outgoing>, dst: &mut Vec<u8>) -> crate::Result<Encoder> { + trace!( + "Client::encode method={:?}, body={:?}", + msg.head.subject.0, + msg.body + ); + + *msg.req_method = Some(msg.head.subject.0.clone()); + + let body = Client::set_length(msg.head, msg.body); + + let init_cap = 30 + msg.head.headers.len() * AVERAGE_HEADER_SIZE; + dst.reserve(init_cap); + + extend(dst, msg.head.subject.0.as_str().as_bytes()); + extend(dst, b" "); + //TODO: add API to http::Uri to encode without std::fmt + let _ = write!(FastWrite(dst), "{} ", msg.head.subject.1); + + match msg.head.version { + Version::HTTP_10 => extend(dst, b"HTTP/1.0"), + Version::HTTP_11 => extend(dst, b"HTTP/1.1"), + Version::HTTP_2 => { + debug!("request with HTTP2 version coerced to HTTP/1.1"); + extend(dst, b"HTTP/1.1"); + } + other => panic!("unexpected request version: {:?}", other), + } + extend(dst, b"\r\n"); + + if let Some(orig_headers) = msg.head.extensions.get::<HeaderCaseMap>() { + write_headers_original_case( + &msg.head.headers, + orig_headers, + dst, + msg.title_case_headers, + ); + } else if msg.title_case_headers { + write_headers_title_case(&msg.head.headers, dst); + } else { + write_headers(&msg.head.headers, dst); + } + + extend(dst, b"\r\n"); + msg.head.headers.clear(); //TODO: remove when switching to drain() + + Ok(body) + } + + fn on_error(_err: &crate::Error) -> Option<MessageHead<Self::Outgoing>> { + // we can't tell the server about any errors it creates + None + } + + fn is_client() -> bool { + true + } +} + +#[cfg(feature = "client")] +impl Client { + /// Returns Some(length, wants_upgrade) if successful. + /// + /// Returns None if this message head should be skipped (like a 100 status). + fn decoder( + inc: &MessageHead<StatusCode>, + method: &mut Option<Method>, + ) -> Result<Option<(DecodedLength, bool)>, Parse> { + // According to https://tools.ietf.org/html/rfc7230#section-3.3.3 + // 1. HEAD responses, and Status 1xx, 204, and 304 cannot have a body. + // 2. Status 2xx to a CONNECT cannot have a body. + // 3. Transfer-Encoding: chunked has a chunked body. + // 4. If multiple differing Content-Length headers or invalid, close connection. + // 5. Content-Length header has a sized body. + // 6. (irrelevant to Response) + // 7. Read till EOF. + + match inc.subject.as_u16() { + 101 => { + return Ok(Some((DecodedLength::ZERO, true))); + } + 100 | 102..=199 => { + trace!("ignoring informational response: {}", inc.subject.as_u16()); + return Ok(None); + } + 204 | 304 => return Ok(Some((DecodedLength::ZERO, false))), + _ => (), + } + match *method { + Some(Method::HEAD) => { + return Ok(Some((DecodedLength::ZERO, false))); + } + Some(Method::CONNECT) => { + if let 200..=299 = inc.subject.as_u16() { + return Ok(Some((DecodedLength::ZERO, true))); + } + } + Some(_) => {} + None => { + trace!("Client::decoder is missing the Method"); + } + } + + if inc.headers.contains_key(header::TRANSFER_ENCODING) { + // https://tools.ietf.org/html/rfc7230#section-3.3.3 + // If Transfer-Encoding header is present, and 'chunked' is + // not the final encoding, and this is a Request, then it is + // malformed. A server should respond with 400 Bad Request. + if inc.version == Version::HTTP_10 { + debug!("HTTP/1.0 cannot have Transfer-Encoding header"); + Err(Parse::transfer_encoding_unexpected()) + } else if headers::transfer_encoding_is_chunked(&inc.headers) { + Ok(Some((DecodedLength::CHUNKED, false))) + } else { + trace!("not chunked, read till eof"); + Ok(Some((DecodedLength::CLOSE_DELIMITED, false))) + } + } else if let Some(len) = headers::content_length_parse_all(&inc.headers) { + Ok(Some((DecodedLength::checked_new(len)?, false))) + } else if inc.headers.contains_key(header::CONTENT_LENGTH) { + debug!("illegal Content-Length header"); + Err(Parse::content_length_invalid()) + } else { + trace!("neither Transfer-Encoding nor Content-Length"); + Ok(Some((DecodedLength::CLOSE_DELIMITED, false))) + } + } + fn set_length(head: &mut RequestHead, body: Option<BodyLength>) -> Encoder { + let body = if let Some(body) = body { + body + } else { + head.headers.remove(header::TRANSFER_ENCODING); + return Encoder::length(0); + }; + + // HTTP/1.0 doesn't know about chunked + let can_chunked = head.version == Version::HTTP_11; + let headers = &mut head.headers; + + // If the user already set specific headers, we should respect them, regardless + // of what the HttpBody knows about itself. They set them for a reason. + + // Because of the borrow checker, we can't check the for an existing + // Content-Length header while holding an `Entry` for the Transfer-Encoding + // header, so unfortunately, we must do the check here, first. + + let existing_con_len = headers::content_length_parse_all(headers); + let mut should_remove_con_len = false; + + if !can_chunked { + // Chunked isn't legal, so if it is set, we need to remove it. + if headers.remove(header::TRANSFER_ENCODING).is_some() { + trace!("removing illegal transfer-encoding header"); + } + + return if let Some(len) = existing_con_len { + Encoder::length(len) + } else if let BodyLength::Known(len) = body { + set_content_length(headers, len) + } else { + // HTTP/1.0 client requests without a content-length + // cannot have any body at all. + Encoder::length(0) + }; + } + + // If the user set a transfer-encoding, respect that. Let's just + // make sure `chunked` is the final encoding. + let encoder = match headers.entry(header::TRANSFER_ENCODING) { + Entry::Occupied(te) => { + should_remove_con_len = true; + if headers::is_chunked(te.iter()) { + Some(Encoder::chunked()) + } else { + warn!("user provided transfer-encoding does not end in 'chunked'"); + + // There's a Transfer-Encoding, but it doesn't end in 'chunked'! + // An example that could trigger this: + // + // Transfer-Encoding: gzip + // + // This can be bad, depending on if this is a request or a + // response. + // + // - A request is illegal if there is a `Transfer-Encoding` + // but it doesn't end in `chunked`. + // - A response that has `Transfer-Encoding` but doesn't + // end in `chunked` isn't illegal, it just forces this + // to be close-delimited. + // + // We can try to repair this, by adding `chunked` ourselves. + + headers::add_chunked(te); + Some(Encoder::chunked()) + } + } + Entry::Vacant(te) => { + if let Some(len) = existing_con_len { + Some(Encoder::length(len)) + } else if let BodyLength::Unknown = body { + // GET, HEAD, and CONNECT almost never have bodies. + // + // So instead of sending a "chunked" body with a 0-chunk, + // assume no body here. If you *must* send a body, + // set the headers explicitly. + match head.subject.0 { + Method::GET | Method::HEAD | Method::CONNECT => Some(Encoder::length(0)), + _ => { + te.insert(HeaderValue::from_static("chunked")); + Some(Encoder::chunked()) + } + } + } else { + None + } + } + }; + + // This is because we need a second mutable borrow to remove + // content-length header. + if let Some(encoder) = encoder { + if should_remove_con_len && existing_con_len.is_some() { + headers.remove(header::CONTENT_LENGTH); + } + return encoder; + } + + // User didn't set transfer-encoding, AND we know body length, + // so we can just set the Content-Length automatically. + + let len = if let BodyLength::Known(len) = body { + len + } else { + unreachable!("BodyLength::Unknown would set chunked"); + }; + + set_content_length(headers, len) + } +} + +fn set_content_length(headers: &mut HeaderMap, len: u64) -> Encoder { + // At this point, there should not be a valid Content-Length + // header. However, since we'll be indexing in anyways, we can + // warn the user if there was an existing illegal header. + // + // Or at least, we can in theory. It's actually a little bit slower, + // so perhaps only do that while the user is developing/testing. + + if cfg!(debug_assertions) { + match headers.entry(header::CONTENT_LENGTH) { + Entry::Occupied(mut cl) => { + // Internal sanity check, we should have already determined + // that the header was illegal before calling this function. + debug_assert!(headers::content_length_parse_all_values(cl.iter()).is_none()); + // Uh oh, the user set `Content-Length` headers, but set bad ones. + // This would be an illegal message anyways, so let's try to repair + // with our known good length. + error!("user provided content-length header was invalid"); + + cl.insert(HeaderValue::from(len)); + Encoder::length(len) + } + Entry::Vacant(cl) => { + cl.insert(HeaderValue::from(len)); + Encoder::length(len) + } + } + } else { + headers.insert(header::CONTENT_LENGTH, HeaderValue::from(len)); + Encoder::length(len) + } +} + +#[derive(Clone, Copy)] +struct HeaderIndices { + name: (usize, usize), + value: (usize, usize), +} + +fn record_header_indices( + bytes: &[u8], + headers: &[httparse::Header<'_>], + indices: &mut [MaybeUninit<HeaderIndices>], +) -> Result<(), crate::error::Parse> { + let bytes_ptr = bytes.as_ptr() as usize; + + for (header, indices) in headers.iter().zip(indices.iter_mut()) { + if header.name.len() >= (1 << 16) { + debug!("header name larger than 64kb: {:?}", header.name); + return Err(crate::error::Parse::TooLarge); + } + let name_start = header.name.as_ptr() as usize - bytes_ptr; + let name_end = name_start + header.name.len(); + let value_start = header.value.as_ptr() as usize - bytes_ptr; + let value_end = value_start + header.value.len(); + + // FIXME(maybe_uninit_extra) + // FIXME(addr_of) + // Currently we don't have `ptr::addr_of_mut` in stable rust or + // MaybeUninit::write, so this is some way of assigning into a MaybeUninit + // safely + let new_header_indices = HeaderIndices { + name: (name_start, name_end), + value: (value_start, value_end), + }; + *indices = MaybeUninit::new(new_header_indices); + } + + Ok(()) +} + +// Write header names as title case. The header name is assumed to be ASCII. +fn title_case(dst: &mut Vec<u8>, name: &[u8]) { + dst.reserve(name.len()); + + // Ensure first character is uppercased + let mut prev = b'-'; + for &(mut c) in name { + if prev == b'-' { + c.make_ascii_uppercase(); + } + dst.push(c); + prev = c; + } +} + +fn write_headers_title_case(headers: &HeaderMap, dst: &mut Vec<u8>) { + for (name, value) in headers { + title_case(dst, name.as_str().as_bytes()); + extend(dst, b": "); + extend(dst, value.as_bytes()); + extend(dst, b"\r\n"); + } +} + +fn write_headers(headers: &HeaderMap, dst: &mut Vec<u8>) { + for (name, value) in headers { + extend(dst, name.as_str().as_bytes()); + extend(dst, b": "); + extend(dst, value.as_bytes()); + extend(dst, b"\r\n"); + } +} + +#[cold] +fn write_headers_original_case( + headers: &HeaderMap, + orig_case: &HeaderCaseMap, + dst: &mut Vec<u8>, + title_case_headers: bool, +) { + // For each header name/value pair, there may be a value in the casemap + // that corresponds to the HeaderValue. So, we iterator all the keys, + // and for each one, try to pair the originally cased name with the value. + // + // TODO: consider adding http::HeaderMap::entries() iterator + for name in headers.keys() { + let mut names = orig_case.get_all(name); + + for value in headers.get_all(name) { + if let Some(orig_name) = names.next() { + extend(dst, orig_name.as_ref()); + } else if title_case_headers { + title_case(dst, name.as_str().as_bytes()); + } else { + extend(dst, name.as_str().as_bytes()); + } + + // Wanted for curl test cases that send `X-Custom-Header:\r\n` + if value.is_empty() { + extend(dst, b":\r\n"); + } else { + extend(dst, b": "); + extend(dst, value.as_bytes()); + extend(dst, b"\r\n"); + } + } + } +} + +struct FastWrite<'a>(&'a mut Vec<u8>); + +impl<'a> fmt::Write for FastWrite<'a> { + #[inline] + fn write_str(&mut self, s: &str) -> fmt::Result { + extend(self.0, s.as_bytes()); + Ok(()) + } + + #[inline] + fn write_fmt(&mut self, args: fmt::Arguments<'_>) -> fmt::Result { + fmt::write(self, args) + } +} + +#[inline] +fn extend(dst: &mut Vec<u8>, data: &[u8]) { + dst.extend_from_slice(data); +} + +#[cfg(test)] +mod tests { + use bytes::BytesMut; + + use super::*; + + #[test] + fn test_parse_request() { + let _ = pretty_env_logger::try_init(); + let mut raw = BytesMut::from("GET /echo HTTP/1.1\r\nHost: hyper.rs\r\n\r\n"); + let mut method = None; + let msg = Server::parse( + &mut raw, + ParseContext { + cached_headers: &mut None, + req_method: &mut method, + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }, + ) + .unwrap() + .unwrap(); + assert_eq!(raw.len(), 0); + assert_eq!(msg.head.subject.0, crate::Method::GET); + assert_eq!(msg.head.subject.1, "/echo"); + assert_eq!(msg.head.version, crate::Version::HTTP_11); + assert_eq!(msg.head.headers.len(), 1); + assert_eq!(msg.head.headers["Host"], "hyper.rs"); + assert_eq!(method, Some(crate::Method::GET)); + } + + #[test] + fn test_parse_response() { + let _ = pretty_env_logger::try_init(); + let mut raw = BytesMut::from("HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"); + let ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut Some(crate::Method::GET), + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }; + let msg = Client::parse(&mut raw, ctx).unwrap().unwrap(); + assert_eq!(raw.len(), 0); + assert_eq!(msg.head.subject, crate::StatusCode::OK); + assert_eq!(msg.head.version, crate::Version::HTTP_11); + assert_eq!(msg.head.headers.len(), 1); + assert_eq!(msg.head.headers["Content-Length"], "0"); + } + + #[test] + fn test_parse_request_errors() { + let mut raw = BytesMut::from("GET htt:p// HTTP/1.1\r\nHost: hyper.rs\r\n\r\n"); + let ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut None, + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }; + Server::parse(&mut raw, ctx).unwrap_err(); + } + + const H09_RESPONSE: &'static str = "Baguettes are super delicious, don't you agree?"; + + #[test] + fn test_parse_response_h09_allowed() { + let _ = pretty_env_logger::try_init(); + let mut raw = BytesMut::from(H09_RESPONSE); + let ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut Some(crate::Method::GET), + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: true, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }; + let msg = Client::parse(&mut raw, ctx).unwrap().unwrap(); + assert_eq!(raw, H09_RESPONSE); + assert_eq!(msg.head.subject, crate::StatusCode::OK); + assert_eq!(msg.head.version, crate::Version::HTTP_09); + assert_eq!(msg.head.headers.len(), 0); + } + + #[test] + fn test_parse_response_h09_rejected() { + let _ = pretty_env_logger::try_init(); + let mut raw = BytesMut::from(H09_RESPONSE); + let ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut Some(crate::Method::GET), + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }; + Client::parse(&mut raw, ctx).unwrap_err(); + assert_eq!(raw, H09_RESPONSE); + } + + const RESPONSE_WITH_WHITESPACE_BETWEEN_HEADER_NAME_AND_COLON: &'static str = + "HTTP/1.1 200 OK\r\nAccess-Control-Allow-Credentials : true\r\n\r\n"; + + #[test] + fn test_parse_allow_response_with_spaces_before_colons() { + use httparse::ParserConfig; + + let _ = pretty_env_logger::try_init(); + let mut raw = BytesMut::from(RESPONSE_WITH_WHITESPACE_BETWEEN_HEADER_NAME_AND_COLON); + let mut h1_parser_config = ParserConfig::default(); + h1_parser_config.allow_spaces_after_header_name_in_responses(true); + let ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut Some(crate::Method::GET), + h1_parser_config, + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }; + let msg = Client::parse(&mut raw, ctx).unwrap().unwrap(); + assert_eq!(raw.len(), 0); + assert_eq!(msg.head.subject, crate::StatusCode::OK); + assert_eq!(msg.head.version, crate::Version::HTTP_11); + assert_eq!(msg.head.headers.len(), 1); + assert_eq!(msg.head.headers["Access-Control-Allow-Credentials"], "true"); + } + + #[test] + fn test_parse_reject_response_with_spaces_before_colons() { + let _ = pretty_env_logger::try_init(); + let mut raw = BytesMut::from(RESPONSE_WITH_WHITESPACE_BETWEEN_HEADER_NAME_AND_COLON); + let ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut Some(crate::Method::GET), + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }; + Client::parse(&mut raw, ctx).unwrap_err(); + } + + #[test] + fn test_parse_preserve_header_case_in_request() { + let mut raw = + BytesMut::from("GET / HTTP/1.1\r\nHost: hyper.rs\r\nX-BREAD: baguette\r\n\r\n"); + let ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut None, + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: true, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }; + let parsed_message = Server::parse(&mut raw, ctx).unwrap().unwrap(); + let orig_headers = parsed_message + .head + .extensions + .get::<HeaderCaseMap>() + .unwrap(); + assert_eq!( + orig_headers + .get_all_internal(&HeaderName::from_static("host")) + .into_iter() + .collect::<Vec<_>>(), + vec![&Bytes::from("Host")] + ); + assert_eq!( + orig_headers + .get_all_internal(&HeaderName::from_static("x-bread")) + .into_iter() + .collect::<Vec<_>>(), + vec![&Bytes::from("X-BREAD")] + ); + } + + #[test] + fn test_decoder_request() { + fn parse(s: &str) -> ParsedMessage<RequestLine> { + let mut bytes = BytesMut::from(s); + Server::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut None, + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }, + ) + .expect("parse ok") + .expect("parse complete") + } + + fn parse_err(s: &str, comment: &str) -> crate::error::Parse { + let mut bytes = BytesMut::from(s); + Server::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut None, + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }, + ) + .expect_err(comment) + } + + // no length or transfer-encoding means 0-length body + assert_eq!( + parse( + "\ + GET / HTTP/1.1\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::ZERO + ); + + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::ZERO + ); + + // transfer-encoding: chunked + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: gzip, chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: gzip\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + // content-length + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + content-length: 10\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::new(10) + ); + + // transfer-encoding and content-length = chunked + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + content-length: 10\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\ + content-length: 10\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: gzip\r\n\ + content-length: 10\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + // multiple content-lengths of same value are fine + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + content-length: 10\r\n\ + content-length: 10\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::new(10) + ); + + // multiple content-lengths with different values is an error + parse_err( + "\ + POST / HTTP/1.1\r\n\ + content-length: 10\r\n\ + content-length: 11\r\n\ + \r\n\ + ", + "multiple content-lengths", + ); + + // content-length with prefix is not allowed + parse_err( + "\ + POST / HTTP/1.1\r\n\ + content-length: +10\r\n\ + \r\n\ + ", + "prefixed content-length", + ); + + // transfer-encoding that isn't chunked is an error + parse_err( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: gzip\r\n\ + \r\n\ + ", + "transfer-encoding but not chunked", + ); + + parse_err( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: chunked, gzip\r\n\ + \r\n\ + ", + "transfer-encoding doesn't end in chunked", + ); + + parse_err( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\ + transfer-encoding: afterlol\r\n\ + \r\n\ + ", + "transfer-encoding multiple lines doesn't end in chunked", + ); + + // http/1.0 + + assert_eq!( + parse( + "\ + POST / HTTP/1.0\r\n\ + content-length: 10\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::new(10) + ); + + // 1.0 doesn't understand chunked, so its an error + parse_err( + "\ + POST / HTTP/1.0\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + ", + "1.0 chunked", + ); + } + + #[test] + fn test_decoder_response() { + fn parse(s: &str) -> ParsedMessage<StatusCode> { + parse_with_method(s, Method::GET) + } + + fn parse_ignores(s: &str) { + let mut bytes = BytesMut::from(s); + assert!(Client::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut Some(Method::GET), + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + } + ) + .expect("parse ok") + .is_none()) + } + + fn parse_with_method(s: &str, m: Method) -> ParsedMessage<StatusCode> { + let mut bytes = BytesMut::from(s); + Client::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut Some(m), + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }, + ) + .expect("parse ok") + .expect("parse complete") + } + + fn parse_err(s: &str) -> crate::error::Parse { + let mut bytes = BytesMut::from(s); + Client::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut Some(Method::GET), + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }, + ) + .expect_err("parse should err") + } + + // no content-length or transfer-encoding means close-delimited + assert_eq!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CLOSE_DELIMITED + ); + + // 204 and 304 never have a body + assert_eq!( + parse( + "\ + HTTP/1.1 204 No Content\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::ZERO + ); + + assert_eq!( + parse( + "\ + HTTP/1.1 304 Not Modified\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::ZERO + ); + + // content-length + assert_eq!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 8\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::new(8) + ); + + assert_eq!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 8\r\n\ + content-length: 8\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::new(8) + ); + + parse_err( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 8\r\n\ + content-length: 9\r\n\ + \r\n\ + ", + ); + + parse_err( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: +8\r\n\ + \r\n\ + ", + ); + + // transfer-encoding: chunked + assert_eq!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + // transfer-encoding not-chunked is close-delimited + assert_eq!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + transfer-encoding: yolo\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CLOSE_DELIMITED + ); + + // transfer-encoding and content-length = chunked + assert_eq!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 10\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + // HEAD can have content-length, but not body + assert_eq!( + parse_with_method( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 8\r\n\ + \r\n\ + ", + Method::HEAD + ) + .decode, + DecodedLength::ZERO + ); + + // CONNECT with 200 never has body + { + let msg = parse_with_method( + "\ + HTTP/1.1 200 OK\r\n\ + \r\n\ + ", + Method::CONNECT, + ); + assert_eq!(msg.decode, DecodedLength::ZERO); + assert!(!msg.keep_alive, "should be upgrade"); + assert!(msg.wants_upgrade, "should be upgrade"); + } + + // CONNECT receiving non 200 can have a body + assert_eq!( + parse_with_method( + "\ + HTTP/1.1 400 Bad Request\r\n\ + \r\n\ + ", + Method::CONNECT + ) + .decode, + DecodedLength::CLOSE_DELIMITED + ); + + // 1xx status codes + parse_ignores( + "\ + HTTP/1.1 100 Continue\r\n\ + \r\n\ + ", + ); + + parse_ignores( + "\ + HTTP/1.1 103 Early Hints\r\n\ + \r\n\ + ", + ); + + // 101 upgrade not supported yet + { + let msg = parse( + "\ + HTTP/1.1 101 Switching Protocols\r\n\ + \r\n\ + ", + ); + assert_eq!(msg.decode, DecodedLength::ZERO); + assert!(!msg.keep_alive, "should be last"); + assert!(msg.wants_upgrade, "should be upgrade"); + } + + // http/1.0 + assert_eq!( + parse( + "\ + HTTP/1.0 200 OK\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CLOSE_DELIMITED + ); + + // 1.0 doesn't understand chunked + parse_err( + "\ + HTTP/1.0 200 OK\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + ", + ); + + // keep-alive + assert!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 0\r\n\ + \r\n\ + " + ) + .keep_alive, + "HTTP/1.1 keep-alive is default" + ); + + assert!( + !parse( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 0\r\n\ + connection: foo, close, bar\r\n\ + \r\n\ + " + ) + .keep_alive, + "connection close is always close" + ); + + assert!( + !parse( + "\ + HTTP/1.0 200 OK\r\n\ + content-length: 0\r\n\ + \r\n\ + " + ) + .keep_alive, + "HTTP/1.0 close is default" + ); + + assert!( + parse( + "\ + HTTP/1.0 200 OK\r\n\ + content-length: 0\r\n\ + connection: foo, keep-alive, bar\r\n\ + \r\n\ + " + ) + .keep_alive, + "connection keep-alive is always keep-alive" + ); + } + + #[test] + fn test_client_request_encode_title_case() { + use crate::proto::BodyLength; + use http::header::HeaderValue; + + let mut head = MessageHead::default(); + head.headers + .insert("content-length", HeaderValue::from_static("10")); + head.headers + .insert("content-type", HeaderValue::from_static("application/json")); + head.headers.insert("*-*", HeaderValue::from_static("o_o")); + + let mut vec = Vec::new(); + Client::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut None, + title_case_headers: true, + }, + &mut vec, + ) + .unwrap(); + + assert_eq!(vec, b"GET / HTTP/1.1\r\nContent-Length: 10\r\nContent-Type: application/json\r\n*-*: o_o\r\n\r\n".to_vec()); + } + + #[test] + fn test_client_request_encode_orig_case() { + use crate::proto::BodyLength; + use http::header::{HeaderValue, CONTENT_LENGTH}; + + let mut head = MessageHead::default(); + head.headers + .insert("content-length", HeaderValue::from_static("10")); + head.headers + .insert("content-type", HeaderValue::from_static("application/json")); + + let mut orig_headers = HeaderCaseMap::default(); + orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into()); + head.extensions.insert(orig_headers); + + let mut vec = Vec::new(); + Client::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut None, + title_case_headers: false, + }, + &mut vec, + ) + .unwrap(); + + assert_eq!( + &*vec, + b"GET / HTTP/1.1\r\nCONTENT-LENGTH: 10\r\ncontent-type: application/json\r\n\r\n" + .as_ref(), + ); + } + #[test] + fn test_client_request_encode_orig_and_title_case() { + use crate::proto::BodyLength; + use http::header::{HeaderValue, CONTENT_LENGTH}; + + let mut head = MessageHead::default(); + head.headers + .insert("content-length", HeaderValue::from_static("10")); + head.headers + .insert("content-type", HeaderValue::from_static("application/json")); + + let mut orig_headers = HeaderCaseMap::default(); + orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into()); + head.extensions.insert(orig_headers); + + let mut vec = Vec::new(); + Client::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut None, + title_case_headers: true, + }, + &mut vec, + ) + .unwrap(); + + assert_eq!( + &*vec, + b"GET / HTTP/1.1\r\nCONTENT-LENGTH: 10\r\nContent-Type: application/json\r\n\r\n" + .as_ref(), + ); + } + + #[test] + fn test_server_encode_connect_method() { + let mut head = MessageHead::default(); + + let mut vec = Vec::new(); + let encoder = Server::encode( + Encode { + head: &mut head, + body: None, + keep_alive: true, + req_method: &mut Some(Method::CONNECT), + title_case_headers: false, + }, + &mut vec, + ) + .unwrap(); + + assert!(encoder.is_last()); + } + + #[test] + fn test_server_response_encode_title_case() { + use crate::proto::BodyLength; + use http::header::HeaderValue; + + let mut head = MessageHead::default(); + head.headers + .insert("content-length", HeaderValue::from_static("10")); + head.headers + .insert("content-type", HeaderValue::from_static("application/json")); + head.headers + .insert("weird--header", HeaderValue::from_static("")); + + let mut vec = Vec::new(); + Server::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut None, + title_case_headers: true, + }, + &mut vec, + ) + .unwrap(); + + let expected_response = + b"HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Type: application/json\r\nWeird--Header: \r\n"; + + assert_eq!(&vec[..expected_response.len()], &expected_response[..]); + } + + #[test] + fn test_server_response_encode_orig_case() { + use crate::proto::BodyLength; + use http::header::{HeaderValue, CONTENT_LENGTH}; + + let mut head = MessageHead::default(); + head.headers + .insert("content-length", HeaderValue::from_static("10")); + head.headers + .insert("content-type", HeaderValue::from_static("application/json")); + + let mut orig_headers = HeaderCaseMap::default(); + orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into()); + head.extensions.insert(orig_headers); + + let mut vec = Vec::new(); + Server::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut None, + title_case_headers: false, + }, + &mut vec, + ) + .unwrap(); + + let expected_response = + b"HTTP/1.1 200 OK\r\nCONTENT-LENGTH: 10\r\ncontent-type: application/json\r\ndate: "; + + assert_eq!(&vec[..expected_response.len()], &expected_response[..]); + } + + #[test] + fn test_server_response_encode_orig_and_title_case() { + use crate::proto::BodyLength; + use http::header::{HeaderValue, CONTENT_LENGTH}; + + let mut head = MessageHead::default(); + head.headers + .insert("content-length", HeaderValue::from_static("10")); + head.headers + .insert("content-type", HeaderValue::from_static("application/json")); + + let mut orig_headers = HeaderCaseMap::default(); + orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into()); + head.extensions.insert(orig_headers); + + let mut vec = Vec::new(); + Server::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut None, + title_case_headers: true, + }, + &mut vec, + ) + .unwrap(); + + let expected_response = + b"HTTP/1.1 200 OK\r\nCONTENT-LENGTH: 10\r\nContent-Type: application/json\r\nDate: "; + + assert_eq!(&vec[..expected_response.len()], &expected_response[..]); + } + + #[test] + fn parse_header_htabs() { + let mut bytes = BytesMut::from("HTTP/1.1 200 OK\r\nserver: hello\tworld\r\n\r\n"); + let parsed = Client::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut Some(Method::GET), + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }, + ) + .expect("parse ok") + .expect("parse complete"); + + assert_eq!(parsed.head.headers["server"], "hello\tworld"); + } + + #[test] + fn test_write_headers_orig_case_empty_value() { + let mut headers = HeaderMap::new(); + let name = http::header::HeaderName::from_static("x-empty"); + headers.insert(&name, "".parse().expect("parse empty")); + let mut orig_cases = HeaderCaseMap::default(); + orig_cases.insert(name, Bytes::from_static(b"X-EmptY")); + + let mut dst = Vec::new(); + super::write_headers_original_case(&headers, &orig_cases, &mut dst, false); + + assert_eq!( + dst, b"X-EmptY:\r\n", + "there should be no space between the colon and CRLF" + ); + } + + #[test] + fn test_write_headers_orig_case_multiple_entries() { + let mut headers = HeaderMap::new(); + let name = http::header::HeaderName::from_static("x-empty"); + headers.insert(&name, "a".parse().unwrap()); + headers.append(&name, "b".parse().unwrap()); + + let mut orig_cases = HeaderCaseMap::default(); + orig_cases.insert(name.clone(), Bytes::from_static(b"X-Empty")); + orig_cases.append(name, Bytes::from_static(b"X-EMPTY")); + + let mut dst = Vec::new(); + super::write_headers_original_case(&headers, &orig_cases, &mut dst, false); + + assert_eq!(dst, b"X-Empty: a\r\nX-EMPTY: b\r\n"); + } + + #[cfg(feature = "nightly")] + use test::Bencher; + + #[cfg(feature = "nightly")] + #[bench] + fn bench_parse_incoming(b: &mut Bencher) { + let mut raw = BytesMut::from( + &b"GET /super_long_uri/and_whatever?what_should_we_talk_about/\ + I_wonder/Hard_to_write_in_an_uri_after_all/you_have_to_make\ + _up_the_punctuation_yourself/how_fun_is_that?test=foo&test1=\ + foo1&test2=foo2&test3=foo3&test4=foo4 HTTP/1.1\r\nHost: \ + hyper.rs\r\nAccept: a lot of things\r\nAccept-Charset: \ + utf8\r\nAccept-Encoding: *\r\nAccess-Control-Allow-\ + Credentials: None\r\nAccess-Control-Allow-Origin: None\r\n\ + Access-Control-Allow-Methods: None\r\nAccess-Control-Allow-\ + Headers: None\r\nContent-Encoding: utf8\r\nContent-Security-\ + Policy: None\r\nContent-Type: text/html\r\nOrigin: hyper\ + \r\nSec-Websocket-Extensions: It looks super important!\r\n\ + Sec-Websocket-Origin: hyper\r\nSec-Websocket-Version: 4.3\r\ + \nStrict-Transport-Security: None\r\nUser-Agent: hyper\r\n\ + X-Content-Duration: None\r\nX-Content-Security-Policy: None\ + \r\nX-DNSPrefetch-Control: None\r\nX-Frame-Options: \ + Something important obviously\r\nX-Requested-With: Nothing\ + \r\n\r\n"[..], + ); + let len = raw.len(); + let mut headers = Some(HeaderMap::new()); + + b.bytes = len as u64; + b.iter(|| { + let mut msg = Server::parse( + &mut raw, + ParseContext { + cached_headers: &mut headers, + req_method: &mut None, + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }, + ) + .unwrap() + .unwrap(); + ::test::black_box(&msg); + msg.head.headers.clear(); + headers = Some(msg.head.headers); + restart(&mut raw, len); + }); + + fn restart(b: &mut BytesMut, len: usize) { + b.reserve(1); + unsafe { + b.set_len(len); + } + } + } + + #[cfg(feature = "nightly")] + #[bench] + fn bench_parse_short(b: &mut Bencher) { + let s = &b"GET / HTTP/1.1\r\nHost: localhost:8080\r\n\r\n"[..]; + let mut raw = BytesMut::from(s); + let len = raw.len(); + let mut headers = Some(HeaderMap::new()); + + b.bytes = len as u64; + b.iter(|| { + let mut msg = Server::parse( + &mut raw, + ParseContext { + cached_headers: &mut headers, + req_method: &mut None, + h1_parser_config: Default::default(), + #[cfg(feature = "runtime")] + h1_header_read_timeout: None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_fut: &mut None, + #[cfg(feature = "runtime")] + h1_header_read_timeout_running: &mut false, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + #[cfg(feature = "ffi")] + raw_headers: false, + }, + ) + .unwrap() + .unwrap(); + ::test::black_box(&msg); + msg.head.headers.clear(); + headers = Some(msg.head.headers); + restart(&mut raw, len); + }); + + fn restart(b: &mut BytesMut, len: usize) { + b.reserve(1); + unsafe { + b.set_len(len); + } + } + } + + #[cfg(feature = "nightly")] + #[bench] + fn bench_server_encode_headers_preset(b: &mut Bencher) { + use crate::proto::BodyLength; + use http::header::HeaderValue; + + let len = 108; + b.bytes = len as u64; + + let mut head = MessageHead::default(); + let mut headers = HeaderMap::new(); + headers.insert("content-length", HeaderValue::from_static("10")); + headers.insert("content-type", HeaderValue::from_static("application/json")); + + b.iter(|| { + let mut vec = Vec::new(); + head.headers = headers.clone(); + Server::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut Some(Method::GET), + title_case_headers: false, + }, + &mut vec, + ) + .unwrap(); + assert_eq!(vec.len(), len); + ::test::black_box(vec); + }) + } + + #[cfg(feature = "nightly")] + #[bench] + fn bench_server_encode_no_headers(b: &mut Bencher) { + use crate::proto::BodyLength; + + let len = 76; + b.bytes = len as u64; + + let mut head = MessageHead::default(); + let mut vec = Vec::with_capacity(128); + + b.iter(|| { + Server::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut Some(Method::GET), + title_case_headers: false, + }, + &mut vec, + ) + .unwrap(); + assert_eq!(vec.len(), len); + ::test::black_box(&vec); + + vec.clear(); + }) + } +} |