use std::marker::Unpin; use std::{cmp, io}; use bytes::{Buf, Bytes}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use crate::common::{task, Pin, Poll}; /// Combine a buffer with an IO, rewinding reads to use the buffer. #[derive(Debug)] pub(crate) struct Rewind { pre: Option, inner: T, } impl Rewind { #[cfg(any(all(feature = "http2", feature = "server"), test))] pub(crate) fn new(io: T) -> Self { Rewind { pre: None, inner: io, } } pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self { Rewind { pre: Some(buf), inner: io, } } #[cfg(any(all(feature = "http1", feature = "http2", feature = "server"), test))] pub(crate) fn rewind(&mut self, bs: Bytes) { debug_assert!(self.pre.is_none()); self.pre = Some(bs); } pub(crate) fn into_inner(self) -> (T, Bytes) { (self.inner, self.pre.unwrap_or_else(Bytes::new)) } // pub(crate) fn get_mut(&mut self) -> &mut T { // &mut self.inner // } } impl AsyncRead for Rewind where T: AsyncRead + Unpin, { fn poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { if let Some(mut prefix) = self.pre.take() { // If there are no remaining bytes, let the bytes get dropped. if !prefix.is_empty() { let copy_len = cmp::min(prefix.len(), buf.remaining()); // TODO: There should be a way to do following two lines cleaner... buf.put_slice(&prefix[..copy_len]); prefix.advance(copy_len); // Put back what's left if !prefix.is_empty() { self.pre = Some(prefix); } return Poll::Ready(Ok(())); } } Pin::new(&mut self.inner).poll_read(cx, buf) } } impl AsyncWrite for Rewind where T: AsyncWrite + Unpin, { fn poll_write( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8], ) -> Poll> { Pin::new(&mut self.inner).poll_write(cx, buf) } fn poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll> { Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { Pin::new(&mut self.inner).poll_flush(cx) } fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { Pin::new(&mut self.inner).poll_shutdown(cx) } fn is_write_vectored(&self) -> bool { self.inner.is_write_vectored() } } #[cfg(test)] mod tests { // FIXME: re-implement tests with `async/await`, this import should // trigger a warning to remind us use super::Rewind; use bytes::Bytes; use tokio::io::AsyncReadExt; #[tokio::test] async fn partial_rewind() { let underlying = [104, 101, 108, 108, 111]; let mock = tokio_test::io::Builder::new().read(&underlying).build(); let mut stream = Rewind::new(mock); // Read off some bytes, ensure we filled o1 let mut buf = [0; 2]; stream.read_exact(&mut buf).await.expect("read1"); // Rewind the stream so that it is as if we never read in the first place. stream.rewind(Bytes::copy_from_slice(&buf[..])); let mut buf = [0; 5]; stream.read_exact(&mut buf).await.expect("read1"); // At this point we should have read everything that was in the MockStream assert_eq!(&buf, &underlying); } #[tokio::test] async fn full_rewind() { let underlying = [104, 101, 108, 108, 111]; let mock = tokio_test::io::Builder::new().read(&underlying).build(); let mut stream = Rewind::new(mock); let mut buf = [0; 5]; stream.read_exact(&mut buf).await.expect("read1"); // Rewind the stream so that it is as if we never read in the first place. stream.rewind(Bytes::copy_from_slice(&buf[..])); let mut buf = [0; 5]; stream.read_exact(&mut buf).await.expect("read1"); } }