use std::fmt; use std::io::IoSlice; use bytes::buf::{Chain, Take}; use bytes::Buf; use tracing::trace; use super::io::WriteBuf; type StaticBuf = &'static [u8]; /// Encoders to handle different Transfer-Encodings. #[derive(Debug, Clone, PartialEq)] pub(crate) struct Encoder { kind: Kind, is_last: bool, } #[derive(Debug)] pub(crate) struct EncodedBuf { kind: BufKind, } #[derive(Debug)] pub(crate) struct NotEof(u64); #[derive(Debug, PartialEq, Clone)] enum Kind { /// An Encoder for when Transfer-Encoding includes `chunked`. Chunked, /// An Encoder for when Content-Length is set. /// /// Enforces that the body is not longer than the Content-Length header. Length(u64), /// An Encoder for when neither Content-Length nor Chunked encoding is set. /// /// This is mostly only used with HTTP/1.0 with a length. This kind requires /// the connection to be closed when the body is finished. #[cfg(feature = "server")] CloseDelimited, } #[derive(Debug)] enum BufKind { Exact(B), Limited(Take), Chunked(Chain, StaticBuf>), ChunkedEnd(StaticBuf), } impl Encoder { fn new(kind: Kind) -> Encoder { Encoder { kind, is_last: false, } } pub(crate) fn chunked() -> Encoder { Encoder::new(Kind::Chunked) } pub(crate) fn length(len: u64) -> Encoder { Encoder::new(Kind::Length(len)) } #[cfg(feature = "server")] pub(crate) fn close_delimited() -> Encoder { Encoder::new(Kind::CloseDelimited) } pub(crate) fn is_eof(&self) -> bool { matches!(self.kind, Kind::Length(0)) } #[cfg(feature = "server")] pub(crate) fn set_last(mut self, is_last: bool) -> Self { self.is_last = is_last; self } pub(crate) fn is_last(&self) -> bool { self.is_last } pub(crate) fn is_close_delimited(&self) -> bool { match self.kind { #[cfg(feature = "server")] Kind::CloseDelimited => true, _ => false, } } pub(crate) fn end(&self) -> Result>, NotEof> { match self.kind { Kind::Length(0) => Ok(None), Kind::Chunked => Ok(Some(EncodedBuf { kind: BufKind::ChunkedEnd(b"0\r\n\r\n"), })), #[cfg(feature = "server")] Kind::CloseDelimited => Ok(None), Kind::Length(n) => Err(NotEof(n)), } } pub(crate) fn encode(&mut self, msg: B) -> EncodedBuf where B: Buf, { let len = msg.remaining(); debug_assert!(len > 0, "encode() called with empty buf"); let kind = match self.kind { Kind::Chunked => { trace!("encoding chunked {}B", len); let buf = ChunkSize::new(len) .chain(msg) .chain(b"\r\n" as &'static [u8]); BufKind::Chunked(buf) } Kind::Length(ref mut remaining) => { trace!("sized write, len = {}", len); if len as u64 > *remaining { let limit = *remaining as usize; *remaining = 0; BufKind::Limited(msg.take(limit)) } else { *remaining -= len as u64; BufKind::Exact(msg) } } #[cfg(feature = "server")] Kind::CloseDelimited => { trace!("close delimited write {}B", len); BufKind::Exact(msg) } }; EncodedBuf { kind } } pub(super) fn encode_and_end(&self, msg: B, dst: &mut WriteBuf>) -> bool where B: Buf, { let len = msg.remaining(); debug_assert!(len > 0, "encode() called with empty buf"); match self.kind { Kind::Chunked => { trace!("encoding chunked {}B", len); let buf = ChunkSize::new(len) .chain(msg) .chain(b"\r\n0\r\n\r\n" as &'static [u8]); dst.buffer(buf); !self.is_last } Kind::Length(remaining) => { use std::cmp::Ordering; trace!("sized write, len = {}", len); match (len as u64).cmp(&remaining) { Ordering::Equal => { dst.buffer(msg); !self.is_last } Ordering::Greater => { dst.buffer(msg.take(remaining as usize)); !self.is_last } Ordering::Less => { dst.buffer(msg); false } } } #[cfg(feature = "server")] Kind::CloseDelimited => { trace!("close delimited write {}B", len); dst.buffer(msg); false } } } /// Encodes the full body, without verifying the remaining length matches. /// /// This is used in conjunction with HttpBody::__hyper_full_data(), which /// means we can trust that the buf has the correct size (the buf itself /// was checked to make the headers). pub(super) fn danger_full_buf(self, msg: B, dst: &mut WriteBuf>) where B: Buf, { debug_assert!(msg.remaining() > 0, "encode() called with empty buf"); debug_assert!( match self.kind { Kind::Length(len) => len == msg.remaining() as u64, _ => true, }, "danger_full_buf length mismatches" ); match self.kind { Kind::Chunked => { let len = msg.remaining(); trace!("encoding chunked {}B", len); let buf = ChunkSize::new(len) .chain(msg) .chain(b"\r\n0\r\n\r\n" as &'static [u8]); dst.buffer(buf); } _ => { dst.buffer(msg); } } } } impl Buf for EncodedBuf where B: Buf, { #[inline] fn remaining(&self) -> usize { match self.kind { BufKind::Exact(ref b) => b.remaining(), BufKind::Limited(ref b) => b.remaining(), BufKind::Chunked(ref b) => b.remaining(), BufKind::ChunkedEnd(ref b) => b.remaining(), } } #[inline] fn chunk(&self) -> &[u8] { match self.kind { BufKind::Exact(ref b) => b.chunk(), BufKind::Limited(ref b) => b.chunk(), BufKind::Chunked(ref b) => b.chunk(), BufKind::ChunkedEnd(ref b) => b.chunk(), } } #[inline] fn advance(&mut self, cnt: usize) { match self.kind { BufKind::Exact(ref mut b) => b.advance(cnt), BufKind::Limited(ref mut b) => b.advance(cnt), BufKind::Chunked(ref mut b) => b.advance(cnt), BufKind::ChunkedEnd(ref mut b) => b.advance(cnt), } } #[inline] fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { match self.kind { BufKind::Exact(ref b) => b.chunks_vectored(dst), BufKind::Limited(ref b) => b.chunks_vectored(dst), BufKind::Chunked(ref b) => b.chunks_vectored(dst), BufKind::ChunkedEnd(ref b) => b.chunks_vectored(dst), } } } #[cfg(target_pointer_width = "32")] const USIZE_BYTES: usize = 4; #[cfg(target_pointer_width = "64")] const USIZE_BYTES: usize = 8; // each byte will become 2 hex const CHUNK_SIZE_MAX_BYTES: usize = USIZE_BYTES * 2; #[derive(Clone, Copy)] struct ChunkSize { bytes: [u8; CHUNK_SIZE_MAX_BYTES + 2], pos: u8, len: u8, } impl ChunkSize { fn new(len: usize) -> ChunkSize { use std::fmt::Write; let mut size = ChunkSize { bytes: [0; CHUNK_SIZE_MAX_BYTES + 2], pos: 0, len: 0, }; write!(&mut size, "{:X}\r\n", len).expect("CHUNK_SIZE_MAX_BYTES should fit any usize"); size } } impl Buf for ChunkSize { #[inline] fn remaining(&self) -> usize { (self.len - self.pos).into() } #[inline] fn chunk(&self) -> &[u8] { &self.bytes[self.pos.into()..self.len.into()] } #[inline] fn advance(&mut self, cnt: usize) { assert!(cnt <= self.remaining()); self.pos += cnt as u8; // just asserted cnt fits in u8 } } impl fmt::Debug for ChunkSize { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ChunkSize") .field("bytes", &&self.bytes[..self.len.into()]) .field("pos", &self.pos) .finish() } } impl fmt::Write for ChunkSize { fn write_str(&mut self, num: &str) -> fmt::Result { use std::io::Write; (&mut self.bytes[self.len.into()..]) .write_all(num.as_bytes()) .expect("&mut [u8].write() cannot error"); self.len += num.len() as u8; // safe because bytes is never bigger than 256 Ok(()) } } impl From for EncodedBuf { fn from(buf: B) -> Self { EncodedBuf { kind: BufKind::Exact(buf), } } } impl From> for EncodedBuf { fn from(buf: Take) -> Self { EncodedBuf { kind: BufKind::Limited(buf), } } } impl From, StaticBuf>> for EncodedBuf { fn from(buf: Chain, StaticBuf>) -> Self { EncodedBuf { kind: BufKind::Chunked(buf), } } } impl fmt::Display for NotEof { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "early end, expected {} more bytes", self.0) } } impl std::error::Error for NotEof {} #[cfg(test)] mod tests { use bytes::BufMut; use super::super::io::Cursor; use super::Encoder; #[test] fn chunked() { let mut encoder = Encoder::chunked(); let mut dst = Vec::new(); let msg1 = b"foo bar".as_ref(); let buf1 = encoder.encode(msg1); dst.put(buf1); assert_eq!(dst, b"7\r\nfoo bar\r\n"); let msg2 = b"baz quux herp".as_ref(); let buf2 = encoder.encode(msg2); dst.put(buf2); assert_eq!(dst, b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n"); let end = encoder.end::>>().unwrap().unwrap(); dst.put(end); assert_eq!( dst, b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n0\r\n\r\n".as_ref() ); } #[test] fn length() { let max_len = 8; let mut encoder = Encoder::length(max_len as u64); let mut dst = Vec::new(); let msg1 = b"foo bar".as_ref(); let buf1 = encoder.encode(msg1); dst.put(buf1); assert_eq!(dst, b"foo bar"); assert!(!encoder.is_eof()); encoder.end::<()>().unwrap_err(); let msg2 = b"baz".as_ref(); let buf2 = encoder.encode(msg2); dst.put(buf2); assert_eq!(dst.len(), max_len); assert_eq!(dst, b"foo barb"); assert!(encoder.is_eof()); assert!(encoder.end::<()>().unwrap().is_none()); } #[test] fn eof() { let mut encoder = Encoder::close_delimited(); let mut dst = Vec::new(); let msg1 = b"foo bar".as_ref(); let buf1 = encoder.encode(msg1); dst.put(buf1); assert_eq!(dst, b"foo bar"); assert!(!encoder.is_eof()); encoder.end::<()>().unwrap(); let msg2 = b"baz".as_ref(); let buf2 = encoder.encode(msg2); dst.put(buf2); assert_eq!(dst, b"foo barbaz"); assert!(!encoder.is_eof()); encoder.end::<()>().unwrap(); } }