diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
commit | 36d22d82aa202bb199967e9512281e9a53db42c9 (patch) | |
tree | 105e8c98ddea1c1e4784a60a5a6410fa416be2de /third_party/rust/h2/src/frame | |
parent | Initial commit. (diff) | |
download | firefox-esr-upstream.tar.xz firefox-esr-upstream.zip |
Adding upstream version 115.7.0esr.upstream/115.7.0esrupstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/h2/src/frame')
-rw-r--r-- | third_party/rust/h2/src/frame/data.rs | 233 | ||||
-rw-r--r-- | third_party/rust/h2/src/frame/go_away.rs | 79 | ||||
-rw-r--r-- | third_party/rust/h2/src/frame/head.rs | 94 | ||||
-rw-r--r-- | third_party/rust/h2/src/frame/headers.rs | 1039 | ||||
-rw-r--r-- | third_party/rust/h2/src/frame/mod.rs | 171 | ||||
-rw-r--r-- | third_party/rust/h2/src/frame/ping.rs | 102 | ||||
-rw-r--r-- | third_party/rust/h2/src/frame/priority.rs | 72 | ||||
-rw-r--r-- | third_party/rust/h2/src/frame/reason.rs | 134 | ||||
-rw-r--r-- | third_party/rust/h2/src/frame/reset.rs | 56 | ||||
-rw-r--r-- | third_party/rust/h2/src/frame/settings.rs | 391 | ||||
-rw-r--r-- | third_party/rust/h2/src/frame/stream_id.rs | 96 | ||||
-rw-r--r-- | third_party/rust/h2/src/frame/util.rs | 79 | ||||
-rw-r--r-- | third_party/rust/h2/src/frame/window_update.rs | 62 |
13 files changed, 2608 insertions, 0 deletions
diff --git a/third_party/rust/h2/src/frame/data.rs b/third_party/rust/h2/src/frame/data.rs new file mode 100644 index 0000000000..e253d5e23d --- /dev/null +++ b/third_party/rust/h2/src/frame/data.rs @@ -0,0 +1,233 @@ +use crate::frame::{util, Error, Frame, Head, Kind, StreamId}; +use bytes::{Buf, BufMut, Bytes}; + +use std::fmt; + +/// Data frame +/// +/// Data frames convey arbitrary, variable-length sequences of octets associated +/// with a stream. One or more DATA frames are used, for instance, to carry HTTP +/// request or response payloads. +#[derive(Eq, PartialEq)] +pub struct Data<T = Bytes> { + stream_id: StreamId, + data: T, + flags: DataFlags, + pad_len: Option<u8>, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +struct DataFlags(u8); + +const END_STREAM: u8 = 0x1; +const PADDED: u8 = 0x8; +const ALL: u8 = END_STREAM | PADDED; + +impl<T> Data<T> { + /// Creates a new DATA frame. + pub fn new(stream_id: StreamId, payload: T) -> Self { + assert!(!stream_id.is_zero()); + + Data { + stream_id, + data: payload, + flags: DataFlags::default(), + pad_len: None, + } + } + + /// Returns the stream identifier that this frame is associated with. + /// + /// This cannot be a zero stream identifier. + pub fn stream_id(&self) -> StreamId { + self.stream_id + } + + /// Gets the value of the `END_STREAM` flag for this frame. + /// + /// If true, this frame is the last that the endpoint will send for the + /// identified stream. + /// + /// Setting this flag causes the stream to enter one of the "half-closed" + /// states or the "closed" state (Section 5.1). + pub fn is_end_stream(&self) -> bool { + self.flags.is_end_stream() + } + + /// Sets the value for the `END_STREAM` flag on this frame. + pub fn set_end_stream(&mut self, val: bool) { + if val { + self.flags.set_end_stream(); + } else { + self.flags.unset_end_stream(); + } + } + + /// Returns whether the `PADDED` flag is set on this frame. + #[cfg(feature = "unstable")] + pub fn is_padded(&self) -> bool { + self.flags.is_padded() + } + + /// Sets the value for the `PADDED` flag on this frame. + #[cfg(feature = "unstable")] + pub fn set_padded(&mut self) { + self.flags.set_padded(); + } + + /// Returns a reference to this frame's payload. + /// + /// This does **not** include any padding that might have been originally + /// included. + pub fn payload(&self) -> &T { + &self.data + } + + /// Returns a mutable reference to this frame's payload. + /// + /// This does **not** include any padding that might have been originally + /// included. + pub fn payload_mut(&mut self) -> &mut T { + &mut self.data + } + + /// Consumes `self` and returns the frame's payload. + /// + /// This does **not** include any padding that might have been originally + /// included. + pub fn into_payload(self) -> T { + self.data + } + + pub(crate) fn head(&self) -> Head { + Head::new(Kind::Data, self.flags.into(), self.stream_id) + } + + pub(crate) fn map<F, U>(self, f: F) -> Data<U> + where + F: FnOnce(T) -> U, + { + Data { + stream_id: self.stream_id, + data: f(self.data), + flags: self.flags, + pad_len: self.pad_len, + } + } +} + +impl Data<Bytes> { + pub(crate) fn load(head: Head, mut payload: Bytes) -> Result<Self, Error> { + let flags = DataFlags::load(head.flag()); + + // The stream identifier must not be zero + if head.stream_id().is_zero() { + return Err(Error::InvalidStreamId); + } + + let pad_len = if flags.is_padded() { + let len = util::strip_padding(&mut payload)?; + Some(len) + } else { + None + }; + + Ok(Data { + stream_id: head.stream_id(), + data: payload, + flags, + pad_len, + }) + } +} + +impl<T: Buf> Data<T> { + /// Encode the data frame into the `dst` buffer. + /// + /// # Panics + /// + /// Panics if `dst` cannot contain the data frame. + pub(crate) fn encode_chunk<U: BufMut>(&mut self, dst: &mut U) { + let len = self.data.remaining() as usize; + + assert!(dst.remaining_mut() >= len); + + self.head().encode(len, dst); + dst.put(&mut self.data); + } +} + +impl<T> From<Data<T>> for Frame<T> { + fn from(src: Data<T>) -> Self { + Frame::Data(src) + } +} + +impl<T> fmt::Debug for Data<T> { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + let mut f = fmt.debug_struct("Data"); + f.field("stream_id", &self.stream_id); + if !self.flags.is_empty() { + f.field("flags", &self.flags); + } + if let Some(ref pad_len) = self.pad_len { + f.field("pad_len", pad_len); + } + // `data` bytes purposefully excluded + f.finish() + } +} + +// ===== impl DataFlags ===== + +impl DataFlags { + fn load(bits: u8) -> DataFlags { + DataFlags(bits & ALL) + } + + fn is_empty(&self) -> bool { + self.0 == 0 + } + + fn is_end_stream(&self) -> bool { + self.0 & END_STREAM == END_STREAM + } + + fn set_end_stream(&mut self) { + self.0 |= END_STREAM + } + + fn unset_end_stream(&mut self) { + self.0 &= !END_STREAM + } + + fn is_padded(&self) -> bool { + self.0 & PADDED == PADDED + } + + #[cfg(feature = "unstable")] + fn set_padded(&mut self) { + self.0 |= PADDED + } +} + +impl Default for DataFlags { + fn default() -> Self { + DataFlags(0) + } +} + +impl From<DataFlags> for u8 { + fn from(src: DataFlags) -> u8 { + src.0 + } +} + +impl fmt::Debug for DataFlags { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + util::debug_flags(fmt, self.0) + .flag_if(self.is_end_stream(), "END_STREAM") + .flag_if(self.is_padded(), "PADDED") + .finish() + } +} diff --git a/third_party/rust/h2/src/frame/go_away.rs b/third_party/rust/h2/src/frame/go_away.rs new file mode 100644 index 0000000000..91d9c4c6b5 --- /dev/null +++ b/third_party/rust/h2/src/frame/go_away.rs @@ -0,0 +1,79 @@ +use std::fmt; + +use bytes::{BufMut, Bytes}; + +use crate::frame::{self, Error, Head, Kind, Reason, StreamId}; + +#[derive(Clone, Eq, PartialEq)] +pub struct GoAway { + last_stream_id: StreamId, + error_code: Reason, + #[allow(unused)] + debug_data: Bytes, +} + +impl GoAway { + pub fn new(last_stream_id: StreamId, reason: Reason) -> Self { + GoAway { + last_stream_id, + error_code: reason, + debug_data: Bytes::new(), + } + } + + pub fn last_stream_id(&self) -> StreamId { + self.last_stream_id + } + + pub fn reason(&self) -> Reason { + self.error_code + } + + pub fn debug_data(&self) -> &Bytes { + &self.debug_data + } + + pub fn load(payload: &[u8]) -> Result<GoAway, Error> { + if payload.len() < 8 { + return Err(Error::BadFrameSize); + } + + let (last_stream_id, _) = StreamId::parse(&payload[..4]); + let error_code = unpack_octets_4!(payload, 4, u32); + let debug_data = Bytes::copy_from_slice(&payload[8..]); + + Ok(GoAway { + last_stream_id, + error_code: error_code.into(), + debug_data, + }) + } + + pub fn encode<B: BufMut>(&self, dst: &mut B) { + tracing::trace!("encoding GO_AWAY; code={:?}", self.error_code); + let head = Head::new(Kind::GoAway, 0, StreamId::zero()); + head.encode(8, dst); + dst.put_u32(self.last_stream_id.into()); + dst.put_u32(self.error_code.into()); + } +} + +impl<B> From<GoAway> for frame::Frame<B> { + fn from(src: GoAway) -> Self { + frame::Frame::GoAway(src) + } +} + +impl fmt::Debug for GoAway { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = f.debug_struct("GoAway"); + builder.field("error_code", &self.error_code); + builder.field("last_stream_id", &self.last_stream_id); + + if !self.debug_data.is_empty() { + builder.field("debug_data", &self.debug_data); + } + + builder.finish() + } +} diff --git a/third_party/rust/h2/src/frame/head.rs b/third_party/rust/h2/src/frame/head.rs new file mode 100644 index 0000000000..38be2f6973 --- /dev/null +++ b/third_party/rust/h2/src/frame/head.rs @@ -0,0 +1,94 @@ +use super::StreamId; + +use bytes::BufMut; + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct Head { + kind: Kind, + flag: u8, + stream_id: StreamId, +} + +#[repr(u8)] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum Kind { + Data = 0, + Headers = 1, + Priority = 2, + Reset = 3, + Settings = 4, + PushPromise = 5, + Ping = 6, + GoAway = 7, + WindowUpdate = 8, + Continuation = 9, + Unknown, +} + +// ===== impl Head ===== + +impl Head { + pub fn new(kind: Kind, flag: u8, stream_id: StreamId) -> Head { + Head { + kind, + flag, + stream_id, + } + } + + /// Parse an HTTP/2 frame header + pub fn parse(header: &[u8]) -> Head { + let (stream_id, _) = StreamId::parse(&header[5..]); + + Head { + kind: Kind::new(header[3]), + flag: header[4], + stream_id, + } + } + + pub fn stream_id(&self) -> StreamId { + self.stream_id + } + + pub fn kind(&self) -> Kind { + self.kind + } + + pub fn flag(&self) -> u8 { + self.flag + } + + pub fn encode_len(&self) -> usize { + super::HEADER_LEN + } + + pub fn encode<T: BufMut>(&self, payload_len: usize, dst: &mut T) { + debug_assert!(self.encode_len() <= dst.remaining_mut()); + + dst.put_uint(payload_len as u64, 3); + dst.put_u8(self.kind as u8); + dst.put_u8(self.flag); + dst.put_u32(self.stream_id.into()); + } +} + +// ===== impl Kind ===== + +impl Kind { + pub fn new(byte: u8) -> Kind { + match byte { + 0 => Kind::Data, + 1 => Kind::Headers, + 2 => Kind::Priority, + 3 => Kind::Reset, + 4 => Kind::Settings, + 5 => Kind::PushPromise, + 6 => Kind::Ping, + 7 => Kind::GoAway, + 8 => Kind::WindowUpdate, + 9 => Kind::Continuation, + _ => Kind::Unknown, + } + } +} diff --git a/third_party/rust/h2/src/frame/headers.rs b/third_party/rust/h2/src/frame/headers.rs new file mode 100644 index 0000000000..bcb9050133 --- /dev/null +++ b/third_party/rust/h2/src/frame/headers.rs @@ -0,0 +1,1039 @@ +use super::{util, StreamDependency, StreamId}; +use crate::ext::Protocol; +use crate::frame::{Error, Frame, Head, Kind}; +use crate::hpack::{self, BytesStr}; + +use http::header::{self, HeaderName, HeaderValue}; +use http::{uri, HeaderMap, Method, Request, StatusCode, Uri}; + +use bytes::{BufMut, Bytes, BytesMut}; + +use std::fmt; +use std::io::Cursor; + +type EncodeBuf<'a> = bytes::buf::Limit<&'a mut BytesMut>; +/// Header frame +/// +/// This could be either a request or a response. +#[derive(Eq, PartialEq)] +pub struct Headers { + /// The ID of the stream with which this frame is associated. + stream_id: StreamId, + + /// The stream dependency information, if any. + stream_dep: Option<StreamDependency>, + + /// The header block fragment + header_block: HeaderBlock, + + /// The associated flags + flags: HeadersFlag, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub struct HeadersFlag(u8); + +#[derive(Eq, PartialEq)] +pub struct PushPromise { + /// The ID of the stream with which this frame is associated. + stream_id: StreamId, + + /// The ID of the stream being reserved by this PushPromise. + promised_id: StreamId, + + /// The header block fragment + header_block: HeaderBlock, + + /// The associated flags + flags: PushPromiseFlag, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub struct PushPromiseFlag(u8); + +#[derive(Debug)] +pub struct Continuation { + /// Stream ID of continuation frame + stream_id: StreamId, + + header_block: EncodingHeaderBlock, +} + +// TODO: These fields shouldn't be `pub` +#[derive(Debug, Default, Eq, PartialEq)] +pub struct Pseudo { + // Request + pub method: Option<Method>, + pub scheme: Option<BytesStr>, + pub authority: Option<BytesStr>, + pub path: Option<BytesStr>, + pub protocol: Option<Protocol>, + + // Response + pub status: Option<StatusCode>, +} + +#[derive(Debug)] +pub struct Iter { + /// Pseudo headers + pseudo: Option<Pseudo>, + + /// Header fields + fields: header::IntoIter<HeaderValue>, +} + +#[derive(Debug, PartialEq, Eq)] +struct HeaderBlock { + /// The decoded header fields + fields: HeaderMap, + + /// Set to true if decoding went over the max header list size. + is_over_size: bool, + + /// Pseudo headers, these are broken out as they must be sent as part of the + /// headers frame. + pseudo: Pseudo, +} + +#[derive(Debug)] +struct EncodingHeaderBlock { + hpack: Bytes, +} + +const END_STREAM: u8 = 0x1; +const END_HEADERS: u8 = 0x4; +const PADDED: u8 = 0x8; +const PRIORITY: u8 = 0x20; +const ALL: u8 = END_STREAM | END_HEADERS | PADDED | PRIORITY; + +// ===== impl Headers ===== + +impl Headers { + /// Create a new HEADERS frame + pub fn new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self { + Headers { + stream_id, + stream_dep: None, + header_block: HeaderBlock { + fields, + is_over_size: false, + pseudo, + }, + flags: HeadersFlag::default(), + } + } + + pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self { + let mut flags = HeadersFlag::default(); + flags.set_end_stream(); + + Headers { + stream_id, + stream_dep: None, + header_block: HeaderBlock { + fields, + is_over_size: false, + pseudo: Pseudo::default(), + }, + flags, + } + } + + /// Loads the header frame but doesn't actually do HPACK decoding. + /// + /// HPACK decoding is done in the `load_hpack` step. + pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> { + let flags = HeadersFlag(head.flag()); + let mut pad = 0; + + tracing::trace!("loading headers; flags={:?}", flags); + + if head.stream_id().is_zero() { + return Err(Error::InvalidStreamId); + } + + // Read the padding length + if flags.is_padded() { + if src.is_empty() { + return Err(Error::MalformedMessage); + } + pad = src[0] as usize; + + // Drop the padding + let _ = src.split_to(1); + } + + // Read the stream dependency + let stream_dep = if flags.is_priority() { + if src.len() < 5 { + return Err(Error::MalformedMessage); + } + let stream_dep = StreamDependency::load(&src[..5])?; + + if stream_dep.dependency_id() == head.stream_id() { + return Err(Error::InvalidDependencyId); + } + + // Drop the next 5 bytes + let _ = src.split_to(5); + + Some(stream_dep) + } else { + None + }; + + if pad > 0 { + if pad > src.len() { + return Err(Error::TooMuchPadding); + } + + let len = src.len() - pad; + src.truncate(len); + } + + let headers = Headers { + stream_id: head.stream_id(), + stream_dep, + header_block: HeaderBlock { + fields: HeaderMap::new(), + is_over_size: false, + pseudo: Pseudo::default(), + }, + flags, + }; + + Ok((headers, src)) + } + + pub fn load_hpack( + &mut self, + src: &mut BytesMut, + max_header_list_size: usize, + decoder: &mut hpack::Decoder, + ) -> Result<(), Error> { + self.header_block.load(src, max_header_list_size, decoder) + } + + pub fn stream_id(&self) -> StreamId { + self.stream_id + } + + pub fn is_end_headers(&self) -> bool { + self.flags.is_end_headers() + } + + pub fn set_end_headers(&mut self) { + self.flags.set_end_headers(); + } + + pub fn is_end_stream(&self) -> bool { + self.flags.is_end_stream() + } + + pub fn set_end_stream(&mut self) { + self.flags.set_end_stream() + } + + pub fn is_over_size(&self) -> bool { + self.header_block.is_over_size + } + + pub fn into_parts(self) -> (Pseudo, HeaderMap) { + (self.header_block.pseudo, self.header_block.fields) + } + + #[cfg(feature = "unstable")] + pub fn pseudo_mut(&mut self) -> &mut Pseudo { + &mut self.header_block.pseudo + } + + /// Whether it has status 1xx + pub(crate) fn is_informational(&self) -> bool { + self.header_block.pseudo.is_informational() + } + + pub fn fields(&self) -> &HeaderMap { + &self.header_block.fields + } + + pub fn into_fields(self) -> HeaderMap { + self.header_block.fields + } + + pub fn encode( + self, + encoder: &mut hpack::Encoder, + dst: &mut EncodeBuf<'_>, + ) -> Option<Continuation> { + // At this point, the `is_end_headers` flag should always be set + debug_assert!(self.flags.is_end_headers()); + + // Get the HEADERS frame head + let head = self.head(); + + self.header_block + .into_encoding(encoder) + .encode(&head, dst, |_| {}) + } + + fn head(&self) -> Head { + Head::new(Kind::Headers, self.flags.into(), self.stream_id) + } +} + +impl<T> From<Headers> for Frame<T> { + fn from(src: Headers) -> Self { + Frame::Headers(src) + } +} + +impl fmt::Debug for Headers { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut builder = f.debug_struct("Headers"); + builder + .field("stream_id", &self.stream_id) + .field("flags", &self.flags); + + if let Some(ref protocol) = self.header_block.pseudo.protocol { + builder.field("protocol", protocol); + } + + if let Some(ref dep) = self.stream_dep { + builder.field("stream_dep", dep); + } + + // `fields` and `pseudo` purposefully not included + builder.finish() + } +} + +// ===== util ===== + +pub fn parse_u64(src: &[u8]) -> Result<u64, ()> { + if src.len() > 19 { + // At danger for overflow... + return Err(()); + } + + let mut ret = 0; + + for &d in src { + if d < b'0' || d > b'9' { + return Err(()); + } + + ret *= 10; + ret += (d - b'0') as u64; + } + + Ok(ret) +} + +// ===== impl PushPromise ===== + +#[derive(Debug)] +pub enum PushPromiseHeaderError { + InvalidContentLength(Result<u64, ()>), + NotSafeAndCacheable, +} + +impl PushPromise { + pub fn new( + stream_id: StreamId, + promised_id: StreamId, + pseudo: Pseudo, + fields: HeaderMap, + ) -> Self { + PushPromise { + flags: PushPromiseFlag::default(), + header_block: HeaderBlock { + fields, + is_over_size: false, + pseudo, + }, + promised_id, + stream_id, + } + } + + pub fn validate_request(req: &Request<()>) -> Result<(), PushPromiseHeaderError> { + use PushPromiseHeaderError::*; + // The spec has some requirements for promised request headers + // [https://httpwg.org/specs/rfc7540.html#PushRequests] + + // A promised request "that indicates the presence of a request body + // MUST reset the promised stream with a stream error" + if let Some(content_length) = req.headers().get(header::CONTENT_LENGTH) { + let parsed_length = parse_u64(content_length.as_bytes()); + if parsed_length != Ok(0) { + return Err(InvalidContentLength(parsed_length)); + } + } + // "The server MUST include a method in the :method pseudo-header field + // that is safe and cacheable" + if !Self::safe_and_cacheable(req.method()) { + return Err(NotSafeAndCacheable); + } + + Ok(()) + } + + fn safe_and_cacheable(method: &Method) -> bool { + // Cacheable: https://httpwg.org/specs/rfc7231.html#cacheable.methods + // Safe: https://httpwg.org/specs/rfc7231.html#safe.methods + return method == Method::GET || method == Method::HEAD; + } + + pub fn fields(&self) -> &HeaderMap { + &self.header_block.fields + } + + #[cfg(feature = "unstable")] + pub fn into_fields(self) -> HeaderMap { + self.header_block.fields + } + + /// Loads the push promise frame but doesn't actually do HPACK decoding. + /// + /// HPACK decoding is done in the `load_hpack` step. + pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> { + let flags = PushPromiseFlag(head.flag()); + let mut pad = 0; + + if head.stream_id().is_zero() { + return Err(Error::InvalidStreamId); + } + + // Read the padding length + if flags.is_padded() { + if src.is_empty() { + return Err(Error::MalformedMessage); + } + + // TODO: Ensure payload is sized correctly + pad = src[0] as usize; + + // Drop the padding + let _ = src.split_to(1); + } + + if src.len() < 5 { + return Err(Error::MalformedMessage); + } + + let (promised_id, _) = StreamId::parse(&src[..4]); + // Drop promised_id bytes + let _ = src.split_to(4); + + if pad > 0 { + if pad > src.len() { + return Err(Error::TooMuchPadding); + } + + let len = src.len() - pad; + src.truncate(len); + } + + let frame = PushPromise { + flags, + header_block: HeaderBlock { + fields: HeaderMap::new(), + is_over_size: false, + pseudo: Pseudo::default(), + }, + promised_id, + stream_id: head.stream_id(), + }; + Ok((frame, src)) + } + + pub fn load_hpack( + &mut self, + src: &mut BytesMut, + max_header_list_size: usize, + decoder: &mut hpack::Decoder, + ) -> Result<(), Error> { + self.header_block.load(src, max_header_list_size, decoder) + } + + pub fn stream_id(&self) -> StreamId { + self.stream_id + } + + pub fn promised_id(&self) -> StreamId { + self.promised_id + } + + pub fn is_end_headers(&self) -> bool { + self.flags.is_end_headers() + } + + pub fn set_end_headers(&mut self) { + self.flags.set_end_headers(); + } + + pub fn is_over_size(&self) -> bool { + self.header_block.is_over_size + } + + pub fn encode( + self, + encoder: &mut hpack::Encoder, + dst: &mut EncodeBuf<'_>, + ) -> Option<Continuation> { + // At this point, the `is_end_headers` flag should always be set + debug_assert!(self.flags.is_end_headers()); + + let head = self.head(); + let promised_id = self.promised_id; + + self.header_block + .into_encoding(encoder) + .encode(&head, dst, |dst| { + dst.put_u32(promised_id.into()); + }) + } + + fn head(&self) -> Head { + Head::new(Kind::PushPromise, self.flags.into(), self.stream_id) + } + + /// Consume `self`, returning the parts of the frame + pub fn into_parts(self) -> (Pseudo, HeaderMap) { + (self.header_block.pseudo, self.header_block.fields) + } +} + +impl<T> From<PushPromise> for Frame<T> { + fn from(src: PushPromise) -> Self { + Frame::PushPromise(src) + } +} + +impl fmt::Debug for PushPromise { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("PushPromise") + .field("stream_id", &self.stream_id) + .field("promised_id", &self.promised_id) + .field("flags", &self.flags) + // `fields` and `pseudo` purposefully not included + .finish() + } +} + +// ===== impl Continuation ===== + +impl Continuation { + fn head(&self) -> Head { + Head::new(Kind::Continuation, END_HEADERS, self.stream_id) + } + + pub fn encode(self, dst: &mut EncodeBuf<'_>) -> Option<Continuation> { + // Get the CONTINUATION frame head + let head = self.head(); + + self.header_block.encode(&head, dst, |_| {}) + } +} + +// ===== impl Pseudo ===== + +impl Pseudo { + pub fn request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self { + let parts = uri::Parts::from(uri); + + let mut path = parts + .path_and_query + .map(|v| BytesStr::from(v.as_str())) + .unwrap_or(BytesStr::from_static("")); + + match method { + Method::OPTIONS | Method::CONNECT => {} + _ if path.is_empty() => { + path = BytesStr::from_static("/"); + } + _ => {} + } + + let mut pseudo = Pseudo { + method: Some(method), + scheme: None, + authority: None, + path: Some(path).filter(|p| !p.is_empty()), + protocol, + status: None, + }; + + // If the URI includes a scheme component, add it to the pseudo headers + // + // TODO: Scheme must be set... + if let Some(scheme) = parts.scheme { + pseudo.set_scheme(scheme); + } + + // If the URI includes an authority component, add it to the pseudo + // headers + if let Some(authority) = parts.authority { + pseudo.set_authority(BytesStr::from(authority.as_str())); + } + + pseudo + } + + pub fn response(status: StatusCode) -> Self { + Pseudo { + method: None, + scheme: None, + authority: None, + path: None, + protocol: None, + status: Some(status), + } + } + + #[cfg(feature = "unstable")] + pub fn set_status(&mut self, value: StatusCode) { + self.status = Some(value); + } + + pub fn set_scheme(&mut self, scheme: uri::Scheme) { + let bytes_str = match scheme.as_str() { + "http" => BytesStr::from_static("http"), + "https" => BytesStr::from_static("https"), + s => BytesStr::from(s), + }; + self.scheme = Some(bytes_str); + } + + #[cfg(feature = "unstable")] + pub fn set_protocol(&mut self, protocol: Protocol) { + self.protocol = Some(protocol); + } + + pub fn set_authority(&mut self, authority: BytesStr) { + self.authority = Some(authority); + } + + /// Whether it has status 1xx + pub(crate) fn is_informational(&self) -> bool { + self.status + .map_or(false, |status| status.is_informational()) + } +} + +// ===== impl EncodingHeaderBlock ===== + +impl EncodingHeaderBlock { + fn encode<F>(mut self, head: &Head, dst: &mut EncodeBuf<'_>, f: F) -> Option<Continuation> + where + F: FnOnce(&mut EncodeBuf<'_>), + { + let head_pos = dst.get_ref().len(); + + // At this point, we don't know how big the h2 frame will be. + // So, we write the head with length 0, then write the body, and + // finally write the length once we know the size. + head.encode(0, dst); + + let payload_pos = dst.get_ref().len(); + + f(dst); + + // Now, encode the header payload + let continuation = if self.hpack.len() > dst.remaining_mut() { + dst.put_slice(&self.hpack.split_to(dst.remaining_mut())); + + Some(Continuation { + stream_id: head.stream_id(), + header_block: self, + }) + } else { + dst.put_slice(&self.hpack); + + None + }; + + // Compute the header block length + let payload_len = (dst.get_ref().len() - payload_pos) as u64; + + // Write the frame length + let payload_len_be = payload_len.to_be_bytes(); + assert!(payload_len_be[0..5].iter().all(|b| *b == 0)); + (dst.get_mut()[head_pos..head_pos + 3]).copy_from_slice(&payload_len_be[5..]); + + if continuation.is_some() { + // There will be continuation frames, so the `is_end_headers` flag + // must be unset + debug_assert!(dst.get_ref()[head_pos + 4] & END_HEADERS == END_HEADERS); + + dst.get_mut()[head_pos + 4] -= END_HEADERS; + } + + continuation + } +} + +// ===== impl Iter ===== + +impl Iterator for Iter { + type Item = hpack::Header<Option<HeaderName>>; + + fn next(&mut self) -> Option<Self::Item> { + use crate::hpack::Header::*; + + if let Some(ref mut pseudo) = self.pseudo { + if let Some(method) = pseudo.method.take() { + return Some(Method(method)); + } + + if let Some(scheme) = pseudo.scheme.take() { + return Some(Scheme(scheme)); + } + + if let Some(authority) = pseudo.authority.take() { + return Some(Authority(authority)); + } + + if let Some(path) = pseudo.path.take() { + return Some(Path(path)); + } + + if let Some(protocol) = pseudo.protocol.take() { + return Some(Protocol(protocol)); + } + + if let Some(status) = pseudo.status.take() { + return Some(Status(status)); + } + } + + self.pseudo = None; + + self.fields + .next() + .map(|(name, value)| Field { name, value }) + } +} + +// ===== impl HeadersFlag ===== + +impl HeadersFlag { + pub fn empty() -> HeadersFlag { + HeadersFlag(0) + } + + pub fn load(bits: u8) -> HeadersFlag { + HeadersFlag(bits & ALL) + } + + pub fn is_end_stream(&self) -> bool { + self.0 & END_STREAM == END_STREAM + } + + pub fn set_end_stream(&mut self) { + self.0 |= END_STREAM; + } + + pub fn is_end_headers(&self) -> bool { + self.0 & END_HEADERS == END_HEADERS + } + + pub fn set_end_headers(&mut self) { + self.0 |= END_HEADERS; + } + + pub fn is_padded(&self) -> bool { + self.0 & PADDED == PADDED + } + + pub fn is_priority(&self) -> bool { + self.0 & PRIORITY == PRIORITY + } +} + +impl Default for HeadersFlag { + /// Returns a `HeadersFlag` value with `END_HEADERS` set. + fn default() -> Self { + HeadersFlag(END_HEADERS) + } +} + +impl From<HeadersFlag> for u8 { + fn from(src: HeadersFlag) -> u8 { + src.0 + } +} + +impl fmt::Debug for HeadersFlag { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + util::debug_flags(fmt, self.0) + .flag_if(self.is_end_headers(), "END_HEADERS") + .flag_if(self.is_end_stream(), "END_STREAM") + .flag_if(self.is_padded(), "PADDED") + .flag_if(self.is_priority(), "PRIORITY") + .finish() + } +} + +// ===== impl PushPromiseFlag ===== + +impl PushPromiseFlag { + pub fn empty() -> PushPromiseFlag { + PushPromiseFlag(0) + } + + pub fn load(bits: u8) -> PushPromiseFlag { + PushPromiseFlag(bits & ALL) + } + + pub fn is_end_headers(&self) -> bool { + self.0 & END_HEADERS == END_HEADERS + } + + pub fn set_end_headers(&mut self) { + self.0 |= END_HEADERS; + } + + pub fn is_padded(&self) -> bool { + self.0 & PADDED == PADDED + } +} + +impl Default for PushPromiseFlag { + /// Returns a `PushPromiseFlag` value with `END_HEADERS` set. + fn default() -> Self { + PushPromiseFlag(END_HEADERS) + } +} + +impl From<PushPromiseFlag> for u8 { + fn from(src: PushPromiseFlag) -> u8 { + src.0 + } +} + +impl fmt::Debug for PushPromiseFlag { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + util::debug_flags(fmt, self.0) + .flag_if(self.is_end_headers(), "END_HEADERS") + .flag_if(self.is_padded(), "PADDED") + .finish() + } +} + +// ===== HeaderBlock ===== + +impl HeaderBlock { + fn load( + &mut self, + src: &mut BytesMut, + max_header_list_size: usize, + decoder: &mut hpack::Decoder, + ) -> Result<(), Error> { + let mut reg = !self.fields.is_empty(); + let mut malformed = false; + let mut headers_size = self.calculate_header_list_size(); + + macro_rules! set_pseudo { + ($field:ident, $val:expr) => {{ + if reg { + tracing::trace!("load_hpack; header malformed -- pseudo not at head of block"); + malformed = true; + } else if self.pseudo.$field.is_some() { + tracing::trace!("load_hpack; header malformed -- repeated pseudo"); + malformed = true; + } else { + let __val = $val; + headers_size += + decoded_header_size(stringify!($field).len() + 1, __val.as_str().len()); + if headers_size < max_header_list_size { + self.pseudo.$field = Some(__val); + } else if !self.is_over_size { + tracing::trace!("load_hpack; header list size over max"); + self.is_over_size = true; + } + } + }}; + } + + let mut cursor = Cursor::new(src); + + // If the header frame is malformed, we still have to continue decoding + // the headers. A malformed header frame is a stream level error, but + // the hpack state is connection level. In order to maintain correct + // state for other streams, the hpack decoding process must complete. + let res = decoder.decode(&mut cursor, |header| { + use crate::hpack::Header::*; + + match header { + Field { name, value } => { + // Connection level header fields are not supported and must + // result in a protocol error. + + if name == header::CONNECTION + || name == header::TRANSFER_ENCODING + || name == header::UPGRADE + || name == "keep-alive" + || name == "proxy-connection" + { + tracing::trace!("load_hpack; connection level header"); + malformed = true; + } else if name == header::TE && value != "trailers" { + tracing::trace!( + "load_hpack; TE header not set to trailers; val={:?}", + value + ); + malformed = true; + } else { + reg = true; + + headers_size += decoded_header_size(name.as_str().len(), value.len()); + if headers_size < max_header_list_size { + self.fields.append(name, value); + } else if !self.is_over_size { + tracing::trace!("load_hpack; header list size over max"); + self.is_over_size = true; + } + } + } + Authority(v) => set_pseudo!(authority, v), + Method(v) => set_pseudo!(method, v), + Scheme(v) => set_pseudo!(scheme, v), + Path(v) => set_pseudo!(path, v), + Protocol(v) => set_pseudo!(protocol, v), + Status(v) => set_pseudo!(status, v), + } + }); + + if let Err(e) = res { + tracing::trace!("hpack decoding error; err={:?}", e); + return Err(e.into()); + } + + if malformed { + tracing::trace!("malformed message"); + return Err(Error::MalformedMessage); + } + + Ok(()) + } + + fn into_encoding(self, encoder: &mut hpack::Encoder) -> EncodingHeaderBlock { + let mut hpack = BytesMut::new(); + let headers = Iter { + pseudo: Some(self.pseudo), + fields: self.fields.into_iter(), + }; + + encoder.encode(headers, &mut hpack); + + EncodingHeaderBlock { + hpack: hpack.freeze(), + } + } + + /// Calculates the size of the currently decoded header list. + /// + /// According to http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE + /// + /// > The value is based on the uncompressed size of header fields, + /// > including the length of the name and value in octets plus an + /// > overhead of 32 octets for each header field. + fn calculate_header_list_size(&self) -> usize { + macro_rules! pseudo_size { + ($name:ident) => {{ + self.pseudo + .$name + .as_ref() + .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len())) + .unwrap_or(0) + }}; + } + + pseudo_size!(method) + + pseudo_size!(scheme) + + pseudo_size!(status) + + pseudo_size!(authority) + + pseudo_size!(path) + + self + .fields + .iter() + .map(|(name, value)| decoded_header_size(name.as_str().len(), value.len())) + .sum::<usize>() + } +} + +fn decoded_header_size(name: usize, value: usize) -> usize { + name + value + 32 +} + +#[cfg(test)] +mod test { + use std::iter::FromIterator; + + use http::HeaderValue; + + use super::*; + use crate::frame; + use crate::hpack::{huffman, Encoder}; + + #[test] + fn test_nameless_header_at_resume() { + let mut encoder = Encoder::default(); + let mut dst = BytesMut::new(); + + let headers = Headers::new( + StreamId::ZERO, + Default::default(), + HeaderMap::from_iter(vec![ + ( + HeaderName::from_static("hello"), + HeaderValue::from_static("world"), + ), + ( + HeaderName::from_static("hello"), + HeaderValue::from_static("zomg"), + ), + ( + HeaderName::from_static("hello"), + HeaderValue::from_static("sup"), + ), + ]), + ); + + let continuation = headers + .encode(&mut encoder, &mut (&mut dst).limit(frame::HEADER_LEN + 8)) + .unwrap(); + + assert_eq!(17, dst.len()); + assert_eq!([0, 0, 8, 1, 0, 0, 0, 0, 0], &dst[0..9]); + assert_eq!(&[0x40, 0x80 | 4], &dst[9..11]); + assert_eq!("hello", huff_decode(&dst[11..15])); + assert_eq!(0x80 | 4, dst[15]); + + let mut world = dst[16..17].to_owned(); + + dst.clear(); + + assert!(continuation + .encode(&mut (&mut dst).limit(frame::HEADER_LEN + 16)) + .is_none()); + + world.extend_from_slice(&dst[9..12]); + assert_eq!("world", huff_decode(&world)); + + assert_eq!(24, dst.len()); + assert_eq!([0, 0, 15, 9, 4, 0, 0, 0, 0], &dst[0..9]); + + // // Next is not indexed + assert_eq!(&[15, 47, 0x80 | 3], &dst[12..15]); + assert_eq!("zomg", huff_decode(&dst[15..18])); + assert_eq!(&[15, 47, 0x80 | 3], &dst[18..21]); + assert_eq!("sup", huff_decode(&dst[21..])); + } + + fn huff_decode(src: &[u8]) -> BytesMut { + let mut buf = BytesMut::new(); + huffman::decode(src, &mut buf).unwrap() + } +} diff --git a/third_party/rust/h2/src/frame/mod.rs b/third_party/rust/h2/src/frame/mod.rs new file mode 100644 index 0000000000..570a162a8d --- /dev/null +++ b/third_party/rust/h2/src/frame/mod.rs @@ -0,0 +1,171 @@ +use crate::hpack; + +use bytes::Bytes; + +use std::fmt; + +/// A helper macro that unpacks a sequence of 4 bytes found in the buffer with +/// the given identifier, starting at the given offset, into the given integer +/// type. Obviously, the integer type should be able to support at least 4 +/// bytes. +/// +/// # Examples +/// +/// ```ignore +/// # // We ignore this doctest because the macro is not exported. +/// let buf: [u8; 4] = [0, 0, 0, 1]; +/// assert_eq!(1u32, unpack_octets_4!(buf, 0, u32)); +/// ``` +macro_rules! unpack_octets_4 { + // TODO: Get rid of this macro + ($buf:expr, $offset:expr, $tip:ty) => { + (($buf[$offset + 0] as $tip) << 24) + | (($buf[$offset + 1] as $tip) << 16) + | (($buf[$offset + 2] as $tip) << 8) + | (($buf[$offset + 3] as $tip) << 0) + }; +} + +#[cfg(test)] +mod tests { + #[test] + fn test_unpack_octets_4() { + let buf: [u8; 4] = [0, 0, 0, 1]; + assert_eq!(1u32, unpack_octets_4!(buf, 0, u32)); + } +} + +mod data; +mod go_away; +mod head; +mod headers; +mod ping; +mod priority; +mod reason; +mod reset; +mod settings; +mod stream_id; +mod util; +mod window_update; + +pub use self::data::Data; +pub use self::go_away::GoAway; +pub use self::head::{Head, Kind}; +pub use self::headers::{ + parse_u64, Continuation, Headers, Pseudo, PushPromise, PushPromiseHeaderError, +}; +pub use self::ping::Ping; +pub use self::priority::{Priority, StreamDependency}; +pub use self::reason::Reason; +pub use self::reset::Reset; +pub use self::settings::Settings; +pub use self::stream_id::{StreamId, StreamIdOverflow}; +pub use self::window_update::WindowUpdate; + +#[cfg(feature = "unstable")] +pub use crate::hpack::BytesStr; + +// Re-export some constants + +pub use self::settings::{ + DEFAULT_INITIAL_WINDOW_SIZE, DEFAULT_MAX_FRAME_SIZE, DEFAULT_SETTINGS_HEADER_TABLE_SIZE, + MAX_INITIAL_WINDOW_SIZE, MAX_MAX_FRAME_SIZE, +}; + +pub type FrameSize = u32; + +pub const HEADER_LEN: usize = 9; + +#[derive(Eq, PartialEq)] +pub enum Frame<T = Bytes> { + Data(Data<T>), + Headers(Headers), + Priority(Priority), + PushPromise(PushPromise), + Settings(Settings), + Ping(Ping), + GoAway(GoAway), + WindowUpdate(WindowUpdate), + Reset(Reset), +} + +impl<T> Frame<T> { + pub fn map<F, U>(self, f: F) -> Frame<U> + where + F: FnOnce(T) -> U, + { + use self::Frame::*; + + match self { + Data(frame) => frame.map(f).into(), + Headers(frame) => frame.into(), + Priority(frame) => frame.into(), + PushPromise(frame) => frame.into(), + Settings(frame) => frame.into(), + Ping(frame) => frame.into(), + GoAway(frame) => frame.into(), + WindowUpdate(frame) => frame.into(), + Reset(frame) => frame.into(), + } + } +} + +impl<T> fmt::Debug for Frame<T> { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + use self::Frame::*; + + match *self { + Data(ref frame) => fmt::Debug::fmt(frame, fmt), + Headers(ref frame) => fmt::Debug::fmt(frame, fmt), + Priority(ref frame) => fmt::Debug::fmt(frame, fmt), + PushPromise(ref frame) => fmt::Debug::fmt(frame, fmt), + Settings(ref frame) => fmt::Debug::fmt(frame, fmt), + Ping(ref frame) => fmt::Debug::fmt(frame, fmt), + GoAway(ref frame) => fmt::Debug::fmt(frame, fmt), + WindowUpdate(ref frame) => fmt::Debug::fmt(frame, fmt), + Reset(ref frame) => fmt::Debug::fmt(frame, fmt), + } + } +} + +/// Errors that can occur during parsing an HTTP/2 frame. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Error { + /// A length value other than 8 was set on a PING message. + BadFrameSize, + + /// The padding length was larger than the frame-header-specified + /// length of the payload. + TooMuchPadding, + + /// An invalid setting value was provided + InvalidSettingValue, + + /// An invalid window update value + InvalidWindowUpdateValue, + + /// The payload length specified by the frame header was not the + /// value necessary for the specific frame type. + InvalidPayloadLength, + + /// Received a payload with an ACK settings frame + InvalidPayloadAckSettings, + + /// An invalid stream identifier was provided. + /// + /// This is returned if a SETTINGS or PING frame is received with a stream + /// identifier other than zero. + InvalidStreamId, + + /// A request or response is malformed. + MalformedMessage, + + /// An invalid stream dependency ID was provided + /// + /// This is returned if a HEADERS or PRIORITY frame is received with an + /// invalid stream identifier. + InvalidDependencyId, + + /// Failed to perform HPACK decoding + Hpack(hpack::DecoderError), +} diff --git a/third_party/rust/h2/src/frame/ping.rs b/third_party/rust/h2/src/frame/ping.rs new file mode 100644 index 0000000000..241d06ea17 --- /dev/null +++ b/third_party/rust/h2/src/frame/ping.rs @@ -0,0 +1,102 @@ +use crate::frame::{Error, Frame, Head, Kind, StreamId}; +use bytes::BufMut; + +const ACK_FLAG: u8 = 0x1; + +pub type Payload = [u8; 8]; + +#[derive(Debug, Eq, PartialEq)] +pub struct Ping { + ack: bool, + payload: Payload, +} + +// This was just 8 randomly generated bytes. We use something besides just +// zeroes to distinguish this specific PING from any other. +const SHUTDOWN_PAYLOAD: Payload = [0x0b, 0x7b, 0xa2, 0xf0, 0x8b, 0x9b, 0xfe, 0x54]; +const USER_PAYLOAD: Payload = [0x3b, 0x7c, 0xdb, 0x7a, 0x0b, 0x87, 0x16, 0xb4]; + +impl Ping { + #[cfg(feature = "unstable")] + pub const SHUTDOWN: Payload = SHUTDOWN_PAYLOAD; + + #[cfg(not(feature = "unstable"))] + pub(crate) const SHUTDOWN: Payload = SHUTDOWN_PAYLOAD; + + #[cfg(feature = "unstable")] + pub const USER: Payload = USER_PAYLOAD; + + #[cfg(not(feature = "unstable"))] + pub(crate) const USER: Payload = USER_PAYLOAD; + + pub fn new(payload: Payload) -> Ping { + Ping { + ack: false, + payload, + } + } + + pub fn pong(payload: Payload) -> Ping { + Ping { ack: true, payload } + } + + pub fn is_ack(&self) -> bool { + self.ack + } + + pub fn payload(&self) -> &Payload { + &self.payload + } + + pub fn into_payload(self) -> Payload { + self.payload + } + + /// Builds a `Ping` frame from a raw frame. + pub fn load(head: Head, bytes: &[u8]) -> Result<Ping, Error> { + debug_assert_eq!(head.kind(), crate::frame::Kind::Ping); + + // PING frames are not associated with any individual stream. If a PING + // frame is received with a stream identifier field value other than + // 0x0, the recipient MUST respond with a connection error + // (Section 5.4.1) of type PROTOCOL_ERROR. + if !head.stream_id().is_zero() { + return Err(Error::InvalidStreamId); + } + + // In addition to the frame header, PING frames MUST contain 8 octets of opaque + // data in the payload. + if bytes.len() != 8 { + return Err(Error::BadFrameSize); + } + + let mut payload = [0; 8]; + payload.copy_from_slice(bytes); + + // The PING frame defines the following flags: + // + // ACK (0x1): When set, bit 0 indicates that this PING frame is a PING + // response. An endpoint MUST set this flag in PING responses. An + // endpoint MUST NOT respond to PING frames containing this flag. + let ack = head.flag() & ACK_FLAG != 0; + + Ok(Ping { ack, payload }) + } + + pub fn encode<B: BufMut>(&self, dst: &mut B) { + let sz = self.payload.len(); + tracing::trace!("encoding PING; ack={} len={}", self.ack, sz); + + let flags = if self.ack { ACK_FLAG } else { 0 }; + let head = Head::new(Kind::Ping, flags, StreamId::zero()); + + head.encode(sz, dst); + dst.put_slice(&self.payload); + } +} + +impl<T> From<Ping> for Frame<T> { + fn from(src: Ping) -> Frame<T> { + Frame::Ping(src) + } +} diff --git a/third_party/rust/h2/src/frame/priority.rs b/third_party/rust/h2/src/frame/priority.rs new file mode 100644 index 0000000000..d7d47dbb01 --- /dev/null +++ b/third_party/rust/h2/src/frame/priority.rs @@ -0,0 +1,72 @@ +use crate::frame::*; + +#[derive(Debug, Eq, PartialEq)] +pub struct Priority { + stream_id: StreamId, + dependency: StreamDependency, +} + +#[derive(Debug, Eq, PartialEq)] +pub struct StreamDependency { + /// The ID of the stream dependency target + dependency_id: StreamId, + + /// The weight for the stream. The value exposed (and set) here is always in + /// the range [0, 255], instead of [1, 256] (as defined in section 5.3.2.) + /// so that the value fits into a `u8`. + weight: u8, + + /// True if the stream dependency is exclusive. + is_exclusive: bool, +} + +impl Priority { + pub fn load(head: Head, payload: &[u8]) -> Result<Self, Error> { + let dependency = StreamDependency::load(payload)?; + + if dependency.dependency_id() == head.stream_id() { + return Err(Error::InvalidDependencyId); + } + + Ok(Priority { + stream_id: head.stream_id(), + dependency, + }) + } +} + +impl<B> From<Priority> for Frame<B> { + fn from(src: Priority) -> Self { + Frame::Priority(src) + } +} + +// ===== impl StreamDependency ===== + +impl StreamDependency { + pub fn new(dependency_id: StreamId, weight: u8, is_exclusive: bool) -> Self { + StreamDependency { + dependency_id, + weight, + is_exclusive, + } + } + + pub fn load(src: &[u8]) -> Result<Self, Error> { + if src.len() != 5 { + return Err(Error::InvalidPayloadLength); + } + + // Parse the stream ID and exclusive flag + let (dependency_id, is_exclusive) = StreamId::parse(&src[..4]); + + // Read the weight + let weight = src[4]; + + Ok(StreamDependency::new(dependency_id, weight, is_exclusive)) + } + + pub fn dependency_id(&self) -> StreamId { + self.dependency_id + } +} diff --git a/third_party/rust/h2/src/frame/reason.rs b/third_party/rust/h2/src/frame/reason.rs new file mode 100644 index 0000000000..ff5e2012f8 --- /dev/null +++ b/third_party/rust/h2/src/frame/reason.rs @@ -0,0 +1,134 @@ +use std::fmt; + +/// HTTP/2 error codes. +/// +/// Error codes are used in `RST_STREAM` and `GOAWAY` frames to convey the +/// reasons for the stream or connection error. For example, +/// [`SendStream::send_reset`] takes a `Reason` argument. Also, the `Error` type +/// may contain a `Reason`. +/// +/// Error codes share a common code space. Some error codes apply only to +/// streams, others apply only to connections, and others may apply to either. +/// See [RFC 7540] for more information. +/// +/// See [Error Codes in the spec][spec]. +/// +/// [spec]: http://httpwg.org/specs/rfc7540.html#ErrorCodes +/// [`SendStream::send_reset`]: struct.SendStream.html#method.send_reset +#[derive(PartialEq, Eq, Clone, Copy)] +pub struct Reason(u32); + +impl Reason { + /// The associated condition is not a result of an error. + /// + /// For example, a GOAWAY might include this code to indicate graceful + /// shutdown of a connection. + pub const NO_ERROR: Reason = Reason(0); + /// The endpoint detected an unspecific protocol error. + /// + /// This error is for use when a more specific error code is not available. + pub const PROTOCOL_ERROR: Reason = Reason(1); + /// The endpoint encountered an unexpected internal error. + pub const INTERNAL_ERROR: Reason = Reason(2); + /// The endpoint detected that its peer violated the flow-control protocol. + pub const FLOW_CONTROL_ERROR: Reason = Reason(3); + /// The endpoint sent a SETTINGS frame but did not receive a response in + /// a timely manner. + pub const SETTINGS_TIMEOUT: Reason = Reason(4); + /// The endpoint received a frame after a stream was half-closed. + pub const STREAM_CLOSED: Reason = Reason(5); + /// The endpoint received a frame with an invalid size. + pub const FRAME_SIZE_ERROR: Reason = Reason(6); + /// The endpoint refused the stream prior to performing any application + /// processing. + pub const REFUSED_STREAM: Reason = Reason(7); + /// Used by the endpoint to indicate that the stream is no longer needed. + pub const CANCEL: Reason = Reason(8); + /// The endpoint is unable to maintain the header compression context for + /// the connection. + pub const COMPRESSION_ERROR: Reason = Reason(9); + /// The connection established in response to a CONNECT request was reset + /// or abnormally closed. + pub const CONNECT_ERROR: Reason = Reason(10); + /// The endpoint detected that its peer is exhibiting a behavior that might + /// be generating excessive load. + pub const ENHANCE_YOUR_CALM: Reason = Reason(11); + /// The underlying transport has properties that do not meet minimum + /// security requirements. + pub const INADEQUATE_SECURITY: Reason = Reason(12); + /// The endpoint requires that HTTP/1.1 be used instead of HTTP/2. + pub const HTTP_1_1_REQUIRED: Reason = Reason(13); + + /// Get a string description of the error code. + pub fn description(&self) -> &str { + match self.0 { + 0 => "not a result of an error", + 1 => "unspecific protocol error detected", + 2 => "unexpected internal error encountered", + 3 => "flow-control protocol violated", + 4 => "settings ACK not received in timely manner", + 5 => "received frame when stream half-closed", + 6 => "frame with invalid size", + 7 => "refused stream before processing any application logic", + 8 => "stream no longer needed", + 9 => "unable to maintain the header compression context", + 10 => { + "connection established in response to a CONNECT request was reset or abnormally \ + closed" + } + 11 => "detected excessive load generating behavior", + 12 => "security properties do not meet minimum requirements", + 13 => "endpoint requires HTTP/1.1", + _ => "unknown reason", + } + } +} + +impl From<u32> for Reason { + fn from(src: u32) -> Reason { + Reason(src) + } +} + +impl From<Reason> for u32 { + fn from(src: Reason) -> u32 { + src.0 + } +} + +impl fmt::Debug for Reason { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let name = match self.0 { + 0 => "NO_ERROR", + 1 => "PROTOCOL_ERROR", + 2 => "INTERNAL_ERROR", + 3 => "FLOW_CONTROL_ERROR", + 4 => "SETTINGS_TIMEOUT", + 5 => "STREAM_CLOSED", + 6 => "FRAME_SIZE_ERROR", + 7 => "REFUSED_STREAM", + 8 => "CANCEL", + 9 => "COMPRESSION_ERROR", + 10 => "CONNECT_ERROR", + 11 => "ENHANCE_YOUR_CALM", + 12 => "INADEQUATE_SECURITY", + 13 => "HTTP_1_1_REQUIRED", + other => return f.debug_tuple("Reason").field(&Hex(other)).finish(), + }; + f.write_str(name) + } +} + +struct Hex(u32); + +impl fmt::Debug for Hex { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::LowerHex::fmt(&self.0, f) + } +} + +impl fmt::Display for Reason { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + write!(fmt, "{}", self.description()) + } +} diff --git a/third_party/rust/h2/src/frame/reset.rs b/third_party/rust/h2/src/frame/reset.rs new file mode 100644 index 0000000000..39f6ac2022 --- /dev/null +++ b/third_party/rust/h2/src/frame/reset.rs @@ -0,0 +1,56 @@ +use crate::frame::{self, Error, Head, Kind, Reason, StreamId}; + +use bytes::BufMut; + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub struct Reset { + stream_id: StreamId, + error_code: Reason, +} + +impl Reset { + pub fn new(stream_id: StreamId, error: Reason) -> Reset { + Reset { + stream_id, + error_code: error, + } + } + + pub fn stream_id(&self) -> StreamId { + self.stream_id + } + + pub fn reason(&self) -> Reason { + self.error_code + } + + pub fn load(head: Head, payload: &[u8]) -> Result<Reset, Error> { + if payload.len() != 4 { + return Err(Error::InvalidPayloadLength); + } + + let error_code = unpack_octets_4!(payload, 0, u32); + + Ok(Reset { + stream_id: head.stream_id(), + error_code: error_code.into(), + }) + } + + pub fn encode<B: BufMut>(&self, dst: &mut B) { + tracing::trace!( + "encoding RESET; id={:?} code={:?}", + self.stream_id, + self.error_code + ); + let head = Head::new(Kind::Reset, 0, self.stream_id); + head.encode(4, dst); + dst.put_u32(self.error_code.into()); + } +} + +impl<B> From<Reset> for frame::Frame<B> { + fn from(src: Reset) -> Self { + frame::Frame::Reset(src) + } +} diff --git a/third_party/rust/h2/src/frame/settings.rs b/third_party/rust/h2/src/frame/settings.rs new file mode 100644 index 0000000000..080d0f4e58 --- /dev/null +++ b/third_party/rust/h2/src/frame/settings.rs @@ -0,0 +1,391 @@ +use std::fmt; + +use crate::frame::{util, Error, Frame, FrameSize, Head, Kind, StreamId}; +use bytes::{BufMut, BytesMut}; + +#[derive(Clone, Default, Eq, PartialEq)] +pub struct Settings { + flags: SettingsFlags, + // Fields + header_table_size: Option<u32>, + enable_push: Option<u32>, + max_concurrent_streams: Option<u32>, + initial_window_size: Option<u32>, + max_frame_size: Option<u32>, + max_header_list_size: Option<u32>, + enable_connect_protocol: Option<u32>, +} + +/// An enum that lists all valid settings that can be sent in a SETTINGS +/// frame. +/// +/// Each setting has a value that is a 32 bit unsigned integer (6.5.1.). +#[derive(Debug)] +pub enum Setting { + HeaderTableSize(u32), + EnablePush(u32), + MaxConcurrentStreams(u32), + InitialWindowSize(u32), + MaxFrameSize(u32), + MaxHeaderListSize(u32), + EnableConnectProtocol(u32), +} + +#[derive(Copy, Clone, Eq, PartialEq, Default)] +pub struct SettingsFlags(u8); + +const ACK: u8 = 0x1; +const ALL: u8 = ACK; + +/// The default value of SETTINGS_HEADER_TABLE_SIZE +pub const DEFAULT_SETTINGS_HEADER_TABLE_SIZE: usize = 4_096; + +/// The default value of SETTINGS_INITIAL_WINDOW_SIZE +pub const DEFAULT_INITIAL_WINDOW_SIZE: u32 = 65_535; + +/// The default value of MAX_FRAME_SIZE +pub const DEFAULT_MAX_FRAME_SIZE: FrameSize = 16_384; + +/// INITIAL_WINDOW_SIZE upper bound +pub const MAX_INITIAL_WINDOW_SIZE: usize = (1 << 31) - 1; + +/// MAX_FRAME_SIZE upper bound +pub const MAX_MAX_FRAME_SIZE: FrameSize = (1 << 24) - 1; + +// ===== impl Settings ===== + +impl Settings { + pub fn ack() -> Settings { + Settings { + flags: SettingsFlags::ack(), + ..Settings::default() + } + } + + pub fn is_ack(&self) -> bool { + self.flags.is_ack() + } + + pub fn initial_window_size(&self) -> Option<u32> { + self.initial_window_size + } + + pub fn set_initial_window_size(&mut self, size: Option<u32>) { + self.initial_window_size = size; + } + + pub fn max_concurrent_streams(&self) -> Option<u32> { + self.max_concurrent_streams + } + + pub fn set_max_concurrent_streams(&mut self, max: Option<u32>) { + self.max_concurrent_streams = max; + } + + pub fn max_frame_size(&self) -> Option<u32> { + self.max_frame_size + } + + pub fn set_max_frame_size(&mut self, size: Option<u32>) { + if let Some(val) = size { + assert!(DEFAULT_MAX_FRAME_SIZE <= val && val <= MAX_MAX_FRAME_SIZE); + } + self.max_frame_size = size; + } + + pub fn max_header_list_size(&self) -> Option<u32> { + self.max_header_list_size + } + + pub fn set_max_header_list_size(&mut self, size: Option<u32>) { + self.max_header_list_size = size; + } + + pub fn is_push_enabled(&self) -> Option<bool> { + self.enable_push.map(|val| val != 0) + } + + pub fn set_enable_push(&mut self, enable: bool) { + self.enable_push = Some(enable as u32); + } + + pub fn is_extended_connect_protocol_enabled(&self) -> Option<bool> { + self.enable_connect_protocol.map(|val| val != 0) + } + + pub fn set_enable_connect_protocol(&mut self, val: Option<u32>) { + self.enable_connect_protocol = val; + } + + pub fn header_table_size(&self) -> Option<u32> { + self.header_table_size + } + + /* + pub fn set_header_table_size(&mut self, size: Option<u32>) { + self.header_table_size = size; + } + */ + + pub fn load(head: Head, payload: &[u8]) -> Result<Settings, Error> { + use self::Setting::*; + + debug_assert_eq!(head.kind(), crate::frame::Kind::Settings); + + if !head.stream_id().is_zero() { + return Err(Error::InvalidStreamId); + } + + // Load the flag + let flag = SettingsFlags::load(head.flag()); + + if flag.is_ack() { + // Ensure that the payload is empty + if !payload.is_empty() { + return Err(Error::InvalidPayloadLength); + } + + // Return the ACK frame + return Ok(Settings::ack()); + } + + // Ensure the payload length is correct, each setting is 6 bytes long. + if payload.len() % 6 != 0 { + tracing::debug!("invalid settings payload length; len={:?}", payload.len()); + return Err(Error::InvalidPayloadAckSettings); + } + + let mut settings = Settings::default(); + debug_assert!(!settings.flags.is_ack()); + + for raw in payload.chunks(6) { + match Setting::load(raw) { + Some(HeaderTableSize(val)) => { + settings.header_table_size = Some(val); + } + Some(EnablePush(val)) => match val { + 0 | 1 => { + settings.enable_push = Some(val); + } + _ => { + return Err(Error::InvalidSettingValue); + } + }, + Some(MaxConcurrentStreams(val)) => { + settings.max_concurrent_streams = Some(val); + } + Some(InitialWindowSize(val)) => { + if val as usize > MAX_INITIAL_WINDOW_SIZE { + return Err(Error::InvalidSettingValue); + } else { + settings.initial_window_size = Some(val); + } + } + Some(MaxFrameSize(val)) => { + if val < DEFAULT_MAX_FRAME_SIZE || val > MAX_MAX_FRAME_SIZE { + return Err(Error::InvalidSettingValue); + } else { + settings.max_frame_size = Some(val); + } + } + Some(MaxHeaderListSize(val)) => { + settings.max_header_list_size = Some(val); + } + Some(EnableConnectProtocol(val)) => match val { + 0 | 1 => { + settings.enable_connect_protocol = Some(val); + } + _ => { + return Err(Error::InvalidSettingValue); + } + }, + None => {} + } + } + + Ok(settings) + } + + fn payload_len(&self) -> usize { + let mut len = 0; + self.for_each(|_| len += 6); + len + } + + pub fn encode(&self, dst: &mut BytesMut) { + // Create & encode an appropriate frame head + let head = Head::new(Kind::Settings, self.flags.into(), StreamId::zero()); + let payload_len = self.payload_len(); + + tracing::trace!("encoding SETTINGS; len={}", payload_len); + + head.encode(payload_len, dst); + + // Encode the settings + self.for_each(|setting| { + tracing::trace!("encoding setting; val={:?}", setting); + setting.encode(dst) + }); + } + + fn for_each<F: FnMut(Setting)>(&self, mut f: F) { + use self::Setting::*; + + if let Some(v) = self.header_table_size { + f(HeaderTableSize(v)); + } + + if let Some(v) = self.enable_push { + f(EnablePush(v)); + } + + if let Some(v) = self.max_concurrent_streams { + f(MaxConcurrentStreams(v)); + } + + if let Some(v) = self.initial_window_size { + f(InitialWindowSize(v)); + } + + if let Some(v) = self.max_frame_size { + f(MaxFrameSize(v)); + } + + if let Some(v) = self.max_header_list_size { + f(MaxHeaderListSize(v)); + } + + if let Some(v) = self.enable_connect_protocol { + f(EnableConnectProtocol(v)); + } + } +} + +impl<T> From<Settings> for Frame<T> { + fn from(src: Settings) -> Frame<T> { + Frame::Settings(src) + } +} + +impl fmt::Debug for Settings { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut builder = f.debug_struct("Settings"); + builder.field("flags", &self.flags); + + self.for_each(|setting| match setting { + Setting::EnablePush(v) => { + builder.field("enable_push", &v); + } + Setting::HeaderTableSize(v) => { + builder.field("header_table_size", &v); + } + Setting::InitialWindowSize(v) => { + builder.field("initial_window_size", &v); + } + Setting::MaxConcurrentStreams(v) => { + builder.field("max_concurrent_streams", &v); + } + Setting::MaxFrameSize(v) => { + builder.field("max_frame_size", &v); + } + Setting::MaxHeaderListSize(v) => { + builder.field("max_header_list_size", &v); + } + Setting::EnableConnectProtocol(v) => { + builder.field("enable_connect_protocol", &v); + } + }); + + builder.finish() + } +} + +// ===== impl Setting ===== + +impl Setting { + /// Creates a new `Setting` with the correct variant corresponding to the + /// given setting id, based on the settings IDs defined in section + /// 6.5.2. + pub fn from_id(id: u16, val: u32) -> Option<Setting> { + use self::Setting::*; + + match id { + 1 => Some(HeaderTableSize(val)), + 2 => Some(EnablePush(val)), + 3 => Some(MaxConcurrentStreams(val)), + 4 => Some(InitialWindowSize(val)), + 5 => Some(MaxFrameSize(val)), + 6 => Some(MaxHeaderListSize(val)), + 8 => Some(EnableConnectProtocol(val)), + _ => None, + } + } + + /// Creates a new `Setting` by parsing the given buffer of 6 bytes, which + /// contains the raw byte representation of the setting, according to the + /// "SETTINGS format" defined in section 6.5.1. + /// + /// The `raw` parameter should have length at least 6 bytes, since the + /// length of the raw setting is exactly 6 bytes. + /// + /// # Panics + /// + /// If given a buffer shorter than 6 bytes, the function will panic. + fn load(raw: &[u8]) -> Option<Setting> { + let id: u16 = (u16::from(raw[0]) << 8) | u16::from(raw[1]); + let val: u32 = unpack_octets_4!(raw, 2, u32); + + Setting::from_id(id, val) + } + + fn encode(&self, dst: &mut BytesMut) { + use self::Setting::*; + + let (kind, val) = match *self { + HeaderTableSize(v) => (1, v), + EnablePush(v) => (2, v), + MaxConcurrentStreams(v) => (3, v), + InitialWindowSize(v) => (4, v), + MaxFrameSize(v) => (5, v), + MaxHeaderListSize(v) => (6, v), + EnableConnectProtocol(v) => (8, v), + }; + + dst.put_u16(kind); + dst.put_u32(val); + } +} + +// ===== impl SettingsFlags ===== + +impl SettingsFlags { + pub fn empty() -> SettingsFlags { + SettingsFlags(0) + } + + pub fn load(bits: u8) -> SettingsFlags { + SettingsFlags(bits & ALL) + } + + pub fn ack() -> SettingsFlags { + SettingsFlags(ACK) + } + + pub fn is_ack(&self) -> bool { + self.0 & ACK == ACK + } +} + +impl From<SettingsFlags> for u8 { + fn from(src: SettingsFlags) -> u8 { + src.0 + } +} + +impl fmt::Debug for SettingsFlags { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + util::debug_flags(f, self.0) + .flag_if(self.is_ack(), "ACK") + .finish() + } +} diff --git a/third_party/rust/h2/src/frame/stream_id.rs b/third_party/rust/h2/src/frame/stream_id.rs new file mode 100644 index 0000000000..10a14d3c82 --- /dev/null +++ b/third_party/rust/h2/src/frame/stream_id.rs @@ -0,0 +1,96 @@ +use std::u32; + +/// A stream identifier, as described in [Section 5.1.1] of RFC 7540. +/// +/// Streams are identified with an unsigned 31-bit integer. Streams +/// initiated by a client MUST use odd-numbered stream identifiers; those +/// initiated by the server MUST use even-numbered stream identifiers. A +/// stream identifier of zero (0x0) is used for connection control +/// messages; the stream identifier of zero cannot be used to establish a +/// new stream. +/// +/// [Section 5.1.1]: https://tools.ietf.org/html/rfc7540#section-5.1.1 +#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct StreamId(u32); + +#[derive(Debug, Copy, Clone)] +pub struct StreamIdOverflow; + +const STREAM_ID_MASK: u32 = 1 << 31; + +impl StreamId { + /// Stream ID 0. + pub const ZERO: StreamId = StreamId(0); + + /// The maximum allowed stream ID. + pub const MAX: StreamId = StreamId(u32::MAX >> 1); + + /// Parse the stream ID + #[inline] + pub fn parse(buf: &[u8]) -> (StreamId, bool) { + let mut ubuf = [0; 4]; + ubuf.copy_from_slice(&buf[0..4]); + let unpacked = u32::from_be_bytes(ubuf); + let flag = unpacked & STREAM_ID_MASK == STREAM_ID_MASK; + + // Now clear the most significant bit, as that is reserved and MUST be + // ignored when received. + (StreamId(unpacked & !STREAM_ID_MASK), flag) + } + + /// Returns true if this stream ID corresponds to a stream that + /// was initiated by the client. + pub fn is_client_initiated(&self) -> bool { + let id = self.0; + id != 0 && id % 2 == 1 + } + + /// Returns true if this stream ID corresponds to a stream that + /// was initiated by the server. + pub fn is_server_initiated(&self) -> bool { + let id = self.0; + id != 0 && id % 2 == 0 + } + + /// Return a new `StreamId` for stream 0. + #[inline] + pub fn zero() -> StreamId { + StreamId::ZERO + } + + /// Returns true if this stream ID is zero. + pub fn is_zero(&self) -> bool { + self.0 == 0 + } + + /// Returns the next stream ID initiated by the same peer as this stream + /// ID, or an error if incrementing this stream ID would overflow the + /// maximum. + pub fn next_id(&self) -> Result<StreamId, StreamIdOverflow> { + let next = self.0 + 2; + if next > StreamId::MAX.0 { + Err(StreamIdOverflow) + } else { + Ok(StreamId(next)) + } + } +} + +impl From<u32> for StreamId { + fn from(src: u32) -> Self { + assert_eq!(src & STREAM_ID_MASK, 0, "invalid stream ID -- MSB is set"); + StreamId(src) + } +} + +impl From<StreamId> for u32 { + fn from(src: StreamId) -> Self { + src.0 + } +} + +impl PartialEq<u32> for StreamId { + fn eq(&self, other: &u32) -> bool { + self.0 == *other + } +} diff --git a/third_party/rust/h2/src/frame/util.rs b/third_party/rust/h2/src/frame/util.rs new file mode 100644 index 0000000000..6bee7bd9bb --- /dev/null +++ b/third_party/rust/h2/src/frame/util.rs @@ -0,0 +1,79 @@ +use std::fmt; + +use super::Error; +use bytes::Bytes; + +/// Strip padding from the given payload. +/// +/// It is assumed that the frame had the padded flag set. This means that the +/// first byte is the length of the padding with that many +/// 0 bytes expected to follow the actual payload. +/// +/// # Returns +/// +/// A slice of the given payload where the actual one is found and the length +/// of the padding. +/// +/// If the padded payload is invalid (e.g. the length of the padding is equal +/// to the total length), returns `None`. +pub fn strip_padding(payload: &mut Bytes) -> Result<u8, Error> { + let payload_len = payload.len(); + if payload_len == 0 { + // If this is the case, the frame is invalid as no padding length can be + // extracted, even though the frame should be padded. + return Err(Error::TooMuchPadding); + } + + let pad_len = payload[0] as usize; + + if pad_len >= payload_len { + // This is invalid: the padding length MUST be less than the + // total frame size. + return Err(Error::TooMuchPadding); + } + + let _ = payload.split_to(1); + let _ = payload.split_off(payload_len - pad_len - 1); + + Ok(pad_len as u8) +} + +pub(super) fn debug_flags<'a, 'f: 'a>( + fmt: &'a mut fmt::Formatter<'f>, + bits: u8, +) -> DebugFlags<'a, 'f> { + let result = write!(fmt, "({:#x}", bits); + DebugFlags { + fmt, + result, + started: false, + } +} + +pub(super) struct DebugFlags<'a, 'f: 'a> { + fmt: &'a mut fmt::Formatter<'f>, + result: fmt::Result, + started: bool, +} + +impl<'a, 'f: 'a> DebugFlags<'a, 'f> { + pub(super) fn flag_if(&mut self, enabled: bool, name: &str) -> &mut Self { + if enabled { + self.result = self.result.and_then(|()| { + let prefix = if self.started { + " | " + } else { + self.started = true; + ": " + }; + + write!(self.fmt, "{}{}", prefix, name) + }); + } + self + } + + pub(super) fn finish(&mut self) -> fmt::Result { + self.result.and_then(|()| write!(self.fmt, ")")) + } +} diff --git a/third_party/rust/h2/src/frame/window_update.rs b/third_party/rust/h2/src/frame/window_update.rs new file mode 100644 index 0000000000..eed2ce17ec --- /dev/null +++ b/third_party/rust/h2/src/frame/window_update.rs @@ -0,0 +1,62 @@ +use crate::frame::{self, Error, Head, Kind, StreamId}; + +use bytes::BufMut; + +const SIZE_INCREMENT_MASK: u32 = 1 << 31; + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct WindowUpdate { + stream_id: StreamId, + size_increment: u32, +} + +impl WindowUpdate { + pub fn new(stream_id: StreamId, size_increment: u32) -> WindowUpdate { + WindowUpdate { + stream_id, + size_increment, + } + } + + pub fn stream_id(&self) -> StreamId { + self.stream_id + } + + pub fn size_increment(&self) -> u32 { + self.size_increment + } + + /// Builds a `WindowUpdate` frame from a raw frame. + pub fn load(head: Head, payload: &[u8]) -> Result<WindowUpdate, Error> { + debug_assert_eq!(head.kind(), crate::frame::Kind::WindowUpdate); + if payload.len() != 4 { + return Err(Error::BadFrameSize); + } + + // Clear the most significant bit, as that is reserved and MUST be ignored + // when received. + let size_increment = unpack_octets_4!(payload, 0, u32) & !SIZE_INCREMENT_MASK; + + if size_increment == 0 { + return Err(Error::InvalidWindowUpdateValue); + } + + Ok(WindowUpdate { + stream_id: head.stream_id(), + size_increment, + }) + } + + pub fn encode<B: BufMut>(&self, dst: &mut B) { + tracing::trace!("encoding WINDOW_UPDATE; id={:?}", self.stream_id); + let head = Head::new(Kind::WindowUpdate, 0, self.stream_id); + head.encode(4, dst); + dst.put_u32(self.size_increment); + } +} + +impl<B> From<WindowUpdate> for frame::Frame<B> { + fn from(src: WindowUpdate) -> Self { + frame::Frame::WindowUpdate(src) + } +} |