summaryrefslogtreecommitdiffstats
path: root/third_party/rust/h2/src/codec/framed_read.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/h2/src/codec/framed_read.rs')
-rw-r--r--third_party/rust/h2/src/codec/framed_read.rs421
1 files changed, 421 insertions, 0 deletions
diff --git a/third_party/rust/h2/src/codec/framed_read.rs b/third_party/rust/h2/src/codec/framed_read.rs
new file mode 100644
index 0000000000..3b0030d930
--- /dev/null
+++ b/third_party/rust/h2/src/codec/framed_read.rs
@@ -0,0 +1,421 @@
+use crate::frame::{self, Frame, Kind, Reason};
+use crate::frame::{
+ DEFAULT_MAX_FRAME_SIZE, DEFAULT_SETTINGS_HEADER_TABLE_SIZE, MAX_MAX_FRAME_SIZE,
+};
+use crate::proto::Error;
+
+use crate::hpack;
+
+use futures_core::Stream;
+
+use bytes::BytesMut;
+
+use std::io;
+
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tokio::io::AsyncRead;
+use tokio_util::codec::FramedRead as InnerFramedRead;
+use tokio_util::codec::{LengthDelimitedCodec, LengthDelimitedCodecError};
+
+// 16 MB "sane default" taken from golang http2
+const DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE: usize = 16 << 20;
+
+#[derive(Debug)]
+pub struct FramedRead<T> {
+ inner: InnerFramedRead<T, LengthDelimitedCodec>,
+
+ // hpack decoder state
+ hpack: hpack::Decoder,
+
+ max_header_list_size: usize,
+
+ partial: Option<Partial>,
+}
+
+/// Partially loaded headers frame
+#[derive(Debug)]
+struct Partial {
+ /// Empty frame
+ frame: Continuable,
+
+ /// Partial header payload
+ buf: BytesMut,
+}
+
+#[derive(Debug)]
+enum Continuable {
+ Headers(frame::Headers),
+ PushPromise(frame::PushPromise),
+}
+
+impl<T> FramedRead<T> {
+ pub fn new(inner: InnerFramedRead<T, LengthDelimitedCodec>) -> FramedRead<T> {
+ FramedRead {
+ inner,
+ hpack: hpack::Decoder::new(DEFAULT_SETTINGS_HEADER_TABLE_SIZE),
+ max_header_list_size: DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE,
+ partial: None,
+ }
+ }
+
+ pub fn get_ref(&self) -> &T {
+ self.inner.get_ref()
+ }
+
+ pub fn get_mut(&mut self) -> &mut T {
+ self.inner.get_mut()
+ }
+
+ /// Returns the current max frame size setting
+ #[cfg(feature = "unstable")]
+ #[inline]
+ pub fn max_frame_size(&self) -> usize {
+ self.inner.decoder().max_frame_length()
+ }
+
+ /// Updates the max frame size setting.
+ ///
+ /// Must be within 16,384 and 16,777,215.
+ #[inline]
+ pub fn set_max_frame_size(&mut self, val: usize) {
+ assert!(DEFAULT_MAX_FRAME_SIZE as usize <= val && val <= MAX_MAX_FRAME_SIZE as usize);
+ self.inner.decoder_mut().set_max_frame_length(val)
+ }
+
+ /// Update the max header list size setting.
+ #[inline]
+ pub fn set_max_header_list_size(&mut self, val: usize) {
+ self.max_header_list_size = val;
+ }
+
+ /// Update the header table size setting.
+ #[inline]
+ pub fn set_header_table_size(&mut self, val: usize) {
+ self.hpack.queue_size_update(val);
+ }
+}
+
+/// Decodes a frame.
+///
+/// This method is intentionally de-generified and outlined because it is very large.
+fn decode_frame(
+ hpack: &mut hpack::Decoder,
+ max_header_list_size: usize,
+ partial_inout: &mut Option<Partial>,
+ mut bytes: BytesMut,
+) -> Result<Option<Frame>, Error> {
+ let span = tracing::trace_span!("FramedRead::decode_frame", offset = bytes.len());
+ let _e = span.enter();
+
+ tracing::trace!("decoding frame from {}B", bytes.len());
+
+ // Parse the head
+ let head = frame::Head::parse(&bytes);
+
+ if partial_inout.is_some() && head.kind() != Kind::Continuation {
+ proto_err!(conn: "expected CONTINUATION, got {:?}", head.kind());
+ return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
+ }
+
+ let kind = head.kind();
+
+ tracing::trace!(frame.kind = ?kind);
+
+ macro_rules! header_block {
+ ($frame:ident, $head:ident, $bytes:ident) => ({
+ // Drop the frame header
+ // TODO: Change to drain: carllerche/bytes#130
+ let _ = $bytes.split_to(frame::HEADER_LEN);
+
+ // Parse the header frame w/o parsing the payload
+ let (mut frame, mut payload) = match frame::$frame::load($head, $bytes) {
+ Ok(res) => res,
+ Err(frame::Error::InvalidDependencyId) => {
+ proto_err!(stream: "invalid HEADERS dependency ID");
+ // A stream cannot depend on itself. An endpoint MUST
+ // treat this as a stream error (Section 5.4.2) of type
+ // `PROTOCOL_ERROR`.
+ return Err(Error::library_reset($head.stream_id(), Reason::PROTOCOL_ERROR));
+ },
+ Err(e) => {
+ proto_err!(conn: "failed to load frame; err={:?}", e);
+ return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
+ }
+ };
+
+ let is_end_headers = frame.is_end_headers();
+
+ // Load the HPACK encoded headers
+ match frame.load_hpack(&mut payload, max_header_list_size, hpack) {
+ Ok(_) => {},
+ Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {},
+ Err(frame::Error::MalformedMessage) => {
+ let id = $head.stream_id();
+ proto_err!(stream: "malformed header block; stream={:?}", id);
+ return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR));
+ },
+ Err(e) => {
+ proto_err!(conn: "failed HPACK decoding; err={:?}", e);
+ return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
+ }
+ }
+
+ if is_end_headers {
+ frame.into()
+ } else {
+ tracing::trace!("loaded partial header block");
+ // Defer returning the frame
+ *partial_inout = Some(Partial {
+ frame: Continuable::$frame(frame),
+ buf: payload,
+ });
+
+ return Ok(None);
+ }
+ });
+ }
+
+ let frame = match kind {
+ Kind::Settings => {
+ let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]);
+
+ res.map_err(|e| {
+ proto_err!(conn: "failed to load SETTINGS frame; err={:?}", e);
+ Error::library_go_away(Reason::PROTOCOL_ERROR)
+ })?
+ .into()
+ }
+ Kind::Ping => {
+ let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]);
+
+ res.map_err(|e| {
+ proto_err!(conn: "failed to load PING frame; err={:?}", e);
+ Error::library_go_away(Reason::PROTOCOL_ERROR)
+ })?
+ .into()
+ }
+ Kind::WindowUpdate => {
+ let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]);
+
+ res.map_err(|e| {
+ proto_err!(conn: "failed to load WINDOW_UPDATE frame; err={:?}", e);
+ Error::library_go_away(Reason::PROTOCOL_ERROR)
+ })?
+ .into()
+ }
+ Kind::Data => {
+ let _ = bytes.split_to(frame::HEADER_LEN);
+ let res = frame::Data::load(head, bytes.freeze());
+
+ // TODO: Should this always be connection level? Probably not...
+ res.map_err(|e| {
+ proto_err!(conn: "failed to load DATA frame; err={:?}", e);
+ Error::library_go_away(Reason::PROTOCOL_ERROR)
+ })?
+ .into()
+ }
+ Kind::Headers => header_block!(Headers, head, bytes),
+ Kind::Reset => {
+ let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]);
+ res.map_err(|e| {
+ proto_err!(conn: "failed to load RESET frame; err={:?}", e);
+ Error::library_go_away(Reason::PROTOCOL_ERROR)
+ })?
+ .into()
+ }
+ Kind::GoAway => {
+ let res = frame::GoAway::load(&bytes[frame::HEADER_LEN..]);
+ res.map_err(|e| {
+ proto_err!(conn: "failed to load GO_AWAY frame; err={:?}", e);
+ Error::library_go_away(Reason::PROTOCOL_ERROR)
+ })?
+ .into()
+ }
+ Kind::PushPromise => header_block!(PushPromise, head, bytes),
+ Kind::Priority => {
+ if head.stream_id() == 0 {
+ // Invalid stream identifier
+ proto_err!(conn: "invalid stream ID 0");
+ return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
+ }
+
+ match frame::Priority::load(head, &bytes[frame::HEADER_LEN..]) {
+ Ok(frame) => frame.into(),
+ Err(frame::Error::InvalidDependencyId) => {
+ // A stream cannot depend on itself. An endpoint MUST
+ // treat this as a stream error (Section 5.4.2) of type
+ // `PROTOCOL_ERROR`.
+ let id = head.stream_id();
+ proto_err!(stream: "PRIORITY invalid dependency ID; stream={:?}", id);
+ return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR));
+ }
+ Err(e) => {
+ proto_err!(conn: "failed to load PRIORITY frame; err={:?};", e);
+ return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
+ }
+ }
+ }
+ Kind::Continuation => {
+ let is_end_headers = (head.flag() & 0x4) == 0x4;
+
+ let mut partial = match partial_inout.take() {
+ Some(partial) => partial,
+ None => {
+ proto_err!(conn: "received unexpected CONTINUATION frame");
+ return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
+ }
+ };
+
+ // The stream identifiers must match
+ if partial.frame.stream_id() != head.stream_id() {
+ proto_err!(conn: "CONTINUATION frame stream ID does not match previous frame stream ID");
+ return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
+ }
+
+ // Extend the buf
+ if partial.buf.is_empty() {
+ partial.buf = bytes.split_off(frame::HEADER_LEN);
+ } else {
+ if partial.frame.is_over_size() {
+ // If there was left over bytes previously, they may be
+ // needed to continue decoding, even though we will
+ // be ignoring this frame. This is done to keep the HPACK
+ // decoder state up-to-date.
+ //
+ // Still, we need to be careful, because if a malicious
+ // attacker were to try to send a gigantic string, such
+ // that it fits over multiple header blocks, we could
+ // grow memory uncontrollably again, and that'd be a shame.
+ //
+ // Instead, we use a simple heuristic to determine if
+ // we should continue to ignore decoding, or to tell
+ // the attacker to go away.
+ if partial.buf.len() + bytes.len() > max_header_list_size {
+ proto_err!(conn: "CONTINUATION frame header block size over ignorable limit");
+ return Err(Error::library_go_away(Reason::COMPRESSION_ERROR));
+ }
+ }
+ partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]);
+ }
+
+ match partial
+ .frame
+ .load_hpack(&mut partial.buf, max_header_list_size, hpack)
+ {
+ Ok(_) => {}
+ Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {}
+ Err(frame::Error::MalformedMessage) => {
+ let id = head.stream_id();
+ proto_err!(stream: "malformed CONTINUATION frame; stream={:?}", id);
+ return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR));
+ }
+ Err(e) => {
+ proto_err!(conn: "failed HPACK decoding; err={:?}", e);
+ return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
+ }
+ }
+
+ if is_end_headers {
+ partial.frame.into()
+ } else {
+ *partial_inout = Some(partial);
+ return Ok(None);
+ }
+ }
+ Kind::Unknown => {
+ // Unknown frames are ignored
+ return Ok(None);
+ }
+ };
+
+ Ok(Some(frame))
+}
+
+impl<T> Stream for FramedRead<T>
+where
+ T: AsyncRead + Unpin,
+{
+ type Item = Result<Frame, Error>;
+
+ fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ let span = tracing::trace_span!("FramedRead::poll_next");
+ let _e = span.enter();
+ loop {
+ tracing::trace!("poll");
+ let bytes = match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
+ Some(Ok(bytes)) => bytes,
+ Some(Err(e)) => return Poll::Ready(Some(Err(map_err(e)))),
+ None => return Poll::Ready(None),
+ };
+
+ tracing::trace!(read.bytes = bytes.len());
+ let Self {
+ ref mut hpack,
+ max_header_list_size,
+ ref mut partial,
+ ..
+ } = *self;
+ if let Some(frame) = decode_frame(hpack, max_header_list_size, partial, bytes)? {
+ tracing::debug!(?frame, "received");
+ return Poll::Ready(Some(Ok(frame)));
+ }
+ }
+ }
+}
+
+fn map_err(err: io::Error) -> Error {
+ if let io::ErrorKind::InvalidData = err.kind() {
+ if let Some(custom) = err.get_ref() {
+ if custom.is::<LengthDelimitedCodecError>() {
+ return Error::library_go_away(Reason::FRAME_SIZE_ERROR);
+ }
+ }
+ }
+ err.into()
+}
+
+// ===== impl Continuable =====
+
+impl Continuable {
+ fn stream_id(&self) -> frame::StreamId {
+ match *self {
+ Continuable::Headers(ref h) => h.stream_id(),
+ Continuable::PushPromise(ref p) => p.stream_id(),
+ }
+ }
+
+ fn is_over_size(&self) -> bool {
+ match *self {
+ Continuable::Headers(ref h) => h.is_over_size(),
+ Continuable::PushPromise(ref p) => p.is_over_size(),
+ }
+ }
+
+ fn load_hpack(
+ &mut self,
+ src: &mut BytesMut,
+ max_header_list_size: usize,
+ decoder: &mut hpack::Decoder,
+ ) -> Result<(), frame::Error> {
+ match *self {
+ Continuable::Headers(ref mut h) => h.load_hpack(src, max_header_list_size, decoder),
+ Continuable::PushPromise(ref mut p) => p.load_hpack(src, max_header_list_size, decoder),
+ }
+ }
+}
+
+impl<T> From<Continuable> for Frame<T> {
+ fn from(cont: Continuable) -> Self {
+ match cont {
+ Continuable::Headers(mut headers) => {
+ headers.set_end_headers();
+ headers.into()
+ }
+ Continuable::PushPromise(mut push) => {
+ push.set_end_headers();
+ push.into()
+ }
+ }
+ }
+}