use crate::codec::decoder::Decoder; use crate::codec::encoder::Encoder; use futures_core::Stream; use tokio::io::{AsyncRead, AsyncWrite}; use bytes::BytesMut; use futures_core::ready; use futures_sink::Sink; use pin_project_lite::pin_project; use std::borrow::{Borrow, BorrowMut}; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; use tracing::trace; pin_project! { #[derive(Debug)] pub(crate) struct FramedImpl { #[pin] pub(crate) inner: T, pub(crate) state: State, pub(crate) codec: U, } } const INITIAL_CAPACITY: usize = 8 * 1024; const BACKPRESSURE_BOUNDARY: usize = INITIAL_CAPACITY; #[derive(Debug)] pub(crate) struct ReadFrame { pub(crate) eof: bool, pub(crate) is_readable: bool, pub(crate) buffer: BytesMut, pub(crate) has_errored: bool, } pub(crate) struct WriteFrame { pub(crate) buffer: BytesMut, } #[derive(Default)] pub(crate) struct RWFrames { pub(crate) read: ReadFrame, pub(crate) write: WriteFrame, } impl Default for ReadFrame { fn default() -> Self { Self { eof: false, is_readable: false, buffer: BytesMut::with_capacity(INITIAL_CAPACITY), has_errored: false, } } } impl Default for WriteFrame { fn default() -> Self { Self { buffer: BytesMut::with_capacity(INITIAL_CAPACITY), } } } impl From for ReadFrame { fn from(mut buffer: BytesMut) -> Self { let size = buffer.capacity(); if size < INITIAL_CAPACITY { buffer.reserve(INITIAL_CAPACITY - size); } Self { buffer, is_readable: size > 0, eof: false, has_errored: false, } } } impl From for WriteFrame { fn from(mut buffer: BytesMut) -> Self { let size = buffer.capacity(); if size < INITIAL_CAPACITY { buffer.reserve(INITIAL_CAPACITY - size); } Self { buffer } } } impl Borrow for RWFrames { fn borrow(&self) -> &ReadFrame { &self.read } } impl BorrowMut for RWFrames { fn borrow_mut(&mut self) -> &mut ReadFrame { &mut self.read } } impl Borrow for RWFrames { fn borrow(&self) -> &WriteFrame { &self.write } } impl BorrowMut for RWFrames { fn borrow_mut(&mut self) -> &mut WriteFrame { &mut self.write } } impl Stream for FramedImpl where T: AsyncRead, U: Decoder, R: BorrowMut, { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { use crate::util::poll_read_buf; let mut pinned = self.project(); let state: &mut ReadFrame = pinned.state.borrow_mut(); // The following loops implements a state machine with each state corresponding // to a combination of the `is_readable` and `eof` flags. States persist across // loop entries and most state transitions occur with a return. // // The initial state is `reading`. // // | state | eof | is_readable | has_errored | // |---------|-------|-------------|-------------| // | reading | false | false | false | // | framing | false | true | false | // | pausing | true | true | false | // | paused | true | false | false | // | errored | | | true | // `decode_eof` returns Err // ┌────────────────────────────────────────────────────────┐ // `decode_eof` returns │ │ // `Ok(Some)` │ │ // ┌─────┐ │ `decode_eof` returns After returning │ // Read 0 bytes ├─────▼──┴┐ `Ok(None)` ┌────────┐ ◄───┐ `None` ┌───▼─────┐ // ┌────────────────►│ Pausing ├───────────────────────►│ Paused ├─┐ └───────────┤ Errored │ // │ └─────────┘ └─┬──▲───┘ │ └───▲───▲─┘ // Pending read │ │ │ │ │ │ // ┌──────┐ │ `decode` returns `Some` │ └─────┘ │ │ // │ │ │ ┌──────┐ │ Pending │ │ // │ ┌────▼──┴─┐ Read n>0 bytes ┌┴──────▼─┐ read n>0 bytes │ read │ │ // └─┤ Reading ├───────────────►│ Framing │◄────────────────────────┘ │ │ // └──┬─▲────┘ └─────┬──┬┘ │ │ // │ │ │ │ `decode` returns Err │ │ // │ └───decode` returns `None`──┘ └───────────────────────────────────────────────────────┘ │ // │ read returns Err │ // └────────────────────────────────────────────────────────────────────────────────────────────┘ loop { // Return `None` if we have encountered an error from the underlying decoder // See: https://github.com/tokio-rs/tokio/issues/3976 if state.has_errored { // preparing has_errored -> paused trace!("Returning None and setting paused"); state.is_readable = false; state.has_errored = false; return Poll::Ready(None); } // Repeatedly call `decode` or `decode_eof` while the buffer is "readable", // i.e. it _might_ contain data consumable as a frame or closing frame. // Both signal that there is no such data by returning `None`. // // If `decode` couldn't read a frame and the upstream source has returned eof, // `decode_eof` will attempt to decode the remaining bytes as closing frames. // // If the underlying AsyncRead is resumable, we may continue after an EOF, // but must finish emitting all of it's associated `decode_eof` frames. // Furthermore, we don't want to emit any `decode_eof` frames on retried // reads after an EOF unless we've actually read more data. if state.is_readable { // pausing or framing if state.eof { // pausing let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| { trace!("Got an error, going to errored state"); state.has_errored = true; err })?; if frame.is_none() { state.is_readable = false; // prepare pausing -> paused } // implicit pausing -> pausing or pausing -> paused return Poll::Ready(frame.map(Ok)); } // framing trace!("attempting to decode a frame"); if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| { trace!("Got an error, going to errored state"); state.has_errored = true; op })? { trace!("frame decoded from buffer"); // implicit framing -> framing return Poll::Ready(Some(Ok(frame))); } // framing -> reading state.is_readable = false; } // reading or paused // If we can't build a frame yet, try to read more data and try again. // Make sure we've got room for at least one byte to read to ensure // that we don't get a spurious 0 that looks like EOF. state.buffer.reserve(1); let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err( |err| { trace!("Got an error, going to errored state"); state.has_errored = true; err }, )? { Poll::Ready(ct) => ct, // implicit reading -> reading or implicit paused -> paused Poll::Pending => return Poll::Pending, }; if bytect == 0 { if state.eof { // We're already at an EOF, and since we've reached this path // we're also not readable. This implies that we've already finished // our `decode_eof` handling, so we can simply return `None`. // implicit paused -> paused return Poll::Ready(None); } // prepare reading -> paused state.eof = true; } else { // prepare paused -> framing or noop reading -> framing state.eof = false; } // paused -> framing or reading -> framing or reading -> pausing state.is_readable = true; } } } impl Sink for FramedImpl where T: AsyncWrite, U: Encoder, U::Error: From, W: BorrowMut, { type Error = U::Error; fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if self.state.borrow().buffer.len() >= BACKPRESSURE_BOUNDARY { self.as_mut().poll_flush(cx) } else { Poll::Ready(Ok(())) } } fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { let pinned = self.project(); pinned .codec .encode(item, &mut pinned.state.borrow_mut().buffer)?; Ok(()) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { use crate::util::poll_write_buf; trace!("flushing framed transport"); let mut pinned = self.project(); while !pinned.state.borrow_mut().buffer.is_empty() { let WriteFrame { buffer } = pinned.state.borrow_mut(); trace!(remaining = buffer.len(), "writing;"); let n = ready!(poll_write_buf(pinned.inner.as_mut(), cx, buffer))?; if n == 0 { return Poll::Ready(Err(io::Error::new( io::ErrorKind::WriteZero, "failed to \ write frame to transport", ) .into())); } } // Try flushing the underlying IO ready!(pinned.inner.poll_flush(cx))?; trace!("framed transport flushed"); Poll::Ready(Ok(())) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { ready!(self.as_mut().poll_flush(cx))?; ready!(self.project().inner.poll_shutdown(cx))?; Poll::Ready(Ok(())) } }