summaryrefslogtreecommitdiffstats
path: root/third_party/rust/h2/src/codec/framed_write.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/h2/src/codec/framed_write.rs')
-rw-r--r--third_party/rust/h2/src/codec/framed_write.rs362
1 files changed, 362 insertions, 0 deletions
diff --git a/third_party/rust/h2/src/codec/framed_write.rs b/third_party/rust/h2/src/codec/framed_write.rs
new file mode 100644
index 0000000000..c88af02daf
--- /dev/null
+++ b/third_party/rust/h2/src/codec/framed_write.rs
@@ -0,0 +1,362 @@
+use crate::codec::UserError;
+use crate::codec::UserError::*;
+use crate::frame::{self, Frame, FrameSize};
+use crate::hpack;
+
+use bytes::{Buf, BufMut, BytesMut};
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+use tokio_util::io::poll_write_buf;
+
+use std::io::{self, Cursor};
+
+// A macro to get around a method needing to borrow &mut self
+macro_rules! limited_write_buf {
+ ($self:expr) => {{
+ let limit = $self.max_frame_size() + frame::HEADER_LEN;
+ $self.buf.get_mut().limit(limit)
+ }};
+}
+
+#[derive(Debug)]
+pub struct FramedWrite<T, B> {
+ /// Upstream `AsyncWrite`
+ inner: T,
+
+ encoder: Encoder<B>,
+}
+
+#[derive(Debug)]
+struct Encoder<B> {
+ /// HPACK encoder
+ hpack: hpack::Encoder,
+
+ /// Write buffer
+ ///
+ /// TODO: Should this be a ring buffer?
+ buf: Cursor<BytesMut>,
+
+ /// Next frame to encode
+ next: Option<Next<B>>,
+
+ /// Last data frame
+ last_data_frame: Option<frame::Data<B>>,
+
+ /// Max frame size, this is specified by the peer
+ max_frame_size: FrameSize,
+
+ /// Chain payloads bigger than this.
+ chain_threshold: usize,
+
+ /// Min buffer required to attempt to write a frame
+ min_buffer_capacity: usize,
+}
+
+#[derive(Debug)]
+enum Next<B> {
+ Data(frame::Data<B>),
+ Continuation(frame::Continuation),
+}
+
+/// Initialize the connection with this amount of write buffer.
+///
+/// The minimum MAX_FRAME_SIZE is 16kb, so always be able to send a HEADERS
+/// frame that big.
+const DEFAULT_BUFFER_CAPACITY: usize = 16 * 1_024;
+
+/// Chain payloads bigger than this when vectored I/O is enabled. The remote
+/// will never advertise a max frame size less than this (well, the spec says
+/// the max frame size can't be less than 16kb, so not even close).
+const CHAIN_THRESHOLD: usize = 256;
+
+/// Chain payloads bigger than this when vectored I/O is **not** enabled.
+/// A larger value in this scenario will reduce the number of small and
+/// fragmented data being sent, and hereby improve the throughput.
+const CHAIN_THRESHOLD_WITHOUT_VECTORED_IO: usize = 1024;
+
+// TODO: Make generic
+impl<T, B> FramedWrite<T, B>
+where
+ T: AsyncWrite + Unpin,
+ B: Buf,
+{
+ pub fn new(inner: T) -> FramedWrite<T, B> {
+ let chain_threshold = if inner.is_write_vectored() {
+ CHAIN_THRESHOLD
+ } else {
+ CHAIN_THRESHOLD_WITHOUT_VECTORED_IO
+ };
+ FramedWrite {
+ inner,
+ encoder: Encoder {
+ hpack: hpack::Encoder::default(),
+ buf: Cursor::new(BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY)),
+ next: None,
+ last_data_frame: None,
+ max_frame_size: frame::DEFAULT_MAX_FRAME_SIZE,
+ chain_threshold,
+ min_buffer_capacity: chain_threshold + frame::HEADER_LEN,
+ },
+ }
+ }
+
+ /// Returns `Ready` when `send` is able to accept a frame
+ ///
+ /// Calling this function may result in the current contents of the buffer
+ /// to be flushed to `T`.
+ pub fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
+ if !self.encoder.has_capacity() {
+ // Try flushing
+ ready!(self.flush(cx))?;
+
+ if !self.encoder.has_capacity() {
+ return Poll::Pending;
+ }
+ }
+
+ Poll::Ready(Ok(()))
+ }
+
+ /// Buffer a frame.
+ ///
+ /// `poll_ready` must be called first to ensure that a frame may be
+ /// accepted.
+ pub fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
+ self.encoder.buffer(item)
+ }
+
+ /// Flush buffered data to the wire
+ pub fn flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
+ let span = tracing::trace_span!("FramedWrite::flush");
+ let _e = span.enter();
+
+ loop {
+ while !self.encoder.is_empty() {
+ match self.encoder.next {
+ Some(Next::Data(ref mut frame)) => {
+ tracing::trace!(queued_data_frame = true);
+ let mut buf = (&mut self.encoder.buf).chain(frame.payload_mut());
+ ready!(poll_write_buf(Pin::new(&mut self.inner), cx, &mut buf))?
+ }
+ _ => {
+ tracing::trace!(queued_data_frame = false);
+ ready!(poll_write_buf(
+ Pin::new(&mut self.inner),
+ cx,
+ &mut self.encoder.buf
+ ))?
+ }
+ };
+ }
+
+ match self.encoder.unset_frame() {
+ ControlFlow::Continue => (),
+ ControlFlow::Break => break,
+ }
+ }
+
+ tracing::trace!("flushing buffer");
+ // Flush the upstream
+ ready!(Pin::new(&mut self.inner).poll_flush(cx))?;
+
+ Poll::Ready(Ok(()))
+ }
+
+ /// Close the codec
+ pub fn shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
+ ready!(self.flush(cx))?;
+ Pin::new(&mut self.inner).poll_shutdown(cx)
+ }
+}
+
+#[must_use]
+enum ControlFlow {
+ Continue,
+ Break,
+}
+
+impl<B> Encoder<B>
+where
+ B: Buf,
+{
+ fn unset_frame(&mut self) -> ControlFlow {
+ // Clear internal buffer
+ self.buf.set_position(0);
+ self.buf.get_mut().clear();
+
+ // The data frame has been written, so unset it
+ match self.next.take() {
+ Some(Next::Data(frame)) => {
+ self.last_data_frame = Some(frame);
+ debug_assert!(self.is_empty());
+ ControlFlow::Break
+ }
+ Some(Next::Continuation(frame)) => {
+ // Buffer the continuation frame, then try to write again
+ let mut buf = limited_write_buf!(self);
+ if let Some(continuation) = frame.encode(&mut buf) {
+ self.next = Some(Next::Continuation(continuation));
+ }
+ ControlFlow::Continue
+ }
+ None => ControlFlow::Break,
+ }
+ }
+
+ fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
+ // Ensure that we have enough capacity to accept the write.
+ assert!(self.has_capacity());
+ let span = tracing::trace_span!("FramedWrite::buffer", frame = ?item);
+ let _e = span.enter();
+
+ tracing::debug!(frame = ?item, "send");
+
+ match item {
+ Frame::Data(mut v) => {
+ // Ensure that the payload is not greater than the max frame.
+ let len = v.payload().remaining();
+
+ if len > self.max_frame_size() {
+ return Err(PayloadTooBig);
+ }
+
+ if len >= self.chain_threshold {
+ let head = v.head();
+
+ // Encode the frame head to the buffer
+ head.encode(len, self.buf.get_mut());
+
+ if self.buf.get_ref().remaining() < self.chain_threshold {
+ let extra_bytes = self.chain_threshold - self.buf.remaining();
+ self.buf.get_mut().put(v.payload_mut().take(extra_bytes));
+ }
+
+ // Save the data frame
+ self.next = Some(Next::Data(v));
+ } else {
+ v.encode_chunk(self.buf.get_mut());
+
+ // The chunk has been fully encoded, so there is no need to
+ // keep it around
+ assert_eq!(v.payload().remaining(), 0, "chunk not fully encoded");
+
+ // Save off the last frame...
+ self.last_data_frame = Some(v);
+ }
+ }
+ Frame::Headers(v) => {
+ let mut buf = limited_write_buf!(self);
+ if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
+ self.next = Some(Next::Continuation(continuation));
+ }
+ }
+ Frame::PushPromise(v) => {
+ let mut buf = limited_write_buf!(self);
+ if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
+ self.next = Some(Next::Continuation(continuation));
+ }
+ }
+ Frame::Settings(v) => {
+ v.encode(self.buf.get_mut());
+ tracing::trace!(rem = self.buf.remaining(), "encoded settings");
+ }
+ Frame::GoAway(v) => {
+ v.encode(self.buf.get_mut());
+ tracing::trace!(rem = self.buf.remaining(), "encoded go_away");
+ }
+ Frame::Ping(v) => {
+ v.encode(self.buf.get_mut());
+ tracing::trace!(rem = self.buf.remaining(), "encoded ping");
+ }
+ Frame::WindowUpdate(v) => {
+ v.encode(self.buf.get_mut());
+ tracing::trace!(rem = self.buf.remaining(), "encoded window_update");
+ }
+
+ Frame::Priority(_) => {
+ /*
+ v.encode(self.buf.get_mut());
+ tracing::trace!("encoded priority; rem={:?}", self.buf.remaining());
+ */
+ unimplemented!();
+ }
+ Frame::Reset(v) => {
+ v.encode(self.buf.get_mut());
+ tracing::trace!(rem = self.buf.remaining(), "encoded reset");
+ }
+ }
+
+ Ok(())
+ }
+
+ fn has_capacity(&self) -> bool {
+ self.next.is_none()
+ && (self.buf.get_ref().capacity() - self.buf.get_ref().len()
+ >= self.min_buffer_capacity)
+ }
+
+ fn is_empty(&self) -> bool {
+ match self.next {
+ Some(Next::Data(ref frame)) => !frame.payload().has_remaining(),
+ _ => !self.buf.has_remaining(),
+ }
+ }
+}
+
+impl<B> Encoder<B> {
+ fn max_frame_size(&self) -> usize {
+ self.max_frame_size as usize
+ }
+}
+
+impl<T, B> FramedWrite<T, B> {
+ /// Returns the max frame size that can be sent
+ pub fn max_frame_size(&self) -> usize {
+ self.encoder.max_frame_size()
+ }
+
+ /// Set the peer's max frame size.
+ pub fn set_max_frame_size(&mut self, val: usize) {
+ assert!(val <= frame::MAX_MAX_FRAME_SIZE as usize);
+ self.encoder.max_frame_size = val as FrameSize;
+ }
+
+ /// Set the peer's header table size.
+ pub fn set_header_table_size(&mut self, val: usize) {
+ self.encoder.hpack.update_max_size(val);
+ }
+
+ /// Retrieve the last data frame that has been sent
+ pub fn take_last_data_frame(&mut self) -> Option<frame::Data<B>> {
+ self.encoder.last_data_frame.take()
+ }
+
+ pub fn get_mut(&mut self) -> &mut T {
+ &mut self.inner
+ }
+}
+
+impl<T: AsyncRead + Unpin, B> AsyncRead for FramedWrite<T, B> {
+ fn poll_read(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut ReadBuf,
+ ) -> Poll<io::Result<()>> {
+ Pin::new(&mut self.inner).poll_read(cx, buf)
+ }
+}
+
+// We never project the Pin to `B`.
+impl<T: Unpin, B> Unpin for FramedWrite<T, B> {}
+
+#[cfg(feature = "unstable")]
+mod unstable {
+ use super::*;
+
+ impl<T, B> FramedWrite<T, B> {
+ pub fn get_ref(&self) -> &T {
+ &self.inner
+ }
+ }
+}