diff options
Diffstat (limited to 'third_party/rust/h2/src/codec/framed_write.rs')
-rw-r--r-- | third_party/rust/h2/src/codec/framed_write.rs | 374 |
1 files changed, 374 insertions, 0 deletions
diff --git a/third_party/rust/h2/src/codec/framed_write.rs b/third_party/rust/h2/src/codec/framed_write.rs new file mode 100644 index 0000000000..4b1b4accc4 --- /dev/null +++ b/third_party/rust/h2/src/codec/framed_write.rs @@ -0,0 +1,374 @@ +use crate::codec::UserError; +use crate::codec::UserError::*; +use crate::frame::{self, Frame, FrameSize}; +use crate::hpack; + +use bytes::{Buf, BufMut, BytesMut}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use std::io::{self, Cursor, IoSlice}; + +// A macro to get around a method needing to borrow &mut self +macro_rules! limited_write_buf { + ($self:expr) => {{ + let limit = $self.max_frame_size() + frame::HEADER_LEN; + $self.buf.get_mut().limit(limit) + }}; +} + +#[derive(Debug)] +pub struct FramedWrite<T, B> { + /// Upstream `AsyncWrite` + inner: T, + + encoder: Encoder<B>, +} + +#[derive(Debug)] +struct Encoder<B> { + /// HPACK encoder + hpack: hpack::Encoder, + + /// Write buffer + /// + /// TODO: Should this be a ring buffer? + buf: Cursor<BytesMut>, + + /// Next frame to encode + next: Option<Next<B>>, + + /// Last data frame + last_data_frame: Option<frame::Data<B>>, + + /// Max frame size, this is specified by the peer + max_frame_size: FrameSize, + + /// Whether or not the wrapped `AsyncWrite` supports vectored IO. + is_write_vectored: bool, +} + +#[derive(Debug)] +enum Next<B> { + Data(frame::Data<B>), + Continuation(frame::Continuation), +} + +/// Initialize the connection with this amount of write buffer. +/// +/// The minimum MAX_FRAME_SIZE is 16kb, so always be able to send a HEADERS +/// frame that big. +const DEFAULT_BUFFER_CAPACITY: usize = 16 * 1_024; + +/// Min buffer required to attempt to write a frame +const MIN_BUFFER_CAPACITY: usize = frame::HEADER_LEN + CHAIN_THRESHOLD; + +/// Chain payloads bigger than this. The remote will never advertise a max frame +/// size less than this (well, the spec says the max frame size can't be less +/// than 16kb, so not even close). +const CHAIN_THRESHOLD: usize = 256; + +// TODO: Make generic +impl<T, B> FramedWrite<T, B> +where + T: AsyncWrite + Unpin, + B: Buf, +{ + pub fn new(inner: T) -> FramedWrite<T, B> { + let is_write_vectored = inner.is_write_vectored(); + FramedWrite { + inner, + encoder: Encoder { + hpack: hpack::Encoder::default(), + buf: Cursor::new(BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY)), + next: None, + last_data_frame: None, + max_frame_size: frame::DEFAULT_MAX_FRAME_SIZE, + is_write_vectored, + }, + } + } + + /// Returns `Ready` when `send` is able to accept a frame + /// + /// Calling this function may result in the current contents of the buffer + /// to be flushed to `T`. + pub fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> { + if !self.encoder.has_capacity() { + // Try flushing + ready!(self.flush(cx))?; + + if !self.encoder.has_capacity() { + return Poll::Pending; + } + } + + Poll::Ready(Ok(())) + } + + /// Buffer a frame. + /// + /// `poll_ready` must be called first to ensure that a frame may be + /// accepted. + pub fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> { + self.encoder.buffer(item) + } + + /// Flush buffered data to the wire + pub fn flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> { + let span = tracing::trace_span!("FramedWrite::flush"); + let _e = span.enter(); + + loop { + while !self.encoder.is_empty() { + match self.encoder.next { + Some(Next::Data(ref mut frame)) => { + tracing::trace!(queued_data_frame = true); + let mut buf = (&mut self.encoder.buf).chain(frame.payload_mut()); + ready!(write( + &mut self.inner, + self.encoder.is_write_vectored, + &mut buf, + cx, + ))? + } + _ => { + tracing::trace!(queued_data_frame = false); + ready!(write( + &mut self.inner, + self.encoder.is_write_vectored, + &mut self.encoder.buf, + cx, + ))? + } + } + } + + match self.encoder.unset_frame() { + ControlFlow::Continue => (), + ControlFlow::Break => break, + } + } + + tracing::trace!("flushing buffer"); + // Flush the upstream + ready!(Pin::new(&mut self.inner).poll_flush(cx))?; + + Poll::Ready(Ok(())) + } + + /// Close the codec + pub fn shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>> { + ready!(self.flush(cx))?; + Pin::new(&mut self.inner).poll_shutdown(cx) + } +} + +fn write<T, B>( + writer: &mut T, + is_write_vectored: bool, + buf: &mut B, + cx: &mut Context<'_>, +) -> Poll<io::Result<()>> +where + T: AsyncWrite + Unpin, + B: Buf, +{ + // TODO(eliza): when tokio-util 0.5.1 is released, this + // could just use `poll_write_buf`... + const MAX_IOVS: usize = 64; + let n = if is_write_vectored { + let mut bufs = [IoSlice::new(&[]); MAX_IOVS]; + let cnt = buf.chunks_vectored(&mut bufs); + ready!(Pin::new(writer).poll_write_vectored(cx, &bufs[..cnt]))? + } else { + ready!(Pin::new(writer).poll_write(cx, buf.chunk()))? + }; + buf.advance(n); + Ok(()).into() +} + +#[must_use] +enum ControlFlow { + Continue, + Break, +} + +impl<B> Encoder<B> +where + B: Buf, +{ + fn unset_frame(&mut self) -> ControlFlow { + // Clear internal buffer + self.buf.set_position(0); + self.buf.get_mut().clear(); + + // The data frame has been written, so unset it + match self.next.take() { + Some(Next::Data(frame)) => { + self.last_data_frame = Some(frame); + debug_assert!(self.is_empty()); + ControlFlow::Break + } + Some(Next::Continuation(frame)) => { + // Buffer the continuation frame, then try to write again + let mut buf = limited_write_buf!(self); + if let Some(continuation) = frame.encode(&mut buf) { + self.next = Some(Next::Continuation(continuation)); + } + ControlFlow::Continue + } + None => ControlFlow::Break, + } + } + + fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> { + // Ensure that we have enough capacity to accept the write. + assert!(self.has_capacity()); + let span = tracing::trace_span!("FramedWrite::buffer", frame = ?item); + let _e = span.enter(); + + tracing::debug!(frame = ?item, "send"); + + match item { + Frame::Data(mut v) => { + // Ensure that the payload is not greater than the max frame. + let len = v.payload().remaining(); + + if len > self.max_frame_size() { + return Err(PayloadTooBig); + } + + if len >= CHAIN_THRESHOLD { + let head = v.head(); + + // Encode the frame head to the buffer + head.encode(len, self.buf.get_mut()); + + // Save the data frame + self.next = Some(Next::Data(v)); + } else { + v.encode_chunk(self.buf.get_mut()); + + // The chunk has been fully encoded, so there is no need to + // keep it around + assert_eq!(v.payload().remaining(), 0, "chunk not fully encoded"); + + // Save off the last frame... + self.last_data_frame = Some(v); + } + } + Frame::Headers(v) => { + let mut buf = limited_write_buf!(self); + if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) { + self.next = Some(Next::Continuation(continuation)); + } + } + Frame::PushPromise(v) => { + let mut buf = limited_write_buf!(self); + if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) { + self.next = Some(Next::Continuation(continuation)); + } + } + Frame::Settings(v) => { + v.encode(self.buf.get_mut()); + tracing::trace!(rem = self.buf.remaining(), "encoded settings"); + } + Frame::GoAway(v) => { + v.encode(self.buf.get_mut()); + tracing::trace!(rem = self.buf.remaining(), "encoded go_away"); + } + Frame::Ping(v) => { + v.encode(self.buf.get_mut()); + tracing::trace!(rem = self.buf.remaining(), "encoded ping"); + } + Frame::WindowUpdate(v) => { + v.encode(self.buf.get_mut()); + tracing::trace!(rem = self.buf.remaining(), "encoded window_update"); + } + + Frame::Priority(_) => { + /* + v.encode(self.buf.get_mut()); + tracing::trace!("encoded priority; rem={:?}", self.buf.remaining()); + */ + unimplemented!(); + } + Frame::Reset(v) => { + v.encode(self.buf.get_mut()); + tracing::trace!(rem = self.buf.remaining(), "encoded reset"); + } + } + + Ok(()) + } + + fn has_capacity(&self) -> bool { + self.next.is_none() && self.buf.get_ref().remaining_mut() >= MIN_BUFFER_CAPACITY + } + + fn is_empty(&self) -> bool { + match self.next { + Some(Next::Data(ref frame)) => !frame.payload().has_remaining(), + _ => !self.buf.has_remaining(), + } + } +} + +impl<B> Encoder<B> { + fn max_frame_size(&self) -> usize { + self.max_frame_size as usize + } +} + +impl<T, B> FramedWrite<T, B> { + /// Returns the max frame size that can be sent + pub fn max_frame_size(&self) -> usize { + self.encoder.max_frame_size() + } + + /// Set the peer's max frame size. + pub fn set_max_frame_size(&mut self, val: usize) { + assert!(val <= frame::MAX_MAX_FRAME_SIZE as usize); + self.encoder.max_frame_size = val as FrameSize; + } + + /// Set the peer's header table size. + pub fn set_header_table_size(&mut self, val: usize) { + self.encoder.hpack.update_max_size(val); + } + + /// Retrieve the last data frame that has been sent + pub fn take_last_data_frame(&mut self) -> Option<frame::Data<B>> { + self.encoder.last_data_frame.take() + } + + pub fn get_mut(&mut self) -> &mut T { + &mut self.inner + } +} + +impl<T: AsyncRead + Unpin, B> AsyncRead for FramedWrite<T, B> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf, + ) -> Poll<io::Result<()>> { + Pin::new(&mut self.inner).poll_read(cx, buf) + } +} + +// We never project the Pin to `B`. +impl<T: Unpin, B> Unpin for FramedWrite<T, B> {} + +#[cfg(feature = "unstable")] +mod unstable { + use super::*; + + impl<T, B> FramedWrite<T, B> { + pub fn get_ref(&self) -> &T { + &self.inner + } + } +} |