use crate::frame::{self, Frame, Kind, Reason}; use crate::frame::{ DEFAULT_MAX_FRAME_SIZE, DEFAULT_SETTINGS_HEADER_TABLE_SIZE, MAX_MAX_FRAME_SIZE, }; use crate::proto::Error; use crate::hpack; use futures_core::Stream; use bytes::BytesMut; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::AsyncRead; use tokio_util::codec::FramedRead as InnerFramedRead; use tokio_util::codec::{LengthDelimitedCodec, LengthDelimitedCodecError}; // 16 MB "sane default" taken from golang http2 const DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE: usize = 16 << 20; #[derive(Debug)] pub struct FramedRead { inner: InnerFramedRead, // hpack decoder state hpack: hpack::Decoder, max_header_list_size: usize, partial: Option, } /// Partially loaded headers frame #[derive(Debug)] struct Partial { /// Empty frame frame: Continuable, /// Partial header payload buf: BytesMut, } #[derive(Debug)] enum Continuable { Headers(frame::Headers), PushPromise(frame::PushPromise), } impl FramedRead { pub fn new(inner: InnerFramedRead) -> FramedRead { FramedRead { inner, hpack: hpack::Decoder::new(DEFAULT_SETTINGS_HEADER_TABLE_SIZE), max_header_list_size: DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE, partial: None, } } pub fn get_ref(&self) -> &T { self.inner.get_ref() } pub fn get_mut(&mut self) -> &mut T { self.inner.get_mut() } /// Returns the current max frame size setting #[cfg(feature = "unstable")] #[inline] pub fn max_frame_size(&self) -> usize { self.inner.decoder().max_frame_length() } /// Updates the max frame size setting. /// /// Must be within 16,384 and 16,777,215. #[inline] pub fn set_max_frame_size(&mut self, val: usize) { assert!(DEFAULT_MAX_FRAME_SIZE as usize <= val && val <= MAX_MAX_FRAME_SIZE as usize); self.inner.decoder_mut().set_max_frame_length(val) } /// Update the max header list size setting. #[inline] pub fn set_max_header_list_size(&mut self, val: usize) { self.max_header_list_size = val; } } /// Decodes a frame. /// /// This method is intentionally de-generified and outlined because it is very large. fn decode_frame( hpack: &mut hpack::Decoder, max_header_list_size: usize, partial_inout: &mut Option, mut bytes: BytesMut, ) -> Result, Error> { let span = tracing::trace_span!("FramedRead::decode_frame", offset = bytes.len()); let _e = span.enter(); tracing::trace!("decoding frame from {}B", bytes.len()); // Parse the head let head = frame::Head::parse(&bytes); if partial_inout.is_some() && head.kind() != Kind::Continuation { proto_err!(conn: "expected CONTINUATION, got {:?}", head.kind()); return Err(Error::library_go_away(Reason::PROTOCOL_ERROR).into()); } let kind = head.kind(); tracing::trace!(frame.kind = ?kind); macro_rules! header_block { ($frame:ident, $head:ident, $bytes:ident) => ({ // Drop the frame header // TODO: Change to drain: carllerche/bytes#130 let _ = $bytes.split_to(frame::HEADER_LEN); // Parse the header frame w/o parsing the payload let (mut frame, mut payload) = match frame::$frame::load($head, $bytes) { Ok(res) => res, Err(frame::Error::InvalidDependencyId) => { proto_err!(stream: "invalid HEADERS dependency ID"); // A stream cannot depend on itself. An endpoint MUST // treat this as a stream error (Section 5.4.2) of type // `PROTOCOL_ERROR`. return Err(Error::library_reset($head.stream_id(), Reason::PROTOCOL_ERROR)); }, Err(e) => { proto_err!(conn: "failed to load frame; err={:?}", e); return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); } }; let is_end_headers = frame.is_end_headers(); // Load the HPACK encoded headers match frame.load_hpack(&mut payload, max_header_list_size, hpack) { Ok(_) => {}, Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {}, Err(frame::Error::MalformedMessage) => { let id = $head.stream_id(); proto_err!(stream: "malformed header block; stream={:?}", id); return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR)); }, Err(e) => { proto_err!(conn: "failed HPACK decoding; err={:?}", e); return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); } } if is_end_headers { frame.into() } else { tracing::trace!("loaded partial header block"); // Defer returning the frame *partial_inout = Some(Partial { frame: Continuable::$frame(frame), buf: payload, }); return Ok(None); } }); } let frame = match kind { Kind::Settings => { let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]); res.map_err(|e| { proto_err!(conn: "failed to load SETTINGS frame; err={:?}", e); Error::library_go_away(Reason::PROTOCOL_ERROR) })? .into() } Kind::Ping => { let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]); res.map_err(|e| { proto_err!(conn: "failed to load PING frame; err={:?}", e); Error::library_go_away(Reason::PROTOCOL_ERROR) })? .into() } Kind::WindowUpdate => { let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]); res.map_err(|e| { proto_err!(conn: "failed to load WINDOW_UPDATE frame; err={:?}", e); Error::library_go_away(Reason::PROTOCOL_ERROR) })? .into() } Kind::Data => { let _ = bytes.split_to(frame::HEADER_LEN); let res = frame::Data::load(head, bytes.freeze()); // TODO: Should this always be connection level? Probably not... res.map_err(|e| { proto_err!(conn: "failed to load DATA frame; err={:?}", e); Error::library_go_away(Reason::PROTOCOL_ERROR) })? .into() } Kind::Headers => header_block!(Headers, head, bytes), Kind::Reset => { let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]); res.map_err(|e| { proto_err!(conn: "failed to load RESET frame; err={:?}", e); Error::library_go_away(Reason::PROTOCOL_ERROR) })? .into() } Kind::GoAway => { let res = frame::GoAway::load(&bytes[frame::HEADER_LEN..]); res.map_err(|e| { proto_err!(conn: "failed to load GO_AWAY frame; err={:?}", e); Error::library_go_away(Reason::PROTOCOL_ERROR) })? .into() } Kind::PushPromise => header_block!(PushPromise, head, bytes), Kind::Priority => { if head.stream_id() == 0 { // Invalid stream identifier proto_err!(conn: "invalid stream ID 0"); return Err(Error::library_go_away(Reason::PROTOCOL_ERROR).into()); } match frame::Priority::load(head, &bytes[frame::HEADER_LEN..]) { Ok(frame) => frame.into(), Err(frame::Error::InvalidDependencyId) => { // A stream cannot depend on itself. An endpoint MUST // treat this as a stream error (Section 5.4.2) of type // `PROTOCOL_ERROR`. let id = head.stream_id(); proto_err!(stream: "PRIORITY invalid dependency ID; stream={:?}", id); return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR)); } Err(e) => { proto_err!(conn: "failed to load PRIORITY frame; err={:?};", e); return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); } } } Kind::Continuation => { let is_end_headers = (head.flag() & 0x4) == 0x4; let mut partial = match partial_inout.take() { Some(partial) => partial, None => { proto_err!(conn: "received unexpected CONTINUATION frame"); return Err(Error::library_go_away(Reason::PROTOCOL_ERROR).into()); } }; // The stream identifiers must match if partial.frame.stream_id() != head.stream_id() { proto_err!(conn: "CONTINUATION frame stream ID does not match previous frame stream ID"); return Err(Error::library_go_away(Reason::PROTOCOL_ERROR).into()); } // Extend the buf if partial.buf.is_empty() { partial.buf = bytes.split_off(frame::HEADER_LEN); } else { if partial.frame.is_over_size() { // If there was left over bytes previously, they may be // needed to continue decoding, even though we will // be ignoring this frame. This is done to keep the HPACK // decoder state up-to-date. // // Still, we need to be careful, because if a malicious // attacker were to try to send a gigantic string, such // that it fits over multiple header blocks, we could // grow memory uncontrollably again, and that'd be a shame. // // Instead, we use a simple heuristic to determine if // we should continue to ignore decoding, or to tell // the attacker to go away. if partial.buf.len() + bytes.len() > max_header_list_size { proto_err!(conn: "CONTINUATION frame header block size over ignorable limit"); return Err(Error::library_go_away(Reason::COMPRESSION_ERROR).into()); } } partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]); } match partial .frame .load_hpack(&mut partial.buf, max_header_list_size, hpack) { Ok(_) => {} Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {} Err(frame::Error::MalformedMessage) => { let id = head.stream_id(); proto_err!(stream: "malformed CONTINUATION frame; stream={:?}", id); return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR)); } Err(e) => { proto_err!(conn: "failed HPACK decoding; err={:?}", e); return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); } } if is_end_headers { partial.frame.into() } else { *partial_inout = Some(partial); return Ok(None); } } Kind::Unknown => { // Unknown frames are ignored return Ok(None); } }; Ok(Some(frame)) } impl Stream for FramedRead where T: AsyncRead + Unpin, { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let span = tracing::trace_span!("FramedRead::poll_next"); let _e = span.enter(); loop { tracing::trace!("poll"); let bytes = match ready!(Pin::new(&mut self.inner).poll_next(cx)) { Some(Ok(bytes)) => bytes, Some(Err(e)) => return Poll::Ready(Some(Err(map_err(e)))), None => return Poll::Ready(None), }; tracing::trace!(read.bytes = bytes.len()); let Self { ref mut hpack, max_header_list_size, ref mut partial, .. } = *self; if let Some(frame) = decode_frame(hpack, max_header_list_size, partial, bytes)? { tracing::debug!(?frame, "received"); return Poll::Ready(Some(Ok(frame))); } } } } fn map_err(err: io::Error) -> Error { if let io::ErrorKind::InvalidData = err.kind() { if let Some(custom) = err.get_ref() { if custom.is::() { return Error::library_go_away(Reason::FRAME_SIZE_ERROR); } } } err.into() } // ===== impl Continuable ===== impl Continuable { fn stream_id(&self) -> frame::StreamId { match *self { Continuable::Headers(ref h) => h.stream_id(), Continuable::PushPromise(ref p) => p.stream_id(), } } fn is_over_size(&self) -> bool { match *self { Continuable::Headers(ref h) => h.is_over_size(), Continuable::PushPromise(ref p) => p.is_over_size(), } } fn load_hpack( &mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder, ) -> Result<(), frame::Error> { match *self { Continuable::Headers(ref mut h) => h.load_hpack(src, max_header_list_size, decoder), Continuable::PushPromise(ref mut p) => p.load_hpack(src, max_header_list_size, decoder), } } } impl From for Frame { fn from(cont: Continuable) -> Self { match cont { Continuable::Headers(mut headers) => { headers.set_end_headers(); headers.into() } Continuable::PushPromise(mut push) => { push.set_end_headers(); push.into() } } } }