diff options
Diffstat (limited to 'third_party/rust/hyper/src/proto')
-rw-r--r-- | third_party/rust/hyper/src/proto/h1/conn.rs | 1321 | ||||
-rw-r--r-- | third_party/rust/hyper/src/proto/h1/date.rs | 82 | ||||
-rw-r--r-- | third_party/rust/hyper/src/proto/h1/decode.rs | 674 | ||||
-rw-r--r-- | third_party/rust/hyper/src/proto/h1/dispatch.rs | 702 | ||||
-rw-r--r-- | third_party/rust/hyper/src/proto/h1/encode.rs | 418 | ||||
-rw-r--r-- | third_party/rust/hyper/src/proto/h1/io.rs | 907 | ||||
-rw-r--r-- | third_party/rust/hyper/src/proto/h1/mod.rs | 95 | ||||
-rw-r--r-- | third_party/rust/hyper/src/proto/h1/role.rs | 1835 | ||||
-rw-r--r-- | third_party/rust/hyper/src/proto/h2/client.rs | 292 | ||||
-rw-r--r-- | third_party/rust/hyper/src/proto/h2/mod.rs | 263 | ||||
-rw-r--r-- | third_party/rust/hyper/src/proto/h2/ping.rs | 506 | ||||
-rw-r--r-- | third_party/rust/hyper/src/proto/h2/server.rs | 439 | ||||
-rw-r--r-- | third_party/rust/hyper/src/proto/mod.rs | 145 |
13 files changed, 7679 insertions, 0 deletions
diff --git a/third_party/rust/hyper/src/proto/h1/conn.rs b/third_party/rust/hyper/src/proto/h1/conn.rs new file mode 100644 index 0000000000..c8b355cd63 --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/conn.rs @@ -0,0 +1,1321 @@ +use std::fmt; +use std::io::{self}; +use std::marker::PhantomData; + +use bytes::{Buf, Bytes}; +use http::header::{HeaderValue, CONNECTION}; +use http::{HeaderMap, Method, Version}; +use tokio::io::{AsyncRead, AsyncWrite}; + +use super::io::Buffered; +use super::{Decoder, Encode, EncodedBuf, Encoder, Http1Transaction, ParseContext, Wants}; +use crate::common::{task, Pin, Poll, Unpin}; +use crate::headers::connection_keep_alive; +use crate::proto::{BodyLength, DecodedLength, MessageHead}; + +const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; + +/// This handles a connection, which will have been established over an +/// `AsyncRead + AsyncWrite` (like a socket), and will likely include multiple +/// `Transaction`s over HTTP. +/// +/// The connection will determine when a message begins and ends as well as +/// determine if this connection can be kept alive after the message, +/// or if it is complete. +pub(crate) struct Conn<I, B, T> { + io: Buffered<I, EncodedBuf<B>>, + state: State, + _marker: PhantomData<fn(T)>, +} + +impl<I, B, T> Conn<I, B, T> +where + I: AsyncRead + AsyncWrite + Unpin, + B: Buf, + T: Http1Transaction, +{ + pub fn new(io: I) -> Conn<I, B, T> { + Conn { + io: Buffered::new(io), + state: State { + allow_half_close: false, + cached_headers: None, + error: None, + keep_alive: KA::Busy, + method: None, + title_case_headers: false, + notify_read: false, + reading: Reading::Init, + writing: Writing::Init, + upgrade: None, + // We assume a modern world where the remote speaks HTTP/1.1. + // If they tell us otherwise, we'll downgrade in `read_head`. + version: Version::HTTP_11, + }, + _marker: PhantomData, + } + } + + pub fn set_flush_pipeline(&mut self, enabled: bool) { + self.io.set_flush_pipeline(enabled); + } + + pub fn set_max_buf_size(&mut self, max: usize) { + self.io.set_max_buf_size(max); + } + + pub fn set_read_buf_exact_size(&mut self, sz: usize) { + self.io.set_read_buf_exact_size(sz); + } + + pub fn set_write_strategy_flatten(&mut self) { + self.io.set_write_strategy_flatten(); + } + + pub fn set_title_case_headers(&mut self) { + self.state.title_case_headers = true; + } + + pub(crate) fn set_allow_half_close(&mut self) { + self.state.allow_half_close = true; + } + + pub fn into_inner(self) -> (I, Bytes) { + self.io.into_inner() + } + + pub fn pending_upgrade(&mut self) -> Option<crate::upgrade::Pending> { + self.state.upgrade.take() + } + + pub fn is_read_closed(&self) -> bool { + self.state.is_read_closed() + } + + pub fn is_write_closed(&self) -> bool { + self.state.is_write_closed() + } + + pub fn can_read_head(&self) -> bool { + match self.state.reading { + Reading::Init => { + if T::should_read_first() { + true + } else { + match self.state.writing { + Writing::Init => false, + _ => true, + } + } + } + _ => false, + } + } + + pub fn can_read_body(&self) -> bool { + match self.state.reading { + Reading::Body(..) | Reading::Continue(..) => true, + _ => false, + } + } + + fn should_error_on_eof(&self) -> bool { + // If we're idle, it's probably just the connection closing gracefully. + T::should_error_on_parse_eof() && !self.state.is_idle() + } + + fn has_h2_prefix(&self) -> bool { + let read_buf = self.io.read_buf(); + read_buf.len() >= 24 && read_buf[..24] == *H2_PREFACE + } + + pub(super) fn poll_read_head( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Option<crate::Result<(MessageHead<T::Incoming>, DecodedLength, Wants)>>> { + debug_assert!(self.can_read_head()); + trace!("Conn::read_head"); + + let msg = match ready!(self.io.parse::<T>( + cx, + ParseContext { + cached_headers: &mut self.state.cached_headers, + req_method: &mut self.state.method, + } + )) { + Ok(msg) => msg, + Err(e) => return self.on_read_head_error(e), + }; + + // Note: don't deconstruct `msg` into local variables, it appears + // the optimizer doesn't remove the extra copies. + + debug!("incoming body is {}", msg.decode); + + self.state.busy(); + self.state.keep_alive &= msg.keep_alive; + self.state.version = msg.head.version; + + let mut wants = if msg.wants_upgrade { + Wants::UPGRADE + } else { + Wants::EMPTY + }; + + if msg.decode == DecodedLength::ZERO { + if msg.expect_continue { + debug!("ignoring expect-continue since body is empty"); + } + self.state.reading = Reading::KeepAlive; + if !T::should_read_first() { + self.try_keep_alive(cx); + } + } else if msg.expect_continue { + self.state.reading = Reading::Continue(Decoder::new(msg.decode)); + wants = wants.add(Wants::EXPECT); + } else { + self.state.reading = Reading::Body(Decoder::new(msg.decode)); + } + + Poll::Ready(Some(Ok((msg.head, msg.decode, wants)))) + } + + fn on_read_head_error<Z>(&mut self, e: crate::Error) -> Poll<Option<crate::Result<Z>>> { + // If we are currently waiting on a message, then an empty + // message should be reported as an error. If not, it is just + // the connection closing gracefully. + let must_error = self.should_error_on_eof(); + self.close_read(); + self.io.consume_leading_lines(); + let was_mid_parse = e.is_parse() || !self.io.read_buf().is_empty(); + if was_mid_parse || must_error { + // We check if the buf contains the h2 Preface + debug!( + "parse error ({}) with {} bytes", + e, + self.io.read_buf().len() + ); + match self.on_parse_error(e) { + Ok(()) => Poll::Pending, // XXX: wat? + Err(e) => Poll::Ready(Some(Err(e))), + } + } else { + debug!("read eof"); + self.close_write(); + Poll::Ready(None) + } + } + + pub fn poll_read_body( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Option<io::Result<Bytes>>> { + debug_assert!(self.can_read_body()); + + let (reading, ret) = match self.state.reading { + Reading::Body(ref mut decoder) => { + match decoder.decode(cx, &mut self.io) { + Poll::Ready(Ok(slice)) => { + let (reading, chunk) = if decoder.is_eof() { + debug!("incoming body completed"); + ( + Reading::KeepAlive, + if !slice.is_empty() { + Some(Ok(slice)) + } else { + None + }, + ) + } else if slice.is_empty() { + error!("incoming body unexpectedly ended"); + // This should be unreachable, since all 3 decoders + // either set eof=true or return an Err when reading + // an empty slice... + (Reading::Closed, None) + } else { + return Poll::Ready(Some(Ok(slice))); + }; + (reading, Poll::Ready(chunk)) + } + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(e)) => { + debug!("incoming body decode error: {}", e); + (Reading::Closed, Poll::Ready(Some(Err(e)))) + } + } + } + Reading::Continue(ref decoder) => { + // Write the 100 Continue if not already responded... + if let Writing::Init = self.state.writing { + trace!("automatically sending 100 Continue"); + let cont = b"HTTP/1.1 100 Continue\r\n\r\n"; + self.io.headers_buf().extend_from_slice(cont); + } + + // And now recurse once in the Reading::Body state... + self.state.reading = Reading::Body(decoder.clone()); + return self.poll_read_body(cx); + } + _ => unreachable!("poll_read_body invalid state: {:?}", self.state.reading), + }; + + self.state.reading = reading; + self.try_keep_alive(cx); + ret + } + + pub fn wants_read_again(&mut self) -> bool { + let ret = self.state.notify_read; + self.state.notify_read = false; + ret + } + + pub fn poll_read_keep_alive(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + debug_assert!(!self.can_read_head() && !self.can_read_body()); + + if self.is_read_closed() { + Poll::Pending + } else if self.is_mid_message() { + self.mid_message_detect_eof(cx) + } else { + self.require_empty_read(cx) + } + } + + fn is_mid_message(&self) -> bool { + match (&self.state.reading, &self.state.writing) { + (&Reading::Init, &Writing::Init) => false, + _ => true, + } + } + + // This will check to make sure the io object read is empty. + // + // This should only be called for Clients wanting to enter the idle + // state. + fn require_empty_read(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + debug_assert!(!self.can_read_head() && !self.can_read_body() && !self.is_read_closed()); + debug_assert!(!self.is_mid_message()); + debug_assert!(T::is_client()); + + if !self.io.read_buf().is_empty() { + debug!("received an unexpected {} bytes", self.io.read_buf().len()); + return Poll::Ready(Err(crate::Error::new_unexpected_message())); + } + + let num_read = ready!(self.force_io_read(cx)).map_err(crate::Error::new_io)?; + + if num_read == 0 { + let ret = if self.should_error_on_eof() { + trace!("found unexpected EOF on busy connection: {:?}", self.state); + Poll::Ready(Err(crate::Error::new_incomplete())) + } else { + trace!("found EOF on idle connection, closing"); + Poll::Ready(Ok(())) + }; + + // order is important: should_error needs state BEFORE close_read + self.state.close_read(); + return ret; + } + + debug!( + "received unexpected {} bytes on an idle connection", + num_read + ); + Poll::Ready(Err(crate::Error::new_unexpected_message())) + } + + fn mid_message_detect_eof(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + debug_assert!(!self.can_read_head() && !self.can_read_body() && !self.is_read_closed()); + debug_assert!(self.is_mid_message()); + + if self.state.allow_half_close || !self.io.read_buf().is_empty() { + return Poll::Pending; + } + + let num_read = ready!(self.force_io_read(cx)).map_err(crate::Error::new_io)?; + + if num_read == 0 { + trace!("found unexpected EOF on busy connection: {:?}", self.state); + self.state.close_read(); + Poll::Ready(Err(crate::Error::new_incomplete())) + } else { + Poll::Ready(Ok(())) + } + } + + fn force_io_read(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<usize>> { + debug_assert!(!self.state.is_read_closed()); + + let result = ready!(self.io.poll_read_from_io(cx)); + Poll::Ready(result.map_err(|e| { + trace!("force_io_read; io error = {:?}", e); + self.state.close(); + e + })) + } + + fn maybe_notify(&mut self, cx: &mut task::Context<'_>) { + // its possible that we returned NotReady from poll() without having + // exhausted the underlying Io. We would have done this when we + // determined we couldn't keep reading until we knew how writing + // would finish. + + match self.state.reading { + Reading::Continue(..) | Reading::Body(..) | Reading::KeepAlive | Reading::Closed => { + return + } + Reading::Init => (), + }; + + match self.state.writing { + Writing::Body(..) => return, + Writing::Init | Writing::KeepAlive | Writing::Closed => (), + } + + if !self.io.is_read_blocked() { + if self.io.read_buf().is_empty() { + match self.io.poll_read_from_io(cx) { + Poll::Ready(Ok(n)) => { + if n == 0 { + trace!("maybe_notify; read eof"); + if self.state.is_idle() { + self.state.close(); + } else { + self.close_read() + } + return; + } + } + Poll::Pending => { + trace!("maybe_notify; read_from_io blocked"); + return; + } + Poll::Ready(Err(e)) => { + trace!("maybe_notify; read_from_io error: {}", e); + self.state.close(); + self.state.error = Some(crate::Error::new_io(e)); + } + } + } + self.state.notify_read = true; + } + } + + fn try_keep_alive(&mut self, cx: &mut task::Context<'_>) { + self.state.try_keep_alive::<T>(); + self.maybe_notify(cx); + } + + pub fn can_write_head(&self) -> bool { + if !T::should_read_first() { + if let Reading::Closed = self.state.reading { + return false; + } + } + match self.state.writing { + Writing::Init => true, + _ => false, + } + } + + pub fn can_write_body(&self) -> bool { + match self.state.writing { + Writing::Body(..) => true, + Writing::Init | Writing::KeepAlive | Writing::Closed => false, + } + } + + pub fn can_buffer_body(&self) -> bool { + self.io.can_buffer() + } + + pub fn write_head(&mut self, head: MessageHead<T::Outgoing>, body: Option<BodyLength>) { + if let Some(encoder) = self.encode_head(head, body) { + self.state.writing = if !encoder.is_eof() { + Writing::Body(encoder) + } else if encoder.is_last() { + Writing::Closed + } else { + Writing::KeepAlive + }; + } + } + + pub fn write_full_msg(&mut self, head: MessageHead<T::Outgoing>, body: B) { + if let Some(encoder) = + self.encode_head(head, Some(BodyLength::Known(body.remaining() as u64))) + { + let is_last = encoder.is_last(); + // Make sure we don't write a body if we weren't actually allowed + // to do so, like because its a HEAD request. + if !encoder.is_eof() { + encoder.danger_full_buf(body, self.io.write_buf()); + } + self.state.writing = if is_last { + Writing::Closed + } else { + Writing::KeepAlive + } + } + } + + fn encode_head( + &mut self, + mut head: MessageHead<T::Outgoing>, + body: Option<BodyLength>, + ) -> Option<Encoder> { + debug_assert!(self.can_write_head()); + + if !T::should_read_first() { + self.state.busy(); + } + + self.enforce_version(&mut head); + + let buf = self.io.headers_buf(); + match T::encode( + Encode { + head: &mut head, + body, + keep_alive: self.state.wants_keep_alive(), + req_method: &mut self.state.method, + title_case_headers: self.state.title_case_headers, + }, + buf, + ) { + Ok(encoder) => { + debug_assert!(self.state.cached_headers.is_none()); + debug_assert!(head.headers.is_empty()); + self.state.cached_headers = Some(head.headers); + Some(encoder) + } + Err(err) => { + self.state.error = Some(err); + self.state.writing = Writing::Closed; + None + } + } + } + + // Fix keep-alives when Connection: keep-alive header is not present + fn fix_keep_alive(&mut self, head: &mut MessageHead<T::Outgoing>) { + let outgoing_is_keep_alive = head + .headers + .get(CONNECTION) + .map(connection_keep_alive) + .unwrap_or(false); + + if !outgoing_is_keep_alive { + match head.version { + // If response is version 1.0 and keep-alive is not present in the response, + // disable keep-alive so the server closes the connection + Version::HTTP_10 => self.state.disable_keep_alive(), + // If response is version 1.1 and keep-alive is wanted, add + // Connection: keep-alive header when not present + Version::HTTP_11 => { + if self.state.wants_keep_alive() { + head.headers + .insert(CONNECTION, HeaderValue::from_static("keep-alive")); + } + } + _ => (), + } + } + } + + // If we know the remote speaks an older version, we try to fix up any messages + // to work with our older peer. + fn enforce_version(&mut self, head: &mut MessageHead<T::Outgoing>) { + if let Version::HTTP_10 = self.state.version { + // Fixes response or connection when keep-alive header is not present + self.fix_keep_alive(head); + // If the remote only knows HTTP/1.0, we should force ourselves + // to do only speak HTTP/1.0 as well. + head.version = Version::HTTP_10; + } + // If the remote speaks HTTP/1.1, then it *should* be fine with + // both HTTP/1.0 and HTTP/1.1 from us. So again, we just let + // the user's headers be. + } + + pub fn write_body(&mut self, chunk: B) { + debug_assert!(self.can_write_body() && self.can_buffer_body()); + // empty chunks should be discarded at Dispatcher level + debug_assert!(chunk.remaining() != 0); + + let state = match self.state.writing { + Writing::Body(ref mut encoder) => { + self.io.buffer(encoder.encode(chunk)); + + if encoder.is_eof() { + if encoder.is_last() { + Writing::Closed + } else { + Writing::KeepAlive + } + } else { + return; + } + } + _ => unreachable!("write_body invalid state: {:?}", self.state.writing), + }; + + self.state.writing = state; + } + + pub fn write_body_and_end(&mut self, chunk: B) { + debug_assert!(self.can_write_body() && self.can_buffer_body()); + // empty chunks should be discarded at Dispatcher level + debug_assert!(chunk.remaining() != 0); + + let state = match self.state.writing { + Writing::Body(ref encoder) => { + let can_keep_alive = encoder.encode_and_end(chunk, self.io.write_buf()); + if can_keep_alive { + Writing::KeepAlive + } else { + Writing::Closed + } + } + _ => unreachable!("write_body invalid state: {:?}", self.state.writing), + }; + + self.state.writing = state; + } + + pub fn end_body(&mut self) { + debug_assert!(self.can_write_body()); + + let state = match self.state.writing { + Writing::Body(ref mut encoder) => { + // end of stream, that means we should try to eof + match encoder.end() { + Ok(end) => { + if let Some(end) = end { + self.io.buffer(end); + } + if encoder.is_last() { + Writing::Closed + } else { + Writing::KeepAlive + } + } + Err(_not_eof) => Writing::Closed, + } + } + _ => return, + }; + + self.state.writing = state; + } + + // When we get a parse error, depending on what side we are, we might be able + // to write a response before closing the connection. + // + // - Client: there is nothing we can do + // - Server: if Response hasn't been written yet, we can send a 4xx response + fn on_parse_error(&mut self, err: crate::Error) -> crate::Result<()> { + if let Writing::Init = self.state.writing { + if self.has_h2_prefix() { + return Err(crate::Error::new_version_h2()); + } + if let Some(msg) = T::on_error(&err) { + // Drop the cached headers so as to not trigger a debug + // assert in `write_head`... + self.state.cached_headers.take(); + self.write_head(msg, None); + self.state.error = Some(err); + return Ok(()); + } + } + + // fallback is pass the error back up + Err(err) + } + + pub fn poll_flush(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + ready!(Pin::new(&mut self.io).poll_flush(cx))?; + self.try_keep_alive(cx); + trace!("flushed({}): {:?}", T::LOG, self.state); + Poll::Ready(Ok(())) + } + + pub fn poll_shutdown(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + match ready!(Pin::new(self.io.io_mut()).poll_shutdown(cx)) { + Ok(()) => { + trace!("shut down IO complete"); + Poll::Ready(Ok(())) + } + Err(e) => { + debug!("error shutting down IO: {}", e); + Poll::Ready(Err(e)) + } + } + } + + /// If the read side can be cheaply drained, do so. Otherwise, close. + pub(super) fn poll_drain_or_close_read(&mut self, cx: &mut task::Context<'_>) { + let _ = self.poll_read_body(cx); + + // If still in Reading::Body, just give up + match self.state.reading { + Reading::Init | Reading::KeepAlive => { + trace!("body drained"); + return; + } + _ => self.close_read(), + } + } + + pub fn close_read(&mut self) { + self.state.close_read(); + } + + pub fn close_write(&mut self) { + self.state.close_write(); + } + + pub fn disable_keep_alive(&mut self) { + if self.state.is_idle() { + trace!("disable_keep_alive; closing idle connection"); + self.state.close(); + } else { + trace!("disable_keep_alive; in-progress connection"); + self.state.disable_keep_alive(); + } + } + + pub fn take_error(&mut self) -> crate::Result<()> { + if let Some(err) = self.state.error.take() { + Err(err) + } else { + Ok(()) + } + } + + pub(super) fn on_upgrade(&mut self) -> crate::upgrade::OnUpgrade { + trace!("{}: prepare possible HTTP upgrade", T::LOG); + self.state.prepare_upgrade() + } +} + +impl<I, B: Buf, T> fmt::Debug for Conn<I, B, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Conn") + .field("state", &self.state) + .field("io", &self.io) + .finish() + } +} + +// B and T are never pinned +impl<I: Unpin, B, T> Unpin for Conn<I, B, T> {} + +struct State { + allow_half_close: bool, + /// Re-usable HeaderMap to reduce allocating new ones. + cached_headers: Option<HeaderMap>, + /// If an error occurs when there wasn't a direct way to return it + /// back to the user, this is set. + error: Option<crate::Error>, + /// Current keep-alive status. + keep_alive: KA, + /// If mid-message, the HTTP Method that started it. + /// + /// This is used to know things such as if the message can include + /// a body or not. + method: Option<Method>, + title_case_headers: bool, + /// Set to true when the Dispatcher should poll read operations + /// again. See the `maybe_notify` method for more. + notify_read: bool, + /// State of allowed reads + reading: Reading, + /// State of allowed writes + writing: Writing, + /// An expected pending HTTP upgrade. + upgrade: Option<crate::upgrade::Pending>, + /// Either HTTP/1.0 or 1.1 connection + version: Version, +} + +#[derive(Debug)] +enum Reading { + Init, + Continue(Decoder), + Body(Decoder), + KeepAlive, + Closed, +} + +enum Writing { + Init, + Body(Encoder), + KeepAlive, + Closed, +} + +impl fmt::Debug for State { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = f.debug_struct("State"); + builder + .field("reading", &self.reading) + .field("writing", &self.writing) + .field("keep_alive", &self.keep_alive); + + // Only show error field if it's interesting... + if let Some(ref error) = self.error { + builder.field("error", error); + } + + if self.allow_half_close { + builder.field("allow_half_close", &true); + } + + // Purposefully leaving off other fields.. + + builder.finish() + } +} + +impl fmt::Debug for Writing { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Writing::Init => f.write_str("Init"), + Writing::Body(ref enc) => f.debug_tuple("Body").field(enc).finish(), + Writing::KeepAlive => f.write_str("KeepAlive"), + Writing::Closed => f.write_str("Closed"), + } + } +} + +impl std::ops::BitAndAssign<bool> for KA { + fn bitand_assign(&mut self, enabled: bool) { + if !enabled { + trace!("remote disabling keep-alive"); + *self = KA::Disabled; + } + } +} + +#[derive(Clone, Copy, Debug)] +enum KA { + Idle, + Busy, + Disabled, +} + +impl Default for KA { + fn default() -> KA { + KA::Busy + } +} + +impl KA { + fn idle(&mut self) { + *self = KA::Idle; + } + + fn busy(&mut self) { + *self = KA::Busy; + } + + fn disable(&mut self) { + *self = KA::Disabled; + } + + fn status(&self) -> KA { + *self + } +} + +impl State { + fn close(&mut self) { + trace!("State::close()"); + self.reading = Reading::Closed; + self.writing = Writing::Closed; + self.keep_alive.disable(); + } + + fn close_read(&mut self) { + trace!("State::close_read()"); + self.reading = Reading::Closed; + self.keep_alive.disable(); + } + + fn close_write(&mut self) { + trace!("State::close_write()"); + self.writing = Writing::Closed; + self.keep_alive.disable(); + } + + fn wants_keep_alive(&self) -> bool { + if let KA::Disabled = self.keep_alive.status() { + false + } else { + true + } + } + + fn try_keep_alive<T: Http1Transaction>(&mut self) { + match (&self.reading, &self.writing) { + (&Reading::KeepAlive, &Writing::KeepAlive) => { + if let KA::Busy = self.keep_alive.status() { + self.idle::<T>(); + } else { + trace!( + "try_keep_alive({}): could keep-alive, but status = {:?}", + T::LOG, + self.keep_alive + ); + self.close(); + } + } + (&Reading::Closed, &Writing::KeepAlive) | (&Reading::KeepAlive, &Writing::Closed) => { + self.close() + } + _ => (), + } + } + + fn disable_keep_alive(&mut self) { + self.keep_alive.disable() + } + + fn busy(&mut self) { + if let KA::Disabled = self.keep_alive.status() { + return; + } + self.keep_alive.busy(); + } + + fn idle<T: Http1Transaction>(&mut self) { + debug_assert!(!self.is_idle(), "State::idle() called while idle"); + + self.method = None; + self.keep_alive.idle(); + if self.is_idle() { + self.reading = Reading::Init; + self.writing = Writing::Init; + + // !T::should_read_first() means Client. + // + // If Client connection has just gone idle, the Dispatcher + // should try the poll loop one more time, so as to poll the + // pending requests stream. + if !T::should_read_first() { + self.notify_read = true; + } + } else { + self.close(); + } + } + + fn is_idle(&self) -> bool { + if let KA::Idle = self.keep_alive.status() { + true + } else { + false + } + } + + fn is_read_closed(&self) -> bool { + match self.reading { + Reading::Closed => true, + _ => false, + } + } + + fn is_write_closed(&self) -> bool { + match self.writing { + Writing::Closed => true, + _ => false, + } + } + + fn prepare_upgrade(&mut self) -> crate::upgrade::OnUpgrade { + let (tx, rx) = crate::upgrade::pending(); + self.upgrade = Some(tx); + rx + } +} + +#[cfg(test)] +mod tests { + #[cfg(feature = "nightly")] + #[bench] + fn bench_read_head_short(b: &mut ::test::Bencher) { + use super::*; + let s = b"GET / HTTP/1.1\r\nHost: localhost:8080\r\n\r\n"; + let len = s.len(); + b.bytes = len as u64; + + // an empty IO, we'll be skipping and using the read buffer anyways + let io = tokio_test::io::Builder::new().build(); + let mut conn = Conn::<_, bytes::Bytes, crate::proto::h1::ServerTransaction>::new(io); + *conn.io.read_buf_mut() = ::bytes::BytesMut::from(&s[..]); + conn.state.cached_headers = Some(HeaderMap::with_capacity(2)); + + let mut rt = tokio::runtime::Builder::new() + .enable_all() + .basic_scheduler() + .build() + .unwrap(); + + b.iter(|| { + rt.block_on(futures_util::future::poll_fn(|cx| { + match conn.poll_read_head(cx) { + Poll::Ready(Some(Ok(x))) => { + ::test::black_box(&x); + let mut headers = x.0.headers; + headers.clear(); + conn.state.cached_headers = Some(headers); + } + f => panic!("expected Ready(Some(Ok(..))): {:?}", f), + } + + conn.io.read_buf_mut().reserve(1); + unsafe { + conn.io.read_buf_mut().set_len(len); + } + conn.state.reading = Reading::Init; + Poll::Ready(()) + })); + }); + } + + /* + //TODO: rewrite these using dispatch... someday... + use futures::{Async, Future, Stream, Sink}; + use futures::future; + + use proto::{self, ClientTransaction, MessageHead, ServerTransaction}; + use super::super::Encoder; + use mock::AsyncIo; + + use super::{Conn, Decoder, Reading, Writing}; + use ::uri::Uri; + + use std::str::FromStr; + + #[test] + fn test_conn_init_read() { + let good_message = b"GET / HTTP/1.1\r\n\r\n".to_vec(); + let len = good_message.len(); + let io = AsyncIo::new_buf(good_message, len); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + + match conn.poll().unwrap() { + Async::Ready(Some(Frame::Message { message, body: false })) => { + assert_eq!(message, MessageHead { + subject: ::proto::RequestLine(::Get, Uri::from_str("/").unwrap()), + .. MessageHead::default() + }) + }, + f => panic!("frame is not Frame::Message: {:?}", f) + } + } + + #[test] + fn test_conn_parse_partial() { + let _: Result<(), ()> = future::lazy(|| { + let good_message = b"GET / HTTP/1.1\r\nHost: foo.bar\r\n\r\n".to_vec(); + let io = AsyncIo::new_buf(good_message, 10); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + assert!(conn.poll().unwrap().is_not_ready()); + conn.io.io_mut().block_in(50); + let async = conn.poll().unwrap(); + assert!(async.is_ready()); + match async { + Async::Ready(Some(Frame::Message { .. })) => (), + f => panic!("frame is not Message: {:?}", f), + } + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_init_read_eof_idle() { + let io = AsyncIo::new_buf(vec![], 1); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.idle(); + + match conn.poll().unwrap() { + Async::Ready(None) => {}, + other => panic!("frame is not None: {:?}", other) + } + } + + #[test] + fn test_conn_init_read_eof_idle_partial_parse() { + let io = AsyncIo::new_buf(b"GET / HTTP/1.1".to_vec(), 100); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.idle(); + + match conn.poll() { + Err(ref err) if err.kind() == std::io::ErrorKind::UnexpectedEof => {}, + other => panic!("unexpected frame: {:?}", other) + } + } + + #[test] + fn test_conn_init_read_eof_busy() { + let _: Result<(), ()> = future::lazy(|| { + // server ignores + let io = AsyncIo::new_eof(); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.busy(); + + match conn.poll().unwrap() { + Async::Ready(None) => {}, + other => panic!("unexpected frame: {:?}", other) + } + + // client + let io = AsyncIo::new_eof(); + let mut conn = Conn::<_, proto::Bytes, ClientTransaction>::new(io); + conn.state.busy(); + + match conn.poll() { + Err(ref err) if err.kind() == std::io::ErrorKind::UnexpectedEof => {}, + other => panic!("unexpected frame: {:?}", other) + } + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_body_finish_read_eof() { + let _: Result<(), ()> = future::lazy(|| { + let io = AsyncIo::new_eof(); + let mut conn = Conn::<_, proto::Bytes, ClientTransaction>::new(io); + conn.state.busy(); + conn.state.writing = Writing::KeepAlive; + conn.state.reading = Reading::Body(Decoder::length(0)); + + match conn.poll() { + Ok(Async::Ready(Some(Frame::Body { chunk: None }))) => (), + other => panic!("unexpected frame: {:?}", other) + } + + // conn eofs, but tokio-proto will call poll() again, before calling flush() + // the conn eof in this case is perfectly fine + + match conn.poll() { + Ok(Async::Ready(None)) => (), + other => panic!("unexpected frame: {:?}", other) + } + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_message_empty_body_read_eof() { + let _: Result<(), ()> = future::lazy(|| { + let io = AsyncIo::new_buf(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n".to_vec(), 1024); + let mut conn = Conn::<_, proto::Bytes, ClientTransaction>::new(io); + conn.state.busy(); + conn.state.writing = Writing::KeepAlive; + + match conn.poll() { + Ok(Async::Ready(Some(Frame::Message { body: false, .. }))) => (), + other => panic!("unexpected frame: {:?}", other) + } + + // conn eofs, but tokio-proto will call poll() again, before calling flush() + // the conn eof in this case is perfectly fine + + match conn.poll() { + Ok(Async::Ready(None)) => (), + other => panic!("unexpected frame: {:?}", other) + } + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_read_body_end() { + let _: Result<(), ()> = future::lazy(|| { + let io = AsyncIo::new_buf(b"POST / HTTP/1.1\r\nContent-Length: 5\r\n\r\n12345".to_vec(), 1024); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.busy(); + + match conn.poll() { + Ok(Async::Ready(Some(Frame::Message { body: true, .. }))) => (), + other => panic!("unexpected frame: {:?}", other) + } + + match conn.poll() { + Ok(Async::Ready(Some(Frame::Body { chunk: Some(_) }))) => (), + other => panic!("unexpected frame: {:?}", other) + } + + // When the body is done, `poll` MUST return a `Body` frame with chunk set to `None` + match conn.poll() { + Ok(Async::Ready(Some(Frame::Body { chunk: None }))) => (), + other => panic!("unexpected frame: {:?}", other) + } + + match conn.poll() { + Ok(Async::NotReady) => (), + other => panic!("unexpected frame: {:?}", other) + } + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_closed_read() { + let io = AsyncIo::new_buf(vec![], 0); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.close(); + + match conn.poll().unwrap() { + Async::Ready(None) => {}, + other => panic!("frame is not None: {:?}", other) + } + } + + #[test] + fn test_conn_body_write_length() { + let _ = pretty_env_logger::try_init(); + let _: Result<(), ()> = future::lazy(|| { + let io = AsyncIo::new_buf(vec![], 0); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + let max = super::super::io::DEFAULT_MAX_BUFFER_SIZE + 4096; + conn.state.writing = Writing::Body(Encoder::length((max * 2) as u64)); + + assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'a'; max].into()) }).unwrap().is_ready()); + assert!(!conn.can_buffer_body()); + + assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'b'; 1024 * 8].into()) }).unwrap().is_not_ready()); + + conn.io.io_mut().block_in(1024 * 3); + assert!(conn.poll_complete().unwrap().is_not_ready()); + conn.io.io_mut().block_in(1024 * 3); + assert!(conn.poll_complete().unwrap().is_not_ready()); + conn.io.io_mut().block_in(max * 2); + assert!(conn.poll_complete().unwrap().is_ready()); + + assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'c'; 1024 * 8].into()) }).unwrap().is_ready()); + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_body_write_chunked() { + let _: Result<(), ()> = future::lazy(|| { + let io = AsyncIo::new_buf(vec![], 4096); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.writing = Writing::Body(Encoder::chunked()); + + assert!(conn.start_send(Frame::Body { chunk: Some("headers".into()) }).unwrap().is_ready()); + assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'x'; 8192].into()) }).unwrap().is_ready()); + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_body_flush() { + let _: Result<(), ()> = future::lazy(|| { + let io = AsyncIo::new_buf(vec![], 1024 * 1024 * 5); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.writing = Writing::Body(Encoder::length(1024 * 1024)); + assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'a'; 1024 * 1024].into()) }).unwrap().is_ready()); + assert!(!conn.can_buffer_body()); + conn.io.io_mut().block_in(1024 * 1024 * 5); + assert!(conn.poll_complete().unwrap().is_ready()); + assert!(conn.can_buffer_body()); + assert!(conn.io.io_mut().flushed()); + + Ok(()) + }).wait(); + } + + #[test] + fn test_conn_parking() { + use std::sync::Arc; + use futures::executor::Notify; + use futures::executor::NotifyHandle; + + struct Car { + permit: bool, + } + impl Notify for Car { + fn notify(&self, _id: usize) { + assert!(self.permit, "unparked without permit"); + } + } + + fn car(permit: bool) -> NotifyHandle { + Arc::new(Car { + permit: permit, + }).into() + } + + // test that once writing is done, unparks + let f = future::lazy(|| { + let io = AsyncIo::new_buf(vec![], 4096); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.reading = Reading::KeepAlive; + assert!(conn.poll().unwrap().is_not_ready()); + + conn.state.writing = Writing::KeepAlive; + assert!(conn.poll_complete().unwrap().is_ready()); + Ok::<(), ()>(()) + }); + ::futures::executor::spawn(f).poll_future_notify(&car(true), 0).unwrap(); + + + // test that flushing when not waiting on read doesn't unpark + let f = future::lazy(|| { + let io = AsyncIo::new_buf(vec![], 4096); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.writing = Writing::KeepAlive; + assert!(conn.poll_complete().unwrap().is_ready()); + Ok::<(), ()>(()) + }); + ::futures::executor::spawn(f).poll_future_notify(&car(false), 0).unwrap(); + + + // test that flushing and writing isn't done doesn't unpark + let f = future::lazy(|| { + let io = AsyncIo::new_buf(vec![], 4096); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.reading = Reading::KeepAlive; + assert!(conn.poll().unwrap().is_not_ready()); + conn.state.writing = Writing::Body(Encoder::length(5_000)); + assert!(conn.poll_complete().unwrap().is_ready()); + Ok::<(), ()>(()) + }); + ::futures::executor::spawn(f).poll_future_notify(&car(false), 0).unwrap(); + } + + #[test] + fn test_conn_closed_write() { + let io = AsyncIo::new_buf(vec![], 0); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.close(); + + match conn.start_send(Frame::Body { chunk: Some(b"foobar".to_vec().into()) }) { + Err(_e) => {}, + other => panic!("did not return Err: {:?}", other) + } + + assert!(conn.state.is_write_closed()); + } + + #[test] + fn test_conn_write_empty_chunk() { + let io = AsyncIo::new_buf(vec![], 0); + let mut conn = Conn::<_, proto::Bytes, ServerTransaction>::new(io); + conn.state.writing = Writing::KeepAlive; + + assert!(conn.start_send(Frame::Body { chunk: None }).unwrap().is_ready()); + assert!(conn.start_send(Frame::Body { chunk: Some(Vec::new().into()) }).unwrap().is_ready()); + conn.start_send(Frame::Body { chunk: Some(vec![b'a'].into()) }).unwrap_err(); + } + */ +} diff --git a/third_party/rust/hyper/src/proto/h1/date.rs b/third_party/rust/hyper/src/proto/h1/date.rs new file mode 100644 index 0000000000..3e972d6e00 --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/date.rs @@ -0,0 +1,82 @@ +use std::cell::RefCell; +use std::fmt::{self, Write}; +use std::str; + +use http::header::HeaderValue; +use time::{self, Duration}; + +// "Sun, 06 Nov 1994 08:49:37 GMT".len() +pub const DATE_VALUE_LENGTH: usize = 29; + +pub fn extend(dst: &mut Vec<u8>) { + CACHED.with(|cache| { + dst.extend_from_slice(cache.borrow().buffer()); + }) +} + +pub fn update() { + CACHED.with(|cache| { + cache.borrow_mut().check(); + }) +} + +pub(crate) fn update_and_header_value() -> HeaderValue { + CACHED.with(|cache| { + let mut cache = cache.borrow_mut(); + cache.check(); + HeaderValue::from_bytes(cache.buffer()).expect("Date format should be valid HeaderValue") + }) +} + +struct CachedDate { + bytes: [u8; DATE_VALUE_LENGTH], + pos: usize, + next_update: time::Timespec, +} + +thread_local!(static CACHED: RefCell<CachedDate> = RefCell::new(CachedDate::new())); + +impl CachedDate { + fn new() -> Self { + let mut cache = CachedDate { + bytes: [0; DATE_VALUE_LENGTH], + pos: 0, + next_update: time::Timespec::new(0, 0), + }; + cache.update(time::get_time()); + cache + } + + fn buffer(&self) -> &[u8] { + &self.bytes[..] + } + + fn check(&mut self) { + let now = time::get_time(); + if now > self.next_update { + self.update(now); + } + } + + fn update(&mut self, now: time::Timespec) { + self.pos = 0; + let _ = write!(self, "{}", time::at_utc(now).rfc822()); + debug_assert!(self.pos == DATE_VALUE_LENGTH); + self.next_update = now + Duration::seconds(1); + self.next_update.nsec = 0; + } +} + +impl fmt::Write for CachedDate { + fn write_str(&mut self, s: &str) -> fmt::Result { + let len = s.len(); + self.bytes[self.pos..self.pos + len].copy_from_slice(s.as_bytes()); + self.pos += len; + Ok(()) + } +} + +#[test] +fn test_date_len() { + assert_eq!(DATE_VALUE_LENGTH, "Sun, 06 Nov 1994 08:49:37 GMT".len()); +} diff --git a/third_party/rust/hyper/src/proto/h1/decode.rs b/third_party/rust/hyper/src/proto/h1/decode.rs new file mode 100644 index 0000000000..beaf9aff7a --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/decode.rs @@ -0,0 +1,674 @@ +use std::error::Error as StdError; +use std::fmt; +use std::io; +use std::usize; + +use bytes::Bytes; + +use crate::common::{task, Poll}; + +use super::io::MemRead; +use super::DecodedLength; + +use self::Kind::{Chunked, Eof, Length}; + +/// Decoders to handle different Transfer-Encodings. +/// +/// If a message body does not include a Transfer-Encoding, it *should* +/// include a Content-Length header. +#[derive(Clone, PartialEq)] +pub struct Decoder { + kind: Kind, +} + +#[derive(Debug, Clone, Copy, PartialEq)] +enum Kind { + /// A Reader used when a Content-Length header is passed with a positive integer. + Length(u64), + /// A Reader used when Transfer-Encoding is `chunked`. + Chunked(ChunkedState, u64), + /// A Reader used for responses that don't indicate a length or chunked. + /// + /// The bool tracks when EOF is seen on the transport. + /// + /// Note: This should only used for `Response`s. It is illegal for a + /// `Request` to be made with both `Content-Length` and + /// `Transfer-Encoding: chunked` missing, as explained from the spec: + /// + /// > If a Transfer-Encoding header field is present in a response and + /// > the chunked transfer coding is not the final encoding, the + /// > message body length is determined by reading the connection until + /// > it is closed by the server. If a Transfer-Encoding header field + /// > is present in a request and the chunked transfer coding is not + /// > the final encoding, the message body length cannot be determined + /// > reliably; the server MUST respond with the 400 (Bad Request) + /// > status code and then close the connection. + Eof(bool), +} + +#[derive(Debug, PartialEq, Clone, Copy)] +enum ChunkedState { + Size, + SizeLws, + Extension, + SizeLf, + Body, + BodyCr, + BodyLf, + EndCr, + EndLf, + End, +} + +impl Decoder { + // constructors + + pub fn length(x: u64) -> Decoder { + Decoder { + kind: Kind::Length(x), + } + } + + pub fn chunked() -> Decoder { + Decoder { + kind: Kind::Chunked(ChunkedState::Size, 0), + } + } + + pub fn eof() -> Decoder { + Decoder { + kind: Kind::Eof(false), + } + } + + pub(super) fn new(len: DecodedLength) -> Self { + match len { + DecodedLength::CHUNKED => Decoder::chunked(), + DecodedLength::CLOSE_DELIMITED => Decoder::eof(), + length => Decoder::length(length.danger_len()), + } + } + + // methods + + pub fn is_eof(&self) -> bool { + match self.kind { + Length(0) | Chunked(ChunkedState::End, _) | Eof(true) => true, + _ => false, + } + } + + pub fn decode<R: MemRead>( + &mut self, + cx: &mut task::Context<'_>, + body: &mut R, + ) -> Poll<Result<Bytes, io::Error>> { + trace!("decode; state={:?}", self.kind); + match self.kind { + Length(ref mut remaining) => { + if *remaining == 0 { + Poll::Ready(Ok(Bytes::new())) + } else { + let to_read = *remaining as usize; + let buf = ready!(body.read_mem(cx, to_read))?; + let num = buf.as_ref().len() as u64; + if num > *remaining { + *remaining = 0; + } else if num == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + IncompleteBody, + ))); + } else { + *remaining -= num; + } + Poll::Ready(Ok(buf)) + } + } + Chunked(ref mut state, ref mut size) => { + loop { + let mut buf = None; + // advances the chunked state + *state = ready!(state.step(cx, body, size, &mut buf))?; + if *state == ChunkedState::End { + trace!("end of chunked"); + return Poll::Ready(Ok(Bytes::new())); + } + if let Some(buf) = buf { + return Poll::Ready(Ok(buf)); + } + } + } + Eof(ref mut is_eof) => { + if *is_eof { + Poll::Ready(Ok(Bytes::new())) + } else { + // 8192 chosen because its about 2 packets, there probably + // won't be that much available, so don't have MemReaders + // allocate buffers to big + body.read_mem(cx, 8192).map_ok(|slice| { + *is_eof = slice.is_empty(); + slice + }) + } + } + } + } + + #[cfg(test)] + async fn decode_fut<R: MemRead>(&mut self, body: &mut R) -> Result<Bytes, io::Error> { + futures_util::future::poll_fn(move |cx| self.decode(cx, body)).await + } +} + +impl fmt::Debug for Decoder { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&self.kind, f) + } +} + +macro_rules! byte ( + ($rdr:ident, $cx:expr) => ({ + let buf = ready!($rdr.read_mem($cx, 1))?; + if !buf.is_empty() { + buf[0] + } else { + return Poll::Ready(Err(io::Error::new(io::ErrorKind::UnexpectedEof, + "unexpected EOF during chunk size line"))); + } + }) +); + +impl ChunkedState { + fn step<R: MemRead>( + &self, + cx: &mut task::Context<'_>, + body: &mut R, + size: &mut u64, + buf: &mut Option<Bytes>, + ) -> Poll<Result<ChunkedState, io::Error>> { + use self::ChunkedState::*; + match *self { + Size => ChunkedState::read_size(cx, body, size), + SizeLws => ChunkedState::read_size_lws(cx, body), + Extension => ChunkedState::read_extension(cx, body), + SizeLf => ChunkedState::read_size_lf(cx, body, *size), + Body => ChunkedState::read_body(cx, body, size, buf), + BodyCr => ChunkedState::read_body_cr(cx, body), + BodyLf => ChunkedState::read_body_lf(cx, body), + EndCr => ChunkedState::read_end_cr(cx, body), + EndLf => ChunkedState::read_end_lf(cx, body), + End => Poll::Ready(Ok(ChunkedState::End)), + } + } + fn read_size<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + size: &mut u64, + ) -> Poll<Result<ChunkedState, io::Error>> { + trace!("Read chunk hex size"); + let radix = 16; + match byte!(rdr, cx) { + b @ b'0'..=b'9' => { + *size *= radix; + *size += (b - b'0') as u64; + } + b @ b'a'..=b'f' => { + *size *= radix; + *size += (b + 10 - b'a') as u64; + } + b @ b'A'..=b'F' => { + *size *= radix; + *size += (b + 10 - b'A') as u64; + } + b'\t' | b' ' => return Poll::Ready(Ok(ChunkedState::SizeLws)), + b';' => return Poll::Ready(Ok(ChunkedState::Extension)), + b'\r' => return Poll::Ready(Ok(ChunkedState::SizeLf)), + _ => { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size line: Invalid Size", + ))); + } + } + Poll::Ready(Ok(ChunkedState::Size)) + } + fn read_size_lws<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + trace!("read_size_lws"); + match byte!(rdr, cx) { + // LWS can follow the chunk size, but no more digits can come + b'\t' | b' ' => Poll::Ready(Ok(ChunkedState::SizeLws)), + b';' => Poll::Ready(Ok(ChunkedState::Extension)), + b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size linear white space", + ))), + } + } + fn read_extension<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + trace!("read_extension"); + match byte!(rdr, cx) { + b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), + _ => Poll::Ready(Ok(ChunkedState::Extension)), // no supported extensions + } + } + fn read_size_lf<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + size: u64, + ) -> Poll<Result<ChunkedState, io::Error>> { + trace!("Chunk size is {:?}", size); + match byte!(rdr, cx) { + b'\n' => { + if size == 0 { + Poll::Ready(Ok(ChunkedState::EndCr)) + } else { + debug!("incoming chunked header: {0:#X} ({0} bytes)", size); + Poll::Ready(Ok(ChunkedState::Body)) + } + } + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size LF", + ))), + } + } + + fn read_body<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + rem: &mut u64, + buf: &mut Option<Bytes>, + ) -> Poll<Result<ChunkedState, io::Error>> { + trace!("Chunked read, remaining={:?}", rem); + + // cap remaining bytes at the max capacity of usize + let rem_cap = match *rem { + r if r > usize::MAX as u64 => usize::MAX, + r => r as usize, + }; + + let to_read = rem_cap; + let slice = ready!(rdr.read_mem(cx, to_read))?; + let count = slice.len(); + + if count == 0 { + *rem = 0; + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + IncompleteBody, + ))); + } + *buf = Some(slice); + *rem -= count as u64; + + if *rem > 0 { + Poll::Ready(Ok(ChunkedState::Body)) + } else { + Poll::Ready(Ok(ChunkedState::BodyCr)) + } + } + fn read_body_cr<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + match byte!(rdr, cx) { + b'\r' => Poll::Ready(Ok(ChunkedState::BodyLf)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk body CR", + ))), + } + } + fn read_body_lf<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + match byte!(rdr, cx) { + b'\n' => Poll::Ready(Ok(ChunkedState::Size)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk body LF", + ))), + } + } + + fn read_end_cr<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + match byte!(rdr, cx) { + b'\r' => Poll::Ready(Ok(ChunkedState::EndLf)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk end CR", + ))), + } + } + fn read_end_lf<R: MemRead>( + cx: &mut task::Context<'_>, + rdr: &mut R, + ) -> Poll<Result<ChunkedState, io::Error>> { + match byte!(rdr, cx) { + b'\n' => Poll::Ready(Ok(ChunkedState::End)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk end LF", + ))), + } + } +} + +#[derive(Debug)] +struct IncompleteBody; + +impl fmt::Display for IncompleteBody { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "end of file before message length reached") + } +} + +impl StdError for IncompleteBody {} + +#[cfg(test)] +mod tests { + use super::*; + use std::pin::Pin; + use std::time::Duration; + use tokio::io::AsyncRead; + + impl<'a> MemRead for &'a [u8] { + fn read_mem(&mut self, _: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> { + let n = std::cmp::min(len, self.len()); + if n > 0 { + let (a, b) = self.split_at(n); + let buf = Bytes::copy_from_slice(a); + *self = b; + Poll::Ready(Ok(buf)) + } else { + Poll::Ready(Ok(Bytes::new())) + } + } + } + + impl<'a> MemRead for &'a mut (dyn AsyncRead + Unpin) { + fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> { + let mut v = vec![0; len]; + let n = ready!(Pin::new(self).poll_read(cx, &mut v)?); + Poll::Ready(Ok(Bytes::copy_from_slice(&v[..n]))) + } + } + + #[cfg(feature = "nightly")] + impl MemRead for Bytes { + fn read_mem(&mut self, _: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> { + let n = std::cmp::min(len, self.len()); + let ret = self.split_to(n); + Poll::Ready(Ok(ret)) + } + } + + /* + use std::io; + use std::io::Write; + use super::Decoder; + use super::ChunkedState; + use futures::{Async, Poll}; + use bytes::{BytesMut, Bytes}; + use crate::mock::AsyncIo; + */ + + #[tokio::test] + async fn test_read_chunk_size() { + use std::io::ErrorKind::{InvalidInput, UnexpectedEof}; + + async fn read(s: &str) -> u64 { + let mut state = ChunkedState::Size; + let rdr = &mut s.as_bytes(); + let mut size = 0; + loop { + let result = + futures_util::future::poll_fn(|cx| state.step(cx, rdr, &mut size, &mut None)) + .await; + let desc = format!("read_size failed for {:?}", s); + state = result.expect(desc.as_str()); + if state == ChunkedState::Body || state == ChunkedState::EndCr { + break; + } + } + size + } + + async fn read_err(s: &str, expected_err: io::ErrorKind) { + let mut state = ChunkedState::Size; + let rdr = &mut s.as_bytes(); + let mut size = 0; + loop { + let result = + futures_util::future::poll_fn(|cx| state.step(cx, rdr, &mut size, &mut None)) + .await; + state = match result { + Ok(s) => s, + Err(e) => { + assert!( + expected_err == e.kind(), + "Reading {:?}, expected {:?}, but got {:?}", + s, + expected_err, + e.kind() + ); + return; + } + }; + if state == ChunkedState::Body || state == ChunkedState::End { + panic!(format!("Was Ok. Expected Err for {:?}", s)); + } + } + } + + assert_eq!(1, read("1\r\n").await); + assert_eq!(1, read("01\r\n").await); + assert_eq!(0, read("0\r\n").await); + assert_eq!(0, read("00\r\n").await); + assert_eq!(10, read("A\r\n").await); + assert_eq!(10, read("a\r\n").await); + assert_eq!(255, read("Ff\r\n").await); + assert_eq!(255, read("Ff \r\n").await); + // Missing LF or CRLF + read_err("F\rF", InvalidInput).await; + read_err("F", UnexpectedEof).await; + // Invalid hex digit + read_err("X\r\n", InvalidInput).await; + read_err("1X\r\n", InvalidInput).await; + read_err("-\r\n", InvalidInput).await; + read_err("-1\r\n", InvalidInput).await; + // Acceptable (if not fully valid) extensions do not influence the size + assert_eq!(1, read("1;extension\r\n").await); + assert_eq!(10, read("a;ext name=value\r\n").await); + assert_eq!(1, read("1;extension;extension2\r\n").await); + assert_eq!(1, read("1;;; ;\r\n").await); + assert_eq!(2, read("2; extension...\r\n").await); + assert_eq!(3, read("3 ; extension=123\r\n").await); + assert_eq!(3, read("3 ;\r\n").await); + assert_eq!(3, read("3 ; \r\n").await); + // Invalid extensions cause an error + read_err("1 invalid extension\r\n", InvalidInput).await; + read_err("1 A\r\n", InvalidInput).await; + read_err("1;no CRLF", UnexpectedEof).await; + } + + #[tokio::test] + async fn test_read_sized_early_eof() { + let mut bytes = &b"foo bar"[..]; + let mut decoder = Decoder::length(10); + assert_eq!(decoder.decode_fut(&mut bytes).await.unwrap().len(), 7); + let e = decoder.decode_fut(&mut bytes).await.unwrap_err(); + assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof); + } + + #[tokio::test] + async fn test_read_chunked_early_eof() { + let mut bytes = &b"\ + 9\r\n\ + foo bar\ + "[..]; + let mut decoder = Decoder::chunked(); + assert_eq!(decoder.decode_fut(&mut bytes).await.unwrap().len(), 7); + let e = decoder.decode_fut(&mut bytes).await.unwrap_err(); + assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof); + } + + #[tokio::test] + async fn test_read_chunked_single_read() { + let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\n"[..]; + let buf = Decoder::chunked() + .decode_fut(&mut mock_buf) + .await + .expect("decode"); + assert_eq!(16, buf.len()); + let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String"); + assert_eq!("1234567890abcdef", &result); + } + + #[tokio::test] + async fn test_read_chunked_after_eof() { + let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\n\r\n"[..]; + let mut decoder = Decoder::chunked(); + + // normal read + let buf = decoder.decode_fut(&mut mock_buf).await.unwrap(); + assert_eq!(16, buf.len()); + let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String"); + assert_eq!("1234567890abcdef", &result); + + // eof read + let buf = decoder.decode_fut(&mut mock_buf).await.expect("decode"); + assert_eq!(0, buf.len()); + + // ensure read after eof also returns eof + let buf = decoder.decode_fut(&mut mock_buf).await.expect("decode"); + assert_eq!(0, buf.len()); + } + + // perform an async read using a custom buffer size and causing a blocking + // read at the specified byte + async fn read_async(mut decoder: Decoder, content: &[u8], block_at: usize) -> String { + let mut outs = Vec::new(); + + let mut ins = if block_at == 0 { + tokio_test::io::Builder::new() + .wait(Duration::from_millis(10)) + .read(content) + .build() + } else { + tokio_test::io::Builder::new() + .read(&content[..block_at]) + .wait(Duration::from_millis(10)) + .read(&content[block_at..]) + .build() + }; + + let mut ins = &mut ins as &mut (dyn AsyncRead + Unpin); + + loop { + let buf = decoder + .decode_fut(&mut ins) + .await + .expect("unexpected decode error"); + if buf.is_empty() { + break; // eof + } + outs.extend(buf.as_ref()); + } + + String::from_utf8(outs).expect("decode String") + } + + // iterate over the different ways that this async read could go. + // tests blocking a read at each byte along the content - The shotgun approach + async fn all_async_cases(content: &str, expected: &str, decoder: Decoder) { + let content_len = content.len(); + for block_at in 0..content_len { + let actual = read_async(decoder.clone(), content.as_bytes(), block_at).await; + assert_eq!(expected, &actual) //, "Failed async. Blocking at {}", block_at); + } + } + + #[tokio::test] + async fn test_read_length_async() { + let content = "foobar"; + all_async_cases(content, content, Decoder::length(content.len() as u64)).await; + } + + #[tokio::test] + async fn test_read_chunked_async() { + let content = "3\r\nfoo\r\n3\r\nbar\r\n0\r\n\r\n"; + let expected = "foobar"; + all_async_cases(content, expected, Decoder::chunked()).await; + } + + #[tokio::test] + async fn test_read_eof_async() { + let content = "foobar"; + all_async_cases(content, content, Decoder::eof()).await; + } + + #[cfg(feature = "nightly")] + #[bench] + fn bench_decode_chunked_1kb(b: &mut test::Bencher) { + let mut rt = new_runtime(); + + const LEN: usize = 1024; + let mut vec = Vec::new(); + vec.extend(format!("{:x}\r\n", LEN).as_bytes()); + vec.extend(&[0; LEN][..]); + vec.extend(b"\r\n"); + let content = Bytes::from(vec); + + b.bytes = LEN as u64; + + b.iter(|| { + let mut decoder = Decoder::chunked(); + rt.block_on(async { + let mut raw = content.clone(); + let chunk = decoder.decode_fut(&mut raw).await.unwrap(); + assert_eq!(chunk.len(), LEN); + }); + }); + } + + #[cfg(feature = "nightly")] + #[bench] + fn bench_decode_length_1kb(b: &mut test::Bencher) { + let mut rt = new_runtime(); + + const LEN: usize = 1024; + let content = Bytes::from(&[0; LEN][..]); + b.bytes = LEN as u64; + + b.iter(|| { + let mut decoder = Decoder::length(LEN as u64); + rt.block_on(async { + let mut raw = content.clone(); + let chunk = decoder.decode_fut(&mut raw).await.unwrap(); + assert_eq!(chunk.len(), LEN); + }); + }); + } + + #[cfg(feature = "nightly")] + fn new_runtime() -> tokio::runtime::Runtime { + tokio::runtime::Builder::new() + .enable_all() + .basic_scheduler() + .build() + .expect("rt build") + } +} diff --git a/third_party/rust/hyper/src/proto/h1/dispatch.rs b/third_party/rust/hyper/src/proto/h1/dispatch.rs new file mode 100644 index 0000000000..84ee412c3c --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/dispatch.rs @@ -0,0 +1,702 @@ +use std::error::Error as StdError; + +use bytes::{Buf, Bytes}; +use http::{Request, Response, StatusCode}; +use tokio::io::{AsyncRead, AsyncWrite}; + +use super::{Http1Transaction, Wants}; +use crate::body::{Body, Payload}; +use crate::common::{task, Future, Never, Pin, Poll, Unpin}; +use crate::proto::{ + BodyLength, Conn, DecodedLength, Dispatched, MessageHead, RequestHead, RequestLine, + ResponseHead, +}; +use crate::service::HttpService; + +pub(crate) struct Dispatcher<D, Bs: Payload, I, T> { + conn: Conn<I, Bs::Data, T>, + dispatch: D, + body_tx: Option<crate::body::Sender>, + body_rx: Pin<Box<Option<Bs>>>, + is_closing: bool, +} + +pub(crate) trait Dispatch { + type PollItem; + type PollBody; + type PollError; + type RecvItem; + fn poll_msg( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), Self::PollError>>>; + fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, Body)>) -> crate::Result<()>; + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), ()>>; + fn should_poll(&self) -> bool; +} + +pub struct Server<S: HttpService<B>, B> { + in_flight: Pin<Box<Option<S::Future>>>, + pub(crate) service: S, +} + +pub struct Client<B> { + callback: Option<crate::client::dispatch::Callback<Request<B>, Response<Body>>>, + rx: ClientRx<B>, + rx_closed: bool, +} + +type ClientRx<B> = crate::client::dispatch::Receiver<Request<B>, Response<Body>>; + +impl<D, Bs, I, T> Dispatcher<D, Bs, I, T> +where + D: Dispatch< + PollItem = MessageHead<T::Outgoing>, + PollBody = Bs, + RecvItem = MessageHead<T::Incoming>, + > + Unpin, + D::PollError: Into<Box<dyn StdError + Send + Sync>>, + I: AsyncRead + AsyncWrite + Unpin, + T: Http1Transaction + Unpin, + Bs: Payload, +{ + pub fn new(dispatch: D, conn: Conn<I, Bs::Data, T>) -> Self { + Dispatcher { + conn, + dispatch, + body_tx: None, + body_rx: Box::pin(None), + is_closing: false, + } + } + + pub fn disable_keep_alive(&mut self) { + self.conn.disable_keep_alive(); + if self.conn.is_write_closed() { + self.close(); + } + } + + pub fn into_inner(self) -> (I, Bytes, D) { + let (io, buf) = self.conn.into_inner(); + (io, buf, self.dispatch) + } + + /// Run this dispatcher until HTTP says this connection is done, + /// but don't call `AsyncWrite::shutdown` on the underlying IO. + /// + /// This is useful for old-style HTTP upgrades, but ignores + /// newer-style upgrade API. + pub(crate) fn poll_without_shutdown( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<crate::Result<()>> + where + Self: Unpin, + { + Pin::new(self).poll_catch(cx, false).map_ok(|ds| { + if let Dispatched::Upgrade(pending) = ds { + pending.manual(); + } + }) + } + + fn poll_catch( + &mut self, + cx: &mut task::Context<'_>, + should_shutdown: bool, + ) -> Poll<crate::Result<Dispatched>> { + Poll::Ready(ready!(self.poll_inner(cx, should_shutdown)).or_else(|e| { + // An error means we're shutting down either way. + // We just try to give the error to the user, + // and close the connection with an Ok. If we + // cannot give it to the user, then return the Err. + self.dispatch.recv_msg(Err(e))?; + Ok(Dispatched::Shutdown) + })) + } + + fn poll_inner( + &mut self, + cx: &mut task::Context<'_>, + should_shutdown: bool, + ) -> Poll<crate::Result<Dispatched>> { + T::update_date(); + + ready!(self.poll_loop(cx))?; + + if self.is_done() { + if let Some(pending) = self.conn.pending_upgrade() { + self.conn.take_error()?; + return Poll::Ready(Ok(Dispatched::Upgrade(pending))); + } else if should_shutdown { + ready!(self.conn.poll_shutdown(cx)).map_err(crate::Error::new_shutdown)?; + } + self.conn.take_error()?; + Poll::Ready(Ok(Dispatched::Shutdown)) + } else { + Poll::Pending + } + } + + fn poll_loop(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + // Limit the looping on this connection, in case it is ready far too + // often, so that other futures don't starve. + // + // 16 was chosen arbitrarily, as that is number of pipelined requests + // benchmarks often use. Perhaps it should be a config option instead. + for _ in 0..16 { + let _ = self.poll_read(cx)?; + let _ = self.poll_write(cx)?; + let _ = self.poll_flush(cx)?; + + // This could happen if reading paused before blocking on IO, + // such as getting to the end of a framed message, but then + // writing/flushing set the state back to Init. In that case, + // if the read buffer still had bytes, we'd want to try poll_read + // again, or else we wouldn't ever be woken up again. + // + // Using this instead of task::current() and notify() inside + // the Conn is noticeably faster in pipelined benchmarks. + if !self.conn.wants_read_again() { + //break; + return Poll::Ready(Ok(())); + } + } + + trace!("poll_loop yielding (self = {:p})", self); + + task::yield_now(cx).map(|never| match never {}) + } + + fn poll_read(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + loop { + if self.is_closing { + return Poll::Ready(Ok(())); + } else if self.conn.can_read_head() { + ready!(self.poll_read_head(cx))?; + } else if let Some(mut body) = self.body_tx.take() { + if self.conn.can_read_body() { + match body.poll_ready(cx) { + Poll::Ready(Ok(())) => (), + Poll::Pending => { + self.body_tx = Some(body); + return Poll::Pending; + } + Poll::Ready(Err(_canceled)) => { + // user doesn't care about the body + // so we should stop reading + trace!("body receiver dropped before eof, draining or closing"); + self.conn.poll_drain_or_close_read(cx); + continue; + } + } + match self.conn.poll_read_body(cx) { + Poll::Ready(Some(Ok(chunk))) => match body.try_send_data(chunk) { + Ok(()) => { + self.body_tx = Some(body); + } + Err(_canceled) => { + if self.conn.can_read_body() { + trace!("body receiver dropped before eof, closing"); + self.conn.close_read(); + } + } + }, + Poll::Ready(None) => { + // just drop, the body will close automatically + } + Poll::Pending => { + self.body_tx = Some(body); + return Poll::Pending; + } + Poll::Ready(Some(Err(e))) => { + body.send_error(crate::Error::new_body(e)); + } + } + } else { + // just drop, the body will close automatically + } + } else { + return self.conn.poll_read_keep_alive(cx); + } + } + } + + fn poll_read_head(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + // can dispatch receive, or does it still care about, an incoming message? + match ready!(self.dispatch.poll_ready(cx)) { + Ok(()) => (), + Err(()) => { + trace!("dispatch no longer receiving messages"); + self.close(); + return Poll::Ready(Ok(())); + } + } + // dispatch is ready for a message, try to read one + match ready!(self.conn.poll_read_head(cx)) { + Some(Ok((head, body_len, wants))) => { + let mut body = match body_len { + DecodedLength::ZERO => Body::empty(), + other => { + let (tx, rx) = Body::new_channel(other, wants.contains(Wants::EXPECT)); + self.body_tx = Some(tx); + rx + } + }; + if wants.contains(Wants::UPGRADE) { + body.set_on_upgrade(self.conn.on_upgrade()); + } + self.dispatch.recv_msg(Ok((head, body)))?; + Poll::Ready(Ok(())) + } + Some(Err(err)) => { + debug!("read_head error: {}", err); + self.dispatch.recv_msg(Err(err))?; + // if here, the dispatcher gave the user the error + // somewhere else. we still need to shutdown, but + // not as a second error. + self.close(); + Poll::Ready(Ok(())) + } + None => { + // read eof, the write side will have been closed too unless + // allow_read_close was set to true, in which case just do + // nothing... + debug_assert!(self.conn.is_read_closed()); + if self.conn.is_write_closed() { + self.close(); + } + Poll::Ready(Ok(())) + } + } + } + + fn poll_write(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + loop { + if self.is_closing { + return Poll::Ready(Ok(())); + } else if self.body_rx.is_none() + && self.conn.can_write_head() + && self.dispatch.should_poll() + { + if let Some(msg) = ready!(self.dispatch.poll_msg(cx)) { + let (head, mut body) = msg.map_err(crate::Error::new_user_service)?; + + // Check if the body knows its full data immediately. + // + // If so, we can skip a bit of bookkeeping that streaming + // bodies need to do. + if let Some(full) = crate::body::take_full_data(&mut body) { + self.conn.write_full_msg(head, full); + return Poll::Ready(Ok(())); + } + + let body_type = if body.is_end_stream() { + self.body_rx.set(None); + None + } else { + let btype = body + .size_hint() + .exact() + .map(BodyLength::Known) + .or_else(|| Some(BodyLength::Unknown)); + self.body_rx.set(Some(body)); + btype + }; + self.conn.write_head(head, body_type); + } else { + self.close(); + return Poll::Ready(Ok(())); + } + } else if !self.conn.can_buffer_body() { + ready!(self.poll_flush(cx))?; + } else { + // A new scope is needed :( + if let (Some(mut body), clear_body) = + OptGuard::new(self.body_rx.as_mut()).guard_mut() + { + debug_assert!(!*clear_body, "opt guard defaults to keeping body"); + if !self.conn.can_write_body() { + trace!( + "no more write body allowed, user body is_end_stream = {}", + body.is_end_stream(), + ); + *clear_body = true; + continue; + } + + let item = ready!(body.as_mut().poll_data(cx)); + if let Some(item) = item { + let chunk = item.map_err(|e| { + *clear_body = true; + crate::Error::new_user_body(e) + })?; + let eos = body.is_end_stream(); + if eos { + *clear_body = true; + if chunk.remaining() == 0 { + trace!("discarding empty chunk"); + self.conn.end_body(); + } else { + self.conn.write_body_and_end(chunk); + } + } else { + if chunk.remaining() == 0 { + trace!("discarding empty chunk"); + continue; + } + self.conn.write_body(chunk); + } + } else { + *clear_body = true; + self.conn.end_body(); + } + } else { + return Poll::Pending; + } + } + } + } + + fn poll_flush(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + self.conn.poll_flush(cx).map_err(|err| { + debug!("error writing: {}", err); + crate::Error::new_body_write(err) + }) + } + + fn close(&mut self) { + self.is_closing = true; + self.conn.close_read(); + self.conn.close_write(); + } + + fn is_done(&self) -> bool { + if self.is_closing { + return true; + } + + let read_done = self.conn.is_read_closed(); + + if !T::should_read_first() && read_done { + // a client that cannot read may was well be done. + true + } else { + let write_done = self.conn.is_write_closed() + || (!self.dispatch.should_poll() && self.body_rx.is_none()); + read_done && write_done + } + } +} + +impl<D, Bs, I, T> Future for Dispatcher<D, Bs, I, T> +where + D: Dispatch< + PollItem = MessageHead<T::Outgoing>, + PollBody = Bs, + RecvItem = MessageHead<T::Incoming>, + > + Unpin, + D::PollError: Into<Box<dyn StdError + Send + Sync>>, + I: AsyncRead + AsyncWrite + Unpin, + T: Http1Transaction + Unpin, + Bs: Payload, +{ + type Output = crate::Result<Dispatched>; + + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + self.poll_catch(cx, true) + } +} + +// ===== impl OptGuard ===== + +/// A drop guard to allow a mutable borrow of an Option while being able to +/// set whether the `Option` should be cleared on drop. +struct OptGuard<'a, T>(Pin<&'a mut Option<T>>, bool); + +impl<'a, T> OptGuard<'a, T> { + fn new(pin: Pin<&'a mut Option<T>>) -> Self { + OptGuard(pin, false) + } + + fn guard_mut(&mut self) -> (Option<Pin<&mut T>>, &mut bool) { + (self.0.as_mut().as_pin_mut(), &mut self.1) + } +} + +impl<'a, T> Drop for OptGuard<'a, T> { + fn drop(&mut self) { + if self.1 { + self.0.set(None); + } + } +} + +// ===== impl Server ===== + +impl<S, B> Server<S, B> +where + S: HttpService<B>, +{ + pub fn new(service: S) -> Server<S, B> { + Server { + in_flight: Box::pin(None), + service, + } + } + + pub fn into_service(self) -> S { + self.service + } +} + +// Service is never pinned +impl<S: HttpService<B>, B> Unpin for Server<S, B> {} + +impl<S, Bs> Dispatch for Server<S, Body> +where + S: HttpService<Body, ResBody = Bs>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + Bs: Payload, +{ + type PollItem = MessageHead<StatusCode>; + type PollBody = Bs; + type PollError = S::Error; + type RecvItem = RequestHead; + + fn poll_msg( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), Self::PollError>>> { + let ret = if let Some(ref mut fut) = self.in_flight.as_mut().as_pin_mut() { + let resp = ready!(fut.as_mut().poll(cx)?); + let (parts, body) = resp.into_parts(); + let head = MessageHead { + version: parts.version, + subject: parts.status, + headers: parts.headers, + }; + Poll::Ready(Some(Ok((head, body)))) + } else { + unreachable!("poll_msg shouldn't be called if no inflight"); + }; + + // Since in_flight finished, remove it + self.in_flight.set(None); + ret + } + + fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, Body)>) -> crate::Result<()> { + let (msg, body) = msg?; + let mut req = Request::new(body); + *req.method_mut() = msg.subject.0; + *req.uri_mut() = msg.subject.1; + *req.headers_mut() = msg.headers; + *req.version_mut() = msg.version; + let fut = self.service.call(req); + self.in_flight.set(Some(fut)); + Ok(()) + } + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), ()>> { + if self.in_flight.is_some() { + Poll::Pending + } else { + self.service.poll_ready(cx).map_err(|_e| { + // FIXME: return error value. + trace!("service closed"); + }) + } + } + + fn should_poll(&self) -> bool { + self.in_flight.is_some() + } +} + +// ===== impl Client ===== + +impl<B> Client<B> { + pub fn new(rx: ClientRx<B>) -> Client<B> { + Client { + callback: None, + rx, + rx_closed: false, + } + } +} + +impl<B> Dispatch for Client<B> +where + B: Payload, +{ + type PollItem = RequestHead; + type PollBody = B; + type PollError = Never; + type RecvItem = ResponseHead; + + fn poll_msg( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), Never>>> { + debug_assert!(!self.rx_closed); + match self.rx.poll_next(cx) { + Poll::Ready(Some((req, mut cb))) => { + // check that future hasn't been canceled already + match cb.poll_canceled(cx) { + Poll::Ready(()) => { + trace!("request canceled"); + Poll::Ready(None) + } + Poll::Pending => { + let (parts, body) = req.into_parts(); + let head = RequestHead { + version: parts.version, + subject: RequestLine(parts.method, parts.uri), + headers: parts.headers, + }; + self.callback = Some(cb); + Poll::Ready(Some(Ok((head, body)))) + } + } + } + Poll::Ready(None) => { + // user has dropped sender handle + trace!("client tx closed"); + self.rx_closed = true; + Poll::Ready(None) + } + Poll::Pending => Poll::Pending, + } + } + + fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, Body)>) -> crate::Result<()> { + match msg { + Ok((msg, body)) => { + if let Some(cb) = self.callback.take() { + let mut res = Response::new(body); + *res.status_mut() = msg.subject; + *res.headers_mut() = msg.headers; + *res.version_mut() = msg.version; + cb.send(Ok(res)); + Ok(()) + } else { + // Getting here is likely a bug! An error should have happened + // in Conn::require_empty_read() before ever parsing a + // full message! + Err(crate::Error::new_unexpected_message()) + } + } + Err(err) => { + if let Some(cb) = self.callback.take() { + cb.send(Err((err, None))); + Ok(()) + } else if !self.rx_closed { + self.rx.close(); + if let Some((req, cb)) = self.rx.try_recv() { + trace!("canceling queued request with connection error: {}", err); + // in this case, the message was never even started, so it's safe to tell + // the user that the request was completely canceled + cb.send(Err((crate::Error::new_canceled().with(err), Some(req)))); + Ok(()) + } else { + Err(err) + } + } else { + Err(err) + } + } + } + } + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), ()>> { + match self.callback { + Some(ref mut cb) => match cb.poll_canceled(cx) { + Poll::Ready(()) => { + trace!("callback receiver has dropped"); + Poll::Ready(Err(())) + } + Poll::Pending => Poll::Ready(Ok(())), + }, + None => Poll::Ready(Err(())), + } + } + + fn should_poll(&self) -> bool { + self.callback.is_none() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::proto::h1::ClientTransaction; + use std::time::Duration; + + #[test] + fn client_read_bytes_before_writing_request() { + let _ = pretty_env_logger::try_init(); + + tokio_test::task::spawn(()).enter(|cx, _| { + let (io, mut handle) = tokio_test::io::Builder::new().build_with_handle(); + + // Block at 0 for now, but we will release this response before + // the request is ready to write later... + //let io = AsyncIo::new_buf(b"HTTP/1.1 200 OK\r\n\r\n".to_vec(), 0); + let (mut tx, rx) = crate::client::dispatch::channel(); + let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(io); + let mut dispatcher = Dispatcher::new(Client::new(rx), conn); + + // First poll is needed to allow tx to send... + assert!(Pin::new(&mut dispatcher).poll(cx).is_pending()); + + // Unblock our IO, which has a response before we've sent request! + // + handle.read(b"HTTP/1.1 200 OK\r\n\r\n"); + + let mut res_rx = tx + .try_send(crate::Request::new(crate::Body::empty())) + .unwrap(); + + tokio_test::assert_ready_ok!(Pin::new(&mut dispatcher).poll(cx)); + let err = tokio_test::assert_ready_ok!(Pin::new(&mut res_rx).poll(cx)) + .expect_err("callback should send error"); + + match (err.0.kind(), err.1) { + (&crate::error::Kind::Canceled, Some(_)) => (), + other => panic!("expected Canceled, got {:?}", other), + } + }); + } + + #[tokio::test] + async fn body_empty_chunks_ignored() { + let _ = pretty_env_logger::try_init(); + + let io = tokio_test::io::Builder::new() + // no reading or writing, just be blocked for the test... + .wait(Duration::from_secs(5)) + .build(); + + let (mut tx, rx) = crate::client::dispatch::channel(); + let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(io); + let mut dispatcher = tokio_test::task::spawn(Dispatcher::new(Client::new(rx), conn)); + + // First poll is needed to allow tx to send... + assert!(dispatcher.poll().is_pending()); + + let body = { + let (mut tx, body) = crate::Body::channel(); + tx.try_send_data("".into()).unwrap(); + body + }; + + let _res_rx = tx.try_send(crate::Request::new(body)).unwrap(); + + // Ensure conn.write_body wasn't called with the empty chunk. + // If it is, it will trigger an assertion. + assert!(dispatcher.poll().is_pending()); + } +} diff --git a/third_party/rust/hyper/src/proto/h1/encode.rs b/third_party/rust/hyper/src/proto/h1/encode.rs new file mode 100644 index 0000000000..95b0d82b67 --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/encode.rs @@ -0,0 +1,418 @@ +use std::fmt; +use std::io::IoSlice; + +use bytes::buf::ext::{BufExt, Chain, Take}; +use bytes::Buf; + +use super::io::WriteBuf; + +type StaticBuf = &'static [u8]; + +/// Encoders to handle different Transfer-Encodings. +#[derive(Debug, Clone, PartialEq)] +pub struct Encoder { + kind: Kind, + is_last: bool, +} + +#[derive(Debug)] +pub struct EncodedBuf<B> { + kind: BufKind<B>, +} + +#[derive(Debug)] +pub struct NotEof; + +#[derive(Debug, PartialEq, Clone)] +enum Kind { + /// An Encoder for when Transfer-Encoding includes `chunked`. + Chunked, + /// An Encoder for when Content-Length is set. + /// + /// Enforces that the body is not longer than the Content-Length header. + Length(u64), + /// An Encoder for when neither Content-Length nor Chunked encoding is set. + /// + /// This is mostly only used with HTTP/1.0 with a length. This kind requires + /// the connection to be closed when the body is finished. + CloseDelimited, +} + +#[derive(Debug)] +enum BufKind<B> { + Exact(B), + Limited(Take<B>), + Chunked(Chain<Chain<ChunkSize, B>, StaticBuf>), + ChunkedEnd(StaticBuf), +} + +impl Encoder { + fn new(kind: Kind) -> Encoder { + Encoder { + kind, + is_last: false, + } + } + pub fn chunked() -> Encoder { + Encoder::new(Kind::Chunked) + } + + pub fn length(len: u64) -> Encoder { + Encoder::new(Kind::Length(len)) + } + + pub fn close_delimited() -> Encoder { + Encoder::new(Kind::CloseDelimited) + } + + pub fn is_eof(&self) -> bool { + match self.kind { + Kind::Length(0) => true, + _ => false, + } + } + + pub fn set_last(mut self, is_last: bool) -> Self { + self.is_last = is_last; + self + } + + pub fn is_last(&self) -> bool { + self.is_last + } + + pub fn end<B>(&self) -> Result<Option<EncodedBuf<B>>, NotEof> { + match self.kind { + Kind::Length(0) => Ok(None), + Kind::Chunked => Ok(Some(EncodedBuf { + kind: BufKind::ChunkedEnd(b"0\r\n\r\n"), + })), + _ => Err(NotEof), + } + } + + pub fn encode<B>(&mut self, msg: B) -> EncodedBuf<B> + where + B: Buf, + { + let len = msg.remaining(); + debug_assert!(len > 0, "encode() called with empty buf"); + + let kind = match self.kind { + Kind::Chunked => { + trace!("encoding chunked {}B", len); + let buf = ChunkSize::new(len) + .chain(msg) + .chain(b"\r\n" as &'static [u8]); + BufKind::Chunked(buf) + } + Kind::Length(ref mut remaining) => { + trace!("sized write, len = {}", len); + if len as u64 > *remaining { + let limit = *remaining as usize; + *remaining = 0; + BufKind::Limited(msg.take(limit)) + } else { + *remaining -= len as u64; + BufKind::Exact(msg) + } + } + Kind::CloseDelimited => { + trace!("close delimited write {}B", len); + BufKind::Exact(msg) + } + }; + EncodedBuf { kind } + } + + pub(super) fn encode_and_end<B>(&self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) -> bool + where + B: Buf, + { + let len = msg.remaining(); + debug_assert!(len > 0, "encode() called with empty buf"); + + match self.kind { + Kind::Chunked => { + trace!("encoding chunked {}B", len); + let buf = ChunkSize::new(len) + .chain(msg) + .chain(b"\r\n0\r\n\r\n" as &'static [u8]); + dst.buffer(buf); + !self.is_last + } + Kind::Length(remaining) => { + use std::cmp::Ordering; + + trace!("sized write, len = {}", len); + match (len as u64).cmp(&remaining) { + Ordering::Equal => { + dst.buffer(msg); + !self.is_last + } + Ordering::Greater => { + dst.buffer(msg.take(remaining as usize)); + !self.is_last + } + Ordering::Less => { + dst.buffer(msg); + false + } + } + } + Kind::CloseDelimited => { + trace!("close delimited write {}B", len); + dst.buffer(msg); + false + } + } + } + + /// Encodes the full body, without verifying the remaining length matches. + /// + /// This is used in conjunction with Payload::__hyper_full_data(), which + /// means we can trust that the buf has the correct size (the buf itself + /// was checked to make the headers). + pub(super) fn danger_full_buf<B>(self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) + where + B: Buf, + { + debug_assert!(msg.remaining() > 0, "encode() called with empty buf"); + debug_assert!( + match self.kind { + Kind::Length(len) => len == msg.remaining() as u64, + _ => true, + }, + "danger_full_buf length mismatches" + ); + + match self.kind { + Kind::Chunked => { + let len = msg.remaining(); + trace!("encoding chunked {}B", len); + let buf = ChunkSize::new(len) + .chain(msg) + .chain(b"\r\n0\r\n\r\n" as &'static [u8]); + dst.buffer(buf); + } + _ => { + dst.buffer(msg); + } + } + } +} + +impl<B> Buf for EncodedBuf<B> +where + B: Buf, +{ + #[inline] + fn remaining(&self) -> usize { + match self.kind { + BufKind::Exact(ref b) => b.remaining(), + BufKind::Limited(ref b) => b.remaining(), + BufKind::Chunked(ref b) => b.remaining(), + BufKind::ChunkedEnd(ref b) => b.remaining(), + } + } + + #[inline] + fn bytes(&self) -> &[u8] { + match self.kind { + BufKind::Exact(ref b) => b.bytes(), + BufKind::Limited(ref b) => b.bytes(), + BufKind::Chunked(ref b) => b.bytes(), + BufKind::ChunkedEnd(ref b) => b.bytes(), + } + } + + #[inline] + fn advance(&mut self, cnt: usize) { + match self.kind { + BufKind::Exact(ref mut b) => b.advance(cnt), + BufKind::Limited(ref mut b) => b.advance(cnt), + BufKind::Chunked(ref mut b) => b.advance(cnt), + BufKind::ChunkedEnd(ref mut b) => b.advance(cnt), + } + } + + #[inline] + fn bytes_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { + match self.kind { + BufKind::Exact(ref b) => b.bytes_vectored(dst), + BufKind::Limited(ref b) => b.bytes_vectored(dst), + BufKind::Chunked(ref b) => b.bytes_vectored(dst), + BufKind::ChunkedEnd(ref b) => b.bytes_vectored(dst), + } + } +} + +#[cfg(target_pointer_width = "32")] +const USIZE_BYTES: usize = 4; + +#[cfg(target_pointer_width = "64")] +const USIZE_BYTES: usize = 8; + +// each byte will become 2 hex +const CHUNK_SIZE_MAX_BYTES: usize = USIZE_BYTES * 2; + +#[derive(Clone, Copy)] +struct ChunkSize { + bytes: [u8; CHUNK_SIZE_MAX_BYTES + 2], + pos: u8, + len: u8, +} + +impl ChunkSize { + fn new(len: usize) -> ChunkSize { + use std::fmt::Write; + let mut size = ChunkSize { + bytes: [0; CHUNK_SIZE_MAX_BYTES + 2], + pos: 0, + len: 0, + }; + write!(&mut size, "{:X}\r\n", len).expect("CHUNK_SIZE_MAX_BYTES should fit any usize"); + size + } +} + +impl Buf for ChunkSize { + #[inline] + fn remaining(&self) -> usize { + (self.len - self.pos).into() + } + + #[inline] + fn bytes(&self) -> &[u8] { + &self.bytes[self.pos.into()..self.len.into()] + } + + #[inline] + fn advance(&mut self, cnt: usize) { + assert!(cnt <= self.remaining()); + self.pos += cnt as u8; // just asserted cnt fits in u8 + } +} + +impl fmt::Debug for ChunkSize { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ChunkSize") + .field("bytes", &&self.bytes[..self.len.into()]) + .field("pos", &self.pos) + .finish() + } +} + +impl fmt::Write for ChunkSize { + fn write_str(&mut self, num: &str) -> fmt::Result { + use std::io::Write; + (&mut self.bytes[self.len.into()..]) + .write_all(num.as_bytes()) + .expect("&mut [u8].write() cannot error"); + self.len += num.len() as u8; // safe because bytes is never bigger than 256 + Ok(()) + } +} + +impl<B: Buf> From<B> for EncodedBuf<B> { + fn from(buf: B) -> Self { + EncodedBuf { + kind: BufKind::Exact(buf), + } + } +} + +impl<B: Buf> From<Take<B>> for EncodedBuf<B> { + fn from(buf: Take<B>) -> Self { + EncodedBuf { + kind: BufKind::Limited(buf), + } + } +} + +impl<B: Buf> From<Chain<Chain<ChunkSize, B>, StaticBuf>> for EncodedBuf<B> { + fn from(buf: Chain<Chain<ChunkSize, B>, StaticBuf>) -> Self { + EncodedBuf { + kind: BufKind::Chunked(buf), + } + } +} + +#[cfg(test)] +mod tests { + use bytes::BufMut; + + use super::super::io::Cursor; + use super::Encoder; + + #[test] + fn chunked() { + let mut encoder = Encoder::chunked(); + let mut dst = Vec::new(); + + let msg1 = b"foo bar".as_ref(); + let buf1 = encoder.encode(msg1); + dst.put(buf1); + assert_eq!(dst, b"7\r\nfoo bar\r\n"); + + let msg2 = b"baz quux herp".as_ref(); + let buf2 = encoder.encode(msg2); + dst.put(buf2); + + assert_eq!(dst, b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n"); + + let end = encoder.end::<Cursor<Vec<u8>>>().unwrap().unwrap(); + dst.put(end); + + assert_eq!( + dst, + b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n0\r\n\r\n".as_ref() + ); + } + + #[test] + fn length() { + let max_len = 8; + let mut encoder = Encoder::length(max_len as u64); + let mut dst = Vec::new(); + + let msg1 = b"foo bar".as_ref(); + let buf1 = encoder.encode(msg1); + dst.put(buf1); + + assert_eq!(dst, b"foo bar"); + assert!(!encoder.is_eof()); + encoder.end::<()>().unwrap_err(); + + let msg2 = b"baz".as_ref(); + let buf2 = encoder.encode(msg2); + dst.put(buf2); + + assert_eq!(dst.len(), max_len); + assert_eq!(dst, b"foo barb"); + assert!(encoder.is_eof()); + assert!(encoder.end::<()>().unwrap().is_none()); + } + + #[test] + fn eof() { + let mut encoder = Encoder::close_delimited(); + let mut dst = Vec::new(); + + let msg1 = b"foo bar".as_ref(); + let buf1 = encoder.encode(msg1); + dst.put(buf1); + + assert_eq!(dst, b"foo bar"); + assert!(!encoder.is_eof()); + encoder.end::<()>().unwrap_err(); + + let msg2 = b"baz".as_ref(); + let buf2 = encoder.encode(msg2); + dst.put(buf2); + + assert_eq!(dst, b"foo barbaz"); + assert!(!encoder.is_eof()); + encoder.end::<()>().unwrap_err(); + } +} diff --git a/third_party/rust/hyper/src/proto/h1/io.rs b/third_party/rust/hyper/src/proto/h1/io.rs new file mode 100644 index 0000000000..00f4f64f47 --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/io.rs @@ -0,0 +1,907 @@ +use std::cell::Cell; +use std::cmp; +use std::fmt; +use std::io::{self, IoSlice}; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use tokio::io::{AsyncRead, AsyncWrite}; + +use super::{Http1Transaction, ParseContext, ParsedMessage}; +use crate::common::buf::BufList; +use crate::common::{task, Pin, Poll, Unpin}; + +/// The initial buffer size allocated before trying to read from IO. +pub(crate) const INIT_BUFFER_SIZE: usize = 8192; + +/// The minimum value that can be set to max buffer size. +pub const MINIMUM_MAX_BUFFER_SIZE: usize = INIT_BUFFER_SIZE; + +/// The default maximum read buffer size. If the buffer gets this big and +/// a message is still not complete, a `TooLarge` error is triggered. +// Note: if this changes, update server::conn::Http::max_buf_size docs. +pub(crate) const DEFAULT_MAX_BUFFER_SIZE: usize = 8192 + 4096 * 100; + +/// The maximum number of distinct `Buf`s to hold in a list before requiring +/// a flush. Only affects when the buffer strategy is to queue buffers. +/// +/// Note that a flush can happen before reaching the maximum. This simply +/// forces a flush if the queue gets this big. +const MAX_BUF_LIST_BUFFERS: usize = 16; + +pub struct Buffered<T, B> { + flush_pipeline: bool, + io: T, + read_blocked: bool, + read_buf: BytesMut, + read_buf_strategy: ReadStrategy, + write_buf: WriteBuf<B>, +} + +impl<T, B> fmt::Debug for Buffered<T, B> +where + B: Buf, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Buffered") + .field("read_buf", &self.read_buf) + .field("write_buf", &self.write_buf) + .finish() + } +} + +impl<T, B> Buffered<T, B> +where + T: AsyncRead + AsyncWrite + Unpin, + B: Buf, +{ + pub fn new(io: T) -> Buffered<T, B> { + Buffered { + flush_pipeline: false, + io, + read_blocked: false, + read_buf: BytesMut::with_capacity(0), + read_buf_strategy: ReadStrategy::default(), + write_buf: WriteBuf::new(), + } + } + + pub fn set_flush_pipeline(&mut self, enabled: bool) { + debug_assert!(!self.write_buf.has_remaining()); + self.flush_pipeline = enabled; + if enabled { + self.set_write_strategy_flatten(); + } + } + + pub fn set_max_buf_size(&mut self, max: usize) { + assert!( + max >= MINIMUM_MAX_BUFFER_SIZE, + "The max_buf_size cannot be smaller than {}.", + MINIMUM_MAX_BUFFER_SIZE, + ); + self.read_buf_strategy = ReadStrategy::with_max(max); + self.write_buf.max_buf_size = max; + } + + pub fn set_read_buf_exact_size(&mut self, sz: usize) { + self.read_buf_strategy = ReadStrategy::Exact(sz); + } + + pub fn set_write_strategy_flatten(&mut self) { + // this should always be called only at construction time, + // so this assert is here to catch myself + debug_assert!(self.write_buf.queue.bufs_cnt() == 0); + self.write_buf.set_strategy(WriteStrategy::Flatten); + } + + pub fn read_buf(&self) -> &[u8] { + self.read_buf.as_ref() + } + + #[cfg(test)] + #[cfg(feature = "nightly")] + pub(super) fn read_buf_mut(&mut self) -> &mut BytesMut { + &mut self.read_buf + } + + /// Return the "allocated" available space, not the potential space + /// that could be allocated in the future. + fn read_buf_remaining_mut(&self) -> usize { + self.read_buf.capacity() - self.read_buf.len() + } + + pub fn headers_buf(&mut self) -> &mut Vec<u8> { + let buf = self.write_buf.headers_mut(); + &mut buf.bytes + } + + pub(super) fn write_buf(&mut self) -> &mut WriteBuf<B> { + &mut self.write_buf + } + + pub fn buffer<BB: Buf + Into<B>>(&mut self, buf: BB) { + self.write_buf.buffer(buf) + } + + pub fn can_buffer(&self) -> bool { + self.flush_pipeline || self.write_buf.can_buffer() + } + + pub fn consume_leading_lines(&mut self) { + if !self.read_buf.is_empty() { + let mut i = 0; + while i < self.read_buf.len() { + match self.read_buf[i] { + b'\r' | b'\n' => i += 1, + _ => break, + } + } + self.read_buf.advance(i); + } + } + + pub(super) fn parse<S>( + &mut self, + cx: &mut task::Context<'_>, + parse_ctx: ParseContext<'_>, + ) -> Poll<crate::Result<ParsedMessage<S::Incoming>>> + where + S: Http1Transaction, + { + loop { + match S::parse( + &mut self.read_buf, + ParseContext { + cached_headers: parse_ctx.cached_headers, + req_method: parse_ctx.req_method, + }, + )? { + Some(msg) => { + debug!("parsed {} headers", msg.head.headers.len()); + return Poll::Ready(Ok(msg)); + } + None => { + let max = self.read_buf_strategy.max(); + if self.read_buf.len() >= max { + debug!("max_buf_size ({}) reached, closing", max); + return Poll::Ready(Err(crate::Error::new_too_large())); + } + } + } + if ready!(self.poll_read_from_io(cx)).map_err(crate::Error::new_io)? == 0 { + trace!("parse eof"); + return Poll::Ready(Err(crate::Error::new_incomplete())); + } + } + } + + pub fn poll_read_from_io(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<usize>> { + self.read_blocked = false; + let next = self.read_buf_strategy.next(); + if self.read_buf_remaining_mut() < next { + self.read_buf.reserve(next); + } + match Pin::new(&mut self.io).poll_read_buf(cx, &mut self.read_buf) { + Poll::Ready(Ok(n)) => { + debug!("read {} bytes", n); + self.read_buf_strategy.record(n); + Poll::Ready(Ok(n)) + } + Poll::Pending => { + self.read_blocked = true; + Poll::Pending + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + } + } + + pub fn into_inner(self) -> (T, Bytes) { + (self.io, self.read_buf.freeze()) + } + + pub fn io_mut(&mut self) -> &mut T { + &mut self.io + } + + pub fn is_read_blocked(&self) -> bool { + self.read_blocked + } + + pub fn poll_flush(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + if self.flush_pipeline && !self.read_buf.is_empty() { + Poll::Ready(Ok(())) + } else if self.write_buf.remaining() == 0 { + Pin::new(&mut self.io).poll_flush(cx) + } else { + if let WriteStrategy::Flatten = self.write_buf.strategy { + return self.poll_flush_flattened(cx); + } + loop { + let n = + ready!(Pin::new(&mut self.io).poll_write_buf(cx, &mut self.write_buf.auto()))?; + debug!("flushed {} bytes", n); + if self.write_buf.remaining() == 0 { + break; + } else if n == 0 { + trace!( + "write returned zero, but {} bytes remaining", + self.write_buf.remaining() + ); + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + } + Pin::new(&mut self.io).poll_flush(cx) + } + } + + /// Specialized version of `flush` when strategy is Flatten. + /// + /// Since all buffered bytes are flattened into the single headers buffer, + /// that skips some bookkeeping around using multiple buffers. + fn poll_flush_flattened(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + loop { + let n = ready!(Pin::new(&mut self.io).poll_write(cx, self.write_buf.headers.bytes()))?; + debug!("flushed {} bytes", n); + self.write_buf.headers.advance(n); + if self.write_buf.headers.remaining() == 0 { + self.write_buf.headers.reset(); + break; + } else if n == 0 { + trace!( + "write returned zero, but {} bytes remaining", + self.write_buf.remaining() + ); + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + } + Pin::new(&mut self.io).poll_flush(cx) + } + + #[cfg(test)] + fn flush<'a>(&'a mut self) -> impl std::future::Future<Output = io::Result<()>> + 'a { + futures_util::future::poll_fn(move |cx| self.poll_flush(cx)) + } +} + +// The `B` is a `Buf`, we never project a pin to it +impl<T: Unpin, B> Unpin for Buffered<T, B> {} + +// TODO: This trait is old... at least rename to PollBytes or something... +pub trait MemRead { + fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>>; +} + +impl<T, B> MemRead for Buffered<T, B> +where + T: AsyncRead + AsyncWrite + Unpin, + B: Buf, +{ + fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> { + if !self.read_buf.is_empty() { + let n = std::cmp::min(len, self.read_buf.len()); + Poll::Ready(Ok(self.read_buf.split_to(n).freeze())) + } else { + let n = ready!(self.poll_read_from_io(cx))?; + Poll::Ready(Ok(self.read_buf.split_to(::std::cmp::min(len, n)).freeze())) + } + } +} + +#[derive(Clone, Copy, Debug)] +enum ReadStrategy { + Adaptive { + decrease_now: bool, + next: usize, + max: usize, + }, + Exact(usize), +} + +impl ReadStrategy { + fn with_max(max: usize) -> ReadStrategy { + ReadStrategy::Adaptive { + decrease_now: false, + next: INIT_BUFFER_SIZE, + max, + } + } + + fn next(&self) -> usize { + match *self { + ReadStrategy::Adaptive { next, .. } => next, + ReadStrategy::Exact(exact) => exact, + } + } + + fn max(&self) -> usize { + match *self { + ReadStrategy::Adaptive { max, .. } => max, + ReadStrategy::Exact(exact) => exact, + } + } + + fn record(&mut self, bytes_read: usize) { + if let ReadStrategy::Adaptive { + ref mut decrease_now, + ref mut next, + max, + .. + } = *self + { + if bytes_read >= *next { + *next = cmp::min(incr_power_of_two(*next), max); + *decrease_now = false; + } else { + let decr_to = prev_power_of_two(*next); + if bytes_read < decr_to { + if *decrease_now { + *next = cmp::max(decr_to, INIT_BUFFER_SIZE); + *decrease_now = false; + } else { + // Decreasing is a two "record" process. + *decrease_now = true; + } + } else { + // A read within the current range should cancel + // a potential decrease, since we just saw proof + // that we still need this size. + *decrease_now = false; + } + } + } + } +} + +fn incr_power_of_two(n: usize) -> usize { + n.saturating_mul(2) +} + +fn prev_power_of_two(n: usize) -> usize { + // Only way this shift can underflow is if n is less than 4. + // (Which would means `usize::MAX >> 64` and underflowed!) + debug_assert!(n >= 4); + (::std::usize::MAX >> (n.leading_zeros() + 2)) + 1 +} + +impl Default for ReadStrategy { + fn default() -> ReadStrategy { + ReadStrategy::with_max(DEFAULT_MAX_BUFFER_SIZE) + } +} + +#[derive(Clone)] +pub struct Cursor<T> { + bytes: T, + pos: usize, +} + +impl<T: AsRef<[u8]>> Cursor<T> { + #[inline] + pub(crate) fn new(bytes: T) -> Cursor<T> { + Cursor { bytes, pos: 0 } + } +} + +impl Cursor<Vec<u8>> { + fn reset(&mut self) { + self.pos = 0; + self.bytes.clear(); + } +} + +impl<T: AsRef<[u8]>> fmt::Debug for Cursor<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Cursor") + .field("pos", &self.pos) + .field("len", &self.bytes.as_ref().len()) + .finish() + } +} + +impl<T: AsRef<[u8]>> Buf for Cursor<T> { + #[inline] + fn remaining(&self) -> usize { + self.bytes.as_ref().len() - self.pos + } + + #[inline] + fn bytes(&self) -> &[u8] { + &self.bytes.as_ref()[self.pos..] + } + + #[inline] + fn advance(&mut self, cnt: usize) { + debug_assert!(self.pos + cnt <= self.bytes.as_ref().len()); + self.pos += cnt; + } +} + +// an internal buffer to collect writes before flushes +pub(super) struct WriteBuf<B> { + /// Re-usable buffer that holds message headers + headers: Cursor<Vec<u8>>, + max_buf_size: usize, + /// Deque of user buffers if strategy is Queue + queue: BufList<B>, + strategy: WriteStrategy, +} + +impl<B: Buf> WriteBuf<B> { + fn new() -> WriteBuf<B> { + WriteBuf { + headers: Cursor::new(Vec::with_capacity(INIT_BUFFER_SIZE)), + max_buf_size: DEFAULT_MAX_BUFFER_SIZE, + queue: BufList::new(), + strategy: WriteStrategy::Auto, + } + } +} + +impl<B> WriteBuf<B> +where + B: Buf, +{ + fn set_strategy(&mut self, strategy: WriteStrategy) { + self.strategy = strategy; + } + + #[inline] + fn auto(&mut self) -> WriteBufAuto<'_, B> { + WriteBufAuto::new(self) + } + + pub(super) fn buffer<BB: Buf + Into<B>>(&mut self, mut buf: BB) { + debug_assert!(buf.has_remaining()); + match self.strategy { + WriteStrategy::Flatten => { + let head = self.headers_mut(); + //perf: This is a little faster than <Vec as BufMut>>::put, + //but accomplishes the same result. + loop { + let adv = { + let slice = buf.bytes(); + if slice.is_empty() { + return; + } + head.bytes.extend_from_slice(slice); + slice.len() + }; + buf.advance(adv); + } + } + WriteStrategy::Auto | WriteStrategy::Queue => { + self.queue.push(buf.into()); + } + } + } + + fn can_buffer(&self) -> bool { + match self.strategy { + WriteStrategy::Flatten => self.remaining() < self.max_buf_size, + WriteStrategy::Auto | WriteStrategy::Queue => { + self.queue.bufs_cnt() < MAX_BUF_LIST_BUFFERS && self.remaining() < self.max_buf_size + } + } + } + + fn headers_mut(&mut self) -> &mut Cursor<Vec<u8>> { + debug_assert!(!self.queue.has_remaining()); + &mut self.headers + } +} + +impl<B: Buf> fmt::Debug for WriteBuf<B> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WriteBuf") + .field("remaining", &self.remaining()) + .field("strategy", &self.strategy) + .finish() + } +} + +impl<B: Buf> Buf for WriteBuf<B> { + #[inline] + fn remaining(&self) -> usize { + self.headers.remaining() + self.queue.remaining() + } + + #[inline] + fn bytes(&self) -> &[u8] { + let headers = self.headers.bytes(); + if !headers.is_empty() { + headers + } else { + self.queue.bytes() + } + } + + #[inline] + fn advance(&mut self, cnt: usize) { + let hrem = self.headers.remaining(); + + match hrem.cmp(&cnt) { + cmp::Ordering::Equal => self.headers.reset(), + cmp::Ordering::Greater => self.headers.advance(cnt), + cmp::Ordering::Less => { + let qcnt = cnt - hrem; + self.headers.reset(); + self.queue.advance(qcnt); + } + } + } + + #[inline] + fn bytes_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { + let n = self.headers.bytes_vectored(dst); + self.queue.bytes_vectored(&mut dst[n..]) + n + } +} + +/// Detects when wrapped `WriteBuf` is used for vectored IO, and +/// adjusts the `WriteBuf` strategy if not. +struct WriteBufAuto<'a, B: Buf> { + bytes_called: Cell<bool>, + bytes_vec_called: Cell<bool>, + inner: &'a mut WriteBuf<B>, +} + +impl<'a, B: Buf> WriteBufAuto<'a, B> { + fn new(inner: &'a mut WriteBuf<B>) -> WriteBufAuto<'a, B> { + WriteBufAuto { + bytes_called: Cell::new(false), + bytes_vec_called: Cell::new(false), + inner, + } + } +} + +impl<'a, B: Buf> Buf for WriteBufAuto<'a, B> { + #[inline] + fn remaining(&self) -> usize { + self.inner.remaining() + } + + #[inline] + fn bytes(&self) -> &[u8] { + self.bytes_called.set(true); + self.inner.bytes() + } + + #[inline] + fn advance(&mut self, cnt: usize) { + self.inner.advance(cnt) + } + + #[inline] + fn bytes_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { + self.bytes_vec_called.set(true); + self.inner.bytes_vectored(dst) + } +} + +impl<'a, B: Buf + 'a> Drop for WriteBufAuto<'a, B> { + fn drop(&mut self) { + if let WriteStrategy::Auto = self.inner.strategy { + if self.bytes_vec_called.get() { + self.inner.strategy = WriteStrategy::Queue; + } else if self.bytes_called.get() { + trace!("detected no usage of vectored write, flattening"); + self.inner.strategy = WriteStrategy::Flatten; + self.inner.headers.bytes.put(&mut self.inner.queue); + } + } + } +} + +#[derive(Debug)] +enum WriteStrategy { + Auto, + Flatten, + Queue, +} + +#[cfg(test)] +mod tests { + use super::*; + use std::time::Duration; + + use tokio_test::io::Builder as Mock; + + #[cfg(feature = "nightly")] + use test::Bencher; + + /* + impl<T: Read> MemRead for AsyncIo<T> { + fn read_mem(&mut self, len: usize) -> Poll<Bytes, io::Error> { + let mut v = vec![0; len]; + let n = try_nb!(self.read(v.as_mut_slice())); + Ok(Async::Ready(BytesMut::from(&v[..n]).freeze())) + } + } + */ + + #[tokio::test] + async fn iobuf_write_empty_slice() { + // First, let's just check that the Mock would normally return an + // error on an unexpected write, even if the buffer is empty... + let mut mock = Mock::new().build(); + futures_util::future::poll_fn(|cx| { + Pin::new(&mut mock).poll_write_buf(cx, &mut Cursor::new(&[])) + }) + .await + .expect_err("should be a broken pipe"); + + // underlying io will return the logic error upon write, + // so we are testing that the io_buf does not trigger a write + // when there is nothing to flush + let mock = Mock::new().build(); + let mut io_buf = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + io_buf.flush().await.expect("should short-circuit flush"); + } + + #[tokio::test] + async fn parse_reads_until_blocked() { + use crate::proto::h1::ClientTransaction; + + let mock = Mock::new() + // Split over multiple reads will read all of it + .read(b"HTTP/1.1 200 OK\r\n") + .read(b"Server: hyper\r\n") + // missing last line ending + .wait(Duration::from_secs(1)) + .build(); + + let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + + // We expect a `parse` to be not ready, and so can't await it directly. + // Rather, this `poll_fn` will wrap the `Poll` result. + futures_util::future::poll_fn(|cx| { + let parse_ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut None, + }; + assert!(buffered + .parse::<ClientTransaction>(cx, parse_ctx) + .is_pending()); + Poll::Ready(()) + }) + .await; + + assert_eq!( + buffered.read_buf, + b"HTTP/1.1 200 OK\r\nServer: hyper\r\n"[..] + ); + } + + #[test] + fn read_strategy_adaptive_increments() { + let mut strategy = ReadStrategy::default(); + assert_eq!(strategy.next(), 8192); + + // Grows if record == next + strategy.record(8192); + assert_eq!(strategy.next(), 16384); + + strategy.record(16384); + assert_eq!(strategy.next(), 32768); + + // Enormous records still increment at same rate + strategy.record(::std::usize::MAX); + assert_eq!(strategy.next(), 65536); + + let max = strategy.max(); + while strategy.next() < max { + strategy.record(max); + } + + assert_eq!(strategy.next(), max, "never goes over max"); + strategy.record(max + 1); + assert_eq!(strategy.next(), max, "never goes over max"); + } + + #[test] + fn read_strategy_adaptive_decrements() { + let mut strategy = ReadStrategy::default(); + strategy.record(8192); + assert_eq!(strategy.next(), 16384); + + strategy.record(1); + assert_eq!( + strategy.next(), + 16384, + "first smaller record doesn't decrement yet" + ); + strategy.record(8192); + assert_eq!(strategy.next(), 16384, "record was with range"); + + strategy.record(1); + assert_eq!( + strategy.next(), + 16384, + "in-range record should make this the 'first' again" + ); + + strategy.record(1); + assert_eq!(strategy.next(), 8192, "second smaller record decrements"); + + strategy.record(1); + assert_eq!(strategy.next(), 8192, "first doesn't decrement"); + strategy.record(1); + assert_eq!(strategy.next(), 8192, "doesn't decrement under minimum"); + } + + #[test] + fn read_strategy_adaptive_stays_the_same() { + let mut strategy = ReadStrategy::default(); + strategy.record(8192); + assert_eq!(strategy.next(), 16384); + + strategy.record(8193); + assert_eq!( + strategy.next(), + 16384, + "first smaller record doesn't decrement yet" + ); + + strategy.record(8193); + assert_eq!( + strategy.next(), + 16384, + "with current step does not decrement" + ); + } + + #[test] + fn read_strategy_adaptive_max_fuzz() { + fn fuzz(max: usize) { + let mut strategy = ReadStrategy::with_max(max); + while strategy.next() < max { + strategy.record(::std::usize::MAX); + } + let mut next = strategy.next(); + while next > 8192 { + strategy.record(1); + strategy.record(1); + next = strategy.next(); + assert!( + next.is_power_of_two(), + "decrement should be powers of two: {} (max = {})", + next, + max, + ); + } + } + + let mut max = 8192; + while max < std::usize::MAX { + fuzz(max); + max = (max / 2).saturating_mul(3); + } + fuzz(::std::usize::MAX); + } + + #[test] + #[should_panic] + #[cfg(debug_assertions)] // needs to trigger a debug_assert + fn write_buf_requires_non_empty_bufs() { + let mock = Mock::new().build(); + let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + + buffered.buffer(Cursor::new(Vec::new())); + } + + /* + TODO: needs tokio_test::io to allow configure write_buf calls + #[test] + fn write_buf_queue() { + let _ = pretty_env_logger::try_init(); + + let mock = AsyncIo::new_buf(vec![], 1024); + let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + + + buffered.headers_buf().extend(b"hello "); + buffered.buffer(Cursor::new(b"world, ".to_vec())); + buffered.buffer(Cursor::new(b"it's ".to_vec())); + buffered.buffer(Cursor::new(b"hyper!".to_vec())); + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3); + buffered.flush().unwrap(); + + assert_eq!(buffered.io, b"hello world, it's hyper!"); + assert_eq!(buffered.io.num_writes(), 1); + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0); + } + */ + + #[tokio::test] + async fn write_buf_flatten() { + let _ = pretty_env_logger::try_init(); + + let mock = Mock::new() + // Just a single write + .write(b"hello world, it's hyper!") + .build(); + + let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + buffered.write_buf.set_strategy(WriteStrategy::Flatten); + + buffered.headers_buf().extend(b"hello "); + buffered.buffer(Cursor::new(b"world, ".to_vec())); + buffered.buffer(Cursor::new(b"it's ".to_vec())); + buffered.buffer(Cursor::new(b"hyper!".to_vec())); + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0); + + buffered.flush().await.expect("flush"); + } + + #[tokio::test] + async fn write_buf_auto_flatten() { + let _ = pretty_env_logger::try_init(); + + let mock = Mock::new() + // Expects write_buf to only consume first buffer + .write(b"hello ") + // And then the Auto strategy will have flattened + .write(b"world, it's hyper!") + .build(); + + let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + + // we have 4 buffers, but hope to detect that vectored IO isn't + // being used, and switch to flattening automatically, + // resulting in only 2 writes + buffered.headers_buf().extend(b"hello "); + buffered.buffer(Cursor::new(b"world, ".to_vec())); + buffered.buffer(Cursor::new(b"it's ".to_vec())); + buffered.buffer(Cursor::new(b"hyper!".to_vec())); + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3); + + buffered.flush().await.expect("flush"); + + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0); + } + + #[tokio::test] + async fn write_buf_queue_disable_auto() { + let _ = pretty_env_logger::try_init(); + + let mock = Mock::new() + .write(b"hello ") + .write(b"world, ") + .write(b"it's ") + .write(b"hyper!") + .build(); + + let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock); + buffered.write_buf.set_strategy(WriteStrategy::Queue); + + // we have 4 buffers, and vec IO disabled, but explicitly said + // don't try to auto detect (via setting strategy above) + + buffered.headers_buf().extend(b"hello "); + buffered.buffer(Cursor::new(b"world, ".to_vec())); + buffered.buffer(Cursor::new(b"it's ".to_vec())); + buffered.buffer(Cursor::new(b"hyper!".to_vec())); + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3); + + buffered.flush().await.expect("flush"); + + assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0); + } + + #[cfg(feature = "nightly")] + #[bench] + fn bench_write_buf_flatten_buffer_chunk(b: &mut Bencher) { + let s = "Hello, World!"; + b.bytes = s.len() as u64; + + let mut write_buf = WriteBuf::<bytes::Bytes>::new(); + write_buf.set_strategy(WriteStrategy::Flatten); + b.iter(|| { + let chunk = bytes::Bytes::from(s); + write_buf.buffer(chunk); + ::test::black_box(&write_buf); + write_buf.headers.bytes.clear(); + }) + } +} diff --git a/third_party/rust/hyper/src/proto/h1/mod.rs b/third_party/rust/hyper/src/proto/h1/mod.rs new file mode 100644 index 0000000000..2d0bf39bc9 --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/mod.rs @@ -0,0 +1,95 @@ +use bytes::BytesMut; +use http::{HeaderMap, Method}; + +use crate::proto::{BodyLength, DecodedLength, MessageHead}; + +pub(crate) use self::conn::Conn; +pub use self::decode::Decoder; +pub(crate) use self::dispatch::Dispatcher; +pub use self::encode::{EncodedBuf, Encoder}; +pub use self::io::Cursor; //TODO: move out of h1::io +pub use self::io::MINIMUM_MAX_BUFFER_SIZE; + +mod conn; +pub(super) mod date; +mod decode; +pub(crate) mod dispatch; +mod encode; +mod io; +mod role; + +pub(crate) type ServerTransaction = role::Server; +pub(crate) type ClientTransaction = role::Client; + +pub(crate) trait Http1Transaction { + type Incoming; + type Outgoing: Default; + const LOG: &'static str; + fn parse(bytes: &mut BytesMut, ctx: ParseContext<'_>) -> ParseResult<Self::Incoming>; + fn encode(enc: Encode<'_, Self::Outgoing>, dst: &mut Vec<u8>) -> crate::Result<Encoder>; + + fn on_error(err: &crate::Error) -> Option<MessageHead<Self::Outgoing>>; + + fn is_client() -> bool { + !Self::is_server() + } + + fn is_server() -> bool { + !Self::is_client() + } + + fn should_error_on_parse_eof() -> bool { + Self::is_client() + } + + fn should_read_first() -> bool { + Self::is_server() + } + + fn update_date() {} +} + +/// Result newtype for Http1Transaction::parse. +pub(crate) type ParseResult<T> = Result<Option<ParsedMessage<T>>, crate::error::Parse>; + +#[derive(Debug)] +pub(crate) struct ParsedMessage<T> { + head: MessageHead<T>, + decode: DecodedLength, + expect_continue: bool, + keep_alive: bool, + wants_upgrade: bool, +} + +pub(crate) struct ParseContext<'a> { + cached_headers: &'a mut Option<HeaderMap>, + req_method: &'a mut Option<Method>, +} + +/// Passed to Http1Transaction::encode +pub(crate) struct Encode<'a, T> { + head: &'a mut MessageHead<T>, + body: Option<BodyLength>, + keep_alive: bool, + req_method: &'a mut Option<Method>, + title_case_headers: bool, +} + +/// Extra flags that a request "wants", like expect-continue or upgrades. +#[derive(Clone, Copy, Debug)] +struct Wants(u8); + +impl Wants { + const EMPTY: Wants = Wants(0b00); + const EXPECT: Wants = Wants(0b01); + const UPGRADE: Wants = Wants(0b10); + + #[must_use] + fn add(self, other: Wants) -> Wants { + Wants(self.0 | other.0) + } + + fn contains(&self, other: Wants) -> bool { + (self.0 & other.0) == other.0 + } +} diff --git a/third_party/rust/hyper/src/proto/h1/role.rs b/third_party/rust/hyper/src/proto/h1/role.rs new file mode 100644 index 0000000000..e99f4cf541 --- /dev/null +++ b/third_party/rust/hyper/src/proto/h1/role.rs @@ -0,0 +1,1835 @@ +// `mem::uninitialized` replaced with `mem::MaybeUninit`, +// can't upgrade yet +#![allow(deprecated)] + +use std::fmt::{self, Write}; +use std::mem; + +use bytes::BytesMut; +use http::header::{self, Entry, HeaderName, HeaderValue}; +use http::{HeaderMap, Method, StatusCode, Version}; + +use crate::error::Parse; +use crate::headers; +use crate::proto::h1::{ + date, Encode, Encoder, Http1Transaction, ParseContext, ParseResult, ParsedMessage, +}; +use crate::proto::{BodyLength, DecodedLength, MessageHead, RequestHead, RequestLine}; + +const MAX_HEADERS: usize = 100; +const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific + +macro_rules! header_name { + ($bytes:expr) => {{ + #[cfg(debug_assertions)] + { + match HeaderName::from_bytes($bytes) { + Ok(name) => name, + Err(_) => panic!( + "illegal header name from httparse: {:?}", + ::bytes::Bytes::copy_from_slice($bytes) + ), + } + } + + #[cfg(not(debug_assertions))] + { + HeaderName::from_bytes($bytes).expect("header name validated by httparse") + } + }}; +} + +macro_rules! header_value { + ($bytes:expr) => {{ + #[cfg(debug_assertions)] + { + let __hvb: ::bytes::Bytes = $bytes; + match HeaderValue::from_maybe_shared(__hvb.clone()) { + Ok(name) => name, + Err(_) => panic!("illegal header value from httparse: {:?}", __hvb), + } + } + + #[cfg(not(debug_assertions))] + { + // Unsafe: httparse already validated header value + unsafe { HeaderValue::from_maybe_shared_unchecked($bytes) } + } + }}; +} + +// There are 2 main roles, Client and Server. + +pub(crate) enum Client {} + +pub(crate) enum Server {} + +impl Http1Transaction for Server { + type Incoming = RequestLine; + type Outgoing = StatusCode; + const LOG: &'static str = "{role=server}"; + + fn parse(buf: &mut BytesMut, ctx: ParseContext<'_>) -> ParseResult<RequestLine> { + if buf.is_empty() { + return Ok(None); + } + + let mut keep_alive; + let is_http_11; + let subject; + let version; + let len; + let headers_len; + + // Unsafe: both headers_indices and headers are using uninitialized memory, + // but we *never* read any of it until after httparse has assigned + // values into it. By not zeroing out the stack memory, this saves + // a good ~5% on pipeline benchmarks. + let mut headers_indices: [HeaderIndices; MAX_HEADERS] = unsafe { mem::uninitialized() }; + { + let mut headers: [httparse::Header<'_>; MAX_HEADERS] = unsafe { mem::uninitialized() }; + trace!( + "Request.parse([Header; {}], [u8; {}])", + headers.len(), + buf.len() + ); + let mut req = httparse::Request::new(&mut headers); + let bytes = buf.as_ref(); + match req.parse(bytes) { + Ok(httparse::Status::Complete(parsed_len)) => { + trace!("Request.parse Complete({})", parsed_len); + len = parsed_len; + subject = RequestLine( + Method::from_bytes(req.method.unwrap().as_bytes())?, + req.path.unwrap().parse()?, + ); + version = if req.version.unwrap() == 1 { + keep_alive = true; + is_http_11 = true; + Version::HTTP_11 + } else { + keep_alive = false; + is_http_11 = false; + Version::HTTP_10 + }; + + record_header_indices(bytes, &req.headers, &mut headers_indices)?; + headers_len = req.headers.len(); + } + Ok(httparse::Status::Partial) => return Ok(None), + Err(err) => { + return Err(match err { + // if invalid Token, try to determine if for method or path + httparse::Error::Token => { + if req.method.is_none() { + Parse::Method + } else { + debug_assert!(req.path.is_none()); + Parse::Uri + } + } + other => other.into(), + }); + } + } + }; + + let slice = buf.split_to(len).freeze(); + + // According to https://tools.ietf.org/html/rfc7230#section-3.3.3 + // 1. (irrelevant to Request) + // 2. (irrelevant to Request) + // 3. Transfer-Encoding: chunked has a chunked body. + // 4. If multiple differing Content-Length headers or invalid, close connection. + // 5. Content-Length header has a sized body. + // 6. Length 0. + // 7. (irrelevant to Request) + + let mut decoder = DecodedLength::ZERO; + let mut expect_continue = false; + let mut con_len = None; + let mut is_te = false; + let mut is_te_chunked = false; + let mut wants_upgrade = subject.0 == Method::CONNECT; + + let mut headers = ctx.cached_headers.take().unwrap_or_else(HeaderMap::new); + + headers.reserve(headers_len); + + for header in &headers_indices[..headers_len] { + let name = header_name!(&slice[header.name.0..header.name.1]); + let value = header_value!(slice.slice(header.value.0..header.value.1)); + + match name { + header::TRANSFER_ENCODING => { + // https://tools.ietf.org/html/rfc7230#section-3.3.3 + // If Transfer-Encoding header is present, and 'chunked' is + // not the final encoding, and this is a Request, then it is + // malformed. A server should respond with 400 Bad Request. + if !is_http_11 { + debug!("HTTP/1.0 cannot have Transfer-Encoding header"); + return Err(Parse::Header); + } + is_te = true; + if headers::is_chunked_(&value) { + is_te_chunked = true; + decoder = DecodedLength::CHUNKED; + } + } + header::CONTENT_LENGTH => { + if is_te { + continue; + } + let len = value + .to_str() + .map_err(|_| Parse::Header) + .and_then(|s| s.parse().map_err(|_| Parse::Header))?; + if let Some(prev) = con_len { + if prev != len { + debug!( + "multiple Content-Length headers with different values: [{}, {}]", + prev, len, + ); + return Err(Parse::Header); + } + // we don't need to append this secondary length + continue; + } + decoder = DecodedLength::checked_new(len)?; + con_len = Some(len); + } + header::CONNECTION => { + // keep_alive was previously set to default for Version + if keep_alive { + // HTTP/1.1 + keep_alive = !headers::connection_close(&value); + } else { + // HTTP/1.0 + keep_alive = headers::connection_keep_alive(&value); + } + } + header::EXPECT => { + expect_continue = value.as_bytes() == b"100-continue"; + } + header::UPGRADE => { + // Upgrades are only allowed with HTTP/1.1 + wants_upgrade = is_http_11; + } + + _ => (), + } + + headers.append(name, value); + } + + if is_te && !is_te_chunked { + debug!("request with transfer-encoding header, but not chunked, bad request"); + return Err(Parse::Header); + } + + *ctx.req_method = Some(subject.0.clone()); + + Ok(Some(ParsedMessage { + head: MessageHead { + version, + subject, + headers, + }, + decode: decoder, + expect_continue, + keep_alive, + wants_upgrade, + })) + } + + fn encode( + mut msg: Encode<'_, Self::Outgoing>, + mut dst: &mut Vec<u8>, + ) -> crate::Result<Encoder> { + trace!( + "Server::encode status={:?}, body={:?}, req_method={:?}", + msg.head.subject, + msg.body, + msg.req_method + ); + debug_assert!( + !msg.title_case_headers, + "no server config for title case headers" + ); + + let mut wrote_len = false; + + // hyper currently doesn't support returning 1xx status codes as a Response + // This is because Service only allows returning a single Response, and + // so if you try to reply with a e.g. 100 Continue, you have no way of + // replying with the latter status code response. + let (ret, mut is_last) = if msg.head.subject == StatusCode::SWITCHING_PROTOCOLS { + (Ok(()), true) + } else if msg.req_method == &Some(Method::CONNECT) && msg.head.subject.is_success() { + // Sending content-length or transfer-encoding header on 2xx response + // to CONNECT is forbidden in RFC 7231. + wrote_len = true; + (Ok(()), true) + } else if msg.head.subject.is_informational() { + warn!("response with 1xx status code not supported"); + *msg.head = MessageHead::default(); + msg.head.subject = StatusCode::INTERNAL_SERVER_ERROR; + msg.body = None; + (Err(crate::Error::new_user_unsupported_status_code()), true) + } else { + (Ok(()), !msg.keep_alive) + }; + + // In some error cases, we don't know about the invalid message until already + // pushing some bytes onto the `dst`. In those cases, we don't want to send + // the half-pushed message, so rewind to before. + let orig_len = dst.len(); + let rewind = |dst: &mut Vec<u8>| { + dst.truncate(orig_len); + }; + + let init_cap = 30 + msg.head.headers.len() * AVERAGE_HEADER_SIZE; + dst.reserve(init_cap); + if msg.head.version == Version::HTTP_11 && msg.head.subject == StatusCode::OK { + extend(dst, b"HTTP/1.1 200 OK\r\n"); + } else { + match msg.head.version { + Version::HTTP_10 => extend(dst, b"HTTP/1.0 "), + Version::HTTP_11 => extend(dst, b"HTTP/1.1 "), + Version::HTTP_2 => { + warn!("response with HTTP2 version coerced to HTTP/1.1"); + extend(dst, b"HTTP/1.1 "); + } + other => panic!("unexpected response version: {:?}", other), + } + + extend(dst, msg.head.subject.as_str().as_bytes()); + extend(dst, b" "); + // a reason MUST be written, as many parsers will expect it. + extend( + dst, + msg.head + .subject + .canonical_reason() + .unwrap_or("<none>") + .as_bytes(), + ); + extend(dst, b"\r\n"); + } + + let mut encoder = Encoder::length(0); + let mut wrote_date = false; + let mut cur_name = None; + let mut is_name_written = false; + let mut must_write_chunked = false; + let mut prev_con_len = None; + + macro_rules! handle_is_name_written { + () => {{ + if is_name_written { + // we need to clean up and write the newline + debug_assert_ne!( + &dst[dst.len() - 2..], + b"\r\n", + "previous header wrote newline but set is_name_written" + ); + + if must_write_chunked { + extend(dst, b", chunked\r\n"); + } else { + extend(dst, b"\r\n"); + } + } + }}; + } + + 'headers: for (opt_name, value) in msg.head.headers.drain() { + if let Some(n) = opt_name { + cur_name = Some(n); + handle_is_name_written!(); + is_name_written = false; + } + let name = cur_name.as_ref().expect("current header name"); + match *name { + header::CONTENT_LENGTH => { + if wrote_len && !is_name_written { + warn!("unexpected content-length found, canceling"); + rewind(dst); + return Err(crate::Error::new_user_header()); + } + match msg.body { + Some(BodyLength::Known(known_len)) => { + // The Payload claims to know a length, and + // the headers are already set. For performance + // reasons, we are just going to trust that + // the values match. + // + // In debug builds, we'll assert they are the + // same to help developers find bugs. + #[cfg(debug_assertions)] + { + if let Some(len) = headers::content_length_parse(&value) { + assert!( + len == known_len, + "payload claims content-length of {}, custom content-length header claims {}", + known_len, + len, + ); + } + } + + if !is_name_written { + encoder = Encoder::length(known_len); + extend(dst, b"content-length: "); + extend(dst, value.as_bytes()); + wrote_len = true; + is_name_written = true; + } + continue 'headers; + } + Some(BodyLength::Unknown) => { + // The Payload impl didn't know how long the + // body is, but a length header was included. + // We have to parse the value to return our + // Encoder... + + if let Some(len) = headers::content_length_parse(&value) { + if let Some(prev) = prev_con_len { + if prev != len { + warn!( + "multiple Content-Length values found: [{}, {}]", + prev, len + ); + rewind(dst); + return Err(crate::Error::new_user_header()); + } + debug_assert!(is_name_written); + continue 'headers; + } else { + // we haven't written content-length yet! + encoder = Encoder::length(len); + extend(dst, b"content-length: "); + extend(dst, value.as_bytes()); + wrote_len = true; + is_name_written = true; + prev_con_len = Some(len); + continue 'headers; + } + } else { + warn!("illegal Content-Length value: {:?}", value); + rewind(dst); + return Err(crate::Error::new_user_header()); + } + } + None => { + // We have no body to actually send, + // but the headers claim a content-length. + // There's only 2 ways this makes sense: + // + // - The header says the length is `0`. + // - This is a response to a `HEAD` request. + if msg.req_method == &Some(Method::HEAD) { + debug_assert_eq!(encoder, Encoder::length(0)); + } else { + if value.as_bytes() != b"0" { + warn!( + "content-length value found, but empty body provided: {:?}", + value + ); + } + continue 'headers; + } + } + } + wrote_len = true; + } + header::TRANSFER_ENCODING => { + if wrote_len && !is_name_written { + warn!("unexpected transfer-encoding found, canceling"); + rewind(dst); + return Err(crate::Error::new_user_header()); + } + // check that we actually can send a chunked body... + if msg.head.version == Version::HTTP_10 + || !Server::can_chunked(msg.req_method, msg.head.subject) + { + continue; + } + wrote_len = true; + // Must check each value, because `chunked` needs to be the + // last encoding, or else we add it. + must_write_chunked = !headers::is_chunked_(&value); + + if !is_name_written { + encoder = Encoder::chunked(); + is_name_written = true; + extend(dst, b"transfer-encoding: "); + extend(dst, value.as_bytes()); + } else { + extend(dst, b", "); + extend(dst, value.as_bytes()); + } + continue 'headers; + } + header::CONNECTION => { + if !is_last && headers::connection_close(&value) { + is_last = true; + } + if !is_name_written { + is_name_written = true; + extend(dst, b"connection: "); + extend(dst, value.as_bytes()); + } else { + extend(dst, b", "); + extend(dst, value.as_bytes()); + } + continue 'headers; + } + header::DATE => { + wrote_date = true; + } + _ => (), + } + //TODO: this should perhaps instead combine them into + //single lines, as RFC7230 suggests is preferable. + + // non-special write Name and Value + debug_assert!( + !is_name_written, + "{:?} set is_name_written and didn't continue loop", + name, + ); + extend(dst, name.as_str().as_bytes()); + extend(dst, b": "); + extend(dst, value.as_bytes()); + extend(dst, b"\r\n"); + } + + handle_is_name_written!(); + + if !wrote_len { + encoder = match msg.body { + Some(BodyLength::Unknown) => { + if msg.head.version == Version::HTTP_10 + || !Server::can_chunked(msg.req_method, msg.head.subject) + { + Encoder::close_delimited() + } else { + extend(dst, b"transfer-encoding: chunked\r\n"); + Encoder::chunked() + } + } + None | Some(BodyLength::Known(0)) => { + if msg.head.subject != StatusCode::NOT_MODIFIED { + extend(dst, b"content-length: 0\r\n"); + } + Encoder::length(0) + } + Some(BodyLength::Known(len)) => { + if msg.head.subject == StatusCode::NOT_MODIFIED { + Encoder::length(0) + } else { + extend(dst, b"content-length: "); + let _ = ::itoa::write(&mut dst, len); + extend(dst, b"\r\n"); + Encoder::length(len) + } + } + }; + } + + if !Server::can_have_body(msg.req_method, msg.head.subject) { + trace!( + "server body forced to 0; method={:?}, status={:?}", + msg.req_method, + msg.head.subject + ); + encoder = Encoder::length(0); + } + + // cached date is much faster than formatting every request + if !wrote_date { + dst.reserve(date::DATE_VALUE_LENGTH + 8); + extend(dst, b"date: "); + date::extend(dst); + extend(dst, b"\r\n\r\n"); + } else { + extend(dst, b"\r\n"); + } + + ret.map(|()| encoder.set_last(is_last)) + } + + fn on_error(err: &crate::Error) -> Option<MessageHead<Self::Outgoing>> { + use crate::error::Kind; + let status = match *err.kind() { + Kind::Parse(Parse::Method) + | Kind::Parse(Parse::Header) + | Kind::Parse(Parse::Uri) + | Kind::Parse(Parse::Version) => StatusCode::BAD_REQUEST, + Kind::Parse(Parse::TooLarge) => StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE, + _ => return None, + }; + + debug!("sending automatic response ({}) for parse error", status); + let mut msg = MessageHead::default(); + msg.subject = status; + Some(msg) + } + + fn is_server() -> bool { + true + } + + fn update_date() { + date::update(); + } +} + +impl Server { + fn can_have_body(method: &Option<Method>, status: StatusCode) -> bool { + Server::can_chunked(method, status) + } + + fn can_chunked(method: &Option<Method>, status: StatusCode) -> bool { + if method == &Some(Method::HEAD) || method == &Some(Method::CONNECT) && status.is_success() + { + false + } else { + match status { + // TODO: support for 1xx codes needs improvement everywhere + // would be 100...199 => false + StatusCode::SWITCHING_PROTOCOLS + | StatusCode::NO_CONTENT + | StatusCode::NOT_MODIFIED => false, + _ => true, + } + } + } +} + +impl Http1Transaction for Client { + type Incoming = StatusCode; + type Outgoing = RequestLine; + const LOG: &'static str = "{role=client}"; + + fn parse(buf: &mut BytesMut, ctx: ParseContext<'_>) -> ParseResult<StatusCode> { + // Loop to skip information status code headers (100 Continue, etc). + loop { + if buf.is_empty() { + return Ok(None); + } + // Unsafe: see comment in Server Http1Transaction, above. + let mut headers_indices: [HeaderIndices; MAX_HEADERS] = unsafe { mem::uninitialized() }; + let (len, status, version, headers_len) = { + let mut headers: [httparse::Header<'_>; MAX_HEADERS] = + unsafe { mem::uninitialized() }; + trace!( + "Response.parse([Header; {}], [u8; {}])", + headers.len(), + buf.len() + ); + let mut res = httparse::Response::new(&mut headers); + let bytes = buf.as_ref(); + match res.parse(bytes)? { + httparse::Status::Complete(len) => { + trace!("Response.parse Complete({})", len); + let status = StatusCode::from_u16(res.code.unwrap())?; + let version = if res.version.unwrap() == 1 { + Version::HTTP_11 + } else { + Version::HTTP_10 + }; + record_header_indices(bytes, &res.headers, &mut headers_indices)?; + let headers_len = res.headers.len(); + (len, status, version, headers_len) + } + httparse::Status::Partial => return Ok(None), + } + }; + + let slice = buf.split_to(len).freeze(); + + let mut headers = ctx.cached_headers.take().unwrap_or_else(HeaderMap::new); + + let mut keep_alive = version == Version::HTTP_11; + + headers.reserve(headers_len); + for header in &headers_indices[..headers_len] { + let name = header_name!(&slice[header.name.0..header.name.1]); + let value = header_value!(slice.slice(header.value.0..header.value.1)); + + if let header::CONNECTION = name { + // keep_alive was previously set to default for Version + if keep_alive { + // HTTP/1.1 + keep_alive = !headers::connection_close(&value); + } else { + // HTTP/1.0 + keep_alive = headers::connection_keep_alive(&value); + } + } + headers.append(name, value); + } + + let head = MessageHead { + version, + subject: status, + headers, + }; + if let Some((decode, is_upgrade)) = Client::decoder(&head, ctx.req_method)? { + return Ok(Some(ParsedMessage { + head, + decode, + expect_continue: false, + // a client upgrade means the connection can't be used + // again, as it is definitely upgrading. + keep_alive: keep_alive && !is_upgrade, + wants_upgrade: is_upgrade, + })); + } + } + } + + fn encode(msg: Encode<'_, Self::Outgoing>, dst: &mut Vec<u8>) -> crate::Result<Encoder> { + trace!( + "Client::encode method={:?}, body={:?}", + msg.head.subject.0, + msg.body + ); + + *msg.req_method = Some(msg.head.subject.0.clone()); + + let body = Client::set_length(msg.head, msg.body); + + let init_cap = 30 + msg.head.headers.len() * AVERAGE_HEADER_SIZE; + dst.reserve(init_cap); + + extend(dst, msg.head.subject.0.as_str().as_bytes()); + extend(dst, b" "); + //TODO: add API to http::Uri to encode without std::fmt + let _ = write!(FastWrite(dst), "{} ", msg.head.subject.1); + + match msg.head.version { + Version::HTTP_10 => extend(dst, b"HTTP/1.0"), + Version::HTTP_11 => extend(dst, b"HTTP/1.1"), + Version::HTTP_2 => { + warn!("request with HTTP2 version coerced to HTTP/1.1"); + extend(dst, b"HTTP/1.1"); + } + other => panic!("unexpected request version: {:?}", other), + } + extend(dst, b"\r\n"); + + if msg.title_case_headers { + write_headers_title_case(&msg.head.headers, dst); + } else { + write_headers(&msg.head.headers, dst); + } + extend(dst, b"\r\n"); + msg.head.headers.clear(); //TODO: remove when switching to drain() + + Ok(body) + } + + fn on_error(_err: &crate::Error) -> Option<MessageHead<Self::Outgoing>> { + // we can't tell the server about any errors it creates + None + } + + fn is_client() -> bool { + true + } +} + +impl Client { + /// Returns Some(length, wants_upgrade) if successful. + /// + /// Returns None if this message head should be skipped (like a 100 status). + fn decoder( + inc: &MessageHead<StatusCode>, + method: &mut Option<Method>, + ) -> Result<Option<(DecodedLength, bool)>, Parse> { + // According to https://tools.ietf.org/html/rfc7230#section-3.3.3 + // 1. HEAD responses, and Status 1xx, 204, and 304 cannot have a body. + // 2. Status 2xx to a CONNECT cannot have a body. + // 3. Transfer-Encoding: chunked has a chunked body. + // 4. If multiple differing Content-Length headers or invalid, close connection. + // 5. Content-Length header has a sized body. + // 6. (irrelevant to Response) + // 7. Read till EOF. + + match inc.subject.as_u16() { + 101 => { + return Ok(Some((DecodedLength::ZERO, true))); + } + 100 | 102..=199 => { + trace!("ignoring informational response: {}", inc.subject.as_u16()); + return Ok(None); + } + 204 | 304 => return Ok(Some((DecodedLength::ZERO, false))), + _ => (), + } + match *method { + Some(Method::HEAD) => { + return Ok(Some((DecodedLength::ZERO, false))); + } + Some(Method::CONNECT) => { + if let 200..=299 = inc.subject.as_u16() { + return Ok(Some((DecodedLength::ZERO, true))); + } + } + Some(_) => {} + None => { + trace!("Client::decoder is missing the Method"); + } + } + + if inc.headers.contains_key(header::TRANSFER_ENCODING) { + // https://tools.ietf.org/html/rfc7230#section-3.3.3 + // If Transfer-Encoding header is present, and 'chunked' is + // not the final encoding, and this is a Request, then it is + // malformed. A server should respond with 400 Bad Request. + if inc.version == Version::HTTP_10 { + debug!("HTTP/1.0 cannot have Transfer-Encoding header"); + Err(Parse::Header) + } else if headers::transfer_encoding_is_chunked(&inc.headers) { + Ok(Some((DecodedLength::CHUNKED, false))) + } else { + trace!("not chunked, read till eof"); + Ok(Some((DecodedLength::CLOSE_DELIMITED, false))) + } + } else if let Some(len) = headers::content_length_parse_all(&inc.headers) { + Ok(Some((DecodedLength::checked_new(len)?, false))) + } else if inc.headers.contains_key(header::CONTENT_LENGTH) { + debug!("illegal Content-Length header"); + Err(Parse::Header) + } else { + trace!("neither Transfer-Encoding nor Content-Length"); + Ok(Some((DecodedLength::CLOSE_DELIMITED, false))) + } + } +} + +impl Client { + fn set_length(head: &mut RequestHead, body: Option<BodyLength>) -> Encoder { + let body = if let Some(body) = body { + body + } else { + head.headers.remove(header::TRANSFER_ENCODING); + return Encoder::length(0); + }; + + // HTTP/1.0 doesn't know about chunked + let can_chunked = head.version == Version::HTTP_11; + let headers = &mut head.headers; + + // If the user already set specific headers, we should respect them, regardless + // of what the Payload knows about itself. They set them for a reason. + + // Because of the borrow checker, we can't check the for an existing + // Content-Length header while holding an `Entry` for the Transfer-Encoding + // header, so unfortunately, we must do the check here, first. + + let existing_con_len = headers::content_length_parse_all(headers); + let mut should_remove_con_len = false; + + if !can_chunked { + // Chunked isn't legal, so if it is set, we need to remove it. + if headers.remove(header::TRANSFER_ENCODING).is_some() { + trace!("removing illegal transfer-encoding header"); + } + + return if let Some(len) = existing_con_len { + Encoder::length(len) + } else if let BodyLength::Known(len) = body { + set_content_length(headers, len) + } else { + // HTTP/1.0 client requests without a content-length + // cannot have any body at all. + Encoder::length(0) + }; + } + + // If the user set a transfer-encoding, respect that. Let's just + // make sure `chunked` is the final encoding. + let encoder = match headers.entry(header::TRANSFER_ENCODING) { + Entry::Occupied(te) => { + should_remove_con_len = true; + if headers::is_chunked(te.iter()) { + Some(Encoder::chunked()) + } else { + warn!("user provided transfer-encoding does not end in 'chunked'"); + + // There's a Transfer-Encoding, but it doesn't end in 'chunked'! + // An example that could trigger this: + // + // Transfer-Encoding: gzip + // + // This can be bad, depending on if this is a request or a + // response. + // + // - A request is illegal if there is a `Transfer-Encoding` + // but it doesn't end in `chunked`. + // - A response that has `Transfer-Encoding` but doesn't + // end in `chunked` isn't illegal, it just forces this + // to be close-delimited. + // + // We can try to repair this, by adding `chunked` ourselves. + + headers::add_chunked(te); + Some(Encoder::chunked()) + } + } + Entry::Vacant(te) => { + if let Some(len) = existing_con_len { + Some(Encoder::length(len)) + } else if let BodyLength::Unknown = body { + // GET, HEAD, and CONNECT almost never have bodies. + // + // So instead of sending a "chunked" body with a 0-chunk, + // assume no body here. If you *must* send a body, + // set the headers explicitly. + match head.subject.0 { + Method::GET | Method::HEAD | Method::CONNECT => Some(Encoder::length(0)), + _ => { + te.insert(HeaderValue::from_static("chunked")); + Some(Encoder::chunked()) + } + } + } else { + None + } + } + }; + + // This is because we need a second mutable borrow to remove + // content-length header. + if let Some(encoder) = encoder { + if should_remove_con_len && existing_con_len.is_some() { + headers.remove(header::CONTENT_LENGTH); + } + return encoder; + } + + // User didn't set transfer-encoding, AND we know body length, + // so we can just set the Content-Length automatically. + + let len = if let BodyLength::Known(len) = body { + len + } else { + unreachable!("BodyLength::Unknown would set chunked"); + }; + + set_content_length(headers, len) + } +} + +fn set_content_length(headers: &mut HeaderMap, len: u64) -> Encoder { + // At this point, there should not be a valid Content-Length + // header. However, since we'll be indexing in anyways, we can + // warn the user if there was an existing illegal header. + // + // Or at least, we can in theory. It's actually a little bit slower, + // so perhaps only do that while the user is developing/testing. + + if cfg!(debug_assertions) { + match headers.entry(header::CONTENT_LENGTH) { + Entry::Occupied(mut cl) => { + // Internal sanity check, we should have already determined + // that the header was illegal before calling this function. + debug_assert!(headers::content_length_parse_all_values(cl.iter()).is_none()); + // Uh oh, the user set `Content-Length` headers, but set bad ones. + // This would be an illegal message anyways, so let's try to repair + // with our known good length. + error!("user provided content-length header was invalid"); + + cl.insert(HeaderValue::from(len)); + Encoder::length(len) + } + Entry::Vacant(cl) => { + cl.insert(HeaderValue::from(len)); + Encoder::length(len) + } + } + } else { + headers.insert(header::CONTENT_LENGTH, HeaderValue::from(len)); + Encoder::length(len) + } +} + +#[derive(Clone, Copy)] +struct HeaderIndices { + name: (usize, usize), + value: (usize, usize), +} + +fn record_header_indices( + bytes: &[u8], + headers: &[httparse::Header<'_>], + indices: &mut [HeaderIndices], +) -> Result<(), crate::error::Parse> { + let bytes_ptr = bytes.as_ptr() as usize; + + for (header, indices) in headers.iter().zip(indices.iter_mut()) { + if header.name.len() >= (1 << 16) { + debug!("header name larger than 64kb: {:?}", header.name); + return Err(crate::error::Parse::TooLarge); + } + let name_start = header.name.as_ptr() as usize - bytes_ptr; + let name_end = name_start + header.name.len(); + indices.name = (name_start, name_end); + let value_start = header.value.as_ptr() as usize - bytes_ptr; + let value_end = value_start + header.value.len(); + indices.value = (value_start, value_end); + } + + Ok(()) +} + +// Write header names as title case. The header name is assumed to be ASCII, +// therefore it is trivial to convert an ASCII character from lowercase to +// uppercase. It is as simple as XORing the lowercase character byte with +// space. +fn title_case(dst: &mut Vec<u8>, name: &[u8]) { + dst.reserve(name.len()); + + let mut iter = name.iter(); + + // Uppercase the first character + if let Some(c) = iter.next() { + if *c >= b'a' && *c <= b'z' { + dst.push(*c ^ b' '); + } else { + dst.push(*c); + } + } + + while let Some(c) = iter.next() { + dst.push(*c); + + if *c == b'-' { + if let Some(c) = iter.next() { + if *c >= b'a' && *c <= b'z' { + dst.push(*c ^ b' '); + } else { + dst.push(*c); + } + } + } + } +} + +fn write_headers_title_case(headers: &HeaderMap, dst: &mut Vec<u8>) { + for (name, value) in headers { + title_case(dst, name.as_str().as_bytes()); + extend(dst, b": "); + extend(dst, value.as_bytes()); + extend(dst, b"\r\n"); + } +} + +fn write_headers(headers: &HeaderMap, dst: &mut Vec<u8>) { + for (name, value) in headers { + extend(dst, name.as_str().as_bytes()); + extend(dst, b": "); + extend(dst, value.as_bytes()); + extend(dst, b"\r\n"); + } +} + +struct FastWrite<'a>(&'a mut Vec<u8>); + +impl<'a> fmt::Write for FastWrite<'a> { + #[inline] + fn write_str(&mut self, s: &str) -> fmt::Result { + extend(self.0, s.as_bytes()); + Ok(()) + } + + #[inline] + fn write_fmt(&mut self, args: fmt::Arguments<'_>) -> fmt::Result { + fmt::write(self, args) + } +} + +#[inline] +fn extend(dst: &mut Vec<u8>, data: &[u8]) { + dst.extend_from_slice(data); +} + +#[cfg(test)] +mod tests { + use bytes::BytesMut; + + use super::*; + + #[test] + fn test_parse_request() { + let _ = pretty_env_logger::try_init(); + let mut raw = BytesMut::from("GET /echo HTTP/1.1\r\nHost: hyper.rs\r\n\r\n"); + let mut method = None; + let msg = Server::parse( + &mut raw, + ParseContext { + cached_headers: &mut None, + req_method: &mut method, + }, + ) + .unwrap() + .unwrap(); + assert_eq!(raw.len(), 0); + assert_eq!(msg.head.subject.0, crate::Method::GET); + assert_eq!(msg.head.subject.1, "/echo"); + assert_eq!(msg.head.version, crate::Version::HTTP_11); + assert_eq!(msg.head.headers.len(), 1); + assert_eq!(msg.head.headers["Host"], "hyper.rs"); + assert_eq!(method, Some(crate::Method::GET)); + } + + #[test] + fn test_parse_response() { + let _ = pretty_env_logger::try_init(); + let mut raw = BytesMut::from("HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"); + let ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut Some(crate::Method::GET), + }; + let msg = Client::parse(&mut raw, ctx).unwrap().unwrap(); + assert_eq!(raw.len(), 0); + assert_eq!(msg.head.subject, crate::StatusCode::OK); + assert_eq!(msg.head.version, crate::Version::HTTP_11); + assert_eq!(msg.head.headers.len(), 1); + assert_eq!(msg.head.headers["Content-Length"], "0"); + } + + #[test] + fn test_parse_request_errors() { + let mut raw = BytesMut::from("GET htt:p// HTTP/1.1\r\nHost: hyper.rs\r\n\r\n"); + let ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut None, + }; + Server::parse(&mut raw, ctx).unwrap_err(); + } + + #[test] + fn test_decoder_request() { + fn parse(s: &str) -> ParsedMessage<RequestLine> { + let mut bytes = BytesMut::from(s); + Server::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut None, + }, + ) + .expect("parse ok") + .expect("parse complete") + } + + fn parse_err(s: &str, comment: &str) -> crate::error::Parse { + let mut bytes = BytesMut::from(s); + Server::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut None, + }, + ) + .expect_err(comment) + } + + // no length or transfer-encoding means 0-length body + assert_eq!( + parse( + "\ + GET / HTTP/1.1\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::ZERO + ); + + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::ZERO + ); + + // transfer-encoding: chunked + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: gzip, chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: gzip\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + // content-length + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + content-length: 10\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::new(10) + ); + + // transfer-encoding and content-length = chunked + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + content-length: 10\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\ + content-length: 10\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: gzip\r\n\ + content-length: 10\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + // multiple content-lengths of same value are fine + assert_eq!( + parse( + "\ + POST / HTTP/1.1\r\n\ + content-length: 10\r\n\ + content-length: 10\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::new(10) + ); + + // multiple content-lengths with different values is an error + parse_err( + "\ + POST / HTTP/1.1\r\n\ + content-length: 10\r\n\ + content-length: 11\r\n\ + \r\n\ + ", + "multiple content-lengths", + ); + + // transfer-encoding that isn't chunked is an error + parse_err( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: gzip\r\n\ + \r\n\ + ", + "transfer-encoding but not chunked", + ); + + parse_err( + "\ + POST / HTTP/1.1\r\n\ + transfer-encoding: chunked, gzip\r\n\ + \r\n\ + ", + "transfer-encoding doesn't end in chunked", + ); + + // http/1.0 + + assert_eq!( + parse( + "\ + POST / HTTP/1.0\r\n\ + content-length: 10\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::new(10) + ); + + // 1.0 doesn't understand chunked, so its an error + parse_err( + "\ + POST / HTTP/1.0\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + ", + "1.0 chunked", + ); + } + + #[test] + fn test_decoder_response() { + fn parse(s: &str) -> ParsedMessage<StatusCode> { + parse_with_method(s, Method::GET) + } + + fn parse_ignores(s: &str) { + let mut bytes = BytesMut::from(s); + assert!(Client::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut Some(Method::GET), + } + ) + .expect("parse ok") + .is_none()) + } + + fn parse_with_method(s: &str, m: Method) -> ParsedMessage<StatusCode> { + let mut bytes = BytesMut::from(s); + Client::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut Some(m), + }, + ) + .expect("parse ok") + .expect("parse complete") + } + + fn parse_err(s: &str) -> crate::error::Parse { + let mut bytes = BytesMut::from(s); + Client::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut Some(Method::GET), + }, + ) + .expect_err("parse should err") + } + + // no content-length or transfer-encoding means close-delimited + assert_eq!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CLOSE_DELIMITED + ); + + // 204 and 304 never have a body + assert_eq!( + parse( + "\ + HTTP/1.1 204 No Content\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::ZERO + ); + + assert_eq!( + parse( + "\ + HTTP/1.1 304 Not Modified\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::ZERO + ); + + // content-length + assert_eq!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 8\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::new(8) + ); + + assert_eq!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 8\r\n\ + content-length: 8\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::new(8) + ); + + parse_err( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 8\r\n\ + content-length: 9\r\n\ + \r\n\ + ", + ); + + // transfer-encoding: chunked + assert_eq!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + // transfer-encoding not-chunked is close-delimited + assert_eq!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + transfer-encoding: yolo\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CLOSE_DELIMITED + ); + + // transfer-encoding and content-length = chunked + assert_eq!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 10\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CHUNKED + ); + + // HEAD can have content-length, but not body + assert_eq!( + parse_with_method( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 8\r\n\ + \r\n\ + ", + Method::HEAD + ) + .decode, + DecodedLength::ZERO + ); + + // CONNECT with 200 never has body + { + let msg = parse_with_method( + "\ + HTTP/1.1 200 OK\r\n\ + \r\n\ + ", + Method::CONNECT, + ); + assert_eq!(msg.decode, DecodedLength::ZERO); + assert!(!msg.keep_alive, "should be upgrade"); + assert!(msg.wants_upgrade, "should be upgrade"); + } + + // CONNECT receiving non 200 can have a body + assert_eq!( + parse_with_method( + "\ + HTTP/1.1 400 Bad Request\r\n\ + \r\n\ + ", + Method::CONNECT + ) + .decode, + DecodedLength::CLOSE_DELIMITED + ); + + // 1xx status codes + parse_ignores( + "\ + HTTP/1.1 100 Continue\r\n\ + \r\n\ + ", + ); + + parse_ignores( + "\ + HTTP/1.1 103 Early Hints\r\n\ + \r\n\ + ", + ); + + // 101 upgrade not supported yet + { + let msg = parse( + "\ + HTTP/1.1 101 Switching Protocols\r\n\ + \r\n\ + ", + ); + assert_eq!(msg.decode, DecodedLength::ZERO); + assert!(!msg.keep_alive, "should be last"); + assert!(msg.wants_upgrade, "should be upgrade"); + } + + // http/1.0 + assert_eq!( + parse( + "\ + HTTP/1.0 200 OK\r\n\ + \r\n\ + " + ) + .decode, + DecodedLength::CLOSE_DELIMITED + ); + + // 1.0 doesn't understand chunked + parse_err( + "\ + HTTP/1.0 200 OK\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + ", + ); + + // keep-alive + assert!( + parse( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 0\r\n\ + \r\n\ + " + ) + .keep_alive, + "HTTP/1.1 keep-alive is default" + ); + + assert!( + !parse( + "\ + HTTP/1.1 200 OK\r\n\ + content-length: 0\r\n\ + connection: foo, close, bar\r\n\ + \r\n\ + " + ) + .keep_alive, + "connection close is always close" + ); + + assert!( + !parse( + "\ + HTTP/1.0 200 OK\r\n\ + content-length: 0\r\n\ + \r\n\ + " + ) + .keep_alive, + "HTTP/1.0 close is default" + ); + + assert!( + parse( + "\ + HTTP/1.0 200 OK\r\n\ + content-length: 0\r\n\ + connection: foo, keep-alive, bar\r\n\ + \r\n\ + " + ) + .keep_alive, + "connection keep-alive is always keep-alive" + ); + } + + #[test] + fn test_client_request_encode_title_case() { + use crate::proto::BodyLength; + use http::header::HeaderValue; + + let mut head = MessageHead::default(); + head.headers + .insert("content-length", HeaderValue::from_static("10")); + head.headers + .insert("content-type", HeaderValue::from_static("application/json")); + head.headers.insert("*-*", HeaderValue::from_static("o_o")); + + let mut vec = Vec::new(); + Client::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut None, + title_case_headers: true, + }, + &mut vec, + ) + .unwrap(); + + assert_eq!(vec, b"GET / HTTP/1.1\r\nContent-Length: 10\r\nContent-Type: application/json\r\n*-*: o_o\r\n\r\n".to_vec()); + } + + #[test] + fn test_server_encode_connect_method() { + let mut head = MessageHead::default(); + + let mut vec = Vec::new(); + let encoder = Server::encode( + Encode { + head: &mut head, + body: None, + keep_alive: true, + req_method: &mut Some(Method::CONNECT), + title_case_headers: false, + }, + &mut vec, + ) + .unwrap(); + + assert!(encoder.is_last()); + } + + #[test] + fn parse_header_htabs() { + let mut bytes = BytesMut::from("HTTP/1.1 200 OK\r\nserver: hello\tworld\r\n\r\n"); + let parsed = Client::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut Some(Method::GET), + }, + ) + .expect("parse ok") + .expect("parse complete"); + + assert_eq!(parsed.head.headers["server"], "hello\tworld"); + } + + #[cfg(feature = "nightly")] + use test::Bencher; + + #[cfg(feature = "nightly")] + #[bench] + fn bench_parse_incoming(b: &mut Bencher) { + let mut raw = BytesMut::from( + &b"GET /super_long_uri/and_whatever?what_should_we_talk_about/\ + I_wonder/Hard_to_write_in_an_uri_after_all/you_have_to_make\ + _up_the_punctuation_yourself/how_fun_is_that?test=foo&test1=\ + foo1&test2=foo2&test3=foo3&test4=foo4 HTTP/1.1\r\nHost: \ + hyper.rs\r\nAccept: a lot of things\r\nAccept-Charset: \ + utf8\r\nAccept-Encoding: *\r\nAccess-Control-Allow-\ + Credentials: None\r\nAccess-Control-Allow-Origin: None\r\n\ + Access-Control-Allow-Methods: None\r\nAccess-Control-Allow-\ + Headers: None\r\nContent-Encoding: utf8\r\nContent-Security-\ + Policy: None\r\nContent-Type: text/html\r\nOrigin: hyper\ + \r\nSec-Websocket-Extensions: It looks super important!\r\n\ + Sec-Websocket-Origin: hyper\r\nSec-Websocket-Version: 4.3\r\ + \nStrict-Transport-Security: None\r\nUser-Agent: hyper\r\n\ + X-Content-Duration: None\r\nX-Content-Security-Policy: None\ + \r\nX-DNSPrefetch-Control: None\r\nX-Frame-Options: \ + Something important obviously\r\nX-Requested-With: Nothing\ + \r\n\r\n"[..], + ); + let len = raw.len(); + let mut headers = Some(HeaderMap::new()); + + b.bytes = len as u64; + b.iter(|| { + let mut msg = Server::parse( + &mut raw, + ParseContext { + cached_headers: &mut headers, + req_method: &mut None, + }, + ) + .unwrap() + .unwrap(); + ::test::black_box(&msg); + msg.head.headers.clear(); + headers = Some(msg.head.headers); + restart(&mut raw, len); + }); + + fn restart(b: &mut BytesMut, len: usize) { + b.reserve(1); + unsafe { + b.set_len(len); + } + } + } + + #[cfg(feature = "nightly")] + #[bench] + fn bench_parse_short(b: &mut Bencher) { + let s = &b"GET / HTTP/1.1\r\nHost: localhost:8080\r\n\r\n"[..]; + let mut raw = BytesMut::from(s); + let len = raw.len(); + let mut headers = Some(HeaderMap::new()); + + b.bytes = len as u64; + b.iter(|| { + let mut msg = Server::parse( + &mut raw, + ParseContext { + cached_headers: &mut headers, + req_method: &mut None, + }, + ) + .unwrap() + .unwrap(); + ::test::black_box(&msg); + msg.head.headers.clear(); + headers = Some(msg.head.headers); + restart(&mut raw, len); + }); + + fn restart(b: &mut BytesMut, len: usize) { + b.reserve(1); + unsafe { + b.set_len(len); + } + } + } + + #[cfg(feature = "nightly")] + #[bench] + fn bench_server_encode_headers_preset(b: &mut Bencher) { + use crate::proto::BodyLength; + use http::header::HeaderValue; + + let len = 108; + b.bytes = len as u64; + + let mut head = MessageHead::default(); + let mut headers = HeaderMap::new(); + headers.insert("content-length", HeaderValue::from_static("10")); + headers.insert("content-type", HeaderValue::from_static("application/json")); + + b.iter(|| { + let mut vec = Vec::new(); + head.headers = headers.clone(); + Server::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut Some(Method::GET), + title_case_headers: false, + }, + &mut vec, + ) + .unwrap(); + assert_eq!(vec.len(), len); + ::test::black_box(vec); + }) + } + + #[cfg(feature = "nightly")] + #[bench] + fn bench_server_encode_no_headers(b: &mut Bencher) { + use crate::proto::BodyLength; + + let len = 76; + b.bytes = len as u64; + + let mut head = MessageHead::default(); + let mut vec = Vec::with_capacity(128); + + b.iter(|| { + Server::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut Some(Method::GET), + title_case_headers: false, + }, + &mut vec, + ) + .unwrap(); + assert_eq!(vec.len(), len); + ::test::black_box(&vec); + + vec.clear(); + }) + } +} diff --git a/third_party/rust/hyper/src/proto/h2/client.rs b/third_party/rust/hyper/src/proto/h2/client.rs new file mode 100644 index 0000000000..bf4cfccea5 --- /dev/null +++ b/third_party/rust/hyper/src/proto/h2/client.rs @@ -0,0 +1,292 @@ +#[cfg(feature = "runtime")] +use std::time::Duration; + +use futures_channel::{mpsc, oneshot}; +use futures_util::future::{self, Either, FutureExt as _, TryFutureExt as _}; +use futures_util::stream::StreamExt as _; +use h2::client::{Builder, SendRequest}; +use tokio::io::{AsyncRead, AsyncWrite}; + +use super::{decode_content_length, ping, PipeToSendStream, SendBuf}; +use crate::body::Payload; +use crate::common::{task, Exec, Future, Never, Pin, Poll}; +use crate::headers; +use crate::proto::Dispatched; +use crate::{Body, Request, Response}; + +type ClientRx<B> = crate::client::dispatch::Receiver<Request<B>, Response<Body>>; + +///// An mpsc channel is used to help notify the `Connection` task when *all* +///// other handles to it have been dropped, so that it can shutdown. +type ConnDropRef = mpsc::Sender<Never>; + +///// A oneshot channel watches the `Connection` task, and when it completes, +///// the "dispatch" task will be notified and can shutdown sooner. +type ConnEof = oneshot::Receiver<Never>; + +// Our defaults are chosen for the "majority" case, which usually are not +// resource constrained, and so the spec default of 64kb can be too limiting +// for performance. +const DEFAULT_CONN_WINDOW: u32 = 1024 * 1024 * 5; // 5mb +const DEFAULT_STREAM_WINDOW: u32 = 1024 * 1024 * 2; // 2mb + +#[derive(Clone, Debug)] +pub(crate) struct Config { + pub(crate) adaptive_window: bool, + pub(crate) initial_conn_window_size: u32, + pub(crate) initial_stream_window_size: u32, + #[cfg(feature = "runtime")] + pub(crate) keep_alive_interval: Option<Duration>, + #[cfg(feature = "runtime")] + pub(crate) keep_alive_timeout: Duration, + #[cfg(feature = "runtime")] + pub(crate) keep_alive_while_idle: bool, +} + +impl Default for Config { + fn default() -> Config { + Config { + adaptive_window: false, + initial_conn_window_size: DEFAULT_CONN_WINDOW, + initial_stream_window_size: DEFAULT_STREAM_WINDOW, + #[cfg(feature = "runtime")] + keep_alive_interval: None, + #[cfg(feature = "runtime")] + keep_alive_timeout: Duration::from_secs(20), + #[cfg(feature = "runtime")] + keep_alive_while_idle: false, + } + } +} + +pub(crate) async fn handshake<T, B>( + io: T, + req_rx: ClientRx<B>, + config: &Config, + exec: Exec, +) -> crate::Result<ClientTask<B>> +where + T: AsyncRead + AsyncWrite + Send + Unpin + 'static, + B: Payload, +{ + let (h2_tx, mut conn) = Builder::default() + .initial_window_size(config.initial_stream_window_size) + .initial_connection_window_size(config.initial_conn_window_size) + .enable_push(false) + .handshake::<_, SendBuf<B::Data>>(io) + .await + .map_err(crate::Error::new_h2)?; + + // An mpsc channel is used entirely to detect when the + // 'Client' has been dropped. This is to get around a bug + // in h2 where dropping all SendRequests won't notify a + // parked Connection. + let (conn_drop_ref, rx) = mpsc::channel(1); + let (cancel_tx, conn_eof) = oneshot::channel(); + + let conn_drop_rx = rx.into_future().map(|(item, _rx)| { + if let Some(never) = item { + match never {} + } + }); + + let ping_config = ping::Config { + bdp_initial_window: if config.adaptive_window { + Some(config.initial_stream_window_size) + } else { + None + }, + #[cfg(feature = "runtime")] + keep_alive_interval: config.keep_alive_interval, + #[cfg(feature = "runtime")] + keep_alive_timeout: config.keep_alive_timeout, + #[cfg(feature = "runtime")] + keep_alive_while_idle: config.keep_alive_while_idle, + }; + + let ping = if ping_config.is_enabled() { + let pp = conn.ping_pong().expect("conn.ping_pong"); + let (recorder, mut ponger) = ping::channel(pp, ping_config); + + let conn = future::poll_fn(move |cx| { + match ponger.poll(cx) { + Poll::Ready(ping::Ponged::SizeUpdate(wnd)) => { + conn.set_target_window_size(wnd); + conn.set_initial_window_size(wnd)?; + } + #[cfg(feature = "runtime")] + Poll::Ready(ping::Ponged::KeepAliveTimedOut) => { + debug!("connection keep-alive timed out"); + return Poll::Ready(Ok(())); + } + Poll::Pending => {} + } + + Pin::new(&mut conn).poll(cx) + }); + let conn = conn.map_err(|e| debug!("connection error: {}", e)); + + exec.execute(conn_task(conn, conn_drop_rx, cancel_tx)); + recorder + } else { + let conn = conn.map_err(|e| debug!("connection error: {}", e)); + + exec.execute(conn_task(conn, conn_drop_rx, cancel_tx)); + ping::disabled() + }; + + Ok(ClientTask { + ping, + conn_drop_ref, + conn_eof, + executor: exec, + h2_tx, + req_rx, + }) +} + +async fn conn_task<C, D>(conn: C, drop_rx: D, cancel_tx: oneshot::Sender<Never>) +where + C: Future + Unpin, + D: Future<Output = ()> + Unpin, +{ + match future::select(conn, drop_rx).await { + Either::Left(_) => { + // ok or err, the `conn` has finished + } + Either::Right(((), conn)) => { + // mpsc has been dropped, hopefully polling + // the connection some more should start shutdown + // and then close + trace!("send_request dropped, starting conn shutdown"); + drop(cancel_tx); + let _ = conn.await; + } + } +} + +pub(crate) struct ClientTask<B> +where + B: Payload, +{ + ping: ping::Recorder, + conn_drop_ref: ConnDropRef, + conn_eof: ConnEof, + executor: Exec, + h2_tx: SendRequest<SendBuf<B::Data>>, + req_rx: ClientRx<B>, +} + +impl<B> Future for ClientTask<B> +where + B: Payload + 'static, +{ + type Output = crate::Result<Dispatched>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + loop { + match ready!(self.h2_tx.poll_ready(cx)) { + Ok(()) => (), + Err(err) => { + self.ping.ensure_not_timed_out()?; + return if err.reason() == Some(::h2::Reason::NO_ERROR) { + trace!("connection gracefully shutdown"); + Poll::Ready(Ok(Dispatched::Shutdown)) + } else { + Poll::Ready(Err(crate::Error::new_h2(err))) + }; + } + }; + + match Pin::new(&mut self.req_rx).poll_next(cx) { + Poll::Ready(Some((req, cb))) => { + // check that future hasn't been canceled already + if cb.is_canceled() { + trace!("request callback is canceled"); + continue; + } + let (head, body) = req.into_parts(); + let mut req = ::http::Request::from_parts(head, ()); + super::strip_connection_headers(req.headers_mut(), true); + if let Some(len) = body.size_hint().exact() { + if len != 0 || headers::method_has_defined_payload_semantics(req.method()) { + headers::set_content_length_if_missing(req.headers_mut(), len); + } + } + let eos = body.is_end_stream(); + let (fut, body_tx) = match self.h2_tx.send_request(req, eos) { + Ok(ok) => ok, + Err(err) => { + debug!("client send request error: {}", err); + cb.send(Err((crate::Error::new_h2(err), None))); + continue; + } + }; + + let ping = self.ping.clone(); + if !eos { + let mut pipe = Box::pin(PipeToSendStream::new(body, body_tx)).map(|res| { + if let Err(e) = res { + debug!("client request body error: {}", e); + } + }); + + // eagerly see if the body pipe is ready and + // can thus skip allocating in the executor + match Pin::new(&mut pipe).poll(cx) { + Poll::Ready(_) => (), + Poll::Pending => { + let conn_drop_ref = self.conn_drop_ref.clone(); + // keep the ping recorder's knowledge of an + // "open stream" alive while this body is + // still sending... + let ping = ping.clone(); + let pipe = pipe.map(move |x| { + drop(conn_drop_ref); + drop(ping); + x + }); + self.executor.execute(pipe); + } + } + } + + let fut = fut.map(move |result| match result { + Ok(res) => { + // record that we got the response headers + ping.record_non_data(); + + let content_length = decode_content_length(res.headers()); + let res = res.map(|stream| { + let ping = ping.for_stream(&stream); + crate::Body::h2(stream, content_length, ping) + }); + Ok(res) + } + Err(err) => { + ping.ensure_not_timed_out().map_err(|e| (e, None))?; + + debug!("client response error: {}", err); + Err((crate::Error::new_h2(err), None)) + } + }); + self.executor.execute(cb.send_when(fut)); + continue; + } + + Poll::Ready(None) => { + trace!("client::dispatch::Sender dropped"); + return Poll::Ready(Ok(Dispatched::Shutdown)); + } + + Poll::Pending => match ready!(Pin::new(&mut self.conn_eof).poll(cx)) { + Ok(never) => match never {}, + Err(_conn_is_eof) => { + trace!("connection task is closed, closing dispatch task"); + return Poll::Ready(Ok(Dispatched::Shutdown)); + } + }, + } + } + } +} diff --git a/third_party/rust/hyper/src/proto/h2/mod.rs b/third_party/rust/hyper/src/proto/h2/mod.rs new file mode 100644 index 0000000000..e25f038cad --- /dev/null +++ b/third_party/rust/hyper/src/proto/h2/mod.rs @@ -0,0 +1,263 @@ +use bytes::Buf; +use h2::SendStream; +use http::header::{ + HeaderName, CONNECTION, PROXY_AUTHENTICATE, PROXY_AUTHORIZATION, TE, TRAILER, + TRANSFER_ENCODING, UPGRADE, +}; +use http::HeaderMap; +use pin_project::pin_project; + +use super::DecodedLength; +use crate::body::Payload; +use crate::common::{task, Future, Pin, Poll}; +use crate::headers::content_length_parse_all; + +pub(crate) mod client; +pub(crate) mod ping; +pub(crate) mod server; + +pub(crate) use self::client::ClientTask; +pub(crate) use self::server::Server; + +/// Default initial stream window size defined in HTTP2 spec. +pub(crate) const SPEC_WINDOW_SIZE: u32 = 65_535; + +fn strip_connection_headers(headers: &mut HeaderMap, is_request: bool) { + // List of connection headers from: + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Connection + // + // TE headers are allowed in HTTP/2 requests as long as the value is "trailers", so they're + // tested separately. + let connection_headers = [ + HeaderName::from_lowercase(b"keep-alive").unwrap(), + HeaderName::from_lowercase(b"proxy-connection").unwrap(), + PROXY_AUTHENTICATE, + PROXY_AUTHORIZATION, + TRAILER, + TRANSFER_ENCODING, + UPGRADE, + ]; + + for header in connection_headers.iter() { + if headers.remove(header).is_some() { + warn!("Connection header illegal in HTTP/2: {}", header.as_str()); + } + } + + if is_request { + if headers + .get(TE) + .map(|te_header| te_header != "trailers") + .unwrap_or(false) + { + warn!("TE headers not set to \"trailers\" are illegal in HTTP/2 requests"); + headers.remove(TE); + } + } else if headers.remove(TE).is_some() { + warn!("TE headers illegal in HTTP/2 responses"); + } + + if let Some(header) = headers.remove(CONNECTION) { + warn!( + "Connection header illegal in HTTP/2: {}", + CONNECTION.as_str() + ); + let header_contents = header.to_str().unwrap(); + + // A `Connection` header may have a comma-separated list of names of other headers that + // are meant for only this specific connection. + // + // Iterate these names and remove them as headers. Connection-specific headers are + // forbidden in HTTP2, as that information has been moved into frame types of the h2 + // protocol. + for name in header_contents.split(',') { + let name = name.trim(); + headers.remove(name); + } + } +} + +fn decode_content_length(headers: &HeaderMap) -> DecodedLength { + if let Some(len) = content_length_parse_all(headers) { + // If the length is u64::MAX, oh well, just reported chunked. + DecodedLength::checked_new(len).unwrap_or_else(|_| DecodedLength::CHUNKED) + } else { + DecodedLength::CHUNKED + } +} + +// body adapters used by both Client and Server + +#[pin_project] +struct PipeToSendStream<S> +where + S: Payload, +{ + body_tx: SendStream<SendBuf<S::Data>>, + data_done: bool, + #[pin] + stream: S, +} + +impl<S> PipeToSendStream<S> +where + S: Payload, +{ + fn new(stream: S, tx: SendStream<SendBuf<S::Data>>) -> PipeToSendStream<S> { + PipeToSendStream { + body_tx: tx, + data_done: false, + stream, + } + } +} + +impl<S> Future for PipeToSendStream<S> +where + S: Payload, +{ + type Output = crate::Result<()>; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + let mut me = self.project(); + loop { + if !*me.data_done { + // we don't have the next chunk of data yet, so just reserve 1 byte to make + // sure there's some capacity available. h2 will handle the capacity management + // for the actual body chunk. + me.body_tx.reserve_capacity(1); + + if me.body_tx.capacity() == 0 { + loop { + match ready!(me.body_tx.poll_capacity(cx)) { + Some(Ok(0)) => {} + Some(Ok(_)) => break, + Some(Err(e)) => { + return Poll::Ready(Err(crate::Error::new_body_write(e))) + } + None => { + // None means the stream is no longer in a + // streaming state, we either finished it + // somehow, or the remote reset us. + return Poll::Ready(Err(crate::Error::new_body_write( + "send stream capacity unexpectedly closed", + ))); + } + } + } + } else if let Poll::Ready(reason) = me + .body_tx + .poll_reset(cx) + .map_err(crate::Error::new_body_write)? + { + debug!("stream received RST_STREAM: {:?}", reason); + return Poll::Ready(Err(crate::Error::new_body_write(::h2::Error::from( + reason, + )))); + } + + match ready!(me.stream.as_mut().poll_data(cx)) { + Some(Ok(chunk)) => { + let is_eos = me.stream.is_end_stream(); + trace!( + "send body chunk: {} bytes, eos={}", + chunk.remaining(), + is_eos, + ); + + let buf = SendBuf(Some(chunk)); + me.body_tx + .send_data(buf, is_eos) + .map_err(crate::Error::new_body_write)?; + + if is_eos { + return Poll::Ready(Ok(())); + } + } + Some(Err(e)) => return Poll::Ready(Err(me.body_tx.on_user_err(e))), + None => { + me.body_tx.reserve_capacity(0); + let is_eos = me.stream.is_end_stream(); + if is_eos { + return Poll::Ready(me.body_tx.send_eos_frame()); + } else { + *me.data_done = true; + // loop again to poll_trailers + } + } + } + } else { + if let Poll::Ready(reason) = me + .body_tx + .poll_reset(cx) + .map_err(crate::Error::new_body_write)? + { + debug!("stream received RST_STREAM: {:?}", reason); + return Poll::Ready(Err(crate::Error::new_body_write(::h2::Error::from( + reason, + )))); + } + + match ready!(me.stream.poll_trailers(cx)) { + Ok(Some(trailers)) => { + me.body_tx + .send_trailers(trailers) + .map_err(crate::Error::new_body_write)?; + return Poll::Ready(Ok(())); + } + Ok(None) => { + // There were no trailers, so send an empty DATA frame... + return Poll::Ready(me.body_tx.send_eos_frame()); + } + Err(e) => return Poll::Ready(Err(me.body_tx.on_user_err(e))), + } + } + } + } +} + +trait SendStreamExt { + fn on_user_err<E>(&mut self, err: E) -> crate::Error + where + E: Into<Box<dyn std::error::Error + Send + Sync>>; + fn send_eos_frame(&mut self) -> crate::Result<()>; +} + +impl<B: Buf> SendStreamExt for SendStream<SendBuf<B>> { + fn on_user_err<E>(&mut self, err: E) -> crate::Error + where + E: Into<Box<dyn std::error::Error + Send + Sync>>, + { + let err = crate::Error::new_user_body(err); + debug!("send body user stream error: {}", err); + self.send_reset(err.h2_reason()); + err + } + + fn send_eos_frame(&mut self) -> crate::Result<()> { + trace!("send body eos"); + self.send_data(SendBuf(None), true) + .map_err(crate::Error::new_body_write) + } +} + +struct SendBuf<B>(Option<B>); + +impl<B: Buf> Buf for SendBuf<B> { + #[inline] + fn remaining(&self) -> usize { + self.0.as_ref().map(|b| b.remaining()).unwrap_or(0) + } + + #[inline] + fn bytes(&self) -> &[u8] { + self.0.as_ref().map(|b| b.bytes()).unwrap_or(&[]) + } + + #[inline] + fn advance(&mut self, cnt: usize) { + if let Some(b) = self.0.as_mut() { + b.advance(cnt) + } + } +} diff --git a/third_party/rust/hyper/src/proto/h2/ping.rs b/third_party/rust/hyper/src/proto/h2/ping.rs new file mode 100644 index 0000000000..c4fe2dd15c --- /dev/null +++ b/third_party/rust/hyper/src/proto/h2/ping.rs @@ -0,0 +1,506 @@ +/// HTTP2 Ping usage +/// +/// hyper uses HTTP2 pings for two purposes: +/// +/// 1. Adaptive flow control using BDP +/// 2. Connection keep-alive +/// +/// Both cases are optional. +/// +/// # BDP Algorithm +/// +/// 1. When receiving a DATA frame, if a BDP ping isn't outstanding: +/// 1a. Record current time. +/// 1b. Send a BDP ping. +/// 2. Increment the number of received bytes. +/// 3. When the BDP ping ack is received: +/// 3a. Record duration from sent time. +/// 3b. Merge RTT with a running average. +/// 3c. Calculate bdp as bytes/rtt. +/// 3d. If bdp is over 2/3 max, set new max to bdp and update windows. + +#[cfg(feature = "runtime")] +use std::fmt; +#[cfg(feature = "runtime")] +use std::future::Future; +#[cfg(feature = "runtime")] +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::task::{self, Poll}; +use std::time::Duration; +#[cfg(not(feature = "runtime"))] +use std::time::Instant; + +use h2::{Ping, PingPong}; +#[cfg(feature = "runtime")] +use tokio::time::{Delay, Instant}; + +type WindowSize = u32; + +pub(super) fn disabled() -> Recorder { + Recorder { shared: None } +} + +pub(super) fn channel(ping_pong: PingPong, config: Config) -> (Recorder, Ponger) { + debug_assert!( + config.is_enabled(), + "ping channel requires bdp or keep-alive config", + ); + + let bdp = config.bdp_initial_window.map(|wnd| Bdp { + bdp: wnd, + max_bandwidth: 0.0, + rtt: 0.0, + }); + + let bytes = bdp.as_ref().map(|_| 0); + + #[cfg(feature = "runtime")] + let keep_alive = config.keep_alive_interval.map(|interval| KeepAlive { + interval, + timeout: config.keep_alive_timeout, + while_idle: config.keep_alive_while_idle, + timer: tokio::time::delay_for(interval), + state: KeepAliveState::Init, + }); + + #[cfg(feature = "runtime")] + let last_read_at = keep_alive.as_ref().map(|_| Instant::now()); + + let shared = Arc::new(Mutex::new(Shared { + bytes, + #[cfg(feature = "runtime")] + last_read_at, + #[cfg(feature = "runtime")] + is_keep_alive_timed_out: false, + ping_pong, + ping_sent_at: None, + })); + + ( + Recorder { + shared: Some(shared.clone()), + }, + Ponger { + bdp, + #[cfg(feature = "runtime")] + keep_alive, + shared, + }, + ) +} + +#[derive(Clone)] +pub(super) struct Config { + pub(super) bdp_initial_window: Option<WindowSize>, + /// If no frames are received in this amount of time, a PING frame is sent. + #[cfg(feature = "runtime")] + pub(super) keep_alive_interval: Option<Duration>, + /// After sending a keepalive PING, the connection will be closed if + /// a pong is not received in this amount of time. + #[cfg(feature = "runtime")] + pub(super) keep_alive_timeout: Duration, + /// If true, sends pings even when there are no active streams. + #[cfg(feature = "runtime")] + pub(super) keep_alive_while_idle: bool, +} + +#[derive(Clone)] +pub(crate) struct Recorder { + shared: Option<Arc<Mutex<Shared>>>, +} + +pub(super) struct Ponger { + bdp: Option<Bdp>, + #[cfg(feature = "runtime")] + keep_alive: Option<KeepAlive>, + shared: Arc<Mutex<Shared>>, +} + +struct Shared { + ping_pong: PingPong, + ping_sent_at: Option<Instant>, + + // bdp + /// If `Some`, bdp is enabled, and this tracks how many bytes have been + /// read during the current sample. + bytes: Option<usize>, + + // keep-alive + /// If `Some`, keep-alive is enabled, and the Instant is how long ago + /// the connection read the last frame. + #[cfg(feature = "runtime")] + last_read_at: Option<Instant>, + + #[cfg(feature = "runtime")] + is_keep_alive_timed_out: bool, +} + +struct Bdp { + /// Current BDP in bytes + bdp: u32, + /// Largest bandwidth we've seen so far. + max_bandwidth: f64, + /// Round trip time in seconds + rtt: f64, +} + +#[cfg(feature = "runtime")] +struct KeepAlive { + /// If no frames are received in this amount of time, a PING frame is sent. + interval: Duration, + /// After sending a keepalive PING, the connection will be closed if + /// a pong is not received in this amount of time. + timeout: Duration, + /// If true, sends pings even when there are no active streams. + while_idle: bool, + + state: KeepAliveState, + timer: Delay, +} + +#[cfg(feature = "runtime")] +enum KeepAliveState { + Init, + Scheduled, + PingSent, +} + +pub(super) enum Ponged { + SizeUpdate(WindowSize), + #[cfg(feature = "runtime")] + KeepAliveTimedOut, +} + +#[cfg(feature = "runtime")] +#[derive(Debug)] +pub(super) struct KeepAliveTimedOut; + +// ===== impl Config ===== + +impl Config { + pub(super) fn is_enabled(&self) -> bool { + #[cfg(feature = "runtime")] + { + self.bdp_initial_window.is_some() || self.keep_alive_interval.is_some() + } + + #[cfg(not(feature = "runtime"))] + { + self.bdp_initial_window.is_some() + } + } +} + +// ===== impl Recorder ===== + +impl Recorder { + pub(crate) fn record_data(&self, len: usize) { + let shared = if let Some(ref shared) = self.shared { + shared + } else { + return; + }; + + let mut locked = shared.lock().unwrap(); + + #[cfg(feature = "runtime")] + locked.update_last_read_at(); + + if let Some(ref mut bytes) = locked.bytes { + *bytes += len; + } else { + // no need to send bdp ping if bdp is disabled + return; + } + + if !locked.is_ping_sent() { + locked.send_ping(); + } + } + + pub(crate) fn record_non_data(&self) { + #[cfg(feature = "runtime")] + { + let shared = if let Some(ref shared) = self.shared { + shared + } else { + return; + }; + + let mut locked = shared.lock().unwrap(); + + locked.update_last_read_at(); + } + } + + /// If the incoming stream is already closed, convert self into + /// a disabled reporter. + pub(super) fn for_stream(self, stream: &h2::RecvStream) -> Self { + if stream.is_end_stream() { + disabled() + } else { + self + } + } + + pub(super) fn ensure_not_timed_out(&self) -> crate::Result<()> { + #[cfg(feature = "runtime")] + { + if let Some(ref shared) = self.shared { + let locked = shared.lock().unwrap(); + if locked.is_keep_alive_timed_out { + return Err(KeepAliveTimedOut.crate_error()); + } + } + } + + // else + Ok(()) + } +} + +// ===== impl Ponger ===== + +impl Ponger { + pub(super) fn poll(&mut self, cx: &mut task::Context<'_>) -> Poll<Ponged> { + let mut locked = self.shared.lock().unwrap(); + #[cfg(feature = "runtime")] + let is_idle = self.is_idle(); + + #[cfg(feature = "runtime")] + { + if let Some(ref mut ka) = self.keep_alive { + ka.schedule(is_idle, &locked); + ka.maybe_ping(cx, &mut locked); + } + } + + if !locked.is_ping_sent() { + // XXX: this doesn't register a waker...? + return Poll::Pending; + } + + let (bytes, rtt) = match locked.ping_pong.poll_pong(cx) { + Poll::Ready(Ok(_pong)) => { + let rtt = locked + .ping_sent_at + .expect("pong received implies ping_sent_at") + .elapsed(); + locked.ping_sent_at = None; + trace!("recv pong"); + + #[cfg(feature = "runtime")] + { + if let Some(ref mut ka) = self.keep_alive { + locked.update_last_read_at(); + ka.schedule(is_idle, &locked); + } + } + + if self.bdp.is_some() { + let bytes = locked.bytes.expect("bdp enabled implies bytes"); + locked.bytes = Some(0); // reset + trace!("received BDP ack; bytes = {}, rtt = {:?}", bytes, rtt); + (bytes, rtt) + } else { + // no bdp, done! + return Poll::Pending; + } + } + Poll::Ready(Err(e)) => { + debug!("pong error: {}", e); + return Poll::Pending; + } + Poll::Pending => { + #[cfg(feature = "runtime")] + { + if let Some(ref mut ka) = self.keep_alive { + if let Err(KeepAliveTimedOut) = ka.maybe_timeout(cx) { + self.keep_alive = None; + locked.is_keep_alive_timed_out = true; + return Poll::Ready(Ponged::KeepAliveTimedOut); + } + } + } + + return Poll::Pending; + } + }; + + drop(locked); + + if let Some(bdp) = self.bdp.as_mut().and_then(|bdp| bdp.calculate(bytes, rtt)) { + Poll::Ready(Ponged::SizeUpdate(bdp)) + } else { + // XXX: this doesn't register a waker...? + Poll::Pending + } + } + + #[cfg(feature = "runtime")] + fn is_idle(&self) -> bool { + Arc::strong_count(&self.shared) <= 2 + } +} + +// ===== impl Shared ===== + +impl Shared { + fn send_ping(&mut self) { + match self.ping_pong.send_ping(Ping::opaque()) { + Ok(()) => { + self.ping_sent_at = Some(Instant::now()); + trace!("sent ping"); + } + Err(err) => { + debug!("error sending ping: {}", err); + } + } + } + + fn is_ping_sent(&self) -> bool { + self.ping_sent_at.is_some() + } + + #[cfg(feature = "runtime")] + fn update_last_read_at(&mut self) { + if self.last_read_at.is_some() { + self.last_read_at = Some(Instant::now()); + } + } + + #[cfg(feature = "runtime")] + fn last_read_at(&self) -> Instant { + self.last_read_at.expect("keep_alive expects last_read_at") + } +} + +// ===== impl Bdp ===== + +/// Any higher than this likely will be hitting the TCP flow control. +const BDP_LIMIT: usize = 1024 * 1024 * 16; + +impl Bdp { + fn calculate(&mut self, bytes: usize, rtt: Duration) -> Option<WindowSize> { + // No need to do any math if we're at the limit. + if self.bdp as usize == BDP_LIMIT { + return None; + } + + // average the rtt + let rtt = seconds(rtt); + if self.rtt == 0.0 { + // First sample means rtt is first rtt. + self.rtt = rtt; + } else { + // Weigh this rtt as 1/8 for a moving average. + self.rtt += (rtt - self.rtt) * 0.125; + } + + // calculate the current bandwidth + let bw = (bytes as f64) / (self.rtt * 1.5); + trace!("current bandwidth = {:.1}B/s", bw); + + if bw < self.max_bandwidth { + // not a faster bandwidth, so don't update + return None; + } else { + self.max_bandwidth = bw; + } + + // if the current `bytes` sample is at least 2/3 the previous + // bdp, increase to double the current sample. + if bytes >= self.bdp as usize * 2 / 3 { + self.bdp = (bytes * 2).min(BDP_LIMIT) as WindowSize; + trace!("BDP increased to {}", self.bdp); + Some(self.bdp) + } else { + None + } + } +} + +fn seconds(dur: Duration) -> f64 { + const NANOS_PER_SEC: f64 = 1_000_000_000.0; + let secs = dur.as_secs() as f64; + secs + (dur.subsec_nanos() as f64) / NANOS_PER_SEC +} + +// ===== impl KeepAlive ===== + +#[cfg(feature = "runtime")] +impl KeepAlive { + fn schedule(&mut self, is_idle: bool, shared: &Shared) { + match self.state { + KeepAliveState::Init => { + if !self.while_idle && is_idle { + return; + } + + self.state = KeepAliveState::Scheduled; + let interval = shared.last_read_at() + self.interval; + self.timer.reset(interval); + } + KeepAliveState::Scheduled | KeepAliveState::PingSent => (), + } + } + + fn maybe_ping(&mut self, cx: &mut task::Context<'_>, shared: &mut Shared) { + match self.state { + KeepAliveState::Scheduled => { + if Pin::new(&mut self.timer).poll(cx).is_pending() { + return; + } + // check if we've received a frame while we were scheduled + if shared.last_read_at() + self.interval > self.timer.deadline() { + self.state = KeepAliveState::Init; + cx.waker().wake_by_ref(); // schedule us again + return; + } + trace!("keep-alive interval ({:?}) reached", self.interval); + shared.send_ping(); + self.state = KeepAliveState::PingSent; + let timeout = Instant::now() + self.timeout; + self.timer.reset(timeout); + } + KeepAliveState::Init | KeepAliveState::PingSent => (), + } + } + + fn maybe_timeout(&mut self, cx: &mut task::Context<'_>) -> Result<(), KeepAliveTimedOut> { + match self.state { + KeepAliveState::PingSent => { + if Pin::new(&mut self.timer).poll(cx).is_pending() { + return Ok(()); + } + trace!("keep-alive timeout ({:?}) reached", self.timeout); + Err(KeepAliveTimedOut) + } + KeepAliveState::Init | KeepAliveState::Scheduled => Ok(()), + } + } +} + +// ===== impl KeepAliveTimedOut ===== + +#[cfg(feature = "runtime")] +impl KeepAliveTimedOut { + pub(super) fn crate_error(self) -> crate::Error { + crate::Error::new(crate::error::Kind::Http2).with(self) + } +} + +#[cfg(feature = "runtime")] +impl fmt::Display for KeepAliveTimedOut { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("keep-alive timed out") + } +} + +#[cfg(feature = "runtime")] +impl std::error::Error for KeepAliveTimedOut { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(&crate::error::TimedOut) + } +} diff --git a/third_party/rust/hyper/src/proto/h2/server.rs b/third_party/rust/hyper/src/proto/h2/server.rs new file mode 100644 index 0000000000..bf81c1190f --- /dev/null +++ b/third_party/rust/hyper/src/proto/h2/server.rs @@ -0,0 +1,439 @@ +use std::error::Error as StdError; +use std::marker::Unpin; +#[cfg(feature = "runtime")] +use std::time::Duration; + +use h2::server::{Connection, Handshake, SendResponse}; +use h2::Reason; +use pin_project::{pin_project, project}; +use tokio::io::{AsyncRead, AsyncWrite}; + +use super::{decode_content_length, ping, PipeToSendStream, SendBuf}; +use crate::body::Payload; +use crate::common::exec::H2Exec; +use crate::common::{task, Future, Pin, Poll}; +use crate::headers; +use crate::proto::Dispatched; +use crate::service::HttpService; + +use crate::{Body, Response}; + +// Our defaults are chosen for the "majority" case, which usually are not +// resource constrained, and so the spec default of 64kb can be too limiting +// for performance. +// +// At the same time, a server more often has multiple clients connected, and +// so is more likely to use more resources than a client would. +const DEFAULT_CONN_WINDOW: u32 = 1024 * 1024; // 1mb +const DEFAULT_STREAM_WINDOW: u32 = 1024 * 1024; // 1mb + +#[derive(Clone, Debug)] +pub(crate) struct Config { + pub(crate) adaptive_window: bool, + pub(crate) initial_conn_window_size: u32, + pub(crate) initial_stream_window_size: u32, + pub(crate) max_concurrent_streams: Option<u32>, + #[cfg(feature = "runtime")] + pub(crate) keep_alive_interval: Option<Duration>, + #[cfg(feature = "runtime")] + pub(crate) keep_alive_timeout: Duration, +} + +impl Default for Config { + fn default() -> Config { + Config { + adaptive_window: false, + initial_conn_window_size: DEFAULT_CONN_WINDOW, + initial_stream_window_size: DEFAULT_STREAM_WINDOW, + max_concurrent_streams: None, + #[cfg(feature = "runtime")] + keep_alive_interval: None, + #[cfg(feature = "runtime")] + keep_alive_timeout: Duration::from_secs(20), + } + } +} + +#[pin_project] +pub(crate) struct Server<T, S, B, E> +where + S: HttpService<Body>, + B: Payload, +{ + exec: E, + service: S, + state: State<T, B>, +} + +enum State<T, B> +where + B: Payload, +{ + Handshaking { + ping_config: ping::Config, + hs: Handshake<T, SendBuf<B::Data>>, + }, + Serving(Serving<T, B>), + Closed, +} + +struct Serving<T, B> +where + B: Payload, +{ + ping: Option<(ping::Recorder, ping::Ponger)>, + conn: Connection<T, SendBuf<B::Data>>, + closing: Option<crate::Error>, +} + +impl<T, S, B, E> Server<T, S, B, E> +where + T: AsyncRead + AsyncWrite + Unpin, + S: HttpService<Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: Payload, + E: H2Exec<S::Future, B>, +{ + pub(crate) fn new(io: T, service: S, config: &Config, exec: E) -> Server<T, S, B, E> { + let mut builder = h2::server::Builder::default(); + builder + .initial_window_size(config.initial_stream_window_size) + .initial_connection_window_size(config.initial_conn_window_size); + if let Some(max) = config.max_concurrent_streams { + builder.max_concurrent_streams(max); + } + let handshake = builder.handshake(io); + + let bdp = if config.adaptive_window { + Some(config.initial_stream_window_size) + } else { + None + }; + + let ping_config = ping::Config { + bdp_initial_window: bdp, + #[cfg(feature = "runtime")] + keep_alive_interval: config.keep_alive_interval, + #[cfg(feature = "runtime")] + keep_alive_timeout: config.keep_alive_timeout, + // If keep-alive is enabled for servers, always enabled while + // idle, so it can more aggresively close dead connections. + #[cfg(feature = "runtime")] + keep_alive_while_idle: true, + }; + + Server { + exec, + state: State::Handshaking { + ping_config, + hs: handshake, + }, + service, + } + } + + pub fn graceful_shutdown(&mut self) { + trace!("graceful_shutdown"); + match self.state { + State::Handshaking { .. } => { + // fall-through, to replace state with Closed + } + State::Serving(ref mut srv) => { + if srv.closing.is_none() { + srv.conn.graceful_shutdown(); + } + return; + } + State::Closed => { + return; + } + } + self.state = State::Closed; + } +} + +impl<T, S, B, E> Future for Server<T, S, B, E> +where + T: AsyncRead + AsyncWrite + Unpin, + S: HttpService<Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: Payload, + E: H2Exec<S::Future, B>, +{ + type Output = crate::Result<Dispatched>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + let me = &mut *self; + loop { + let next = match me.state { + State::Handshaking { + ref mut hs, + ref ping_config, + } => { + let mut conn = ready!(Pin::new(hs).poll(cx).map_err(crate::Error::new_h2))?; + let ping = if ping_config.is_enabled() { + let pp = conn.ping_pong().expect("conn.ping_pong"); + Some(ping::channel(pp, ping_config.clone())) + } else { + None + }; + State::Serving(Serving { + ping, + conn, + closing: None, + }) + } + State::Serving(ref mut srv) => { + ready!(srv.poll_server(cx, &mut me.service, &mut me.exec))?; + return Poll::Ready(Ok(Dispatched::Shutdown)); + } + State::Closed => { + // graceful_shutdown was called before handshaking finished, + // nothing to do here... + return Poll::Ready(Ok(Dispatched::Shutdown)); + } + }; + me.state = next; + } + } +} + +impl<T, B> Serving<T, B> +where + T: AsyncRead + AsyncWrite + Unpin, + B: Payload, +{ + fn poll_server<S, E>( + &mut self, + cx: &mut task::Context<'_>, + service: &mut S, + exec: &mut E, + ) -> Poll<crate::Result<()>> + where + S: HttpService<Body, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + E: H2Exec<S::Future, B>, + { + if self.closing.is_none() { + loop { + self.poll_ping(cx); + + // Check that the service is ready to accept a new request. + // + // - If not, just drive the connection some. + // - If ready, try to accept a new request from the connection. + match service.poll_ready(cx) { + Poll::Ready(Ok(())) => (), + Poll::Pending => { + // use `poll_closed` instead of `poll_accept`, + // in order to avoid accepting a request. + ready!(self.conn.poll_closed(cx).map_err(crate::Error::new_h2))?; + trace!("incoming connection complete"); + return Poll::Ready(Ok(())); + } + Poll::Ready(Err(err)) => { + let err = crate::Error::new_user_service(err); + debug!("service closed: {}", err); + + let reason = err.h2_reason(); + if reason == Reason::NO_ERROR { + // NO_ERROR is only used for graceful shutdowns... + trace!("interpretting NO_ERROR user error as graceful_shutdown"); + self.conn.graceful_shutdown(); + } else { + trace!("abruptly shutting down with {:?}", reason); + self.conn.abrupt_shutdown(reason); + } + self.closing = Some(err); + break; + } + } + + // When the service is ready, accepts an incoming request. + match ready!(self.conn.poll_accept(cx)) { + Some(Ok((req, respond))) => { + trace!("incoming request"); + let content_length = decode_content_length(req.headers()); + let ping = self + .ping + .as_ref() + .map(|ping| ping.0.clone()) + .unwrap_or_else(ping::disabled); + + // Record the headers received + ping.record_non_data(); + + let req = req.map(|stream| crate::Body::h2(stream, content_length, ping)); + let fut = H2Stream::new(service.call(req), respond); + exec.execute_h2stream(fut); + } + Some(Err(e)) => { + return Poll::Ready(Err(crate::Error::new_h2(e))); + } + None => { + // no more incoming streams... + if let Some((ref ping, _)) = self.ping { + ping.ensure_not_timed_out()?; + } + + trace!("incoming connection complete"); + return Poll::Ready(Ok(())); + } + } + } + } + + debug_assert!( + self.closing.is_some(), + "poll_server broke loop without closing" + ); + + ready!(self.conn.poll_closed(cx).map_err(crate::Error::new_h2))?; + + Poll::Ready(Err(self.closing.take().expect("polled after error"))) + } + + fn poll_ping(&mut self, cx: &mut task::Context<'_>) { + if let Some((_, ref mut estimator)) = self.ping { + match estimator.poll(cx) { + Poll::Ready(ping::Ponged::SizeUpdate(wnd)) => { + self.conn.set_target_window_size(wnd); + let _ = self.conn.set_initial_window_size(wnd); + } + #[cfg(feature = "runtime")] + Poll::Ready(ping::Ponged::KeepAliveTimedOut) => { + debug!("keep-alive timed out, closing connection"); + self.conn.abrupt_shutdown(h2::Reason::NO_ERROR); + } + Poll::Pending => {} + } + } + } +} + +#[allow(missing_debug_implementations)] +#[pin_project] +pub struct H2Stream<F, B> +where + B: Payload, +{ + reply: SendResponse<SendBuf<B::Data>>, + #[pin] + state: H2StreamState<F, B>, +} + +#[pin_project] +enum H2StreamState<F, B> +where + B: Payload, +{ + Service(#[pin] F), + Body(#[pin] PipeToSendStream<B>), +} + +impl<F, B> H2Stream<F, B> +where + B: Payload, +{ + fn new(fut: F, respond: SendResponse<SendBuf<B::Data>>) -> H2Stream<F, B> { + H2Stream { + reply: respond, + state: H2StreamState::Service(fut), + } + } +} + +macro_rules! reply { + ($me:expr, $res:expr, $eos:expr) => {{ + match $me.reply.send_response($res, $eos) { + Ok(tx) => tx, + Err(e) => { + debug!("send response error: {}", e); + $me.reply.send_reset(Reason::INTERNAL_ERROR); + return Poll::Ready(Err(crate::Error::new_h2(e))); + } + } + }}; +} + +impl<F, B, E> H2Stream<F, B> +where + F: Future<Output = Result<Response<B>, E>>, + B: Payload, + E: Into<Box<dyn StdError + Send + Sync>>, +{ + #[project] + fn poll2(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { + let mut me = self.project(); + loop { + #[project] + let next = match me.state.as_mut().project() { + H2StreamState::Service(h) => { + let res = match h.poll(cx) { + Poll::Ready(Ok(r)) => r, + Poll::Pending => { + // Response is not yet ready, so we want to check if the client has sent a + // RST_STREAM frame which would cancel the current request. + if let Poll::Ready(reason) = + me.reply.poll_reset(cx).map_err(crate::Error::new_h2)? + { + debug!("stream received RST_STREAM: {:?}", reason); + return Poll::Ready(Err(crate::Error::new_h2(reason.into()))); + } + return Poll::Pending; + } + Poll::Ready(Err(e)) => { + let err = crate::Error::new_user_service(e); + warn!("http2 service errored: {}", err); + me.reply.send_reset(err.h2_reason()); + return Poll::Ready(Err(err)); + } + }; + + let (head, body) = res.into_parts(); + let mut res = ::http::Response::from_parts(head, ()); + super::strip_connection_headers(res.headers_mut(), false); + + // set Date header if it isn't already set... + res.headers_mut() + .entry(::http::header::DATE) + .or_insert_with(crate::proto::h1::date::update_and_header_value); + + // automatically set Content-Length from body... + if let Some(len) = body.size_hint().exact() { + headers::set_content_length_if_missing(res.headers_mut(), len); + } + + if !body.is_end_stream() { + let body_tx = reply!(me, res, false); + H2StreamState::Body(PipeToSendStream::new(body, body_tx)) + } else { + reply!(me, res, true); + return Poll::Ready(Ok(())); + } + } + H2StreamState::Body(pipe) => { + return pipe.poll(cx); + } + }; + me.state.set(next); + } + } +} + +impl<F, B, E> Future for H2Stream<F, B> +where + F: Future<Output = Result<Response<B>, E>>, + B: Payload, + E: Into<Box<dyn StdError + Send + Sync>>, +{ + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + self.poll2(cx).map(|res| { + if let Err(e) = res { + debug!("stream error: {}", e); + } + }) + } +} diff --git a/third_party/rust/hyper/src/proto/mod.rs b/third_party/rust/hyper/src/proto/mod.rs new file mode 100644 index 0000000000..7268e21265 --- /dev/null +++ b/third_party/rust/hyper/src/proto/mod.rs @@ -0,0 +1,145 @@ +//! Pieces pertaining to the HTTP message protocol. +use http::{HeaderMap, Method, StatusCode, Uri, Version}; + +pub(crate) use self::body_length::DecodedLength; +pub(crate) use self::h1::{dispatch, Conn, ServerTransaction}; + +pub(crate) mod h1; +pub(crate) mod h2; + +/// An Incoming Message head. Includes request/status line, and headers. +#[derive(Clone, Debug, Default, PartialEq)] +pub struct MessageHead<S> { + /// HTTP version of the message. + pub version: Version, + /// Subject (request line or status line) of Incoming message. + pub subject: S, + /// Headers of the Incoming message. + pub headers: HeaderMap, +} + +/// An incoming request message. +pub type RequestHead = MessageHead<RequestLine>; + +#[derive(Debug, Default, PartialEq)] +pub struct RequestLine(pub Method, pub Uri); + +/// An incoming response message. +pub type ResponseHead = MessageHead<StatusCode>; + +#[derive(Debug)] +pub enum BodyLength { + /// Content-Length + Known(u64), + /// Transfer-Encoding: chunked (if h1) + Unknown, +} + +/// Status of when a Disaptcher future completes. +pub(crate) enum Dispatched { + /// Dispatcher completely shutdown connection. + Shutdown, + /// Dispatcher has pending upgrade, and so did not shutdown. + Upgrade(crate::upgrade::Pending), +} + +/// A separate module to encapsulate the invariants of the DecodedLength type. +mod body_length { + use std::fmt; + + #[derive(Clone, Copy, PartialEq, Eq)] + pub(crate) struct DecodedLength(u64); + + const MAX_LEN: u64 = std::u64::MAX - 2; + + impl DecodedLength { + pub(crate) const CLOSE_DELIMITED: DecodedLength = DecodedLength(::std::u64::MAX); + pub(crate) const CHUNKED: DecodedLength = DecodedLength(::std::u64::MAX - 1); + pub(crate) const ZERO: DecodedLength = DecodedLength(0); + + #[cfg(test)] + pub(crate) fn new(len: u64) -> Self { + debug_assert!(len <= MAX_LEN); + DecodedLength(len) + } + + /// Takes the length as a content-length without other checks. + /// + /// Should only be called if previously confirmed this isn't + /// CLOSE_DELIMITED or CHUNKED. + #[inline] + pub(crate) fn danger_len(self) -> u64 { + debug_assert!(self.0 < Self::CHUNKED.0); + self.0 + } + + /// Converts to an Option<u64> representing a Known or Unknown length. + pub(crate) fn into_opt(self) -> Option<u64> { + match self { + DecodedLength::CHUNKED | DecodedLength::CLOSE_DELIMITED => None, + DecodedLength(known) => Some(known), + } + } + + /// Checks the `u64` is within the maximum allowed for content-length. + pub(crate) fn checked_new(len: u64) -> Result<Self, crate::error::Parse> { + if len <= MAX_LEN { + Ok(DecodedLength(len)) + } else { + warn!("content-length bigger than maximum: {} > {}", len, MAX_LEN); + Err(crate::error::Parse::TooLarge) + } + } + + pub(crate) fn sub_if(&mut self, amt: u64) { + match *self { + DecodedLength::CHUNKED | DecodedLength::CLOSE_DELIMITED => (), + DecodedLength(ref mut known) => { + *known -= amt; + } + } + } + } + + impl fmt::Debug for DecodedLength { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + DecodedLength::CLOSE_DELIMITED => f.write_str("CLOSE_DELIMITED"), + DecodedLength::CHUNKED => f.write_str("CHUNKED"), + DecodedLength(n) => f.debug_tuple("DecodedLength").field(&n).finish(), + } + } + } + + impl fmt::Display for DecodedLength { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + DecodedLength::CLOSE_DELIMITED => f.write_str("close-delimited"), + DecodedLength::CHUNKED => f.write_str("chunked encoding"), + DecodedLength::ZERO => f.write_str("empty"), + DecodedLength(n) => write!(f, "content-length ({} bytes)", n), + } + } + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn sub_if_known() { + let mut len = DecodedLength::new(30); + len.sub_if(20); + + assert_eq!(len.0, 10); + } + + #[test] + fn sub_if_chunked() { + let mut len = DecodedLength::CHUNKED; + len.sub_if(20); + + assert_eq!(len, DecodedLength::CHUNKED); + } + } +} |