use crate::codec::UserError; use crate::codec::UserError::*; use crate::frame::{self, Frame, FrameSize}; use crate::hpack; use bytes::{Buf, BufMut, BytesMut}; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio_util::io::poll_write_buf; use std::io::{self, Cursor}; // A macro to get around a method needing to borrow &mut self macro_rules! limited_write_buf { ($self:expr) => {{ let limit = $self.max_frame_size() + frame::HEADER_LEN; $self.buf.get_mut().limit(limit) }}; } #[derive(Debug)] pub struct FramedWrite { /// Upstream `AsyncWrite` inner: T, encoder: Encoder, } #[derive(Debug)] struct Encoder { /// HPACK encoder hpack: hpack::Encoder, /// Write buffer /// /// TODO: Should this be a ring buffer? buf: Cursor, /// Next frame to encode next: Option>, /// Last data frame last_data_frame: Option>, /// Max frame size, this is specified by the peer max_frame_size: FrameSize, /// Chain payloads bigger than this. chain_threshold: usize, /// Min buffer required to attempt to write a frame min_buffer_capacity: usize, } #[derive(Debug)] enum Next { Data(frame::Data), Continuation(frame::Continuation), } /// Initialize the connection with this amount of write buffer. /// /// The minimum MAX_FRAME_SIZE is 16kb, so always be able to send a HEADERS /// frame that big. const DEFAULT_BUFFER_CAPACITY: usize = 16 * 1_024; /// Chain payloads bigger than this when vectored I/O is enabled. The remote /// will never advertise a max frame size less than this (well, the spec says /// the max frame size can't be less than 16kb, so not even close). const CHAIN_THRESHOLD: usize = 256; /// Chain payloads bigger than this when vectored I/O is **not** enabled. /// A larger value in this scenario will reduce the number of small and /// fragmented data being sent, and hereby improve the throughput. const CHAIN_THRESHOLD_WITHOUT_VECTORED_IO: usize = 1024; // TODO: Make generic impl FramedWrite where T: AsyncWrite + Unpin, B: Buf, { pub fn new(inner: T) -> FramedWrite { let chain_threshold = if inner.is_write_vectored() { CHAIN_THRESHOLD } else { CHAIN_THRESHOLD_WITHOUT_VECTORED_IO }; FramedWrite { inner, encoder: Encoder { hpack: hpack::Encoder::default(), buf: Cursor::new(BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY)), next: None, last_data_frame: None, max_frame_size: frame::DEFAULT_MAX_FRAME_SIZE, chain_threshold, min_buffer_capacity: chain_threshold + frame::HEADER_LEN, }, } } /// Returns `Ready` when `send` is able to accept a frame /// /// Calling this function may result in the current contents of the buffer /// to be flushed to `T`. pub fn poll_ready(&mut self, cx: &mut Context) -> Poll> { if !self.encoder.has_capacity() { // Try flushing ready!(self.flush(cx))?; if !self.encoder.has_capacity() { return Poll::Pending; } } Poll::Ready(Ok(())) } /// Buffer a frame. /// /// `poll_ready` must be called first to ensure that a frame may be /// accepted. pub fn buffer(&mut self, item: Frame) -> Result<(), UserError> { self.encoder.buffer(item) } /// Flush buffered data to the wire pub fn flush(&mut self, cx: &mut Context) -> Poll> { let span = tracing::trace_span!("FramedWrite::flush"); let _e = span.enter(); loop { while !self.encoder.is_empty() { match self.encoder.next { Some(Next::Data(ref mut frame)) => { tracing::trace!(queued_data_frame = true); let mut buf = (&mut self.encoder.buf).chain(frame.payload_mut()); ready!(poll_write_buf(Pin::new(&mut self.inner), cx, &mut buf))? } _ => { tracing::trace!(queued_data_frame = false); ready!(poll_write_buf( Pin::new(&mut self.inner), cx, &mut self.encoder.buf ))? } }; } match self.encoder.unset_frame() { ControlFlow::Continue => (), ControlFlow::Break => break, } } tracing::trace!("flushing buffer"); // Flush the upstream ready!(Pin::new(&mut self.inner).poll_flush(cx))?; Poll::Ready(Ok(())) } /// Close the codec pub fn shutdown(&mut self, cx: &mut Context) -> Poll> { ready!(self.flush(cx))?; Pin::new(&mut self.inner).poll_shutdown(cx) } } #[must_use] enum ControlFlow { Continue, Break, } impl Encoder where B: Buf, { fn unset_frame(&mut self) -> ControlFlow { // Clear internal buffer self.buf.set_position(0); self.buf.get_mut().clear(); // The data frame has been written, so unset it match self.next.take() { Some(Next::Data(frame)) => { self.last_data_frame = Some(frame); debug_assert!(self.is_empty()); ControlFlow::Break } Some(Next::Continuation(frame)) => { // Buffer the continuation frame, then try to write again let mut buf = limited_write_buf!(self); if let Some(continuation) = frame.encode(&mut buf) { self.next = Some(Next::Continuation(continuation)); } ControlFlow::Continue } None => ControlFlow::Break, } } fn buffer(&mut self, item: Frame) -> Result<(), UserError> { // Ensure that we have enough capacity to accept the write. assert!(self.has_capacity()); let span = tracing::trace_span!("FramedWrite::buffer", frame = ?item); let _e = span.enter(); tracing::debug!(frame = ?item, "send"); match item { Frame::Data(mut v) => { // Ensure that the payload is not greater than the max frame. let len = v.payload().remaining(); if len > self.max_frame_size() { return Err(PayloadTooBig); } if len >= self.chain_threshold { let head = v.head(); // Encode the frame head to the buffer head.encode(len, self.buf.get_mut()); if self.buf.get_ref().remaining() < self.chain_threshold { let extra_bytes = self.chain_threshold - self.buf.remaining(); self.buf.get_mut().put(v.payload_mut().take(extra_bytes)); } // Save the data frame self.next = Some(Next::Data(v)); } else { v.encode_chunk(self.buf.get_mut()); // The chunk has been fully encoded, so there is no need to // keep it around assert_eq!(v.payload().remaining(), 0, "chunk not fully encoded"); // Save off the last frame... self.last_data_frame = Some(v); } } Frame::Headers(v) => { let mut buf = limited_write_buf!(self); if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) { self.next = Some(Next::Continuation(continuation)); } } Frame::PushPromise(v) => { let mut buf = limited_write_buf!(self); if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) { self.next = Some(Next::Continuation(continuation)); } } Frame::Settings(v) => { v.encode(self.buf.get_mut()); tracing::trace!(rem = self.buf.remaining(), "encoded settings"); } Frame::GoAway(v) => { v.encode(self.buf.get_mut()); tracing::trace!(rem = self.buf.remaining(), "encoded go_away"); } Frame::Ping(v) => { v.encode(self.buf.get_mut()); tracing::trace!(rem = self.buf.remaining(), "encoded ping"); } Frame::WindowUpdate(v) => { v.encode(self.buf.get_mut()); tracing::trace!(rem = self.buf.remaining(), "encoded window_update"); } Frame::Priority(_) => { /* v.encode(self.buf.get_mut()); tracing::trace!("encoded priority; rem={:?}", self.buf.remaining()); */ unimplemented!(); } Frame::Reset(v) => { v.encode(self.buf.get_mut()); tracing::trace!(rem = self.buf.remaining(), "encoded reset"); } } Ok(()) } fn has_capacity(&self) -> bool { self.next.is_none() && (self.buf.get_ref().capacity() - self.buf.get_ref().len() >= self.min_buffer_capacity) } fn is_empty(&self) -> bool { match self.next { Some(Next::Data(ref frame)) => !frame.payload().has_remaining(), _ => !self.buf.has_remaining(), } } } impl Encoder { fn max_frame_size(&self) -> usize { self.max_frame_size as usize } } impl FramedWrite { /// Returns the max frame size that can be sent pub fn max_frame_size(&self) -> usize { self.encoder.max_frame_size() } /// Set the peer's max frame size. pub fn set_max_frame_size(&mut self, val: usize) { assert!(val <= frame::MAX_MAX_FRAME_SIZE as usize); self.encoder.max_frame_size = val as FrameSize; } /// Set the peer's header table size. pub fn set_header_table_size(&mut self, val: usize) { self.encoder.hpack.update_max_size(val); } /// Retrieve the last data frame that has been sent pub fn take_last_data_frame(&mut self) -> Option> { self.encoder.last_data_frame.take() } pub fn get_mut(&mut self) -> &mut T { &mut self.inner } } impl AsyncRead for FramedWrite { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf, ) -> Poll> { Pin::new(&mut self.inner).poll_read(cx, buf) } } // We never project the Pin to `B`. impl Unpin for FramedWrite {} #[cfg(feature = "unstable")] mod unstable { use super::*; impl FramedWrite { pub fn get_ref(&self) -> &T { &self.inner } } }