diff options
Diffstat (limited to 'third_party/rust/h2/src/frame/headers.rs')
-rw-r--r-- | third_party/rust/h2/src/frame/headers.rs | 1039 |
1 files changed, 1039 insertions, 0 deletions
diff --git a/third_party/rust/h2/src/frame/headers.rs b/third_party/rust/h2/src/frame/headers.rs new file mode 100644 index 0000000000..bcb9050133 --- /dev/null +++ b/third_party/rust/h2/src/frame/headers.rs @@ -0,0 +1,1039 @@ +use super::{util, StreamDependency, StreamId}; +use crate::ext::Protocol; +use crate::frame::{Error, Frame, Head, Kind}; +use crate::hpack::{self, BytesStr}; + +use http::header::{self, HeaderName, HeaderValue}; +use http::{uri, HeaderMap, Method, Request, StatusCode, Uri}; + +use bytes::{BufMut, Bytes, BytesMut}; + +use std::fmt; +use std::io::Cursor; + +type EncodeBuf<'a> = bytes::buf::Limit<&'a mut BytesMut>; +/// Header frame +/// +/// This could be either a request or a response. +#[derive(Eq, PartialEq)] +pub struct Headers { + /// The ID of the stream with which this frame is associated. + stream_id: StreamId, + + /// The stream dependency information, if any. + stream_dep: Option<StreamDependency>, + + /// The header block fragment + header_block: HeaderBlock, + + /// The associated flags + flags: HeadersFlag, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub struct HeadersFlag(u8); + +#[derive(Eq, PartialEq)] +pub struct PushPromise { + /// The ID of the stream with which this frame is associated. + stream_id: StreamId, + + /// The ID of the stream being reserved by this PushPromise. + promised_id: StreamId, + + /// The header block fragment + header_block: HeaderBlock, + + /// The associated flags + flags: PushPromiseFlag, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub struct PushPromiseFlag(u8); + +#[derive(Debug)] +pub struct Continuation { + /// Stream ID of continuation frame + stream_id: StreamId, + + header_block: EncodingHeaderBlock, +} + +// TODO: These fields shouldn't be `pub` +#[derive(Debug, Default, Eq, PartialEq)] +pub struct Pseudo { + // Request + pub method: Option<Method>, + pub scheme: Option<BytesStr>, + pub authority: Option<BytesStr>, + pub path: Option<BytesStr>, + pub protocol: Option<Protocol>, + + // Response + pub status: Option<StatusCode>, +} + +#[derive(Debug)] +pub struct Iter { + /// Pseudo headers + pseudo: Option<Pseudo>, + + /// Header fields + fields: header::IntoIter<HeaderValue>, +} + +#[derive(Debug, PartialEq, Eq)] +struct HeaderBlock { + /// The decoded header fields + fields: HeaderMap, + + /// Set to true if decoding went over the max header list size. + is_over_size: bool, + + /// Pseudo headers, these are broken out as they must be sent as part of the + /// headers frame. + pseudo: Pseudo, +} + +#[derive(Debug)] +struct EncodingHeaderBlock { + hpack: Bytes, +} + +const END_STREAM: u8 = 0x1; +const END_HEADERS: u8 = 0x4; +const PADDED: u8 = 0x8; +const PRIORITY: u8 = 0x20; +const ALL: u8 = END_STREAM | END_HEADERS | PADDED | PRIORITY; + +// ===== impl Headers ===== + +impl Headers { + /// Create a new HEADERS frame + pub fn new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self { + Headers { + stream_id, + stream_dep: None, + header_block: HeaderBlock { + fields, + is_over_size: false, + pseudo, + }, + flags: HeadersFlag::default(), + } + } + + pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self { + let mut flags = HeadersFlag::default(); + flags.set_end_stream(); + + Headers { + stream_id, + stream_dep: None, + header_block: HeaderBlock { + fields, + is_over_size: false, + pseudo: Pseudo::default(), + }, + flags, + } + } + + /// Loads the header frame but doesn't actually do HPACK decoding. + /// + /// HPACK decoding is done in the `load_hpack` step. + pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> { + let flags = HeadersFlag(head.flag()); + let mut pad = 0; + + tracing::trace!("loading headers; flags={:?}", flags); + + if head.stream_id().is_zero() { + return Err(Error::InvalidStreamId); + } + + // Read the padding length + if flags.is_padded() { + if src.is_empty() { + return Err(Error::MalformedMessage); + } + pad = src[0] as usize; + + // Drop the padding + let _ = src.split_to(1); + } + + // Read the stream dependency + let stream_dep = if flags.is_priority() { + if src.len() < 5 { + return Err(Error::MalformedMessage); + } + let stream_dep = StreamDependency::load(&src[..5])?; + + if stream_dep.dependency_id() == head.stream_id() { + return Err(Error::InvalidDependencyId); + } + + // Drop the next 5 bytes + let _ = src.split_to(5); + + Some(stream_dep) + } else { + None + }; + + if pad > 0 { + if pad > src.len() { + return Err(Error::TooMuchPadding); + } + + let len = src.len() - pad; + src.truncate(len); + } + + let headers = Headers { + stream_id: head.stream_id(), + stream_dep, + header_block: HeaderBlock { + fields: HeaderMap::new(), + is_over_size: false, + pseudo: Pseudo::default(), + }, + flags, + }; + + Ok((headers, src)) + } + + pub fn load_hpack( + &mut self, + src: &mut BytesMut, + max_header_list_size: usize, + decoder: &mut hpack::Decoder, + ) -> Result<(), Error> { + self.header_block.load(src, max_header_list_size, decoder) + } + + pub fn stream_id(&self) -> StreamId { + self.stream_id + } + + pub fn is_end_headers(&self) -> bool { + self.flags.is_end_headers() + } + + pub fn set_end_headers(&mut self) { + self.flags.set_end_headers(); + } + + pub fn is_end_stream(&self) -> bool { + self.flags.is_end_stream() + } + + pub fn set_end_stream(&mut self) { + self.flags.set_end_stream() + } + + pub fn is_over_size(&self) -> bool { + self.header_block.is_over_size + } + + pub fn into_parts(self) -> (Pseudo, HeaderMap) { + (self.header_block.pseudo, self.header_block.fields) + } + + #[cfg(feature = "unstable")] + pub fn pseudo_mut(&mut self) -> &mut Pseudo { + &mut self.header_block.pseudo + } + + /// Whether it has status 1xx + pub(crate) fn is_informational(&self) -> bool { + self.header_block.pseudo.is_informational() + } + + pub fn fields(&self) -> &HeaderMap { + &self.header_block.fields + } + + pub fn into_fields(self) -> HeaderMap { + self.header_block.fields + } + + pub fn encode( + self, + encoder: &mut hpack::Encoder, + dst: &mut EncodeBuf<'_>, + ) -> Option<Continuation> { + // At this point, the `is_end_headers` flag should always be set + debug_assert!(self.flags.is_end_headers()); + + // Get the HEADERS frame head + let head = self.head(); + + self.header_block + .into_encoding(encoder) + .encode(&head, dst, |_| {}) + } + + fn head(&self) -> Head { + Head::new(Kind::Headers, self.flags.into(), self.stream_id) + } +} + +impl<T> From<Headers> for Frame<T> { + fn from(src: Headers) -> Self { + Frame::Headers(src) + } +} + +impl fmt::Debug for Headers { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut builder = f.debug_struct("Headers"); + builder + .field("stream_id", &self.stream_id) + .field("flags", &self.flags); + + if let Some(ref protocol) = self.header_block.pseudo.protocol { + builder.field("protocol", protocol); + } + + if let Some(ref dep) = self.stream_dep { + builder.field("stream_dep", dep); + } + + // `fields` and `pseudo` purposefully not included + builder.finish() + } +} + +// ===== util ===== + +pub fn parse_u64(src: &[u8]) -> Result<u64, ()> { + if src.len() > 19 { + // At danger for overflow... + return Err(()); + } + + let mut ret = 0; + + for &d in src { + if d < b'0' || d > b'9' { + return Err(()); + } + + ret *= 10; + ret += (d - b'0') as u64; + } + + Ok(ret) +} + +// ===== impl PushPromise ===== + +#[derive(Debug)] +pub enum PushPromiseHeaderError { + InvalidContentLength(Result<u64, ()>), + NotSafeAndCacheable, +} + +impl PushPromise { + pub fn new( + stream_id: StreamId, + promised_id: StreamId, + pseudo: Pseudo, + fields: HeaderMap, + ) -> Self { + PushPromise { + flags: PushPromiseFlag::default(), + header_block: HeaderBlock { + fields, + is_over_size: false, + pseudo, + }, + promised_id, + stream_id, + } + } + + pub fn validate_request(req: &Request<()>) -> Result<(), PushPromiseHeaderError> { + use PushPromiseHeaderError::*; + // The spec has some requirements for promised request headers + // [https://httpwg.org/specs/rfc7540.html#PushRequests] + + // A promised request "that indicates the presence of a request body + // MUST reset the promised stream with a stream error" + if let Some(content_length) = req.headers().get(header::CONTENT_LENGTH) { + let parsed_length = parse_u64(content_length.as_bytes()); + if parsed_length != Ok(0) { + return Err(InvalidContentLength(parsed_length)); + } + } + // "The server MUST include a method in the :method pseudo-header field + // that is safe and cacheable" + if !Self::safe_and_cacheable(req.method()) { + return Err(NotSafeAndCacheable); + } + + Ok(()) + } + + fn safe_and_cacheable(method: &Method) -> bool { + // Cacheable: https://httpwg.org/specs/rfc7231.html#cacheable.methods + // Safe: https://httpwg.org/specs/rfc7231.html#safe.methods + return method == Method::GET || method == Method::HEAD; + } + + pub fn fields(&self) -> &HeaderMap { + &self.header_block.fields + } + + #[cfg(feature = "unstable")] + pub fn into_fields(self) -> HeaderMap { + self.header_block.fields + } + + /// Loads the push promise frame but doesn't actually do HPACK decoding. + /// + /// HPACK decoding is done in the `load_hpack` step. + pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> { + let flags = PushPromiseFlag(head.flag()); + let mut pad = 0; + + if head.stream_id().is_zero() { + return Err(Error::InvalidStreamId); + } + + // Read the padding length + if flags.is_padded() { + if src.is_empty() { + return Err(Error::MalformedMessage); + } + + // TODO: Ensure payload is sized correctly + pad = src[0] as usize; + + // Drop the padding + let _ = src.split_to(1); + } + + if src.len() < 5 { + return Err(Error::MalformedMessage); + } + + let (promised_id, _) = StreamId::parse(&src[..4]); + // Drop promised_id bytes + let _ = src.split_to(4); + + if pad > 0 { + if pad > src.len() { + return Err(Error::TooMuchPadding); + } + + let len = src.len() - pad; + src.truncate(len); + } + + let frame = PushPromise { + flags, + header_block: HeaderBlock { + fields: HeaderMap::new(), + is_over_size: false, + pseudo: Pseudo::default(), + }, + promised_id, + stream_id: head.stream_id(), + }; + Ok((frame, src)) + } + + pub fn load_hpack( + &mut self, + src: &mut BytesMut, + max_header_list_size: usize, + decoder: &mut hpack::Decoder, + ) -> Result<(), Error> { + self.header_block.load(src, max_header_list_size, decoder) + } + + pub fn stream_id(&self) -> StreamId { + self.stream_id + } + + pub fn promised_id(&self) -> StreamId { + self.promised_id + } + + pub fn is_end_headers(&self) -> bool { + self.flags.is_end_headers() + } + + pub fn set_end_headers(&mut self) { + self.flags.set_end_headers(); + } + + pub fn is_over_size(&self) -> bool { + self.header_block.is_over_size + } + + pub fn encode( + self, + encoder: &mut hpack::Encoder, + dst: &mut EncodeBuf<'_>, + ) -> Option<Continuation> { + // At this point, the `is_end_headers` flag should always be set + debug_assert!(self.flags.is_end_headers()); + + let head = self.head(); + let promised_id = self.promised_id; + + self.header_block + .into_encoding(encoder) + .encode(&head, dst, |dst| { + dst.put_u32(promised_id.into()); + }) + } + + fn head(&self) -> Head { + Head::new(Kind::PushPromise, self.flags.into(), self.stream_id) + } + + /// Consume `self`, returning the parts of the frame + pub fn into_parts(self) -> (Pseudo, HeaderMap) { + (self.header_block.pseudo, self.header_block.fields) + } +} + +impl<T> From<PushPromise> for Frame<T> { + fn from(src: PushPromise) -> Self { + Frame::PushPromise(src) + } +} + +impl fmt::Debug for PushPromise { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("PushPromise") + .field("stream_id", &self.stream_id) + .field("promised_id", &self.promised_id) + .field("flags", &self.flags) + // `fields` and `pseudo` purposefully not included + .finish() + } +} + +// ===== impl Continuation ===== + +impl Continuation { + fn head(&self) -> Head { + Head::new(Kind::Continuation, END_HEADERS, self.stream_id) + } + + pub fn encode(self, dst: &mut EncodeBuf<'_>) -> Option<Continuation> { + // Get the CONTINUATION frame head + let head = self.head(); + + self.header_block.encode(&head, dst, |_| {}) + } +} + +// ===== impl Pseudo ===== + +impl Pseudo { + pub fn request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self { + let parts = uri::Parts::from(uri); + + let mut path = parts + .path_and_query + .map(|v| BytesStr::from(v.as_str())) + .unwrap_or(BytesStr::from_static("")); + + match method { + Method::OPTIONS | Method::CONNECT => {} + _ if path.is_empty() => { + path = BytesStr::from_static("/"); + } + _ => {} + } + + let mut pseudo = Pseudo { + method: Some(method), + scheme: None, + authority: None, + path: Some(path).filter(|p| !p.is_empty()), + protocol, + status: None, + }; + + // If the URI includes a scheme component, add it to the pseudo headers + // + // TODO: Scheme must be set... + if let Some(scheme) = parts.scheme { + pseudo.set_scheme(scheme); + } + + // If the URI includes an authority component, add it to the pseudo + // headers + if let Some(authority) = parts.authority { + pseudo.set_authority(BytesStr::from(authority.as_str())); + } + + pseudo + } + + pub fn response(status: StatusCode) -> Self { + Pseudo { + method: None, + scheme: None, + authority: None, + path: None, + protocol: None, + status: Some(status), + } + } + + #[cfg(feature = "unstable")] + pub fn set_status(&mut self, value: StatusCode) { + self.status = Some(value); + } + + pub fn set_scheme(&mut self, scheme: uri::Scheme) { + let bytes_str = match scheme.as_str() { + "http" => BytesStr::from_static("http"), + "https" => BytesStr::from_static("https"), + s => BytesStr::from(s), + }; + self.scheme = Some(bytes_str); + } + + #[cfg(feature = "unstable")] + pub fn set_protocol(&mut self, protocol: Protocol) { + self.protocol = Some(protocol); + } + + pub fn set_authority(&mut self, authority: BytesStr) { + self.authority = Some(authority); + } + + /// Whether it has status 1xx + pub(crate) fn is_informational(&self) -> bool { + self.status + .map_or(false, |status| status.is_informational()) + } +} + +// ===== impl EncodingHeaderBlock ===== + +impl EncodingHeaderBlock { + fn encode<F>(mut self, head: &Head, dst: &mut EncodeBuf<'_>, f: F) -> Option<Continuation> + where + F: FnOnce(&mut EncodeBuf<'_>), + { + let head_pos = dst.get_ref().len(); + + // At this point, we don't know how big the h2 frame will be. + // So, we write the head with length 0, then write the body, and + // finally write the length once we know the size. + head.encode(0, dst); + + let payload_pos = dst.get_ref().len(); + + f(dst); + + // Now, encode the header payload + let continuation = if self.hpack.len() > dst.remaining_mut() { + dst.put_slice(&self.hpack.split_to(dst.remaining_mut())); + + Some(Continuation { + stream_id: head.stream_id(), + header_block: self, + }) + } else { + dst.put_slice(&self.hpack); + + None + }; + + // Compute the header block length + let payload_len = (dst.get_ref().len() - payload_pos) as u64; + + // Write the frame length + let payload_len_be = payload_len.to_be_bytes(); + assert!(payload_len_be[0..5].iter().all(|b| *b == 0)); + (dst.get_mut()[head_pos..head_pos + 3]).copy_from_slice(&payload_len_be[5..]); + + if continuation.is_some() { + // There will be continuation frames, so the `is_end_headers` flag + // must be unset + debug_assert!(dst.get_ref()[head_pos + 4] & END_HEADERS == END_HEADERS); + + dst.get_mut()[head_pos + 4] -= END_HEADERS; + } + + continuation + } +} + +// ===== impl Iter ===== + +impl Iterator for Iter { + type Item = hpack::Header<Option<HeaderName>>; + + fn next(&mut self) -> Option<Self::Item> { + use crate::hpack::Header::*; + + if let Some(ref mut pseudo) = self.pseudo { + if let Some(method) = pseudo.method.take() { + return Some(Method(method)); + } + + if let Some(scheme) = pseudo.scheme.take() { + return Some(Scheme(scheme)); + } + + if let Some(authority) = pseudo.authority.take() { + return Some(Authority(authority)); + } + + if let Some(path) = pseudo.path.take() { + return Some(Path(path)); + } + + if let Some(protocol) = pseudo.protocol.take() { + return Some(Protocol(protocol)); + } + + if let Some(status) = pseudo.status.take() { + return Some(Status(status)); + } + } + + self.pseudo = None; + + self.fields + .next() + .map(|(name, value)| Field { name, value }) + } +} + +// ===== impl HeadersFlag ===== + +impl HeadersFlag { + pub fn empty() -> HeadersFlag { + HeadersFlag(0) + } + + pub fn load(bits: u8) -> HeadersFlag { + HeadersFlag(bits & ALL) + } + + pub fn is_end_stream(&self) -> bool { + self.0 & END_STREAM == END_STREAM + } + + pub fn set_end_stream(&mut self) { + self.0 |= END_STREAM; + } + + pub fn is_end_headers(&self) -> bool { + self.0 & END_HEADERS == END_HEADERS + } + + pub fn set_end_headers(&mut self) { + self.0 |= END_HEADERS; + } + + pub fn is_padded(&self) -> bool { + self.0 & PADDED == PADDED + } + + pub fn is_priority(&self) -> bool { + self.0 & PRIORITY == PRIORITY + } +} + +impl Default for HeadersFlag { + /// Returns a `HeadersFlag` value with `END_HEADERS` set. + fn default() -> Self { + HeadersFlag(END_HEADERS) + } +} + +impl From<HeadersFlag> for u8 { + fn from(src: HeadersFlag) -> u8 { + src.0 + } +} + +impl fmt::Debug for HeadersFlag { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + util::debug_flags(fmt, self.0) + .flag_if(self.is_end_headers(), "END_HEADERS") + .flag_if(self.is_end_stream(), "END_STREAM") + .flag_if(self.is_padded(), "PADDED") + .flag_if(self.is_priority(), "PRIORITY") + .finish() + } +} + +// ===== impl PushPromiseFlag ===== + +impl PushPromiseFlag { + pub fn empty() -> PushPromiseFlag { + PushPromiseFlag(0) + } + + pub fn load(bits: u8) -> PushPromiseFlag { + PushPromiseFlag(bits & ALL) + } + + pub fn is_end_headers(&self) -> bool { + self.0 & END_HEADERS == END_HEADERS + } + + pub fn set_end_headers(&mut self) { + self.0 |= END_HEADERS; + } + + pub fn is_padded(&self) -> bool { + self.0 & PADDED == PADDED + } +} + +impl Default for PushPromiseFlag { + /// Returns a `PushPromiseFlag` value with `END_HEADERS` set. + fn default() -> Self { + PushPromiseFlag(END_HEADERS) + } +} + +impl From<PushPromiseFlag> for u8 { + fn from(src: PushPromiseFlag) -> u8 { + src.0 + } +} + +impl fmt::Debug for PushPromiseFlag { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + util::debug_flags(fmt, self.0) + .flag_if(self.is_end_headers(), "END_HEADERS") + .flag_if(self.is_padded(), "PADDED") + .finish() + } +} + +// ===== HeaderBlock ===== + +impl HeaderBlock { + fn load( + &mut self, + src: &mut BytesMut, + max_header_list_size: usize, + decoder: &mut hpack::Decoder, + ) -> Result<(), Error> { + let mut reg = !self.fields.is_empty(); + let mut malformed = false; + let mut headers_size = self.calculate_header_list_size(); + + macro_rules! set_pseudo { + ($field:ident, $val:expr) => {{ + if reg { + tracing::trace!("load_hpack; header malformed -- pseudo not at head of block"); + malformed = true; + } else if self.pseudo.$field.is_some() { + tracing::trace!("load_hpack; header malformed -- repeated pseudo"); + malformed = true; + } else { + let __val = $val; + headers_size += + decoded_header_size(stringify!($field).len() + 1, __val.as_str().len()); + if headers_size < max_header_list_size { + self.pseudo.$field = Some(__val); + } else if !self.is_over_size { + tracing::trace!("load_hpack; header list size over max"); + self.is_over_size = true; + } + } + }}; + } + + let mut cursor = Cursor::new(src); + + // If the header frame is malformed, we still have to continue decoding + // the headers. A malformed header frame is a stream level error, but + // the hpack state is connection level. In order to maintain correct + // state for other streams, the hpack decoding process must complete. + let res = decoder.decode(&mut cursor, |header| { + use crate::hpack::Header::*; + + match header { + Field { name, value } => { + // Connection level header fields are not supported and must + // result in a protocol error. + + if name == header::CONNECTION + || name == header::TRANSFER_ENCODING + || name == header::UPGRADE + || name == "keep-alive" + || name == "proxy-connection" + { + tracing::trace!("load_hpack; connection level header"); + malformed = true; + } else if name == header::TE && value != "trailers" { + tracing::trace!( + "load_hpack; TE header not set to trailers; val={:?}", + value + ); + malformed = true; + } else { + reg = true; + + headers_size += decoded_header_size(name.as_str().len(), value.len()); + if headers_size < max_header_list_size { + self.fields.append(name, value); + } else if !self.is_over_size { + tracing::trace!("load_hpack; header list size over max"); + self.is_over_size = true; + } + } + } + Authority(v) => set_pseudo!(authority, v), + Method(v) => set_pseudo!(method, v), + Scheme(v) => set_pseudo!(scheme, v), + Path(v) => set_pseudo!(path, v), + Protocol(v) => set_pseudo!(protocol, v), + Status(v) => set_pseudo!(status, v), + } + }); + + if let Err(e) = res { + tracing::trace!("hpack decoding error; err={:?}", e); + return Err(e.into()); + } + + if malformed { + tracing::trace!("malformed message"); + return Err(Error::MalformedMessage); + } + + Ok(()) + } + + fn into_encoding(self, encoder: &mut hpack::Encoder) -> EncodingHeaderBlock { + let mut hpack = BytesMut::new(); + let headers = Iter { + pseudo: Some(self.pseudo), + fields: self.fields.into_iter(), + }; + + encoder.encode(headers, &mut hpack); + + EncodingHeaderBlock { + hpack: hpack.freeze(), + } + } + + /// Calculates the size of the currently decoded header list. + /// + /// According to http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE + /// + /// > The value is based on the uncompressed size of header fields, + /// > including the length of the name and value in octets plus an + /// > overhead of 32 octets for each header field. + fn calculate_header_list_size(&self) -> usize { + macro_rules! pseudo_size { + ($name:ident) => {{ + self.pseudo + .$name + .as_ref() + .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len())) + .unwrap_or(0) + }}; + } + + pseudo_size!(method) + + pseudo_size!(scheme) + + pseudo_size!(status) + + pseudo_size!(authority) + + pseudo_size!(path) + + self + .fields + .iter() + .map(|(name, value)| decoded_header_size(name.as_str().len(), value.len())) + .sum::<usize>() + } +} + +fn decoded_header_size(name: usize, value: usize) -> usize { + name + value + 32 +} + +#[cfg(test)] +mod test { + use std::iter::FromIterator; + + use http::HeaderValue; + + use super::*; + use crate::frame; + use crate::hpack::{huffman, Encoder}; + + #[test] + fn test_nameless_header_at_resume() { + let mut encoder = Encoder::default(); + let mut dst = BytesMut::new(); + + let headers = Headers::new( + StreamId::ZERO, + Default::default(), + HeaderMap::from_iter(vec![ + ( + HeaderName::from_static("hello"), + HeaderValue::from_static("world"), + ), + ( + HeaderName::from_static("hello"), + HeaderValue::from_static("zomg"), + ), + ( + HeaderName::from_static("hello"), + HeaderValue::from_static("sup"), + ), + ]), + ); + + let continuation = headers + .encode(&mut encoder, &mut (&mut dst).limit(frame::HEADER_LEN + 8)) + .unwrap(); + + assert_eq!(17, dst.len()); + assert_eq!([0, 0, 8, 1, 0, 0, 0, 0, 0], &dst[0..9]); + assert_eq!(&[0x40, 0x80 | 4], &dst[9..11]); + assert_eq!("hello", huff_decode(&dst[11..15])); + assert_eq!(0x80 | 4, dst[15]); + + let mut world = dst[16..17].to_owned(); + + dst.clear(); + + assert!(continuation + .encode(&mut (&mut dst).limit(frame::HEADER_LEN + 16)) + .is_none()); + + world.extend_from_slice(&dst[9..12]); + assert_eq!("world", huff_decode(&world)); + + assert_eq!(24, dst.len()); + assert_eq!([0, 0, 15, 9, 4, 0, 0, 0, 0], &dst[0..9]); + + // // Next is not indexed + assert_eq!(&[15, 47, 0x80 | 3], &dst[12..15]); + assert_eq!("zomg", huff_decode(&dst[15..18])); + assert_eq!(&[15, 47, 0x80 | 3], &dst[18..21]); + assert_eq!("sup", huff_decode(&dst[21..])); + } + + fn huff_decode(src: &[u8]) -> BytesMut { + let mut buf = BytesMut::new(); + huffman::decode(src, &mut buf).unwrap() + } +} |