use crate::codec::decoder::Decoder; use crate::codec::encoder::Encoder; use crate::codec::framed::{Fuse, ProjectFuse}; use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite}; use bytes::BytesMut; use futures_core::{ready, Stream}; use futures_sink::Sink; use log::trace; use pin_project_lite::pin_project; use std::fmt; use std::io::{self, BufRead, Read}; use std::mem::MaybeUninit; use std::pin::Pin; use std::task::{Context, Poll}; pin_project! { /// A `Sink` of frames encoded to an `AsyncWrite`. pub struct FramedWrite { #[pin] inner: FramedWrite2>, } } pin_project! { pub(crate) struct FramedWrite2 { #[pin] inner: T, buffer: BytesMut, } } const INITIAL_CAPACITY: usize = 8 * 1024; const BACKPRESSURE_BOUNDARY: usize = INITIAL_CAPACITY; impl FramedWrite where T: AsyncWrite, E: Encoder, { /// Creates a new `FramedWrite` with the given `encoder`. pub fn new(inner: T, encoder: E) -> FramedWrite { FramedWrite { inner: framed_write2(Fuse { io: inner, codec: encoder, }), } } } impl FramedWrite { /// Returns a reference to the underlying I/O stream wrapped by /// `FramedWrite`. /// /// Note that care should be taken to not tamper with the underlying stream /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn get_ref(&self) -> &T { &self.inner.inner.io } /// Returns a mutable reference to the underlying I/O stream wrapped by /// `FramedWrite`. /// /// Note that care should be taken to not tamper with the underlying stream /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn get_mut(&mut self) -> &mut T { &mut self.inner.inner.io } /// Consumes the `FramedWrite`, returning its underlying I/O stream. /// /// Note that care should be taken to not tamper with the underlying stream /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn into_inner(self) -> T { self.inner.inner.io } /// Returns a reference to the underlying decoder. pub fn encoder(&self) -> &E { &self.inner.inner.codec } /// Returns a mutable reference to the underlying decoder. pub fn encoder_mut(&mut self) -> &mut E { &mut self.inner.inner.codec } } // This impl just defers to the underlying FramedWrite2 impl Sink for FramedWrite where T: AsyncWrite, E: Encoder, E::Error: From, { type Error = E::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().inner.poll_ready(cx) } fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { self.project().inner.start_send(item) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().inner.poll_flush(cx) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().inner.poll_close(cx) } } impl Stream for FramedWrite where T: Stream, { type Item = T::Item; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project() .inner .project() .inner .project() .io .poll_next(cx) } } impl fmt::Debug for FramedWrite where T: fmt::Debug, U: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FramedWrite") .field("inner", &self.inner.get_ref().io) .field("encoder", &self.inner.get_ref().codec) .field("buffer", &self.inner.buffer) .finish() } } // ===== impl FramedWrite2 ===== pub(crate) fn framed_write2(inner: T) -> FramedWrite2 { FramedWrite2 { inner, buffer: BytesMut::with_capacity(INITIAL_CAPACITY), } } pub(crate) fn framed_write2_with_buffer(inner: T, mut buf: BytesMut) -> FramedWrite2 { if buf.capacity() < INITIAL_CAPACITY { let bytes_to_reserve = INITIAL_CAPACITY - buf.capacity(); buf.reserve(bytes_to_reserve); } FramedWrite2 { inner, buffer: buf } } impl FramedWrite2 { pub(crate) fn get_ref(&self) -> &T { &self.inner } pub(crate) fn into_inner(self) -> T { self.inner } pub(crate) fn into_parts(self) -> (T, BytesMut) { (self.inner, self.buffer) } pub(crate) fn get_mut(&mut self) -> &mut T { &mut self.inner } } impl Sink for FramedWrite2 where T: ProjectFuse + AsyncWrite, T::Codec: Encoder, { type Error = ::Error; fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { // If the buffer is already over 8KiB, then attempt to flush it. If after flushing it's // *still* over 8KiB, then apply backpressure (reject the send). if self.buffer.len() >= BACKPRESSURE_BOUNDARY { match self.as_mut().poll_flush(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), Poll::Ready(Ok(())) => (), }; if self.buffer.len() >= BACKPRESSURE_BOUNDARY { return Poll::Pending; } } Poll::Ready(Ok(())) } fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { let mut pinned = self.project(); pinned .inner .project() .codec .encode(item, &mut pinned.buffer)?; Ok(()) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { trace!("flushing framed transport"); let mut pinned = self.project(); while !pinned.buffer.is_empty() { trace!("writing; remaining={}", pinned.buffer.len()); let buf = &pinned.buffer; let n = ready!(pinned.inner.as_mut().poll_write(cx, &buf))?; if n == 0 { return Poll::Ready(Err(io::Error::new( io::ErrorKind::WriteZero, "failed to \ write frame to transport", ) .into())); } // TODO: Add a way to `bytes` to do this w/o returning the drained data. let _ = pinned.buffer.split_to(n); } // Try flushing the underlying IO ready!(pinned.inner.poll_flush(cx))?; trace!("framed transport flushed"); Poll::Ready(Ok(())) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { ready!(self.as_mut().poll_flush(cx))?; ready!(self.project().inner.poll_shutdown(cx))?; Poll::Ready(Ok(())) } } impl Decoder for FramedWrite2 { type Item = T::Item; type Error = T::Error; fn decode(&mut self, src: &mut BytesMut) -> Result, T::Error> { self.inner.decode(src) } fn decode_eof(&mut self, src: &mut BytesMut) -> Result, T::Error> { self.inner.decode_eof(src) } } impl Read for FramedWrite2 { fn read(&mut self, dst: &mut [u8]) -> io::Result { self.inner.read(dst) } } impl BufRead for FramedWrite2 { fn fill_buf(&mut self) -> io::Result<&[u8]> { self.inner.fill_buf() } fn consume(&mut self, amt: usize) { self.inner.consume(amt) } } impl AsyncRead for FramedWrite2 { unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit]) -> bool { self.inner.prepare_uninitialized_buffer(buf) } fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { self.project().inner.poll_read(cx, buf) } } impl AsyncBufRead for FramedWrite2 { fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().inner.poll_fill_buf(cx) } fn consume(self: Pin<&mut Self>, amt: usize) { self.project().inner.consume(amt) } } impl ProjectFuse for FramedWrite2 where T: ProjectFuse, { type Io = T::Io; type Codec = T::Codec; fn project(self: Pin<&mut Self>) -> Fuse, &mut Self::Codec> { self.project().inner.project() } }