diff options
Diffstat (limited to 'third_party/rust/h2/src/codec')
-rw-r--r-- | third_party/rust/h2/src/codec/error.rs | 102 | ||||
-rw-r--r-- | third_party/rust/h2/src/codec/framed_read.rs | 415 | ||||
-rw-r--r-- | third_party/rust/h2/src/codec/framed_write.rs | 374 | ||||
-rw-r--r-- | third_party/rust/h2/src/codec/mod.rs | 201 |
4 files changed, 1092 insertions, 0 deletions
diff --git a/third_party/rust/h2/src/codec/error.rs b/third_party/rust/h2/src/codec/error.rs new file mode 100644 index 0000000000..0acb913e52 --- /dev/null +++ b/third_party/rust/h2/src/codec/error.rs @@ -0,0 +1,102 @@ +use crate::proto::Error; + +use std::{error, fmt, io}; + +/// Errors caused by sending a message +#[derive(Debug)] +pub enum SendError { + Connection(Error), + User(UserError), +} + +/// Errors caused by users of the library +#[derive(Debug)] +pub enum UserError { + /// The stream ID is no longer accepting frames. + InactiveStreamId, + + /// The stream is not currently expecting a frame of this type. + UnexpectedFrameType, + + /// The payload size is too big + PayloadTooBig, + + /// The application attempted to initiate too many streams to remote. + Rejected, + + /// The released capacity is larger than claimed capacity. + ReleaseCapacityTooBig, + + /// The stream ID space is overflowed. + /// + /// A new connection is needed. + OverflowedStreamId, + + /// Illegal headers, such as connection-specific headers. + MalformedHeaders, + + /// Request submitted with relative URI. + MissingUriSchemeAndAuthority, + + /// Calls `SendResponse::poll_reset` after having called `send_response`. + PollResetAfterSendResponse, + + /// Calls `PingPong::send_ping` before receiving a pong. + SendPingWhilePending, + + /// Tries to update local SETTINGS while ACK has not been received. + SendSettingsWhilePending, + + /// Tries to send push promise to peer who has disabled server push + PeerDisabledServerPush, +} + +// ===== impl SendError ===== + +impl error::Error for SendError {} + +impl fmt::Display for SendError { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match *self { + Self::Connection(ref e) => e.fmt(fmt), + Self::User(ref e) => e.fmt(fmt), + } + } +} + +impl From<io::Error> for SendError { + fn from(src: io::Error) -> Self { + Self::Connection(src.into()) + } +} + +impl From<UserError> for SendError { + fn from(src: UserError) -> Self { + SendError::User(src) + } +} + +// ===== impl UserError ===== + +impl error::Error for UserError {} + +impl fmt::Display for UserError { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + use self::UserError::*; + + fmt.write_str(match *self { + InactiveStreamId => "inactive stream", + UnexpectedFrameType => "unexpected frame type", + PayloadTooBig => "payload too big", + Rejected => "rejected", + ReleaseCapacityTooBig => "release capacity too big", + OverflowedStreamId => "stream ID overflowed", + MalformedHeaders => "malformed headers", + MissingUriSchemeAndAuthority => "request URI missing scheme and authority", + PollResetAfterSendResponse => "poll_reset after send_response is illegal", + SendPingWhilePending => "send_ping before received previous pong", + SendSettingsWhilePending => "sending SETTINGS before received previous ACK", + PeerDisabledServerPush => "sending PUSH_PROMISE to peer who disabled server push", + }) + } +} diff --git a/third_party/rust/h2/src/codec/framed_read.rs b/third_party/rust/h2/src/codec/framed_read.rs new file mode 100644 index 0000000000..7c3bbb3ba2 --- /dev/null +++ b/third_party/rust/h2/src/codec/framed_read.rs @@ -0,0 +1,415 @@ +use crate::frame::{self, Frame, Kind, Reason}; +use crate::frame::{ + DEFAULT_MAX_FRAME_SIZE, DEFAULT_SETTINGS_HEADER_TABLE_SIZE, MAX_MAX_FRAME_SIZE, +}; +use crate::proto::Error; + +use crate::hpack; + +use futures_core::Stream; + +use bytes::BytesMut; + +use std::io; + +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::AsyncRead; +use tokio_util::codec::FramedRead as InnerFramedRead; +use tokio_util::codec::{LengthDelimitedCodec, LengthDelimitedCodecError}; + +// 16 MB "sane default" taken from golang http2 +const DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE: usize = 16 << 20; + +#[derive(Debug)] +pub struct FramedRead<T> { + inner: InnerFramedRead<T, LengthDelimitedCodec>, + + // hpack decoder state + hpack: hpack::Decoder, + + max_header_list_size: usize, + + partial: Option<Partial>, +} + +/// Partially loaded headers frame +#[derive(Debug)] +struct Partial { + /// Empty frame + frame: Continuable, + + /// Partial header payload + buf: BytesMut, +} + +#[derive(Debug)] +enum Continuable { + Headers(frame::Headers), + PushPromise(frame::PushPromise), +} + +impl<T> FramedRead<T> { + pub fn new(inner: InnerFramedRead<T, LengthDelimitedCodec>) -> FramedRead<T> { + FramedRead { + inner, + hpack: hpack::Decoder::new(DEFAULT_SETTINGS_HEADER_TABLE_SIZE), + max_header_list_size: DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE, + partial: None, + } + } + + pub fn get_ref(&self) -> &T { + self.inner.get_ref() + } + + pub fn get_mut(&mut self) -> &mut T { + self.inner.get_mut() + } + + /// Returns the current max frame size setting + #[cfg(feature = "unstable")] + #[inline] + pub fn max_frame_size(&self) -> usize { + self.inner.decoder().max_frame_length() + } + + /// Updates the max frame size setting. + /// + /// Must be within 16,384 and 16,777,215. + #[inline] + pub fn set_max_frame_size(&mut self, val: usize) { + assert!(DEFAULT_MAX_FRAME_SIZE as usize <= val && val <= MAX_MAX_FRAME_SIZE as usize); + self.inner.decoder_mut().set_max_frame_length(val) + } + + /// Update the max header list size setting. + #[inline] + pub fn set_max_header_list_size(&mut self, val: usize) { + self.max_header_list_size = val; + } +} + +/// Decodes a frame. +/// +/// This method is intentionally de-generified and outlined because it is very large. +fn decode_frame( + hpack: &mut hpack::Decoder, + max_header_list_size: usize, + partial_inout: &mut Option<Partial>, + mut bytes: BytesMut, +) -> Result<Option<Frame>, Error> { + let span = tracing::trace_span!("FramedRead::decode_frame", offset = bytes.len()); + let _e = span.enter(); + + tracing::trace!("decoding frame from {}B", bytes.len()); + + // Parse the head + let head = frame::Head::parse(&bytes); + + if partial_inout.is_some() && head.kind() != Kind::Continuation { + proto_err!(conn: "expected CONTINUATION, got {:?}", head.kind()); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR).into()); + } + + let kind = head.kind(); + + tracing::trace!(frame.kind = ?kind); + + macro_rules! header_block { + ($frame:ident, $head:ident, $bytes:ident) => ({ + // Drop the frame header + // TODO: Change to drain: carllerche/bytes#130 + let _ = $bytes.split_to(frame::HEADER_LEN); + + // Parse the header frame w/o parsing the payload + let (mut frame, mut payload) = match frame::$frame::load($head, $bytes) { + Ok(res) => res, + Err(frame::Error::InvalidDependencyId) => { + proto_err!(stream: "invalid HEADERS dependency ID"); + // A stream cannot depend on itself. An endpoint MUST + // treat this as a stream error (Section 5.4.2) of type + // `PROTOCOL_ERROR`. + return Err(Error::library_reset($head.stream_id(), Reason::PROTOCOL_ERROR)); + }, + Err(e) => { + proto_err!(conn: "failed to load frame; err={:?}", e); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); + } + }; + + let is_end_headers = frame.is_end_headers(); + + // Load the HPACK encoded headers + match frame.load_hpack(&mut payload, max_header_list_size, hpack) { + Ok(_) => {}, + Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {}, + Err(frame::Error::MalformedMessage) => { + let id = $head.stream_id(); + proto_err!(stream: "malformed header block; stream={:?}", id); + return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR)); + }, + Err(e) => { + proto_err!(conn: "failed HPACK decoding; err={:?}", e); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); + } + } + + if is_end_headers { + frame.into() + } else { + tracing::trace!("loaded partial header block"); + // Defer returning the frame + *partial_inout = Some(Partial { + frame: Continuable::$frame(frame), + buf: payload, + }); + + return Ok(None); + } + }); + } + + let frame = match kind { + Kind::Settings => { + let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]); + + res.map_err(|e| { + proto_err!(conn: "failed to load SETTINGS frame; err={:?}", e); + Error::library_go_away(Reason::PROTOCOL_ERROR) + })? + .into() + } + Kind::Ping => { + let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]); + + res.map_err(|e| { + proto_err!(conn: "failed to load PING frame; err={:?}", e); + Error::library_go_away(Reason::PROTOCOL_ERROR) + })? + .into() + } + Kind::WindowUpdate => { + let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]); + + res.map_err(|e| { + proto_err!(conn: "failed to load WINDOW_UPDATE frame; err={:?}", e); + Error::library_go_away(Reason::PROTOCOL_ERROR) + })? + .into() + } + Kind::Data => { + let _ = bytes.split_to(frame::HEADER_LEN); + let res = frame::Data::load(head, bytes.freeze()); + + // TODO: Should this always be connection level? Probably not... + res.map_err(|e| { + proto_err!(conn: "failed to load DATA frame; err={:?}", e); + Error::library_go_away(Reason::PROTOCOL_ERROR) + })? + .into() + } + Kind::Headers => header_block!(Headers, head, bytes), + Kind::Reset => { + let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]); + res.map_err(|e| { + proto_err!(conn: "failed to load RESET frame; err={:?}", e); + Error::library_go_away(Reason::PROTOCOL_ERROR) + })? + .into() + } + Kind::GoAway => { + let res = frame::GoAway::load(&bytes[frame::HEADER_LEN..]); + res.map_err(|e| { + proto_err!(conn: "failed to load GO_AWAY frame; err={:?}", e); + Error::library_go_away(Reason::PROTOCOL_ERROR) + })? + .into() + } + Kind::PushPromise => header_block!(PushPromise, head, bytes), + Kind::Priority => { + if head.stream_id() == 0 { + // Invalid stream identifier + proto_err!(conn: "invalid stream ID 0"); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR).into()); + } + + match frame::Priority::load(head, &bytes[frame::HEADER_LEN..]) { + Ok(frame) => frame.into(), + Err(frame::Error::InvalidDependencyId) => { + // A stream cannot depend on itself. An endpoint MUST + // treat this as a stream error (Section 5.4.2) of type + // `PROTOCOL_ERROR`. + let id = head.stream_id(); + proto_err!(stream: "PRIORITY invalid dependency ID; stream={:?}", id); + return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR)); + } + Err(e) => { + proto_err!(conn: "failed to load PRIORITY frame; err={:?};", e); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); + } + } + } + Kind::Continuation => { + let is_end_headers = (head.flag() & 0x4) == 0x4; + + let mut partial = match partial_inout.take() { + Some(partial) => partial, + None => { + proto_err!(conn: "received unexpected CONTINUATION frame"); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR).into()); + } + }; + + // The stream identifiers must match + if partial.frame.stream_id() != head.stream_id() { + proto_err!(conn: "CONTINUATION frame stream ID does not match previous frame stream ID"); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR).into()); + } + + // Extend the buf + if partial.buf.is_empty() { + partial.buf = bytes.split_off(frame::HEADER_LEN); + } else { + if partial.frame.is_over_size() { + // If there was left over bytes previously, they may be + // needed to continue decoding, even though we will + // be ignoring this frame. This is done to keep the HPACK + // decoder state up-to-date. + // + // Still, we need to be careful, because if a malicious + // attacker were to try to send a gigantic string, such + // that it fits over multiple header blocks, we could + // grow memory uncontrollably again, and that'd be a shame. + // + // Instead, we use a simple heuristic to determine if + // we should continue to ignore decoding, or to tell + // the attacker to go away. + if partial.buf.len() + bytes.len() > max_header_list_size { + proto_err!(conn: "CONTINUATION frame header block size over ignorable limit"); + return Err(Error::library_go_away(Reason::COMPRESSION_ERROR).into()); + } + } + partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]); + } + + match partial + .frame + .load_hpack(&mut partial.buf, max_header_list_size, hpack) + { + Ok(_) => {} + Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {} + Err(frame::Error::MalformedMessage) => { + let id = head.stream_id(); + proto_err!(stream: "malformed CONTINUATION frame; stream={:?}", id); + return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR)); + } + Err(e) => { + proto_err!(conn: "failed HPACK decoding; err={:?}", e); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); + } + } + + if is_end_headers { + partial.frame.into() + } else { + *partial_inout = Some(partial); + return Ok(None); + } + } + Kind::Unknown => { + // Unknown frames are ignored + return Ok(None); + } + }; + + Ok(Some(frame)) +} + +impl<T> Stream for FramedRead<T> +where + T: AsyncRead + Unpin, +{ + type Item = Result<Frame, Error>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + let span = tracing::trace_span!("FramedRead::poll_next"); + let _e = span.enter(); + loop { + tracing::trace!("poll"); + let bytes = match ready!(Pin::new(&mut self.inner).poll_next(cx)) { + Some(Ok(bytes)) => bytes, + Some(Err(e)) => return Poll::Ready(Some(Err(map_err(e)))), + None => return Poll::Ready(None), + }; + + tracing::trace!(read.bytes = bytes.len()); + let Self { + ref mut hpack, + max_header_list_size, + ref mut partial, + .. + } = *self; + if let Some(frame) = decode_frame(hpack, max_header_list_size, partial, bytes)? { + tracing::debug!(?frame, "received"); + return Poll::Ready(Some(Ok(frame))); + } + } + } +} + +fn map_err(err: io::Error) -> Error { + if let io::ErrorKind::InvalidData = err.kind() { + if let Some(custom) = err.get_ref() { + if custom.is::<LengthDelimitedCodecError>() { + return Error::library_go_away(Reason::FRAME_SIZE_ERROR); + } + } + } + err.into() +} + +// ===== impl Continuable ===== + +impl Continuable { + fn stream_id(&self) -> frame::StreamId { + match *self { + Continuable::Headers(ref h) => h.stream_id(), + Continuable::PushPromise(ref p) => p.stream_id(), + } + } + + fn is_over_size(&self) -> bool { + match *self { + Continuable::Headers(ref h) => h.is_over_size(), + Continuable::PushPromise(ref p) => p.is_over_size(), + } + } + + fn load_hpack( + &mut self, + src: &mut BytesMut, + max_header_list_size: usize, + decoder: &mut hpack::Decoder, + ) -> Result<(), frame::Error> { + match *self { + Continuable::Headers(ref mut h) => h.load_hpack(src, max_header_list_size, decoder), + Continuable::PushPromise(ref mut p) => p.load_hpack(src, max_header_list_size, decoder), + } + } +} + +impl<T> From<Continuable> for Frame<T> { + fn from(cont: Continuable) -> Self { + match cont { + Continuable::Headers(mut headers) => { + headers.set_end_headers(); + headers.into() + } + Continuable::PushPromise(mut push) => { + push.set_end_headers(); + push.into() + } + } + } +} diff --git a/third_party/rust/h2/src/codec/framed_write.rs b/third_party/rust/h2/src/codec/framed_write.rs new file mode 100644 index 0000000000..4b1b4accc4 --- /dev/null +++ b/third_party/rust/h2/src/codec/framed_write.rs @@ -0,0 +1,374 @@ +use crate::codec::UserError; +use crate::codec::UserError::*; +use crate::frame::{self, Frame, FrameSize}; +use crate::hpack; + +use bytes::{Buf, BufMut, BytesMut}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use std::io::{self, Cursor, IoSlice}; + +// A macro to get around a method needing to borrow &mut self +macro_rules! limited_write_buf { + ($self:expr) => {{ + let limit = $self.max_frame_size() + frame::HEADER_LEN; + $self.buf.get_mut().limit(limit) + }}; +} + +#[derive(Debug)] +pub struct FramedWrite<T, B> { + /// Upstream `AsyncWrite` + inner: T, + + encoder: Encoder<B>, +} + +#[derive(Debug)] +struct Encoder<B> { + /// HPACK encoder + hpack: hpack::Encoder, + + /// Write buffer + /// + /// TODO: Should this be a ring buffer? + buf: Cursor<BytesMut>, + + /// Next frame to encode + next: Option<Next<B>>, + + /// Last data frame + last_data_frame: Option<frame::Data<B>>, + + /// Max frame size, this is specified by the peer + max_frame_size: FrameSize, + + /// Whether or not the wrapped `AsyncWrite` supports vectored IO. + is_write_vectored: bool, +} + +#[derive(Debug)] +enum Next<B> { + Data(frame::Data<B>), + Continuation(frame::Continuation), +} + +/// Initialize the connection with this amount of write buffer. +/// +/// The minimum MAX_FRAME_SIZE is 16kb, so always be able to send a HEADERS +/// frame that big. +const DEFAULT_BUFFER_CAPACITY: usize = 16 * 1_024; + +/// Min buffer required to attempt to write a frame +const MIN_BUFFER_CAPACITY: usize = frame::HEADER_LEN + CHAIN_THRESHOLD; + +/// Chain payloads bigger than this. The remote will never advertise a max frame +/// size less than this (well, the spec says the max frame size can't be less +/// than 16kb, so not even close). +const CHAIN_THRESHOLD: usize = 256; + +// TODO: Make generic +impl<T, B> FramedWrite<T, B> +where + T: AsyncWrite + Unpin, + B: Buf, +{ + pub fn new(inner: T) -> FramedWrite<T, B> { + let is_write_vectored = inner.is_write_vectored(); + FramedWrite { + inner, + encoder: Encoder { + hpack: hpack::Encoder::default(), + buf: Cursor::new(BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY)), + next: None, + last_data_frame: None, + max_frame_size: frame::DEFAULT_MAX_FRAME_SIZE, + is_write_vectored, + }, + } + } + + /// Returns `Ready` when `send` is able to accept a frame + /// + /// Calling this function may result in the current contents of the buffer + /// to be flushed to `T`. + pub fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> { + if !self.encoder.has_capacity() { + // Try flushing + ready!(self.flush(cx))?; + + if !self.encoder.has_capacity() { + return Poll::Pending; + } + } + + Poll::Ready(Ok(())) + } + + /// Buffer a frame. + /// + /// `poll_ready` must be called first to ensure that a frame may be + /// accepted. + pub fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> { + self.encoder.buffer(item) + } + + /// Flush buffered data to the wire + pub fn flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> { + let span = tracing::trace_span!("FramedWrite::flush"); + let _e = span.enter(); + + loop { + while !self.encoder.is_empty() { + match self.encoder.next { + Some(Next::Data(ref mut frame)) => { + tracing::trace!(queued_data_frame = true); + let mut buf = (&mut self.encoder.buf).chain(frame.payload_mut()); + ready!(write( + &mut self.inner, + self.encoder.is_write_vectored, + &mut buf, + cx, + ))? + } + _ => { + tracing::trace!(queued_data_frame = false); + ready!(write( + &mut self.inner, + self.encoder.is_write_vectored, + &mut self.encoder.buf, + cx, + ))? + } + } + } + + match self.encoder.unset_frame() { + ControlFlow::Continue => (), + ControlFlow::Break => break, + } + } + + tracing::trace!("flushing buffer"); + // Flush the upstream + ready!(Pin::new(&mut self.inner).poll_flush(cx))?; + + Poll::Ready(Ok(())) + } + + /// Close the codec + pub fn shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>> { + ready!(self.flush(cx))?; + Pin::new(&mut self.inner).poll_shutdown(cx) + } +} + +fn write<T, B>( + writer: &mut T, + is_write_vectored: bool, + buf: &mut B, + cx: &mut Context<'_>, +) -> Poll<io::Result<()>> +where + T: AsyncWrite + Unpin, + B: Buf, +{ + // TODO(eliza): when tokio-util 0.5.1 is released, this + // could just use `poll_write_buf`... + const MAX_IOVS: usize = 64; + let n = if is_write_vectored { + let mut bufs = [IoSlice::new(&[]); MAX_IOVS]; + let cnt = buf.chunks_vectored(&mut bufs); + ready!(Pin::new(writer).poll_write_vectored(cx, &bufs[..cnt]))? + } else { + ready!(Pin::new(writer).poll_write(cx, buf.chunk()))? + }; + buf.advance(n); + Ok(()).into() +} + +#[must_use] +enum ControlFlow { + Continue, + Break, +} + +impl<B> Encoder<B> +where + B: Buf, +{ + fn unset_frame(&mut self) -> ControlFlow { + // Clear internal buffer + self.buf.set_position(0); + self.buf.get_mut().clear(); + + // The data frame has been written, so unset it + match self.next.take() { + Some(Next::Data(frame)) => { + self.last_data_frame = Some(frame); + debug_assert!(self.is_empty()); + ControlFlow::Break + } + Some(Next::Continuation(frame)) => { + // Buffer the continuation frame, then try to write again + let mut buf = limited_write_buf!(self); + if let Some(continuation) = frame.encode(&mut buf) { + self.next = Some(Next::Continuation(continuation)); + } + ControlFlow::Continue + } + None => ControlFlow::Break, + } + } + + fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> { + // Ensure that we have enough capacity to accept the write. + assert!(self.has_capacity()); + let span = tracing::trace_span!("FramedWrite::buffer", frame = ?item); + let _e = span.enter(); + + tracing::debug!(frame = ?item, "send"); + + match item { + Frame::Data(mut v) => { + // Ensure that the payload is not greater than the max frame. + let len = v.payload().remaining(); + + if len > self.max_frame_size() { + return Err(PayloadTooBig); + } + + if len >= CHAIN_THRESHOLD { + let head = v.head(); + + // Encode the frame head to the buffer + head.encode(len, self.buf.get_mut()); + + // Save the data frame + self.next = Some(Next::Data(v)); + } else { + v.encode_chunk(self.buf.get_mut()); + + // The chunk has been fully encoded, so there is no need to + // keep it around + assert_eq!(v.payload().remaining(), 0, "chunk not fully encoded"); + + // Save off the last frame... + self.last_data_frame = Some(v); + } + } + Frame::Headers(v) => { + let mut buf = limited_write_buf!(self); + if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) { + self.next = Some(Next::Continuation(continuation)); + } + } + Frame::PushPromise(v) => { + let mut buf = limited_write_buf!(self); + if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) { + self.next = Some(Next::Continuation(continuation)); + } + } + Frame::Settings(v) => { + v.encode(self.buf.get_mut()); + tracing::trace!(rem = self.buf.remaining(), "encoded settings"); + } + Frame::GoAway(v) => { + v.encode(self.buf.get_mut()); + tracing::trace!(rem = self.buf.remaining(), "encoded go_away"); + } + Frame::Ping(v) => { + v.encode(self.buf.get_mut()); + tracing::trace!(rem = self.buf.remaining(), "encoded ping"); + } + Frame::WindowUpdate(v) => { + v.encode(self.buf.get_mut()); + tracing::trace!(rem = self.buf.remaining(), "encoded window_update"); + } + + Frame::Priority(_) => { + /* + v.encode(self.buf.get_mut()); + tracing::trace!("encoded priority; rem={:?}", self.buf.remaining()); + */ + unimplemented!(); + } + Frame::Reset(v) => { + v.encode(self.buf.get_mut()); + tracing::trace!(rem = self.buf.remaining(), "encoded reset"); + } + } + + Ok(()) + } + + fn has_capacity(&self) -> bool { + self.next.is_none() && self.buf.get_ref().remaining_mut() >= MIN_BUFFER_CAPACITY + } + + fn is_empty(&self) -> bool { + match self.next { + Some(Next::Data(ref frame)) => !frame.payload().has_remaining(), + _ => !self.buf.has_remaining(), + } + } +} + +impl<B> Encoder<B> { + fn max_frame_size(&self) -> usize { + self.max_frame_size as usize + } +} + +impl<T, B> FramedWrite<T, B> { + /// Returns the max frame size that can be sent + pub fn max_frame_size(&self) -> usize { + self.encoder.max_frame_size() + } + + /// Set the peer's max frame size. + pub fn set_max_frame_size(&mut self, val: usize) { + assert!(val <= frame::MAX_MAX_FRAME_SIZE as usize); + self.encoder.max_frame_size = val as FrameSize; + } + + /// Set the peer's header table size. + pub fn set_header_table_size(&mut self, val: usize) { + self.encoder.hpack.update_max_size(val); + } + + /// Retrieve the last data frame that has been sent + pub fn take_last_data_frame(&mut self) -> Option<frame::Data<B>> { + self.encoder.last_data_frame.take() + } + + pub fn get_mut(&mut self) -> &mut T { + &mut self.inner + } +} + +impl<T: AsyncRead + Unpin, B> AsyncRead for FramedWrite<T, B> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf, + ) -> Poll<io::Result<()>> { + Pin::new(&mut self.inner).poll_read(cx, buf) + } +} + +// We never project the Pin to `B`. +impl<T: Unpin, B> Unpin for FramedWrite<T, B> {} + +#[cfg(feature = "unstable")] +mod unstable { + use super::*; + + impl<T, B> FramedWrite<T, B> { + pub fn get_ref(&self) -> &T { + &self.inner + } + } +} diff --git a/third_party/rust/h2/src/codec/mod.rs b/third_party/rust/h2/src/codec/mod.rs new file mode 100644 index 0000000000..359adf6e47 --- /dev/null +++ b/third_party/rust/h2/src/codec/mod.rs @@ -0,0 +1,201 @@ +mod error; +mod framed_read; +mod framed_write; + +pub use self::error::{SendError, UserError}; + +use self::framed_read::FramedRead; +use self::framed_write::FramedWrite; + +use crate::frame::{self, Data, Frame}; +use crate::proto::Error; + +use bytes::Buf; +use futures_core::Stream; +use futures_sink::Sink; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::codec::length_delimited; + +use std::io; + +#[derive(Debug)] +pub struct Codec<T, B> { + inner: FramedRead<FramedWrite<T, B>>, +} + +impl<T, B> Codec<T, B> +where + T: AsyncRead + AsyncWrite + Unpin, + B: Buf, +{ + /// Returns a new `Codec` with the default max frame size + #[inline] + pub fn new(io: T) -> Self { + Self::with_max_recv_frame_size(io, frame::DEFAULT_MAX_FRAME_SIZE as usize) + } + + /// Returns a new `Codec` with the given maximum frame size + pub fn with_max_recv_frame_size(io: T, max_frame_size: usize) -> Self { + // Wrap with writer + let framed_write = FramedWrite::new(io); + + // Delimit the frames + let delimited = length_delimited::Builder::new() + .big_endian() + .length_field_length(3) + .length_adjustment(9) + .num_skip(0) // Don't skip the header + .new_read(framed_write); + + let mut inner = FramedRead::new(delimited); + + // Use FramedRead's method since it checks the value is within range. + inner.set_max_frame_size(max_frame_size); + + Codec { inner } + } +} + +impl<T, B> Codec<T, B> { + /// Updates the max received frame size. + /// + /// The change takes effect the next time a frame is decoded. In other + /// words, if a frame is currently in process of being decoded with a frame + /// size greater than `val` but less than the max frame size in effect + /// before calling this function, then the frame will be allowed. + #[inline] + pub fn set_max_recv_frame_size(&mut self, val: usize) { + self.inner.set_max_frame_size(val) + } + + /// Returns the current max received frame size setting. + /// + /// This is the largest size this codec will accept from the wire. Larger + /// frames will be rejected. + #[cfg(feature = "unstable")] + #[inline] + pub fn max_recv_frame_size(&self) -> usize { + self.inner.max_frame_size() + } + + /// Returns the max frame size that can be sent to the peer. + pub fn max_send_frame_size(&self) -> usize { + self.inner.get_ref().max_frame_size() + } + + /// Set the peer's max frame size. + pub fn set_max_send_frame_size(&mut self, val: usize) { + self.framed_write().set_max_frame_size(val) + } + + /// Set the peer's header table size size. + pub fn set_send_header_table_size(&mut self, val: usize) { + self.framed_write().set_header_table_size(val) + } + + /// Set the max header list size that can be received. + pub fn set_max_recv_header_list_size(&mut self, val: usize) { + self.inner.set_max_header_list_size(val); + } + + /// Get a reference to the inner stream. + #[cfg(feature = "unstable")] + pub fn get_ref(&self) -> &T { + self.inner.get_ref().get_ref() + } + + /// Get a mutable reference to the inner stream. + pub fn get_mut(&mut self) -> &mut T { + self.inner.get_mut().get_mut() + } + + /// Takes the data payload value that was fully written to the socket + pub(crate) fn take_last_data_frame(&mut self) -> Option<Data<B>> { + self.framed_write().take_last_data_frame() + } + + fn framed_write(&mut self) -> &mut FramedWrite<T, B> { + self.inner.get_mut() + } +} + +impl<T, B> Codec<T, B> +where + T: AsyncWrite + Unpin, + B: Buf, +{ + /// Returns `Ready` when the codec can buffer a frame + pub fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> { + self.framed_write().poll_ready(cx) + } + + /// Buffer a frame. + /// + /// `poll_ready` must be called first to ensure that a frame may be + /// accepted. + /// + /// TODO: Rename this to avoid conflicts with Sink::buffer + pub fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> { + self.framed_write().buffer(item) + } + + /// Flush buffered data to the wire + pub fn flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> { + self.framed_write().flush(cx) + } + + /// Shutdown the send half + pub fn shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>> { + self.framed_write().shutdown(cx) + } +} + +impl<T, B> Stream for Codec<T, B> +where + T: AsyncRead + Unpin, +{ + type Item = Result<Frame, Error>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + Pin::new(&mut self.inner).poll_next(cx) + } +} + +impl<T, B> Sink<Frame<B>> for Codec<T, B> +where + T: AsyncWrite + Unpin, + B: Buf, +{ + type Error = SendError; + + fn start_send(mut self: Pin<&mut Self>, item: Frame<B>) -> Result<(), Self::Error> { + Codec::buffer(&mut self, item)?; + Ok(()) + } + /// Returns `Ready` when the codec can buffer a frame + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.framed_write().poll_ready(cx).map_err(Into::into) + } + + /// Flush buffered data to the wire + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.framed_write().flush(cx).map_err(Into::into) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + ready!(self.shutdown(cx))?; + Poll::Ready(Ok(())) + } +} + +// TODO: remove (or improve) this +impl<T> From<T> for Codec<T, bytes::Bytes> +where + T: AsyncRead + AsyncWrite + Unpin, +{ + fn from(src: T) -> Self { + Self::new(src) + } +} |