diff options
Diffstat (limited to 'third_party/rust/tokio-util/src')
42 files changed, 8519 insertions, 0 deletions
diff --git a/third_party/rust/tokio-util/src/cfg.rs b/third_party/rust/tokio-util/src/cfg.rs new file mode 100644 index 0000000000..4035255aff --- /dev/null +++ b/third_party/rust/tokio-util/src/cfg.rs @@ -0,0 +1,71 @@ +macro_rules! cfg_codec { + ($($item:item)*) => { + $( + #[cfg(feature = "codec")] + #[cfg_attr(docsrs, doc(cfg(feature = "codec")))] + $item + )* + } +} + +macro_rules! cfg_compat { + ($($item:item)*) => { + $( + #[cfg(feature = "compat")] + #[cfg_attr(docsrs, doc(cfg(feature = "compat")))] + $item + )* + } +} + +macro_rules! cfg_net { + ($($item:item)*) => { + $( + #[cfg(all(feature = "net", feature = "codec"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "net", feature = "codec"))))] + $item + )* + } +} + +macro_rules! cfg_io { + ($($item:item)*) => { + $( + #[cfg(feature = "io")] + #[cfg_attr(docsrs, doc(cfg(feature = "io")))] + $item + )* + } +} + +cfg_io! { + macro_rules! cfg_io_util { + ($($item:item)*) => { + $( + #[cfg(feature = "io-util")] + #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] + $item + )* + } + } +} + +macro_rules! cfg_rt { + ($($item:item)*) => { + $( + #[cfg(feature = "rt")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + $item + )* + } +} + +macro_rules! cfg_time { + ($($item:item)*) => { + $( + #[cfg(feature = "time")] + #[cfg_attr(docsrs, doc(cfg(feature = "time")))] + $item + )* + } +} diff --git a/third_party/rust/tokio-util/src/codec/any_delimiter_codec.rs b/third_party/rust/tokio-util/src/codec/any_delimiter_codec.rs new file mode 100644 index 0000000000..3dbfd456b0 --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/any_delimiter_codec.rs @@ -0,0 +1,263 @@ +use crate::codec::decoder::Decoder; +use crate::codec::encoder::Encoder; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::{cmp, fmt, io, str, usize}; + +const DEFAULT_SEEK_DELIMITERS: &[u8] = b",;\n\r"; +const DEFAULT_SEQUENCE_WRITER: &[u8] = b","; +/// A simple [`Decoder`] and [`Encoder`] implementation that splits up data into chunks based on any character in the given delimiter string. +/// +/// [`Decoder`]: crate::codec::Decoder +/// [`Encoder`]: crate::codec::Encoder +/// +/// # Example +/// Decode string of bytes containing various different delimiters. +/// +/// [`BytesMut`]: bytes::BytesMut +/// [`Error`]: std::io::Error +/// +/// ``` +/// use tokio_util::codec::{AnyDelimiterCodec, Decoder}; +/// use bytes::{BufMut, BytesMut}; +/// +/// # +/// # #[tokio::main(flavor = "current_thread")] +/// # async fn main() -> Result<(), std::io::Error> { +/// let mut codec = AnyDelimiterCodec::new(b",;\r\n".to_vec(),b";".to_vec()); +/// let buf = &mut BytesMut::new(); +/// buf.reserve(200); +/// buf.put_slice(b"chunk 1,chunk 2;chunk 3\n\r"); +/// assert_eq!("chunk 1", codec.decode(buf).unwrap().unwrap()); +/// assert_eq!("chunk 2", codec.decode(buf).unwrap().unwrap()); +/// assert_eq!("chunk 3", codec.decode(buf).unwrap().unwrap()); +/// assert_eq!("", codec.decode(buf).unwrap().unwrap()); +/// assert_eq!(None, codec.decode(buf).unwrap()); +/// # Ok(()) +/// # } +/// ``` +/// +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct AnyDelimiterCodec { + // Stored index of the next index to examine for the delimiter character. + // This is used to optimize searching. + // For example, if `decode` was called with `abc` and the delimiter is '{}', it would hold `3`, + // because that is the next index to examine. + // The next time `decode` is called with `abcde}`, the method will + // only look at `de}` before returning. + next_index: usize, + + /// The maximum length for a given chunk. If `usize::MAX`, chunks will be + /// read until a delimiter character is reached. + max_length: usize, + + /// Are we currently discarding the remainder of a chunk which was over + /// the length limit? + is_discarding: bool, + + /// The bytes that are using for search during decode + seek_delimiters: Vec<u8>, + + /// The bytes that are using for encoding + sequence_writer: Vec<u8>, +} + +impl AnyDelimiterCodec { + /// Returns a `AnyDelimiterCodec` for splitting up data into chunks. + /// + /// # Note + /// + /// The returned `AnyDelimiterCodec` will not have an upper bound on the length + /// of a buffered chunk. See the documentation for [`new_with_max_length`] + /// for information on why this could be a potential security risk. + /// + /// [`new_with_max_length`]: crate::codec::AnyDelimiterCodec::new_with_max_length() + pub fn new(seek_delimiters: Vec<u8>, sequence_writer: Vec<u8>) -> AnyDelimiterCodec { + AnyDelimiterCodec { + next_index: 0, + max_length: usize::MAX, + is_discarding: false, + seek_delimiters, + sequence_writer, + } + } + + /// Returns a `AnyDelimiterCodec` with a maximum chunk length limit. + /// + /// If this is set, calls to `AnyDelimiterCodec::decode` will return a + /// [`AnyDelimiterCodecError`] when a chunk exceeds the length limit. Subsequent calls + /// will discard up to `limit` bytes from that chunk until a delimiter + /// character is reached, returning `None` until the delimiter over the limit + /// has been fully discarded. After that point, calls to `decode` will + /// function as normal. + /// + /// # Note + /// + /// Setting a length limit is highly recommended for any `AnyDelimiterCodec` which + /// will be exposed to untrusted input. Otherwise, the size of the buffer + /// that holds the chunk currently being read is unbounded. An attacker could + /// exploit this unbounded buffer by sending an unbounded amount of input + /// without any delimiter characters, causing unbounded memory consumption. + /// + /// [`AnyDelimiterCodecError`]: crate::codec::AnyDelimiterCodecError + pub fn new_with_max_length( + seek_delimiters: Vec<u8>, + sequence_writer: Vec<u8>, + max_length: usize, + ) -> Self { + AnyDelimiterCodec { + max_length, + ..AnyDelimiterCodec::new(seek_delimiters, sequence_writer) + } + } + + /// Returns the maximum chunk length when decoding. + /// + /// ``` + /// use std::usize; + /// use tokio_util::codec::AnyDelimiterCodec; + /// + /// let codec = AnyDelimiterCodec::new(b",;\n".to_vec(), b";".to_vec()); + /// assert_eq!(codec.max_length(), usize::MAX); + /// ``` + /// ``` + /// use tokio_util::codec::AnyDelimiterCodec; + /// + /// let codec = AnyDelimiterCodec::new_with_max_length(b",;\n".to_vec(), b";".to_vec(), 256); + /// assert_eq!(codec.max_length(), 256); + /// ``` + pub fn max_length(&self) -> usize { + self.max_length + } +} + +impl Decoder for AnyDelimiterCodec { + type Item = Bytes; + type Error = AnyDelimiterCodecError; + + fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Bytes>, AnyDelimiterCodecError> { + loop { + // Determine how far into the buffer we'll search for a delimiter. If + // there's no max_length set, we'll read to the end of the buffer. + let read_to = cmp::min(self.max_length.saturating_add(1), buf.len()); + + let new_chunk_offset = buf[self.next_index..read_to].iter().position(|b| { + self.seek_delimiters + .iter() + .any(|delimiter| *b == *delimiter) + }); + + match (self.is_discarding, new_chunk_offset) { + (true, Some(offset)) => { + // If we found a new chunk, discard up to that offset and + // then stop discarding. On the next iteration, we'll try + // to read a chunk normally. + buf.advance(offset + self.next_index + 1); + self.is_discarding = false; + self.next_index = 0; + } + (true, None) => { + // Otherwise, we didn't find a new chunk, so we'll discard + // everything we read. On the next iteration, we'll continue + // discarding up to max_len bytes unless we find a new chunk. + buf.advance(read_to); + self.next_index = 0; + if buf.is_empty() { + return Ok(None); + } + } + (false, Some(offset)) => { + // Found a chunk! + let new_chunk_index = offset + self.next_index; + self.next_index = 0; + let mut chunk = buf.split_to(new_chunk_index + 1); + chunk.truncate(chunk.len() - 1); + let chunk = chunk.freeze(); + return Ok(Some(chunk)); + } + (false, None) if buf.len() > self.max_length => { + // Reached the maximum length without finding a + // new chunk, return an error and start discarding on the + // next call. + self.is_discarding = true; + return Err(AnyDelimiterCodecError::MaxChunkLengthExceeded); + } + (false, None) => { + // We didn't find a chunk or reach the length limit, so the next + // call will resume searching at the current offset. + self.next_index = read_to; + return Ok(None); + } + } + } + } + + fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Bytes>, AnyDelimiterCodecError> { + Ok(match self.decode(buf)? { + Some(frame) => Some(frame), + None => { + // return remaining data, if any + if buf.is_empty() { + None + } else { + let chunk = buf.split_to(buf.len()); + self.next_index = 0; + Some(chunk.freeze()) + } + } + }) + } +} + +impl<T> Encoder<T> for AnyDelimiterCodec +where + T: AsRef<str>, +{ + type Error = AnyDelimiterCodecError; + + fn encode(&mut self, chunk: T, buf: &mut BytesMut) -> Result<(), AnyDelimiterCodecError> { + let chunk = chunk.as_ref(); + buf.reserve(chunk.len() + 1); + buf.put(chunk.as_bytes()); + buf.put(self.sequence_writer.as_ref()); + + Ok(()) + } +} + +impl Default for AnyDelimiterCodec { + fn default() -> Self { + Self::new( + DEFAULT_SEEK_DELIMITERS.to_vec(), + DEFAULT_SEQUENCE_WRITER.to_vec(), + ) + } +} + +/// An error occurred while encoding or decoding a chunk. +#[derive(Debug)] +pub enum AnyDelimiterCodecError { + /// The maximum chunk length was exceeded. + MaxChunkLengthExceeded, + /// An IO error occurred. + Io(io::Error), +} + +impl fmt::Display for AnyDelimiterCodecError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AnyDelimiterCodecError::MaxChunkLengthExceeded => { + write!(f, "max chunk length exceeded") + } + AnyDelimiterCodecError::Io(e) => write!(f, "{}", e), + } + } +} + +impl From<io::Error> for AnyDelimiterCodecError { + fn from(e: io::Error) -> AnyDelimiterCodecError { + AnyDelimiterCodecError::Io(e) + } +} + +impl std::error::Error for AnyDelimiterCodecError {} diff --git a/third_party/rust/tokio-util/src/codec/bytes_codec.rs b/third_party/rust/tokio-util/src/codec/bytes_codec.rs new file mode 100644 index 0000000000..ceab228b94 --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/bytes_codec.rs @@ -0,0 +1,86 @@ +use crate::codec::decoder::Decoder; +use crate::codec::encoder::Encoder; + +use bytes::{BufMut, Bytes, BytesMut}; +use std::io; + +/// A simple [`Decoder`] and [`Encoder`] implementation that just ships bytes around. +/// +/// [`Decoder`]: crate::codec::Decoder +/// [`Encoder`]: crate::codec::Encoder +/// +/// # Example +/// +/// Turn an [`AsyncRead`] into a stream of `Result<`[`BytesMut`]`, `[`Error`]`>`. +/// +/// [`AsyncRead`]: tokio::io::AsyncRead +/// [`BytesMut`]: bytes::BytesMut +/// [`Error`]: std::io::Error +/// +/// ``` +/// # mod hidden { +/// # #[allow(unused_imports)] +/// use tokio::fs::File; +/// # } +/// use tokio::io::AsyncRead; +/// use tokio_util::codec::{FramedRead, BytesCodec}; +/// +/// # enum File {} +/// # impl File { +/// # async fn open(_name: &str) -> Result<impl AsyncRead, std::io::Error> { +/// # use std::io::Cursor; +/// # Ok(Cursor::new(vec![0, 1, 2, 3, 4, 5])) +/// # } +/// # } +/// # +/// # #[tokio::main(flavor = "current_thread")] +/// # async fn main() -> Result<(), std::io::Error> { +/// let my_async_read = File::open("filename.txt").await?; +/// let my_stream_of_bytes = FramedRead::new(my_async_read, BytesCodec::new()); +/// # Ok(()) +/// # } +/// ``` +/// +#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Default)] +pub struct BytesCodec(()); + +impl BytesCodec { + /// Creates a new `BytesCodec` for shipping around raw bytes. + pub fn new() -> BytesCodec { + BytesCodec(()) + } +} + +impl Decoder for BytesCodec { + type Item = BytesMut; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> { + if !buf.is_empty() { + let len = buf.len(); + Ok(Some(buf.split_to(len))) + } else { + Ok(None) + } + } +} + +impl Encoder<Bytes> for BytesCodec { + type Error = io::Error; + + fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> { + buf.reserve(data.len()); + buf.put(data); + Ok(()) + } +} + +impl Encoder<BytesMut> for BytesCodec { + type Error = io::Error; + + fn encode(&mut self, data: BytesMut, buf: &mut BytesMut) -> Result<(), io::Error> { + buf.reserve(data.len()); + buf.put(data); + Ok(()) + } +} diff --git a/third_party/rust/tokio-util/src/codec/decoder.rs b/third_party/rust/tokio-util/src/codec/decoder.rs new file mode 100644 index 0000000000..c5927783d1 --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/decoder.rs @@ -0,0 +1,184 @@ +use crate::codec::Framed; + +use tokio::io::{AsyncRead, AsyncWrite}; + +use bytes::BytesMut; +use std::io; + +/// Decoding of frames via buffers. +/// +/// This trait is used when constructing an instance of [`Framed`] or +/// [`FramedRead`]. An implementation of `Decoder` takes a byte stream that has +/// already been buffered in `src` and decodes the data into a stream of +/// `Self::Item` frames. +/// +/// Implementations are able to track state on `self`, which enables +/// implementing stateful streaming parsers. In many cases, though, this type +/// will simply be a unit struct (e.g. `struct HttpDecoder`). +/// +/// For some underlying data-sources, namely files and FIFOs, +/// it's possible to temporarily read 0 bytes by reaching EOF. +/// +/// In these cases `decode_eof` will be called until it signals +/// fullfillment of all closing frames by returning `Ok(None)`. +/// After that, repeated attempts to read from the [`Framed`] or [`FramedRead`] +/// will not invoke `decode` or `decode_eof` again, until data can be read +/// during a retry. +/// +/// It is up to the Decoder to keep track of a restart after an EOF, +/// and to decide how to handle such an event by, for example, +/// allowing frames to cross EOF boundaries, re-emitting opening frames, or +/// resetting the entire internal state. +/// +/// [`Framed`]: crate::codec::Framed +/// [`FramedRead`]: crate::codec::FramedRead +pub trait Decoder { + /// The type of decoded frames. + type Item; + + /// The type of unrecoverable frame decoding errors. + /// + /// If an individual message is ill-formed but can be ignored without + /// interfering with the processing of future messages, it may be more + /// useful to report the failure as an `Item`. + /// + /// `From<io::Error>` is required in the interest of making `Error` suitable + /// for returning directly from a [`FramedRead`], and to enable the default + /// implementation of `decode_eof` to yield an `io::Error` when the decoder + /// fails to consume all available data. + /// + /// Note that implementors of this trait can simply indicate `type Error = + /// io::Error` to use I/O errors as this type. + /// + /// [`FramedRead`]: crate::codec::FramedRead + type Error: From<io::Error>; + + /// Attempts to decode a frame from the provided buffer of bytes. + /// + /// This method is called by [`FramedRead`] whenever bytes are ready to be + /// parsed. The provided buffer of bytes is what's been read so far, and + /// this instance of `Decode` can determine whether an entire frame is in + /// the buffer and is ready to be returned. + /// + /// If an entire frame is available, then this instance will remove those + /// bytes from the buffer provided and return them as a decoded + /// frame. Note that removing bytes from the provided buffer doesn't always + /// necessarily copy the bytes, so this should be an efficient operation in + /// most circumstances. + /// + /// If the bytes look valid, but a frame isn't fully available yet, then + /// `Ok(None)` is returned. This indicates to the [`Framed`] instance that + /// it needs to read some more bytes before calling this method again. + /// + /// Note that the bytes provided may be empty. If a previous call to + /// `decode` consumed all the bytes in the buffer then `decode` will be + /// called again until it returns `Ok(None)`, indicating that more bytes need to + /// be read. + /// + /// Finally, if the bytes in the buffer are malformed then an error is + /// returned indicating why. This informs [`Framed`] that the stream is now + /// corrupt and should be terminated. + /// + /// [`Framed`]: crate::codec::Framed + /// [`FramedRead`]: crate::codec::FramedRead + /// + /// # Buffer management + /// + /// Before returning from the function, implementations should ensure that + /// the buffer has appropriate capacity in anticipation of future calls to + /// `decode`. Failing to do so leads to inefficiency. + /// + /// For example, if frames have a fixed length, or if the length of the + /// current frame is known from a header, a possible buffer management + /// strategy is: + /// + /// ```no_run + /// # use std::io; + /// # + /// # use bytes::BytesMut; + /// # use tokio_util::codec::Decoder; + /// # + /// # struct MyCodec; + /// # + /// impl Decoder for MyCodec { + /// // ... + /// # type Item = BytesMut; + /// # type Error = io::Error; + /// + /// fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> { + /// // ... + /// + /// // Reserve enough to complete decoding of the current frame. + /// let current_frame_len: usize = 1000; // Example. + /// // And to start decoding the next frame. + /// let next_frame_header_len: usize = 10; // Example. + /// src.reserve(current_frame_len + next_frame_header_len); + /// + /// return Ok(None); + /// } + /// } + /// ``` + /// + /// An optimal buffer management strategy minimizes reallocations and + /// over-allocations. + fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error>; + + /// A default method available to be called when there are no more bytes + /// available to be read from the underlying I/O. + /// + /// This method defaults to calling `decode` and returns an error if + /// `Ok(None)` is returned while there is unconsumed data in `buf`. + /// Typically this doesn't need to be implemented unless the framing + /// protocol differs near the end of the stream, or if you need to construct + /// frames _across_ eof boundaries on sources that can be resumed. + /// + /// Note that the `buf` argument may be empty. If a previous call to + /// `decode_eof` consumed all the bytes in the buffer, `decode_eof` will be + /// called again until it returns `None`, indicating that there are no more + /// frames to yield. This behavior enables returning finalization frames + /// that may not be based on inbound data. + /// + /// Once `None` has been returned, `decode_eof` won't be called again until + /// an attempt to resume the stream has been made, where the underlying stream + /// actually returned more data. + fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> { + match self.decode(buf)? { + Some(frame) => Ok(Some(frame)), + None => { + if buf.is_empty() { + Ok(None) + } else { + Err(io::Error::new(io::ErrorKind::Other, "bytes remaining on stream").into()) + } + } + } + } + + /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this + /// `Io` object, using `Decode` and `Encode` to read and write the raw data. + /// + /// Raw I/O objects work with byte sequences, but higher-level code usually + /// wants to batch these into meaningful chunks, called "frames". This + /// method layers framing on top of an I/O object, by using the `Codec` + /// traits to handle encoding and decoding of messages frames. Note that + /// the incoming and outgoing frame types may be distinct. + /// + /// This function returns a *single* object that is both `Stream` and + /// `Sink`; grouping this into a single object is often useful for layering + /// things like gzip or TLS, which require both read and write access to the + /// underlying object. + /// + /// If you want to work more directly with the streams and sink, consider + /// calling `split` on the [`Framed`] returned by this method, which will + /// break them into separate objects, allowing them to interact more easily. + /// + /// [`Stream`]: futures_core::Stream + /// [`Sink`]: futures_sink::Sink + /// [`Framed`]: crate::codec::Framed + fn framed<T: AsyncRead + AsyncWrite + Sized>(self, io: T) -> Framed<T, Self> + where + Self: Sized, + { + Framed::new(io, self) + } +} diff --git a/third_party/rust/tokio-util/src/codec/encoder.rs b/third_party/rust/tokio-util/src/codec/encoder.rs new file mode 100644 index 0000000000..770a10fa9b --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/encoder.rs @@ -0,0 +1,25 @@ +use bytes::BytesMut; +use std::io; + +/// Trait of helper objects to write out messages as bytes, for use with +/// [`FramedWrite`]. +/// +/// [`FramedWrite`]: crate::codec::FramedWrite +pub trait Encoder<Item> { + /// The type of encoding errors. + /// + /// [`FramedWrite`] requires `Encoder`s errors to implement `From<io::Error>` + /// in the interest letting it return `Error`s directly. + /// + /// [`FramedWrite`]: crate::codec::FramedWrite + type Error: From<io::Error>; + + /// Encodes a frame into the buffer provided. + /// + /// This method will encode `item` into the byte buffer provided by `dst`. + /// The `dst` provided is an internal buffer of the [`FramedWrite`] instance and + /// will be written out when possible. + /// + /// [`FramedWrite`]: crate::codec::FramedWrite + fn encode(&mut self, item: Item, dst: &mut BytesMut) -> Result<(), Self::Error>; +} diff --git a/third_party/rust/tokio-util/src/codec/framed.rs b/third_party/rust/tokio-util/src/codec/framed.rs new file mode 100644 index 0000000000..d89b8b6dc3 --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/framed.rs @@ -0,0 +1,373 @@ +use crate::codec::decoder::Decoder; +use crate::codec::encoder::Encoder; +use crate::codec::framed_impl::{FramedImpl, RWFrames, ReadFrame, WriteFrame}; + +use futures_core::Stream; +use tokio::io::{AsyncRead, AsyncWrite}; + +use bytes::BytesMut; +use futures_sink::Sink; +use pin_project_lite::pin_project; +use std::fmt; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// A unified [`Stream`] and [`Sink`] interface to an underlying I/O object, using + /// the `Encoder` and `Decoder` traits to encode and decode frames. + /// + /// You can create a `Framed` instance by using the [`Decoder::framed`] adapter, or + /// by using the `new` function seen below. + /// + /// [`Stream`]: futures_core::Stream + /// [`Sink`]: futures_sink::Sink + /// [`AsyncRead`]: tokio::io::AsyncRead + /// [`Decoder::framed`]: crate::codec::Decoder::framed() + pub struct Framed<T, U> { + #[pin] + inner: FramedImpl<T, U, RWFrames> + } +} + +impl<T, U> Framed<T, U> +where + T: AsyncRead + AsyncWrite, +{ + /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this + /// I/O object, using [`Decoder`] and [`Encoder`] to read and write the raw data. + /// + /// Raw I/O objects work with byte sequences, but higher-level code usually + /// wants to batch these into meaningful chunks, called "frames". This + /// method layers framing on top of an I/O object, by using the codec + /// traits to handle encoding and decoding of messages frames. Note that + /// the incoming and outgoing frame types may be distinct. + /// + /// This function returns a *single* object that is both [`Stream`] and + /// [`Sink`]; grouping this into a single object is often useful for layering + /// things like gzip or TLS, which require both read and write access to the + /// underlying object. + /// + /// If you want to work more directly with the streams and sink, consider + /// calling [`split`] on the `Framed` returned by this method, which will + /// break them into separate objects, allowing them to interact more easily. + /// + /// Note that, for some byte sources, the stream can be resumed after an EOF + /// by reading from it, even after it has returned `None`. Repeated attempts + /// to do so, without new data available, continue to return `None` without + /// creating more (closing) frames. + /// + /// [`Stream`]: futures_core::Stream + /// [`Sink`]: futures_sink::Sink + /// [`Decode`]: crate::codec::Decoder + /// [`Encoder`]: crate::codec::Encoder + /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split + pub fn new(inner: T, codec: U) -> Framed<T, U> { + Framed { + inner: FramedImpl { + inner, + codec, + state: Default::default(), + }, + } + } + + /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this + /// I/O object, using [`Decoder`] and [`Encoder`] to read and write the raw data, + /// with a specific read buffer initial capacity. + /// + /// Raw I/O objects work with byte sequences, but higher-level code usually + /// wants to batch these into meaningful chunks, called "frames". This + /// method layers framing on top of an I/O object, by using the codec + /// traits to handle encoding and decoding of messages frames. Note that + /// the incoming and outgoing frame types may be distinct. + /// + /// This function returns a *single* object that is both [`Stream`] and + /// [`Sink`]; grouping this into a single object is often useful for layering + /// things like gzip or TLS, which require both read and write access to the + /// underlying object. + /// + /// If you want to work more directly with the streams and sink, consider + /// calling [`split`] on the `Framed` returned by this method, which will + /// break them into separate objects, allowing them to interact more easily. + /// + /// [`Stream`]: futures_core::Stream + /// [`Sink`]: futures_sink::Sink + /// [`Decode`]: crate::codec::Decoder + /// [`Encoder`]: crate::codec::Encoder + /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split + pub fn with_capacity(inner: T, codec: U, capacity: usize) -> Framed<T, U> { + Framed { + inner: FramedImpl { + inner, + codec, + state: RWFrames { + read: ReadFrame { + eof: false, + is_readable: false, + buffer: BytesMut::with_capacity(capacity), + has_errored: false, + }, + write: WriteFrame::default(), + }, + }, + } + } +} + +impl<T, U> Framed<T, U> { + /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this + /// I/O object, using [`Decoder`] and [`Encoder`] to read and write the raw data. + /// + /// Raw I/O objects work with byte sequences, but higher-level code usually + /// wants to batch these into meaningful chunks, called "frames". This + /// method layers framing on top of an I/O object, by using the `Codec` + /// traits to handle encoding and decoding of messages frames. Note that + /// the incoming and outgoing frame types may be distinct. + /// + /// This function returns a *single* object that is both [`Stream`] and + /// [`Sink`]; grouping this into a single object is often useful for layering + /// things like gzip or TLS, which require both read and write access to the + /// underlying object. + /// + /// This objects takes a stream and a readbuffer and a writebuffer. These field + /// can be obtained from an existing `Framed` with the [`into_parts`] method. + /// + /// If you want to work more directly with the streams and sink, consider + /// calling [`split`] on the `Framed` returned by this method, which will + /// break them into separate objects, allowing them to interact more easily. + /// + /// [`Stream`]: futures_core::Stream + /// [`Sink`]: futures_sink::Sink + /// [`Decoder`]: crate::codec::Decoder + /// [`Encoder`]: crate::codec::Encoder + /// [`into_parts`]: crate::codec::Framed::into_parts() + /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split + pub fn from_parts(parts: FramedParts<T, U>) -> Framed<T, U> { + Framed { + inner: FramedImpl { + inner: parts.io, + codec: parts.codec, + state: RWFrames { + read: parts.read_buf.into(), + write: parts.write_buf.into(), + }, + }, + } + } + + /// Returns a reference to the underlying I/O stream wrapped by + /// `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_ref(&self) -> &T { + &self.inner.inner + } + + /// Returns a mutable reference to the underlying I/O stream wrapped by + /// `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_mut(&mut self) -> &mut T { + &mut self.inner.inner + } + + /// Returns a pinned mutable reference to the underlying I/O stream wrapped by + /// `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> { + self.project().inner.project().inner + } + + /// Returns a reference to the underlying codec wrapped by + /// `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying codec + /// as it may corrupt the stream of frames otherwise being worked with. + pub fn codec(&self) -> &U { + &self.inner.codec + } + + /// Returns a mutable reference to the underlying codec wrapped by + /// `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying codec + /// as it may corrupt the stream of frames otherwise being worked with. + pub fn codec_mut(&mut self) -> &mut U { + &mut self.inner.codec + } + + /// Maps the codec `U` to `C`, preserving the read and write buffers + /// wrapped by `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying codec + /// as it may corrupt the stream of frames otherwise being worked with. + pub fn map_codec<C, F>(self, map: F) -> Framed<T, C> + where + F: FnOnce(U) -> C, + { + // This could be potentially simplified once rust-lang/rust#86555 hits stable + let parts = self.into_parts(); + Framed::from_parts(FramedParts { + io: parts.io, + codec: map(parts.codec), + read_buf: parts.read_buf, + write_buf: parts.write_buf, + _priv: (), + }) + } + + /// Returns a mutable reference to the underlying codec wrapped by + /// `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying codec + /// as it may corrupt the stream of frames otherwise being worked with. + pub fn codec_pin_mut(self: Pin<&mut Self>) -> &mut U { + self.project().inner.project().codec + } + + /// Returns a reference to the read buffer. + pub fn read_buffer(&self) -> &BytesMut { + &self.inner.state.read.buffer + } + + /// Returns a mutable reference to the read buffer. + pub fn read_buffer_mut(&mut self) -> &mut BytesMut { + &mut self.inner.state.read.buffer + } + + /// Returns a reference to the write buffer. + pub fn write_buffer(&self) -> &BytesMut { + &self.inner.state.write.buffer + } + + /// Returns a mutable reference to the write buffer. + pub fn write_buffer_mut(&mut self) -> &mut BytesMut { + &mut self.inner.state.write.buffer + } + + /// Consumes the `Framed`, returning its underlying I/O stream. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn into_inner(self) -> T { + self.inner.inner + } + + /// Consumes the `Framed`, returning its underlying I/O stream, the buffer + /// with unprocessed data, and the codec. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn into_parts(self) -> FramedParts<T, U> { + FramedParts { + io: self.inner.inner, + codec: self.inner.codec, + read_buf: self.inner.state.read.buffer, + write_buf: self.inner.state.write.buffer, + _priv: (), + } + } +} + +// This impl just defers to the underlying FramedImpl +impl<T, U> Stream for Framed<T, U> +where + T: AsyncRead, + U: Decoder, +{ + type Item = Result<U::Item, U::Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + self.project().inner.poll_next(cx) + } +} + +// This impl just defers to the underlying FramedImpl +impl<T, I, U> Sink<I> for Framed<T, U> +where + T: AsyncWrite, + U: Encoder<I>, + U::Error: From<io::Error>, +{ + type Error = U::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + self.project().inner.start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.poll_close(cx) + } +} + +impl<T, U> fmt::Debug for Framed<T, U> +where + T: fmt::Debug, + U: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Framed") + .field("io", self.get_ref()) + .field("codec", self.codec()) + .finish() + } +} + +/// `FramedParts` contains an export of the data of a Framed transport. +/// It can be used to construct a new [`Framed`] with a different codec. +/// It contains all current buffers and the inner transport. +/// +/// [`Framed`]: crate::codec::Framed +#[derive(Debug)] +#[allow(clippy::manual_non_exhaustive)] +pub struct FramedParts<T, U> { + /// The inner transport used to read bytes to and write bytes to + pub io: T, + + /// The codec + pub codec: U, + + /// The buffer with read but unprocessed data. + pub read_buf: BytesMut, + + /// A buffer with unprocessed data which are not written yet. + pub write_buf: BytesMut, + + /// This private field allows us to add additional fields in the future in a + /// backwards compatible way. + _priv: (), +} + +impl<T, U> FramedParts<T, U> { + /// Create a new, default, `FramedParts` + pub fn new<I>(io: T, codec: U) -> FramedParts<T, U> + where + U: Encoder<I>, + { + FramedParts { + io, + codec, + read_buf: BytesMut::new(), + write_buf: BytesMut::new(), + _priv: (), + } + } +} diff --git a/third_party/rust/tokio-util/src/codec/framed_impl.rs b/third_party/rust/tokio-util/src/codec/framed_impl.rs new file mode 100644 index 0000000000..ce1a6db873 --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/framed_impl.rs @@ -0,0 +1,308 @@ +use crate::codec::decoder::Decoder; +use crate::codec::encoder::Encoder; + +use futures_core::Stream; +use tokio::io::{AsyncRead, AsyncWrite}; + +use bytes::BytesMut; +use futures_core::ready; +use futures_sink::Sink; +use pin_project_lite::pin_project; +use std::borrow::{Borrow, BorrowMut}; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tracing::trace; + +pin_project! { + #[derive(Debug)] + pub(crate) struct FramedImpl<T, U, State> { + #[pin] + pub(crate) inner: T, + pub(crate) state: State, + pub(crate) codec: U, + } +} + +const INITIAL_CAPACITY: usize = 8 * 1024; +const BACKPRESSURE_BOUNDARY: usize = INITIAL_CAPACITY; + +#[derive(Debug)] +pub(crate) struct ReadFrame { + pub(crate) eof: bool, + pub(crate) is_readable: bool, + pub(crate) buffer: BytesMut, + pub(crate) has_errored: bool, +} + +pub(crate) struct WriteFrame { + pub(crate) buffer: BytesMut, +} + +#[derive(Default)] +pub(crate) struct RWFrames { + pub(crate) read: ReadFrame, + pub(crate) write: WriteFrame, +} + +impl Default for ReadFrame { + fn default() -> Self { + Self { + eof: false, + is_readable: false, + buffer: BytesMut::with_capacity(INITIAL_CAPACITY), + has_errored: false, + } + } +} + +impl Default for WriteFrame { + fn default() -> Self { + Self { + buffer: BytesMut::with_capacity(INITIAL_CAPACITY), + } + } +} + +impl From<BytesMut> for ReadFrame { + fn from(mut buffer: BytesMut) -> Self { + let size = buffer.capacity(); + if size < INITIAL_CAPACITY { + buffer.reserve(INITIAL_CAPACITY - size); + } + + Self { + buffer, + is_readable: size > 0, + eof: false, + has_errored: false, + } + } +} + +impl From<BytesMut> for WriteFrame { + fn from(mut buffer: BytesMut) -> Self { + let size = buffer.capacity(); + if size < INITIAL_CAPACITY { + buffer.reserve(INITIAL_CAPACITY - size); + } + + Self { buffer } + } +} + +impl Borrow<ReadFrame> for RWFrames { + fn borrow(&self) -> &ReadFrame { + &self.read + } +} +impl BorrowMut<ReadFrame> for RWFrames { + fn borrow_mut(&mut self) -> &mut ReadFrame { + &mut self.read + } +} +impl Borrow<WriteFrame> for RWFrames { + fn borrow(&self) -> &WriteFrame { + &self.write + } +} +impl BorrowMut<WriteFrame> for RWFrames { + fn borrow_mut(&mut self) -> &mut WriteFrame { + &mut self.write + } +} +impl<T, U, R> Stream for FramedImpl<T, U, R> +where + T: AsyncRead, + U: Decoder, + R: BorrowMut<ReadFrame>, +{ + type Item = Result<U::Item, U::Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + use crate::util::poll_read_buf; + + let mut pinned = self.project(); + let state: &mut ReadFrame = pinned.state.borrow_mut(); + // The following loops implements a state machine with each state corresponding + // to a combination of the `is_readable` and `eof` flags. States persist across + // loop entries and most state transitions occur with a return. + // + // The initial state is `reading`. + // + // | state | eof | is_readable | has_errored | + // |---------|-------|-------------|-------------| + // | reading | false | false | false | + // | framing | false | true | false | + // | pausing | true | true | false | + // | paused | true | false | false | + // | errored | <any> | <any> | true | + // `decode_eof` returns Err + // ┌────────────────────────────────────────────────────────┐ + // `decode_eof` returns │ │ + // `Ok(Some)` │ │ + // ┌─────┐ │ `decode_eof` returns After returning │ + // Read 0 bytes ├─────▼──┴┐ `Ok(None)` ┌────────┐ ◄───┐ `None` ┌───▼─────┐ + // ┌────────────────►│ Pausing ├───────────────────────►│ Paused ├─┐ └───────────┤ Errored │ + // │ └─────────┘ └─┬──▲───┘ │ └───▲───▲─┘ + // Pending read │ │ │ │ │ │ + // ┌──────┐ │ `decode` returns `Some` │ └─────┘ │ │ + // │ │ │ ┌──────┐ │ Pending │ │ + // │ ┌────▼──┴─┐ Read n>0 bytes ┌┴──────▼─┐ read n>0 bytes │ read │ │ + // └─┤ Reading ├───────────────►│ Framing │◄────────────────────────┘ │ │ + // └──┬─▲────┘ └─────┬──┬┘ │ │ + // │ │ │ │ `decode` returns Err │ │ + // │ └───decode` returns `None`──┘ └───────────────────────────────────────────────────────┘ │ + // │ read returns Err │ + // └────────────────────────────────────────────────────────────────────────────────────────────┘ + loop { + // Return `None` if we have encountered an error from the underlying decoder + // See: https://github.com/tokio-rs/tokio/issues/3976 + if state.has_errored { + // preparing has_errored -> paused + trace!("Returning None and setting paused"); + state.is_readable = false; + state.has_errored = false; + return Poll::Ready(None); + } + + // Repeatedly call `decode` or `decode_eof` while the buffer is "readable", + // i.e. it _might_ contain data consumable as a frame or closing frame. + // Both signal that there is no such data by returning `None`. + // + // If `decode` couldn't read a frame and the upstream source has returned eof, + // `decode_eof` will attempt to decode the remaining bytes as closing frames. + // + // If the underlying AsyncRead is resumable, we may continue after an EOF, + // but must finish emitting all of it's associated `decode_eof` frames. + // Furthermore, we don't want to emit any `decode_eof` frames on retried + // reads after an EOF unless we've actually read more data. + if state.is_readable { + // pausing or framing + if state.eof { + // pausing + let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| { + trace!("Got an error, going to errored state"); + state.has_errored = true; + err + })?; + if frame.is_none() { + state.is_readable = false; // prepare pausing -> paused + } + // implicit pausing -> pausing or pausing -> paused + return Poll::Ready(frame.map(Ok)); + } + + // framing + trace!("attempting to decode a frame"); + + if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| { + trace!("Got an error, going to errored state"); + state.has_errored = true; + op + })? { + trace!("frame decoded from buffer"); + // implicit framing -> framing + return Poll::Ready(Some(Ok(frame))); + } + + // framing -> reading + state.is_readable = false; + } + // reading or paused + // If we can't build a frame yet, try to read more data and try again. + // Make sure we've got room for at least one byte to read to ensure + // that we don't get a spurious 0 that looks like EOF. + state.buffer.reserve(1); + let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err( + |err| { + trace!("Got an error, going to errored state"); + state.has_errored = true; + err + }, + )? { + Poll::Ready(ct) => ct, + // implicit reading -> reading or implicit paused -> paused + Poll::Pending => return Poll::Pending, + }; + if bytect == 0 { + if state.eof { + // We're already at an EOF, and since we've reached this path + // we're also not readable. This implies that we've already finished + // our `decode_eof` handling, so we can simply return `None`. + // implicit paused -> paused + return Poll::Ready(None); + } + // prepare reading -> paused + state.eof = true; + } else { + // prepare paused -> framing or noop reading -> framing + state.eof = false; + } + + // paused -> framing or reading -> framing or reading -> pausing + state.is_readable = true; + } + } +} + +impl<T, I, U, W> Sink<I> for FramedImpl<T, U, W> +where + T: AsyncWrite, + U: Encoder<I>, + U::Error: From<io::Error>, + W: BorrowMut<WriteFrame>, +{ + type Error = U::Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + if self.state.borrow().buffer.len() >= BACKPRESSURE_BOUNDARY { + self.as_mut().poll_flush(cx) + } else { + Poll::Ready(Ok(())) + } + } + + fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + let pinned = self.project(); + pinned + .codec + .encode(item, &mut pinned.state.borrow_mut().buffer)?; + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + use crate::util::poll_write_buf; + trace!("flushing framed transport"); + let mut pinned = self.project(); + + while !pinned.state.borrow_mut().buffer.is_empty() { + let WriteFrame { buffer } = pinned.state.borrow_mut(); + trace!(remaining = buffer.len(), "writing;"); + + let n = ready!(poll_write_buf(pinned.inner.as_mut(), cx, buffer))?; + + if n == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "failed to \ + write frame to transport", + ) + .into())); + } + } + + // Try flushing the underlying IO + ready!(pinned.inner.poll_flush(cx))?; + + trace!("framed transport flushed"); + Poll::Ready(Ok(())) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + ready!(self.as_mut().poll_flush(cx))?; + ready!(self.project().inner.poll_shutdown(cx))?; + + Poll::Ready(Ok(())) + } +} diff --git a/third_party/rust/tokio-util/src/codec/framed_read.rs b/third_party/rust/tokio-util/src/codec/framed_read.rs new file mode 100644 index 0000000000..184c567b49 --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/framed_read.rs @@ -0,0 +1,199 @@ +use crate::codec::framed_impl::{FramedImpl, ReadFrame}; +use crate::codec::Decoder; + +use futures_core::Stream; +use tokio::io::AsyncRead; + +use bytes::BytesMut; +use futures_sink::Sink; +use pin_project_lite::pin_project; +use std::fmt; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// A [`Stream`] of messages decoded from an [`AsyncRead`]. + /// + /// [`Stream`]: futures_core::Stream + /// [`AsyncRead`]: tokio::io::AsyncRead + pub struct FramedRead<T, D> { + #[pin] + inner: FramedImpl<T, D, ReadFrame>, + } +} + +// ===== impl FramedRead ===== + +impl<T, D> FramedRead<T, D> +where + T: AsyncRead, + D: Decoder, +{ + /// Creates a new `FramedRead` with the given `decoder`. + pub fn new(inner: T, decoder: D) -> FramedRead<T, D> { + FramedRead { + inner: FramedImpl { + inner, + codec: decoder, + state: Default::default(), + }, + } + } + + /// Creates a new `FramedRead` with the given `decoder` and a buffer of `capacity` + /// initial size. + pub fn with_capacity(inner: T, decoder: D, capacity: usize) -> FramedRead<T, D> { + FramedRead { + inner: FramedImpl { + inner, + codec: decoder, + state: ReadFrame { + eof: false, + is_readable: false, + buffer: BytesMut::with_capacity(capacity), + has_errored: false, + }, + }, + } + } +} + +impl<T, D> FramedRead<T, D> { + /// Returns a reference to the underlying I/O stream wrapped by + /// `FramedRead`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_ref(&self) -> &T { + &self.inner.inner + } + + /// Returns a mutable reference to the underlying I/O stream wrapped by + /// `FramedRead`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_mut(&mut self) -> &mut T { + &mut self.inner.inner + } + + /// Returns a pinned mutable reference to the underlying I/O stream wrapped by + /// `FramedRead`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> { + self.project().inner.project().inner + } + + /// Consumes the `FramedRead`, returning its underlying I/O stream. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn into_inner(self) -> T { + self.inner.inner + } + + /// Returns a reference to the underlying decoder. + pub fn decoder(&self) -> &D { + &self.inner.codec + } + + /// Returns a mutable reference to the underlying decoder. + pub fn decoder_mut(&mut self) -> &mut D { + &mut self.inner.codec + } + + /// Maps the decoder `D` to `C`, preserving the read buffer + /// wrapped by `Framed`. + pub fn map_decoder<C, F>(self, map: F) -> FramedRead<T, C> + where + F: FnOnce(D) -> C, + { + // This could be potentially simplified once rust-lang/rust#86555 hits stable + let FramedImpl { + inner, + state, + codec, + } = self.inner; + FramedRead { + inner: FramedImpl { + inner, + state, + codec: map(codec), + }, + } + } + + /// Returns a mutable reference to the underlying decoder. + pub fn decoder_pin_mut(self: Pin<&mut Self>) -> &mut D { + self.project().inner.project().codec + } + + /// Returns a reference to the read buffer. + pub fn read_buffer(&self) -> &BytesMut { + &self.inner.state.buffer + } + + /// Returns a mutable reference to the read buffer. + pub fn read_buffer_mut(&mut self) -> &mut BytesMut { + &mut self.inner.state.buffer + } +} + +// This impl just defers to the underlying FramedImpl +impl<T, D> Stream for FramedRead<T, D> +where + T: AsyncRead, + D: Decoder, +{ + type Item = Result<D::Item, D::Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + self.project().inner.poll_next(cx) + } +} + +// This impl just defers to the underlying T: Sink +impl<T, I, D> Sink<I> for FramedRead<T, D> +where + T: Sink<I>, +{ + type Error = T::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.project().inner.poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + self.project().inner.project().inner.start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.project().inner.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.project().inner.poll_close(cx) + } +} + +impl<T, D> fmt::Debug for FramedRead<T, D> +where + T: fmt::Debug, + D: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("FramedRead") + .field("inner", &self.get_ref()) + .field("decoder", &self.decoder()) + .field("eof", &self.inner.state.eof) + .field("is_readable", &self.inner.state.is_readable) + .field("buffer", &self.read_buffer()) + .finish() + } +} diff --git a/third_party/rust/tokio-util/src/codec/framed_write.rs b/third_party/rust/tokio-util/src/codec/framed_write.rs new file mode 100644 index 0000000000..aa4cec9820 --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/framed_write.rs @@ -0,0 +1,178 @@ +use crate::codec::encoder::Encoder; +use crate::codec::framed_impl::{FramedImpl, WriteFrame}; + +use futures_core::Stream; +use tokio::io::AsyncWrite; + +use bytes::BytesMut; +use futures_sink::Sink; +use pin_project_lite::pin_project; +use std::fmt; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// A [`Sink`] of frames encoded to an `AsyncWrite`. + /// + /// [`Sink`]: futures_sink::Sink + pub struct FramedWrite<T, E> { + #[pin] + inner: FramedImpl<T, E, WriteFrame>, + } +} + +impl<T, E> FramedWrite<T, E> +where + T: AsyncWrite, +{ + /// Creates a new `FramedWrite` with the given `encoder`. + pub fn new(inner: T, encoder: E) -> FramedWrite<T, E> { + FramedWrite { + inner: FramedImpl { + inner, + codec: encoder, + state: WriteFrame::default(), + }, + } + } +} + +impl<T, E> FramedWrite<T, E> { + /// Returns a reference to the underlying I/O stream wrapped by + /// `FramedWrite`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_ref(&self) -> &T { + &self.inner.inner + } + + /// Returns a mutable reference to the underlying I/O stream wrapped by + /// `FramedWrite`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_mut(&mut self) -> &mut T { + &mut self.inner.inner + } + + /// Returns a pinned mutable reference to the underlying I/O stream wrapped by + /// `FramedWrite`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> { + self.project().inner.project().inner + } + + /// Consumes the `FramedWrite`, returning its underlying I/O stream. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn into_inner(self) -> T { + self.inner.inner + } + + /// Returns a reference to the underlying encoder. + pub fn encoder(&self) -> &E { + &self.inner.codec + } + + /// Returns a mutable reference to the underlying encoder. + pub fn encoder_mut(&mut self) -> &mut E { + &mut self.inner.codec + } + + /// Maps the encoder `E` to `C`, preserving the write buffer + /// wrapped by `Framed`. + pub fn map_encoder<C, F>(self, map: F) -> FramedWrite<T, C> + where + F: FnOnce(E) -> C, + { + // This could be potentially simplified once rust-lang/rust#86555 hits stable + let FramedImpl { + inner, + state, + codec, + } = self.inner; + FramedWrite { + inner: FramedImpl { + inner, + state, + codec: map(codec), + }, + } + } + + /// Returns a mutable reference to the underlying encoder. + pub fn encoder_pin_mut(self: Pin<&mut Self>) -> &mut E { + self.project().inner.project().codec + } + + /// Returns a reference to the write buffer. + pub fn write_buffer(&self) -> &BytesMut { + &self.inner.state.buffer + } + + /// Returns a mutable reference to the write buffer. + pub fn write_buffer_mut(&mut self) -> &mut BytesMut { + &mut self.inner.state.buffer + } +} + +// This impl just defers to the underlying FramedImpl +impl<T, I, E> Sink<I> for FramedWrite<T, E> +where + T: AsyncWrite, + E: Encoder<I>, + E::Error: From<io::Error>, +{ + type Error = E::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + self.project().inner.start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.poll_close(cx) + } +} + +// This impl just defers to the underlying T: Stream +impl<T, D> Stream for FramedWrite<T, D> +where + T: Stream, +{ + type Item = T::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + self.project().inner.project().inner.poll_next(cx) + } +} + +impl<T, U> fmt::Debug for FramedWrite<T, U> +where + T: fmt::Debug, + U: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("FramedWrite") + .field("inner", &self.get_ref()) + .field("encoder", &self.encoder()) + .field("buffer", &self.inner.state.buffer) + .finish() + } +} diff --git a/third_party/rust/tokio-util/src/codec/length_delimited.rs b/third_party/rust/tokio-util/src/codec/length_delimited.rs new file mode 100644 index 0000000000..93d2f180d0 --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/length_delimited.rs @@ -0,0 +1,1047 @@ +//! Frame a stream of bytes based on a length prefix +//! +//! Many protocols delimit their frames by prefacing frame data with a +//! frame head that specifies the length of the frame. The +//! `length_delimited` module provides utilities for handling the length +//! based framing. This allows the consumer to work with entire frames +//! without having to worry about buffering or other framing logic. +//! +//! # Getting started +//! +//! If implementing a protocol from scratch, using length delimited framing +//! is an easy way to get started. [`LengthDelimitedCodec::new()`] will +//! return a length delimited codec using default configuration values. +//! This can then be used to construct a framer to adapt a full-duplex +//! byte stream into a stream of frames. +//! +//! ``` +//! use tokio::io::{AsyncRead, AsyncWrite}; +//! use tokio_util::codec::{Framed, LengthDelimitedCodec}; +//! +//! fn bind_transport<T: AsyncRead + AsyncWrite>(io: T) +//! -> Framed<T, LengthDelimitedCodec> +//! { +//! Framed::new(io, LengthDelimitedCodec::new()) +//! } +//! # pub fn main() {} +//! ``` +//! +//! The returned transport implements `Sink + Stream` for `BytesMut`. It +//! encodes the frame with a big-endian `u32` header denoting the frame +//! payload length: +//! +//! ```text +//! +----------+--------------------------------+ +//! | len: u32 | frame payload | +//! +----------+--------------------------------+ +//! ``` +//! +//! Specifically, given the following: +//! +//! ``` +//! use tokio::io::{AsyncRead, AsyncWrite}; +//! use tokio_util::codec::{Framed, LengthDelimitedCodec}; +//! +//! use futures::SinkExt; +//! use bytes::Bytes; +//! +//! async fn write_frame<T>(io: T) -> Result<(), Box<dyn std::error::Error>> +//! where +//! T: AsyncRead + AsyncWrite + Unpin, +//! { +//! let mut transport = Framed::new(io, LengthDelimitedCodec::new()); +//! let frame = Bytes::from("hello world"); +//! +//! transport.send(frame).await?; +//! Ok(()) +//! } +//! ``` +//! +//! The encoded frame will look like this: +//! +//! ```text +//! +---- len: u32 ----+---- data ----+ +//! | \x00\x00\x00\x0b | hello world | +//! +------------------+--------------+ +//! ``` +//! +//! # Decoding +//! +//! [`FramedRead`] adapts an [`AsyncRead`] into a `Stream` of [`BytesMut`], +//! such that each yielded [`BytesMut`] value contains the contents of an +//! entire frame. There are many configuration parameters enabling +//! [`FramedRead`] to handle a wide range of protocols. Here are some +//! examples that will cover the various options at a high level. +//! +//! ## Example 1 +//! +//! The following will parse a `u16` length field at offset 0, including the +//! frame head in the yielded `BytesMut`. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read<T: AsyncRead>(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(0) // default value +//! .length_field_type::<u16>() +//! .length_adjustment(0) // default value +//! .num_skip(0) // Do not strip frame header +//! .new_read(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT DECODED +//! +-- len ---+--- Payload ---+ +-- len ---+--- Payload ---+ +//! | \x00\x0B | Hello world | --> | \x00\x0B | Hello world | +//! +----------+---------------+ +----------+---------------+ +//! ``` +//! +//! The value of the length field is 11 (`\x0B`) which represents the length +//! of the payload, `hello world`. By default, [`FramedRead`] assumes that +//! the length field represents the number of bytes that **follows** the +//! length field. Thus, the entire frame has a length of 13: 2 bytes for the +//! frame head + 11 bytes for the payload. +//! +//! ## Example 2 +//! +//! The following will parse a `u16` length field at offset 0, omitting the +//! frame head in the yielded `BytesMut`. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read<T: AsyncRead>(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(0) // default value +//! .length_field_type::<u16>() +//! .length_adjustment(0) // default value +//! // `num_skip` is not needed, the default is to skip +//! .new_read(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT DECODED +//! +-- len ---+--- Payload ---+ +--- Payload ---+ +//! | \x00\x0B | Hello world | --> | Hello world | +//! +----------+---------------+ +---------------+ +//! ``` +//! +//! This is similar to the first example, the only difference is that the +//! frame head is **not** included in the yielded `BytesMut` value. +//! +//! ## Example 3 +//! +//! The following will parse a `u16` length field at offset 0, including the +//! frame head in the yielded `BytesMut`. In this case, the length field +//! **includes** the frame head length. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read<T: AsyncRead>(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(0) // default value +//! .length_field_type::<u16>() +//! .length_adjustment(-2) // size of head +//! .num_skip(0) +//! .new_read(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT DECODED +//! +-- len ---+--- Payload ---+ +-- len ---+--- Payload ---+ +//! | \x00\x0D | Hello world | --> | \x00\x0D | Hello world | +//! +----------+---------------+ +----------+---------------+ +//! ``` +//! +//! In most cases, the length field represents the length of the payload +//! only, as shown in the previous examples. However, in some protocols the +//! length field represents the length of the whole frame, including the +//! head. In such cases, we specify a negative `length_adjustment` to adjust +//! the value provided in the frame head to represent the payload length. +//! +//! ## Example 4 +//! +//! The following will parse a 3 byte length field at offset 0 in a 5 byte +//! frame head, including the frame head in the yielded `BytesMut`. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read<T: AsyncRead>(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(0) // default value +//! .length_field_length(3) +//! .length_adjustment(2) // remaining head +//! .num_skip(0) +//! .new_read(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT +//! +---- len -----+- head -+--- Payload ---+ +//! | \x00\x00\x0B | \xCAFE | Hello world | +//! +--------------+--------+---------------+ +//! +//! DECODED +//! +---- len -----+- head -+--- Payload ---+ +//! | \x00\x00\x0B | \xCAFE | Hello world | +//! +--------------+--------+---------------+ +//! ``` +//! +//! A more advanced example that shows a case where there is extra frame +//! head data between the length field and the payload. In such cases, it is +//! usually desirable to include the frame head as part of the yielded +//! `BytesMut`. This lets consumers of the length delimited framer to +//! process the frame head as needed. +//! +//! The positive `length_adjustment` value lets `FramedRead` factor in the +//! additional head into the frame length calculation. +//! +//! ## Example 5 +//! +//! The following will parse a `u16` length field at offset 1 of a 4 byte +//! frame head. The first byte and the length field will be omitted from the +//! yielded `BytesMut`, but the trailing 2 bytes of the frame head will be +//! included. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read<T: AsyncRead>(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(1) // length of hdr1 +//! .length_field_type::<u16>() +//! .length_adjustment(1) // length of hdr2 +//! .num_skip(3) // length of hdr1 + LEN +//! .new_read(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT +//! +- hdr1 -+-- len ---+- hdr2 -+--- Payload ---+ +//! | \xCA | \x00\x0B | \xFE | Hello world | +//! +--------+----------+--------+---------------+ +//! +//! DECODED +//! +- hdr2 -+--- Payload ---+ +//! | \xFE | Hello world | +//! +--------+---------------+ +//! ``` +//! +//! The length field is situated in the middle of the frame head. In this +//! case, the first byte in the frame head could be a version or some other +//! identifier that is not needed for processing. On the other hand, the +//! second half of the head is needed. +//! +//! `length_field_offset` indicates how many bytes to skip before starting +//! to read the length field. `length_adjustment` is the number of bytes to +//! skip starting at the end of the length field. In this case, it is the +//! second half of the head. +//! +//! ## Example 6 +//! +//! The following will parse a `u16` length field at offset 1 of a 4 byte +//! frame head. The first byte and the length field will be omitted from the +//! yielded `BytesMut`, but the trailing 2 bytes of the frame head will be +//! included. In this case, the length field **includes** the frame head +//! length. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read<T: AsyncRead>(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(1) // length of hdr1 +//! .length_field_type::<u16>() +//! .length_adjustment(-3) // length of hdr1 + LEN, negative +//! .num_skip(3) +//! .new_read(io); +//! # } +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT +//! +- hdr1 -+-- len ---+- hdr2 -+--- Payload ---+ +//! | \xCA | \x00\x0F | \xFE | Hello world | +//! +--------+----------+--------+---------------+ +//! +//! DECODED +//! +- hdr2 -+--- Payload ---+ +//! | \xFE | Hello world | +//! +--------+---------------+ +//! ``` +//! +//! Similar to the example above, the difference is that the length field +//! represents the length of the entire frame instead of just the payload. +//! The length of `hdr1` and `len` must be counted in `length_adjustment`. +//! Note that the length of `hdr2` does **not** need to be explicitly set +//! anywhere because it already is factored into the total frame length that +//! is read from the byte stream. +//! +//! ## Example 7 +//! +//! The following will parse a 3 byte length field at offset 0 in a 4 byte +//! frame head, excluding the 4th byte from the yielded `BytesMut`. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read<T: AsyncRead>(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(0) // default value +//! .length_field_length(3) +//! .length_adjustment(0) // default value +//! .num_skip(4) // skip the first 4 bytes +//! .new_read(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT DECODED +//! +------- len ------+--- Payload ---+ +--- Payload ---+ +//! | \x00\x00\x0B\xFF | Hello world | => | Hello world | +//! +------------------+---------------+ +---------------+ +//! ``` +//! +//! A simple example where there are unused bytes between the length field +//! and the payload. +//! +//! # Encoding +//! +//! [`FramedWrite`] adapts an [`AsyncWrite`] into a `Sink` of [`BytesMut`], +//! such that each submitted [`BytesMut`] is prefaced by a length field. +//! There are fewer configuration options than [`FramedRead`]. Given +//! protocols that have more complex frame heads, an encoder should probably +//! be written by hand using [`Encoder`]. +//! +//! Here is a simple example, given a `FramedWrite` with the following +//! configuration: +//! +//! ``` +//! # use tokio::io::AsyncWrite; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn write_frame<T: AsyncWrite>(io: T) { +//! # let _ = +//! LengthDelimitedCodec::builder() +//! .length_field_type::<u16>() +//! .new_write(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! A payload of `hello world` will be encoded as: +//! +//! ```text +//! +- len: u16 -+---- data ----+ +//! | \x00\x0b | hello world | +//! +------------+--------------+ +//! ``` +//! +//! [`LengthDelimitedCodec::new()`]: method@LengthDelimitedCodec::new +//! [`FramedRead`]: struct@FramedRead +//! [`FramedWrite`]: struct@FramedWrite +//! [`AsyncRead`]: trait@tokio::io::AsyncRead +//! [`AsyncWrite`]: trait@tokio::io::AsyncWrite +//! [`Encoder`]: trait@Encoder +//! [`BytesMut`]: bytes::BytesMut + +use crate::codec::{Decoder, Encoder, Framed, FramedRead, FramedWrite}; + +use tokio::io::{AsyncRead, AsyncWrite}; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::error::Error as StdError; +use std::io::{self, Cursor}; +use std::{cmp, fmt, mem}; + +/// Configure length delimited `LengthDelimitedCodec`s. +/// +/// `Builder` enables constructing configured length delimited codecs. Note +/// that not all configuration settings apply to both encoding and decoding. See +/// the documentation for specific methods for more detail. +#[derive(Debug, Clone, Copy)] +pub struct Builder { + // Maximum frame length + max_frame_len: usize, + + // Number of bytes representing the field length + length_field_len: usize, + + // Number of bytes in the header before the length field + length_field_offset: usize, + + // Adjust the length specified in the header field by this amount + length_adjustment: isize, + + // Total number of bytes to skip before reading the payload, if not set, + // `length_field_len + length_field_offset` + num_skip: Option<usize>, + + // Length field byte order (little or big endian) + length_field_is_big_endian: bool, +} + +/// An error when the number of bytes read is more than max frame length. +pub struct LengthDelimitedCodecError { + _priv: (), +} + +/// A codec for frames delimited by a frame head specifying their lengths. +/// +/// This allows the consumer to work with entire frames without having to worry +/// about buffering or other framing logic. +/// +/// See [module level] documentation for more detail. +/// +/// [module level]: index.html +#[derive(Debug, Clone)] +pub struct LengthDelimitedCodec { + // Configuration values + builder: Builder, + + // Read state + state: DecodeState, +} + +#[derive(Debug, Clone, Copy)] +enum DecodeState { + Head, + Data(usize), +} + +// ===== impl LengthDelimitedCodec ====== + +impl LengthDelimitedCodec { + /// Creates a new `LengthDelimitedCodec` with the default configuration values. + pub fn new() -> Self { + Self { + builder: Builder::new(), + state: DecodeState::Head, + } + } + + /// Creates a new length delimited codec builder with default configuration + /// values. + pub fn builder() -> Builder { + Builder::new() + } + + /// Returns the current max frame setting + /// + /// This is the largest size this codec will accept from the wire. Larger + /// frames will be rejected. + pub fn max_frame_length(&self) -> usize { + self.builder.max_frame_len + } + + /// Updates the max frame setting. + /// + /// The change takes effect the next time a frame is decoded. In other + /// words, if a frame is currently in process of being decoded with a frame + /// size greater than `val` but less than the max frame length in effect + /// before calling this function, then the frame will be allowed. + pub fn set_max_frame_length(&mut self, val: usize) { + self.builder.max_frame_length(val); + } + + fn decode_head(&mut self, src: &mut BytesMut) -> io::Result<Option<usize>> { + let head_len = self.builder.num_head_bytes(); + let field_len = self.builder.length_field_len; + + if src.len() < head_len { + // Not enough data + return Ok(None); + } + + let n = { + let mut src = Cursor::new(&mut *src); + + // Skip the required bytes + src.advance(self.builder.length_field_offset); + + // match endianness + let n = if self.builder.length_field_is_big_endian { + src.get_uint(field_len) + } else { + src.get_uint_le(field_len) + }; + + if n > self.builder.max_frame_len as u64 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + LengthDelimitedCodecError { _priv: () }, + )); + } + + // The check above ensures there is no overflow + let n = n as usize; + + // Adjust `n` with bounds checking + let n = if self.builder.length_adjustment < 0 { + n.checked_sub(-self.builder.length_adjustment as usize) + } else { + n.checked_add(self.builder.length_adjustment as usize) + }; + + // Error handling + match n { + Some(n) => n, + None => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "provided length would overflow after adjustment", + )); + } + } + }; + + let num_skip = self.builder.get_num_skip(); + + if num_skip > 0 { + src.advance(num_skip); + } + + // Ensure that the buffer has enough space to read the incoming + // payload + src.reserve(n); + + Ok(Some(n)) + } + + fn decode_data(&self, n: usize, src: &mut BytesMut) -> Option<BytesMut> { + // At this point, the buffer has already had the required capacity + // reserved. All there is to do is read. + if src.len() < n { + return None; + } + + Some(src.split_to(n)) + } +} + +impl Decoder for LengthDelimitedCodec { + type Item = BytesMut; + type Error = io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<BytesMut>> { + let n = match self.state { + DecodeState::Head => match self.decode_head(src)? { + Some(n) => { + self.state = DecodeState::Data(n); + n + } + None => return Ok(None), + }, + DecodeState::Data(n) => n, + }; + + match self.decode_data(n, src) { + Some(data) => { + // Update the decode state + self.state = DecodeState::Head; + + // Make sure the buffer has enough space to read the next head + src.reserve(self.builder.num_head_bytes()); + + Ok(Some(data)) + } + None => Ok(None), + } + } +} + +impl Encoder<Bytes> for LengthDelimitedCodec { + type Error = io::Error; + + fn encode(&mut self, data: Bytes, dst: &mut BytesMut) -> Result<(), io::Error> { + let n = data.len(); + + if n > self.builder.max_frame_len { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + LengthDelimitedCodecError { _priv: () }, + )); + } + + // Adjust `n` with bounds checking + let n = if self.builder.length_adjustment < 0 { + n.checked_add(-self.builder.length_adjustment as usize) + } else { + n.checked_sub(self.builder.length_adjustment as usize) + }; + + let n = n.ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "provided length would overflow after adjustment", + ) + })?; + + // Reserve capacity in the destination buffer to fit the frame and + // length field (plus adjustment). + dst.reserve(self.builder.length_field_len + n); + + if self.builder.length_field_is_big_endian { + dst.put_uint(n as u64, self.builder.length_field_len); + } else { + dst.put_uint_le(n as u64, self.builder.length_field_len); + } + + // Write the frame to the buffer + dst.extend_from_slice(&data[..]); + + Ok(()) + } +} + +impl Default for LengthDelimitedCodec { + fn default() -> Self { + Self::new() + } +} + +// ===== impl Builder ===== + +mod builder { + /// Types that can be used with `Builder::length_field_type`. + pub trait LengthFieldType {} + + impl LengthFieldType for u8 {} + impl LengthFieldType for u16 {} + impl LengthFieldType for u32 {} + impl LengthFieldType for u64 {} + + #[cfg(any( + target_pointer_width = "8", + target_pointer_width = "16", + target_pointer_width = "32", + target_pointer_width = "64", + ))] + impl LengthFieldType for usize {} +} + +impl Builder { + /// Creates a new length delimited codec builder with default configuration + /// values. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read<T: AsyncRead>(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_field_offset(0) + /// .length_field_type::<u16>() + /// .length_adjustment(0) + /// .num_skip(0) + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn new() -> Builder { + Builder { + // Default max frame length of 8MB + max_frame_len: 8 * 1_024 * 1_024, + + // Default byte length of 4 + length_field_len: 4, + + // Default to the header field being at the start of the header. + length_field_offset: 0, + + length_adjustment: 0, + + // Total number of bytes to skip before reading the payload, if not set, + // `length_field_len + length_field_offset` + num_skip: None, + + // Default to reading the length field in network (big) endian. + length_field_is_big_endian: true, + } + } + + /// Read the length field as a big endian integer + /// + /// This is the default setting. + /// + /// This configuration option applies to both encoding and decoding. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read<T: AsyncRead>(io: T) { + /// LengthDelimitedCodec::builder() + /// .big_endian() + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn big_endian(&mut self) -> &mut Self { + self.length_field_is_big_endian = true; + self + } + + /// Read the length field as a little endian integer + /// + /// The default setting is big endian. + /// + /// This configuration option applies to both encoding and decoding. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read<T: AsyncRead>(io: T) { + /// LengthDelimitedCodec::builder() + /// .little_endian() + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn little_endian(&mut self) -> &mut Self { + self.length_field_is_big_endian = false; + self + } + + /// Read the length field as a native endian integer + /// + /// The default setting is big endian. + /// + /// This configuration option applies to both encoding and decoding. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read<T: AsyncRead>(io: T) { + /// LengthDelimitedCodec::builder() + /// .native_endian() + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn native_endian(&mut self) -> &mut Self { + if cfg!(target_endian = "big") { + self.big_endian() + } else { + self.little_endian() + } + } + + /// Sets the max frame length in bytes + /// + /// This configuration option applies to both encoding and decoding. The + /// default value is 8MB. + /// + /// When decoding, the length field read from the byte stream is checked + /// against this setting **before** any adjustments are applied. When + /// encoding, the length of the submitted payload is checked against this + /// setting. + /// + /// When frames exceed the max length, an `io::Error` with the custom value + /// of the `LengthDelimitedCodecError` type will be returned. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read<T: AsyncRead>(io: T) { + /// LengthDelimitedCodec::builder() + /// .max_frame_length(8 * 1024 * 1024) + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn max_frame_length(&mut self, val: usize) -> &mut Self { + self.max_frame_len = val; + self + } + + /// Sets the unsigned integer type used to represent the length field. + /// + /// The default type is [`u32`]. The max type is [`u64`] (or [`usize`] on + /// 64-bit targets). + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read<T: AsyncRead>(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_field_type::<u32>() + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + /// + /// Unlike [`Builder::length_field_length`], this does not fail at runtime + /// and instead produces a compile error: + /// + /// ```compile_fail + /// # use tokio::io::AsyncRead; + /// # use tokio_util::codec::LengthDelimitedCodec; + /// # fn bind_read<T: AsyncRead>(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_field_type::<u128>() + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn length_field_type<T: builder::LengthFieldType>(&mut self) -> &mut Self { + self.length_field_length(mem::size_of::<T>()) + } + + /// Sets the number of bytes used to represent the length field + /// + /// The default value is `4`. The max value is `8`. + /// + /// This configuration option applies to both encoding and decoding. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read<T: AsyncRead>(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_field_length(4) + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn length_field_length(&mut self, val: usize) -> &mut Self { + assert!(val > 0 && val <= 8, "invalid length field length"); + self.length_field_len = val; + self + } + + /// Sets the number of bytes in the header before the length field + /// + /// This configuration option only applies to decoding. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read<T: AsyncRead>(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_field_offset(1) + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn length_field_offset(&mut self, val: usize) -> &mut Self { + self.length_field_offset = val; + self + } + + /// Delta between the payload length specified in the header and the real + /// payload length + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read<T: AsyncRead>(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_adjustment(-2) + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn length_adjustment(&mut self, val: isize) -> &mut Self { + self.length_adjustment = val; + self + } + + /// Sets the number of bytes to skip before reading the payload + /// + /// Default value is `length_field_len + length_field_offset` + /// + /// This configuration option only applies to decoding + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read<T: AsyncRead>(io: T) { + /// LengthDelimitedCodec::builder() + /// .num_skip(4) + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn num_skip(&mut self, val: usize) -> &mut Self { + self.num_skip = Some(val); + self + } + + /// Create a configured length delimited `LengthDelimitedCodec` + /// + /// # Examples + /// + /// ``` + /// use tokio_util::codec::LengthDelimitedCodec; + /// # pub fn main() { + /// LengthDelimitedCodec::builder() + /// .length_field_offset(0) + /// .length_field_type::<u16>() + /// .length_adjustment(0) + /// .num_skip(0) + /// .new_codec(); + /// # } + /// ``` + pub fn new_codec(&self) -> LengthDelimitedCodec { + LengthDelimitedCodec { + builder: *self, + state: DecodeState::Head, + } + } + + /// Create a configured length delimited `FramedRead` + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read<T: AsyncRead>(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_field_offset(0) + /// .length_field_type::<u16>() + /// .length_adjustment(0) + /// .num_skip(0) + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn new_read<T>(&self, upstream: T) -> FramedRead<T, LengthDelimitedCodec> + where + T: AsyncRead, + { + FramedRead::new(upstream, self.new_codec()) + } + + /// Create a configured length delimited `FramedWrite` + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncWrite; + /// # use tokio_util::codec::LengthDelimitedCodec; + /// # fn write_frame<T: AsyncWrite>(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_field_type::<u16>() + /// .new_write(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn new_write<T>(&self, inner: T) -> FramedWrite<T, LengthDelimitedCodec> + where + T: AsyncWrite, + { + FramedWrite::new(inner, self.new_codec()) + } + + /// Create a configured length delimited `Framed` + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::{AsyncRead, AsyncWrite}; + /// # use tokio_util::codec::LengthDelimitedCodec; + /// # fn write_frame<T: AsyncRead + AsyncWrite>(io: T) { + /// # let _ = + /// LengthDelimitedCodec::builder() + /// .length_field_type::<u16>() + /// .new_framed(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn new_framed<T>(&self, inner: T) -> Framed<T, LengthDelimitedCodec> + where + T: AsyncRead + AsyncWrite, + { + Framed::new(inner, self.new_codec()) + } + + fn num_head_bytes(&self) -> usize { + let num = self.length_field_offset + self.length_field_len; + cmp::max(num, self.num_skip.unwrap_or(0)) + } + + fn get_num_skip(&self) -> usize { + self.num_skip + .unwrap_or(self.length_field_offset + self.length_field_len) + } +} + +impl Default for Builder { + fn default() -> Self { + Self::new() + } +} + +// ===== impl LengthDelimitedCodecError ===== + +impl fmt::Debug for LengthDelimitedCodecError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("LengthDelimitedCodecError").finish() + } +} + +impl fmt::Display for LengthDelimitedCodecError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("frame size too big") + } +} + +impl StdError for LengthDelimitedCodecError {} diff --git a/third_party/rust/tokio-util/src/codec/lines_codec.rs b/third_party/rust/tokio-util/src/codec/lines_codec.rs new file mode 100644 index 0000000000..7a0a8f0454 --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/lines_codec.rs @@ -0,0 +1,230 @@ +use crate::codec::decoder::Decoder; +use crate::codec::encoder::Encoder; + +use bytes::{Buf, BufMut, BytesMut}; +use std::{cmp, fmt, io, str, usize}; + +/// A simple [`Decoder`] and [`Encoder`] implementation that splits up data into lines. +/// +/// [`Decoder`]: crate::codec::Decoder +/// [`Encoder`]: crate::codec::Encoder +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct LinesCodec { + // Stored index of the next index to examine for a `\n` character. + // This is used to optimize searching. + // For example, if `decode` was called with `abc`, it would hold `3`, + // because that is the next index to examine. + // The next time `decode` is called with `abcde\n`, the method will + // only look at `de\n` before returning. + next_index: usize, + + /// The maximum length for a given line. If `usize::MAX`, lines will be + /// read until a `\n` character is reached. + max_length: usize, + + /// Are we currently discarding the remainder of a line which was over + /// the length limit? + is_discarding: bool, +} + +impl LinesCodec { + /// Returns a `LinesCodec` for splitting up data into lines. + /// + /// # Note + /// + /// The returned `LinesCodec` will not have an upper bound on the length + /// of a buffered line. See the documentation for [`new_with_max_length`] + /// for information on why this could be a potential security risk. + /// + /// [`new_with_max_length`]: crate::codec::LinesCodec::new_with_max_length() + pub fn new() -> LinesCodec { + LinesCodec { + next_index: 0, + max_length: usize::MAX, + is_discarding: false, + } + } + + /// Returns a `LinesCodec` with a maximum line length limit. + /// + /// If this is set, calls to `LinesCodec::decode` will return a + /// [`LinesCodecError`] when a line exceeds the length limit. Subsequent calls + /// will discard up to `limit` bytes from that line until a newline + /// character is reached, returning `None` until the line over the limit + /// has been fully discarded. After that point, calls to `decode` will + /// function as normal. + /// + /// # Note + /// + /// Setting a length limit is highly recommended for any `LinesCodec` which + /// will be exposed to untrusted input. Otherwise, the size of the buffer + /// that holds the line currently being read is unbounded. An attacker could + /// exploit this unbounded buffer by sending an unbounded amount of input + /// without any `\n` characters, causing unbounded memory consumption. + /// + /// [`LinesCodecError`]: crate::codec::LinesCodecError + pub fn new_with_max_length(max_length: usize) -> Self { + LinesCodec { + max_length, + ..LinesCodec::new() + } + } + + /// Returns the maximum line length when decoding. + /// + /// ``` + /// use std::usize; + /// use tokio_util::codec::LinesCodec; + /// + /// let codec = LinesCodec::new(); + /// assert_eq!(codec.max_length(), usize::MAX); + /// ``` + /// ``` + /// use tokio_util::codec::LinesCodec; + /// + /// let codec = LinesCodec::new_with_max_length(256); + /// assert_eq!(codec.max_length(), 256); + /// ``` + pub fn max_length(&self) -> usize { + self.max_length + } +} + +fn utf8(buf: &[u8]) -> Result<&str, io::Error> { + str::from_utf8(buf) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Unable to decode input as UTF8")) +} + +fn without_carriage_return(s: &[u8]) -> &[u8] { + if let Some(&b'\r') = s.last() { + &s[..s.len() - 1] + } else { + s + } +} + +impl Decoder for LinesCodec { + type Item = String; + type Error = LinesCodecError; + + fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<String>, LinesCodecError> { + loop { + // Determine how far into the buffer we'll search for a newline. If + // there's no max_length set, we'll read to the end of the buffer. + let read_to = cmp::min(self.max_length.saturating_add(1), buf.len()); + + let newline_offset = buf[self.next_index..read_to] + .iter() + .position(|b| *b == b'\n'); + + match (self.is_discarding, newline_offset) { + (true, Some(offset)) => { + // If we found a newline, discard up to that offset and + // then stop discarding. On the next iteration, we'll try + // to read a line normally. + buf.advance(offset + self.next_index + 1); + self.is_discarding = false; + self.next_index = 0; + } + (true, None) => { + // Otherwise, we didn't find a newline, so we'll discard + // everything we read. On the next iteration, we'll continue + // discarding up to max_len bytes unless we find a newline. + buf.advance(read_to); + self.next_index = 0; + if buf.is_empty() { + return Ok(None); + } + } + (false, Some(offset)) => { + // Found a line! + let newline_index = offset + self.next_index; + self.next_index = 0; + let line = buf.split_to(newline_index + 1); + let line = &line[..line.len() - 1]; + let line = without_carriage_return(line); + let line = utf8(line)?; + return Ok(Some(line.to_string())); + } + (false, None) if buf.len() > self.max_length => { + // Reached the maximum length without finding a + // newline, return an error and start discarding on the + // next call. + self.is_discarding = true; + return Err(LinesCodecError::MaxLineLengthExceeded); + } + (false, None) => { + // We didn't find a line or reach the length limit, so the next + // call will resume searching at the current offset. + self.next_index = read_to; + return Ok(None); + } + } + } + } + + fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<String>, LinesCodecError> { + Ok(match self.decode(buf)? { + Some(frame) => Some(frame), + None => { + // No terminating newline - return remaining data, if any + if buf.is_empty() || buf == &b"\r"[..] { + None + } else { + let line = buf.split_to(buf.len()); + let line = without_carriage_return(&line); + let line = utf8(line)?; + self.next_index = 0; + Some(line.to_string()) + } + } + }) + } +} + +impl<T> Encoder<T> for LinesCodec +where + T: AsRef<str>, +{ + type Error = LinesCodecError; + + fn encode(&mut self, line: T, buf: &mut BytesMut) -> Result<(), LinesCodecError> { + let line = line.as_ref(); + buf.reserve(line.len() + 1); + buf.put(line.as_bytes()); + buf.put_u8(b'\n'); + Ok(()) + } +} + +impl Default for LinesCodec { + fn default() -> Self { + Self::new() + } +} + +/// An error occurred while encoding or decoding a line. +#[derive(Debug)] +pub enum LinesCodecError { + /// The maximum line length was exceeded. + MaxLineLengthExceeded, + /// An IO error occurred. + Io(io::Error), +} + +impl fmt::Display for LinesCodecError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + LinesCodecError::MaxLineLengthExceeded => write!(f, "max line length exceeded"), + LinesCodecError::Io(e) => write!(f, "{}", e), + } + } +} + +impl From<io::Error> for LinesCodecError { + fn from(e: io::Error) -> LinesCodecError { + LinesCodecError::Io(e) + } +} + +impl std::error::Error for LinesCodecError {} diff --git a/third_party/rust/tokio-util/src/codec/mod.rs b/third_party/rust/tokio-util/src/codec/mod.rs new file mode 100644 index 0000000000..2295176bdc --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/mod.rs @@ -0,0 +1,290 @@ +//! Adaptors from AsyncRead/AsyncWrite to Stream/Sink +//! +//! Raw I/O objects work with byte sequences, but higher-level code usually +//! wants to batch these into meaningful chunks, called "frames". +//! +//! This module contains adapters to go from streams of bytes, [`AsyncRead`] and +//! [`AsyncWrite`], to framed streams implementing [`Sink`] and [`Stream`]. +//! Framed streams are also known as transports. +//! +//! # The Decoder trait +//! +//! A [`Decoder`] is used together with [`FramedRead`] or [`Framed`] to turn an +//! [`AsyncRead`] into a [`Stream`]. The job of the decoder trait is to specify +//! how sequences of bytes are turned into a sequence of frames, and to +//! determine where the boundaries between frames are. The job of the +//! `FramedRead` is to repeatedly switch between reading more data from the IO +//! resource, and asking the decoder whether we have received enough data to +//! decode another frame of data. +//! +//! The main method on the `Decoder` trait is the [`decode`] method. This method +//! takes as argument the data that has been read so far, and when it is called, +//! it will be in one of the following situations: +//! +//! 1. The buffer contains less than a full frame. +//! 2. The buffer contains exactly a full frame. +//! 3. The buffer contains more than a full frame. +//! +//! In the first situation, the decoder should return `Ok(None)`. +//! +//! In the second situation, the decoder should clear the provided buffer and +//! return `Ok(Some(the_decoded_frame))`. +//! +//! In the third situation, the decoder should use a method such as [`split_to`] +//! or [`advance`] to modify the buffer such that the frame is removed from the +//! buffer, but any data in the buffer after that frame should still remain in +//! the buffer. The decoder should also return `Ok(Some(the_decoded_frame))` in +//! this case. +//! +//! Finally the decoder may return an error if the data is invalid in some way. +//! The decoder should _not_ return an error just because it has yet to receive +//! a full frame. +//! +//! It is guaranteed that, from one call to `decode` to another, the provided +//! buffer will contain the exact same data as before, except that if more data +//! has arrived through the IO resource, that data will have been appended to +//! the buffer. This means that reading frames from a `FramedRead` is +//! essentially equivalent to the following loop: +//! +//! ```no_run +//! use tokio::io::AsyncReadExt; +//! # // This uses async_stream to create an example that compiles. +//! # fn foo() -> impl futures_core::Stream<Item = std::io::Result<bytes::BytesMut>> { async_stream::try_stream! { +//! # use tokio_util::codec::Decoder; +//! # let mut decoder = tokio_util::codec::BytesCodec::new(); +//! # let io_resource = &mut &[0u8, 1, 2, 3][..]; +//! +//! let mut buf = bytes::BytesMut::new(); +//! loop { +//! // The read_buf call will append to buf rather than overwrite existing data. +//! let len = io_resource.read_buf(&mut buf).await?; +//! +//! if len == 0 { +//! while let Some(frame) = decoder.decode_eof(&mut buf)? { +//! yield frame; +//! } +//! break; +//! } +//! +//! while let Some(frame) = decoder.decode(&mut buf)? { +//! yield frame; +//! } +//! } +//! # }} +//! ``` +//! The example above uses `yield` whenever the `Stream` produces an item. +//! +//! ## Example decoder +//! +//! As an example, consider a protocol that can be used to send strings where +//! each frame is a four byte integer that contains the length of the frame, +//! followed by that many bytes of string data. The decoder fails with an error +//! if the string data is not valid utf-8 or too long. +//! +//! Such a decoder can be written like this: +//! ``` +//! use tokio_util::codec::Decoder; +//! use bytes::{BytesMut, Buf}; +//! +//! struct MyStringDecoder {} +//! +//! const MAX: usize = 8 * 1024 * 1024; +//! +//! impl Decoder for MyStringDecoder { +//! type Item = String; +//! type Error = std::io::Error; +//! +//! fn decode( +//! &mut self, +//! src: &mut BytesMut +//! ) -> Result<Option<Self::Item>, Self::Error> { +//! if src.len() < 4 { +//! // Not enough data to read length marker. +//! return Ok(None); +//! } +//! +//! // Read length marker. +//! let mut length_bytes = [0u8; 4]; +//! length_bytes.copy_from_slice(&src[..4]); +//! let length = u32::from_le_bytes(length_bytes) as usize; +//! +//! // Check that the length is not too large to avoid a denial of +//! // service attack where the server runs out of memory. +//! if length > MAX { +//! return Err(std::io::Error::new( +//! std::io::ErrorKind::InvalidData, +//! format!("Frame of length {} is too large.", length) +//! )); +//! } +//! +//! if src.len() < 4 + length { +//! // The full string has not yet arrived. +//! // +//! // We reserve more space in the buffer. This is not strictly +//! // necessary, but is a good idea performance-wise. +//! src.reserve(4 + length - src.len()); +//! +//! // We inform the Framed that we need more bytes to form the next +//! // frame. +//! return Ok(None); +//! } +//! +//! // Use advance to modify src such that it no longer contains +//! // this frame. +//! let data = src[4..4 + length].to_vec(); +//! src.advance(4 + length); +//! +//! // Convert the data to a string, or fail if it is not valid utf-8. +//! match String::from_utf8(data) { +//! Ok(string) => Ok(Some(string)), +//! Err(utf8_error) => { +//! Err(std::io::Error::new( +//! std::io::ErrorKind::InvalidData, +//! utf8_error.utf8_error(), +//! )) +//! }, +//! } +//! } +//! } +//! ``` +//! +//! # The Encoder trait +//! +//! An [`Encoder`] is used together with [`FramedWrite`] or [`Framed`] to turn +//! an [`AsyncWrite`] into a [`Sink`]. The job of the encoder trait is to +//! specify how frames are turned into a sequences of bytes. The job of the +//! `FramedWrite` is to take the resulting sequence of bytes and write it to the +//! IO resource. +//! +//! The main method on the `Encoder` trait is the [`encode`] method. This method +//! takes an item that is being written, and a buffer to write the item to. The +//! buffer may already contain data, and in this case, the encoder should append +//! the new frame the to buffer rather than overwrite the existing data. +//! +//! It is guaranteed that, from one call to `encode` to another, the provided +//! buffer will contain the exact same data as before, except that some of the +//! data may have been removed from the front of the buffer. Writing to a +//! `FramedWrite` is essentially equivalent to the following loop: +//! +//! ```no_run +//! use tokio::io::AsyncWriteExt; +//! use bytes::Buf; // for advance +//! # use tokio_util::codec::Encoder; +//! # async fn next_frame() -> bytes::Bytes { bytes::Bytes::new() } +//! # async fn no_more_frames() { } +//! # #[tokio::main] async fn main() -> std::io::Result<()> { +//! # let mut io_resource = tokio::io::sink(); +//! # let mut encoder = tokio_util::codec::BytesCodec::new(); +//! +//! const MAX: usize = 8192; +//! +//! let mut buf = bytes::BytesMut::new(); +//! loop { +//! tokio::select! { +//! num_written = io_resource.write(&buf), if !buf.is_empty() => { +//! buf.advance(num_written?); +//! }, +//! frame = next_frame(), if buf.len() < MAX => { +//! encoder.encode(frame, &mut buf)?; +//! }, +//! _ = no_more_frames() => { +//! io_resource.write_all(&buf).await?; +//! io_resource.shutdown().await?; +//! return Ok(()); +//! }, +//! } +//! } +//! # } +//! ``` +//! Here the `next_frame` method corresponds to any frames you write to the +//! `FramedWrite`. The `no_more_frames` method corresponds to closing the +//! `FramedWrite` with [`SinkExt::close`]. +//! +//! ## Example encoder +//! +//! As an example, consider a protocol that can be used to send strings where +//! each frame is a four byte integer that contains the length of the frame, +//! followed by that many bytes of string data. The encoder will fail if the +//! string is too long. +//! +//! Such an encoder can be written like this: +//! ``` +//! use tokio_util::codec::Encoder; +//! use bytes::BytesMut; +//! +//! struct MyStringEncoder {} +//! +//! const MAX: usize = 8 * 1024 * 1024; +//! +//! impl Encoder<String> for MyStringEncoder { +//! type Error = std::io::Error; +//! +//! fn encode(&mut self, item: String, dst: &mut BytesMut) -> Result<(), Self::Error> { +//! // Don't send a string if it is longer than the other end will +//! // accept. +//! if item.len() > MAX { +//! return Err(std::io::Error::new( +//! std::io::ErrorKind::InvalidData, +//! format!("Frame of length {} is too large.", item.len()) +//! )); +//! } +//! +//! // Convert the length into a byte array. +//! // The cast to u32 cannot overflow due to the length check above. +//! let len_slice = u32::to_le_bytes(item.len() as u32); +//! +//! // Reserve space in the buffer. +//! dst.reserve(4 + item.len()); +//! +//! // Write the length and string to the buffer. +//! dst.extend_from_slice(&len_slice); +//! dst.extend_from_slice(item.as_bytes()); +//! Ok(()) +//! } +//! } +//! ``` +//! +//! [`AsyncRead`]: tokio::io::AsyncRead +//! [`AsyncWrite`]: tokio::io::AsyncWrite +//! [`Stream`]: futures_core::Stream +//! [`Sink`]: futures_sink::Sink +//! [`SinkExt::close`]: https://docs.rs/futures/0.3/futures/sink/trait.SinkExt.html#method.close +//! [`FramedRead`]: struct@crate::codec::FramedRead +//! [`FramedWrite`]: struct@crate::codec::FramedWrite +//! [`Framed`]: struct@crate::codec::Framed +//! [`Decoder`]: trait@crate::codec::Decoder +//! [`decode`]: fn@crate::codec::Decoder::decode +//! [`encode`]: fn@crate::codec::Encoder::encode +//! [`split_to`]: fn@bytes::BytesMut::split_to +//! [`advance`]: fn@bytes::Buf::advance + +mod bytes_codec; +pub use self::bytes_codec::BytesCodec; + +mod decoder; +pub use self::decoder::Decoder; + +mod encoder; +pub use self::encoder::Encoder; + +mod framed_impl; +#[allow(unused_imports)] +pub(crate) use self::framed_impl::{FramedImpl, RWFrames, ReadFrame, WriteFrame}; + +mod framed; +pub use self::framed::{Framed, FramedParts}; + +mod framed_read; +pub use self::framed_read::FramedRead; + +mod framed_write; +pub use self::framed_write::FramedWrite; + +pub mod length_delimited; +pub use self::length_delimited::{LengthDelimitedCodec, LengthDelimitedCodecError}; + +mod lines_codec; +pub use self::lines_codec::{LinesCodec, LinesCodecError}; + +mod any_delimiter_codec; +pub use self::any_delimiter_codec::{AnyDelimiterCodec, AnyDelimiterCodecError}; diff --git a/third_party/rust/tokio-util/src/compat.rs b/third_party/rust/tokio-util/src/compat.rs new file mode 100644 index 0000000000..6a8802d969 --- /dev/null +++ b/third_party/rust/tokio-util/src/compat.rs @@ -0,0 +1,274 @@ +//! Compatibility between the `tokio::io` and `futures-io` versions of the +//! `AsyncRead` and `AsyncWrite` traits. +use futures_core::ready; +use pin_project_lite::pin_project; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// A compatibility layer that allows conversion between the + /// `tokio::io` and `futures-io` `AsyncRead` and `AsyncWrite` traits. + #[derive(Copy, Clone, Debug)] + pub struct Compat<T> { + #[pin] + inner: T, + seek_pos: Option<io::SeekFrom>, + } +} + +/// Extension trait that allows converting a type implementing +/// `futures_io::AsyncRead` to implement `tokio::io::AsyncRead`. +pub trait FuturesAsyncReadCompatExt: futures_io::AsyncRead { + /// Wraps `self` with a compatibility layer that implements + /// `tokio_io::AsyncRead`. + fn compat(self) -> Compat<Self> + where + Self: Sized, + { + Compat::new(self) + } +} + +impl<T: futures_io::AsyncRead> FuturesAsyncReadCompatExt for T {} + +/// Extension trait that allows converting a type implementing +/// `futures_io::AsyncWrite` to implement `tokio::io::AsyncWrite`. +pub trait FuturesAsyncWriteCompatExt: futures_io::AsyncWrite { + /// Wraps `self` with a compatibility layer that implements + /// `tokio::io::AsyncWrite`. + fn compat_write(self) -> Compat<Self> + where + Self: Sized, + { + Compat::new(self) + } +} + +impl<T: futures_io::AsyncWrite> FuturesAsyncWriteCompatExt for T {} + +/// Extension trait that allows converting a type implementing +/// `tokio::io::AsyncRead` to implement `futures_io::AsyncRead`. +pub trait TokioAsyncReadCompatExt: tokio::io::AsyncRead { + /// Wraps `self` with a compatibility layer that implements + /// `futures_io::AsyncRead`. + fn compat(self) -> Compat<Self> + where + Self: Sized, + { + Compat::new(self) + } +} + +impl<T: tokio::io::AsyncRead> TokioAsyncReadCompatExt for T {} + +/// Extension trait that allows converting a type implementing +/// `tokio::io::AsyncWrite` to implement `futures_io::AsyncWrite`. +pub trait TokioAsyncWriteCompatExt: tokio::io::AsyncWrite { + /// Wraps `self` with a compatibility layer that implements + /// `futures_io::AsyncWrite`. + fn compat_write(self) -> Compat<Self> + where + Self: Sized, + { + Compat::new(self) + } +} + +impl<T: tokio::io::AsyncWrite> TokioAsyncWriteCompatExt for T {} + +// === impl Compat === + +impl<T> Compat<T> { + fn new(inner: T) -> Self { + Self { + inner, + seek_pos: None, + } + } + + /// Get a reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object + /// contained within. + pub fn get_ref(&self) -> &T { + &self.inner + } + + /// Get a mutable reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object + /// contained within. + pub fn get_mut(&mut self) -> &mut T { + &mut self.inner + } + + /// Returns the wrapped item. + pub fn into_inner(self) -> T { + self.inner + } +} + +impl<T> tokio::io::AsyncRead for Compat<T> +where + T: futures_io::AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + // We can't trust the inner type to not peak at the bytes, + // so we must defensively initialize the buffer. + let slice = buf.initialize_unfilled(); + let n = ready!(futures_io::AsyncRead::poll_read( + self.project().inner, + cx, + slice + ))?; + buf.advance(n); + Poll::Ready(Ok(())) + } +} + +impl<T> futures_io::AsyncRead for Compat<T> +where + T: tokio::io::AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + slice: &mut [u8], + ) -> Poll<io::Result<usize>> { + let mut buf = tokio::io::ReadBuf::new(slice); + ready!(tokio::io::AsyncRead::poll_read( + self.project().inner, + cx, + &mut buf + ))?; + Poll::Ready(Ok(buf.filled().len())) + } +} + +impl<T> tokio::io::AsyncBufRead for Compat<T> +where + T: futures_io::AsyncBufRead, +{ + fn poll_fill_buf<'a>( + self: Pin<&'a mut Self>, + cx: &mut Context<'_>, + ) -> Poll<io::Result<&'a [u8]>> { + futures_io::AsyncBufRead::poll_fill_buf(self.project().inner, cx) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + futures_io::AsyncBufRead::consume(self.project().inner, amt) + } +} + +impl<T> futures_io::AsyncBufRead for Compat<T> +where + T: tokio::io::AsyncBufRead, +{ + fn poll_fill_buf<'a>( + self: Pin<&'a mut Self>, + cx: &mut Context<'_>, + ) -> Poll<io::Result<&'a [u8]>> { + tokio::io::AsyncBufRead::poll_fill_buf(self.project().inner, cx) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + tokio::io::AsyncBufRead::consume(self.project().inner, amt) + } +} + +impl<T> tokio::io::AsyncWrite for Compat<T> +where + T: futures_io::AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + futures_io::AsyncWrite::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + futures_io::AsyncWrite::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + futures_io::AsyncWrite::poll_close(self.project().inner, cx) + } +} + +impl<T> futures_io::AsyncWrite for Compat<T> +where + T: tokio::io::AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) + } +} + +impl<T: tokio::io::AsyncSeek> futures_io::AsyncSeek for Compat<T> { + fn poll_seek( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + pos: io::SeekFrom, + ) -> Poll<io::Result<u64>> { + if self.seek_pos != Some(pos) { + self.as_mut().project().inner.start_seek(pos)?; + *self.as_mut().project().seek_pos = Some(pos); + } + let res = ready!(self.as_mut().project().inner.poll_complete(cx)); + *self.as_mut().project().seek_pos = None; + Poll::Ready(res.map(|p| p as u64)) + } +} + +impl<T: futures_io::AsyncSeek> tokio::io::AsyncSeek for Compat<T> { + fn start_seek(mut self: Pin<&mut Self>, pos: io::SeekFrom) -> io::Result<()> { + *self.as_mut().project().seek_pos = Some(pos); + Ok(()) + } + + fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { + let pos = match self.seek_pos { + None => { + // tokio 1.x AsyncSeek recommends calling poll_complete before start_seek. + // We don't have to guarantee that the value returned by + // poll_complete called without start_seek is correct, + // so we'll return 0. + return Poll::Ready(Ok(0)); + } + Some(pos) => pos, + }; + let res = ready!(self.as_mut().project().inner.poll_seek(cx, pos)); + *self.as_mut().project().seek_pos = None; + Poll::Ready(res.map(|p| p as u64)) + } +} + +#[cfg(unix)] +impl<T: std::os::unix::io::AsRawFd> std::os::unix::io::AsRawFd for Compat<T> { + fn as_raw_fd(&self) -> std::os::unix::io::RawFd { + self.inner.as_raw_fd() + } +} + +#[cfg(windows)] +impl<T: std::os::windows::io::AsRawHandle> std::os::windows::io::AsRawHandle for Compat<T> { + fn as_raw_handle(&self) -> std::os::windows::io::RawHandle { + self.inner.as_raw_handle() + } +} diff --git a/third_party/rust/tokio-util/src/context.rs b/third_party/rust/tokio-util/src/context.rs new file mode 100644 index 0000000000..a7a5e02949 --- /dev/null +++ b/third_party/rust/tokio-util/src/context.rs @@ -0,0 +1,190 @@ +//! Tokio context aware futures utilities. +//! +//! This module includes utilities around integrating tokio with other runtimes +//! by allowing the context to be attached to futures. This allows spawning +//! futures on other executors while still using tokio to drive them. This +//! can be useful if you need to use a tokio based library in an executor/runtime +//! that does not provide a tokio context. + +use pin_project_lite::pin_project; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::runtime::{Handle, Runtime}; + +pin_project! { + /// `TokioContext` allows running futures that must be inside Tokio's + /// context on a non-Tokio runtime. + /// + /// It contains a [`Handle`] to the runtime. A handle to the runtime can be + /// obtain by calling the [`Runtime::handle()`] method. + /// + /// Note that the `TokioContext` wrapper only works if the `Runtime` it is + /// connected to has not yet been destroyed. You must keep the `Runtime` + /// alive until the future has finished executing. + /// + /// **Warning:** If `TokioContext` is used together with a [current thread] + /// runtime, that runtime must be inside a call to `block_on` for the + /// wrapped future to work. For this reason, it is recommended to use a + /// [multi thread] runtime, even if you configure it to only spawn one + /// worker thread. + /// + /// # Examples + /// + /// This example creates two runtimes, but only [enables time] on one of + /// them. It then uses the context of the runtime with the timer enabled to + /// execute a [`sleep`] future on the runtime with timing disabled. + /// ``` + /// use tokio::time::{sleep, Duration}; + /// use tokio_util::context::RuntimeExt; + /// + /// // This runtime has timers enabled. + /// let rt = tokio::runtime::Builder::new_multi_thread() + /// .enable_all() + /// .build() + /// .unwrap(); + /// + /// // This runtime has timers disabled. + /// let rt2 = tokio::runtime::Builder::new_multi_thread() + /// .build() + /// .unwrap(); + /// + /// // Wrap the sleep future in the context of rt. + /// let fut = rt.wrap(async { sleep(Duration::from_millis(2)).await }); + /// + /// // Execute the future on rt2. + /// rt2.block_on(fut); + /// ``` + /// + /// [`Handle`]: struct@tokio::runtime::Handle + /// [`Runtime::handle()`]: fn@tokio::runtime::Runtime::handle + /// [`RuntimeExt`]: trait@crate::context::RuntimeExt + /// [`new_static`]: fn@Self::new_static + /// [`sleep`]: fn@tokio::time::sleep + /// [current thread]: fn@tokio::runtime::Builder::new_current_thread + /// [enables time]: fn@tokio::runtime::Builder::enable_time + /// [multi thread]: fn@tokio::runtime::Builder::new_multi_thread + pub struct TokioContext<F> { + #[pin] + inner: F, + handle: Handle, + } +} + +impl<F> TokioContext<F> { + /// Associate the provided future with the context of the runtime behind + /// the provided `Handle`. + /// + /// This constructor uses a `'static` lifetime to opt-out of checking that + /// the runtime still exists. + /// + /// # Examples + /// + /// This is the same as the example above, but uses the `new` constructor + /// rather than [`RuntimeExt::wrap`]. + /// + /// [`RuntimeExt::wrap`]: fn@RuntimeExt::wrap + /// + /// ``` + /// use tokio::time::{sleep, Duration}; + /// use tokio_util::context::TokioContext; + /// + /// // This runtime has timers enabled. + /// let rt = tokio::runtime::Builder::new_multi_thread() + /// .enable_all() + /// .build() + /// .unwrap(); + /// + /// // This runtime has timers disabled. + /// let rt2 = tokio::runtime::Builder::new_multi_thread() + /// .build() + /// .unwrap(); + /// + /// let fut = TokioContext::new( + /// async { sleep(Duration::from_millis(2)).await }, + /// rt.handle().clone(), + /// ); + /// + /// // Execute the future on rt2. + /// rt2.block_on(fut); + /// ``` + pub fn new(future: F, handle: Handle) -> TokioContext<F> { + TokioContext { + inner: future, + handle, + } + } + + /// Obtain a reference to the handle inside this `TokioContext`. + pub fn handle(&self) -> &Handle { + &self.handle + } + + /// Remove the association between the Tokio runtime and the wrapped future. + pub fn into_inner(self) -> F { + self.inner + } +} + +impl<F: Future> Future for TokioContext<F> { + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + let handle = me.handle; + let fut = me.inner; + + let _enter = handle.enter(); + fut.poll(cx) + } +} + +/// Extension trait that simplifies bundling a `Handle` with a `Future`. +pub trait RuntimeExt { + /// Create a [`TokioContext`] that wraps the provided future and runs it in + /// this runtime's context. + /// + /// # Examples + /// + /// This example creates two runtimes, but only [enables time] on one of + /// them. It then uses the context of the runtime with the timer enabled to + /// execute a [`sleep`] future on the runtime with timing disabled. + /// + /// ``` + /// use tokio::time::{sleep, Duration}; + /// use tokio_util::context::RuntimeExt; + /// + /// // This runtime has timers enabled. + /// let rt = tokio::runtime::Builder::new_multi_thread() + /// .enable_all() + /// .build() + /// .unwrap(); + /// + /// // This runtime has timers disabled. + /// let rt2 = tokio::runtime::Builder::new_multi_thread() + /// .build() + /// .unwrap(); + /// + /// // Wrap the sleep future in the context of rt. + /// let fut = rt.wrap(async { sleep(Duration::from_millis(2)).await }); + /// + /// // Execute the future on rt2. + /// rt2.block_on(fut); + /// ``` + /// + /// [`TokioContext`]: struct@crate::context::TokioContext + /// [`sleep`]: fn@tokio::time::sleep + /// [enables time]: fn@tokio::runtime::Builder::enable_time + fn wrap<F: Future>(&self, fut: F) -> TokioContext<F>; +} + +impl RuntimeExt for Runtime { + fn wrap<F: Future>(&self, fut: F) -> TokioContext<F> { + TokioContext { + inner: fut, + handle: self.handle().clone(), + } + } +} diff --git a/third_party/rust/tokio-util/src/either.rs b/third_party/rust/tokio-util/src/either.rs new file mode 100644 index 0000000000..9225e53ca6 --- /dev/null +++ b/third_party/rust/tokio-util/src/either.rs @@ -0,0 +1,188 @@ +//! Module defining an Either type. +use std::{ + future::Future, + io::SeekFrom, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf, Result}; + +/// Combines two different futures, streams, or sinks having the same associated types into a single type. +/// +/// This type implements common asynchronous traits such as [`Future`] and those in Tokio. +/// +/// [`Future`]: std::future::Future +/// +/// # Example +/// +/// The following code will not work: +/// +/// ```compile_fail +/// # fn some_condition() -> bool { true } +/// # async fn some_async_function() -> u32 { 10 } +/// # async fn other_async_function() -> u32 { 20 } +/// #[tokio::main] +/// async fn main() { +/// let result = if some_condition() { +/// some_async_function() +/// } else { +/// other_async_function() // <- Will print: "`if` and `else` have incompatible types" +/// }; +/// +/// println!("Result is {}", result.await); +/// } +/// ``` +/// +// This is because although the output types for both futures is the same, the exact future +// types are different, but the compiler must be able to choose a single type for the +// `result` variable. +/// +/// When the output type is the same, we can wrap each future in `Either` to avoid the +/// issue: +/// +/// ``` +/// use tokio_util::either::Either; +/// # fn some_condition() -> bool { true } +/// # async fn some_async_function() -> u32 { 10 } +/// # async fn other_async_function() -> u32 { 20 } +/// +/// #[tokio::main] +/// async fn main() { +/// let result = if some_condition() { +/// Either::Left(some_async_function()) +/// } else { +/// Either::Right(other_async_function()) +/// }; +/// +/// let value = result.await; +/// println!("Result is {}", value); +/// # assert_eq!(value, 10); +/// } +/// ``` +#[allow(missing_docs)] // Doc-comments for variants in this particular case don't make much sense. +#[derive(Debug, Clone)] +pub enum Either<L, R> { + Left(L), + Right(R), +} + +/// A small helper macro which reduces amount of boilerplate in the actual trait method implementation. +/// It takes an invocation of method as an argument (e.g. `self.poll(cx)`), and redirects it to either +/// enum variant held in `self`. +macro_rules! delegate_call { + ($self:ident.$method:ident($($args:ident),+)) => { + unsafe { + match $self.get_unchecked_mut() { + Self::Left(l) => Pin::new_unchecked(l).$method($($args),+), + Self::Right(r) => Pin::new_unchecked(r).$method($($args),+), + } + } + } +} + +impl<L, R, O> Future for Either<L, R> +where + L: Future<Output = O>, + R: Future<Output = O>, +{ + type Output = O; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + delegate_call!(self.poll(cx)) + } +} + +impl<L, R> AsyncRead for Either<L, R> +where + L: AsyncRead, + R: AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<Result<()>> { + delegate_call!(self.poll_read(cx, buf)) + } +} + +impl<L, R> AsyncBufRead for Either<L, R> +where + L: AsyncBufRead, + R: AsyncBufRead, +{ + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<&[u8]>> { + delegate_call!(self.poll_fill_buf(cx)) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + delegate_call!(self.consume(amt)) + } +} + +impl<L, R> AsyncSeek for Either<L, R> +where + L: AsyncSeek, + R: AsyncSeek, +{ + fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> Result<()> { + delegate_call!(self.start_seek(position)) + } + + fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<u64>> { + delegate_call!(self.poll_complete(cx)) + } +} + +impl<L, R> AsyncWrite for Either<L, R> +where + L: AsyncWrite, + R: AsyncWrite, +{ + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> { + delegate_call!(self.poll_write(cx, buf)) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> { + delegate_call!(self.poll_flush(cx)) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> { + delegate_call!(self.poll_shutdown(cx)) + } +} + +impl<L, R> futures_core::stream::Stream for Either<L, R> +where + L: futures_core::stream::Stream, + R: futures_core::stream::Stream<Item = L::Item>, +{ + type Item = L::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + delegate_call!(self.poll_next(cx)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::io::{repeat, AsyncReadExt, Repeat}; + use tokio_stream::{once, Once, StreamExt}; + + #[tokio::test] + async fn either_is_stream() { + let mut either: Either<Once<u32>, Once<u32>> = Either::Left(once(1)); + + assert_eq!(Some(1u32), either.next().await); + } + + #[tokio::test] + async fn either_is_async_read() { + let mut buffer = [0; 3]; + let mut either: Either<Repeat, Repeat> = Either::Right(repeat(0b101)); + + either.read_exact(&mut buffer).await.unwrap(); + assert_eq!(buffer, [0b101, 0b101, 0b101]); + } +} diff --git a/third_party/rust/tokio-util/src/io/mod.rs b/third_party/rust/tokio-util/src/io/mod.rs new file mode 100644 index 0000000000..eb48a21fb9 --- /dev/null +++ b/third_party/rust/tokio-util/src/io/mod.rs @@ -0,0 +1,24 @@ +//! Helpers for IO related tasks. +//! +//! The stream types are often used in combination with hyper or reqwest, as they +//! allow converting between a hyper [`Body`] and [`AsyncRead`]. +//! +//! The [`SyncIoBridge`] type converts from the world of async I/O +//! to synchronous I/O; this may often come up when using synchronous APIs +//! inside [`tokio::task::spawn_blocking`]. +//! +//! [`Body`]: https://docs.rs/hyper/0.13/hyper/struct.Body.html +//! [`AsyncRead`]: tokio::io::AsyncRead + +mod read_buf; +mod reader_stream; +mod stream_reader; +cfg_io_util! { + mod sync_bridge; + pub use self::sync_bridge::SyncIoBridge; +} + +pub use self::read_buf::read_buf; +pub use self::reader_stream::ReaderStream; +pub use self::stream_reader::StreamReader; +pub use crate::util::{poll_read_buf, poll_write_buf}; diff --git a/third_party/rust/tokio-util/src/io/read_buf.rs b/third_party/rust/tokio-util/src/io/read_buf.rs new file mode 100644 index 0000000000..d7938a3bc1 --- /dev/null +++ b/third_party/rust/tokio-util/src/io/read_buf.rs @@ -0,0 +1,65 @@ +use bytes::BufMut; +use std::future::Future; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::AsyncRead; + +/// Read data from an `AsyncRead` into an implementer of the [`BufMut`] trait. +/// +/// [`BufMut`]: bytes::BufMut +/// +/// # Example +/// +/// ``` +/// use bytes::{Bytes, BytesMut}; +/// use tokio_stream as stream; +/// use tokio::io::Result; +/// use tokio_util::io::{StreamReader, read_buf}; +/// # #[tokio::main] +/// # async fn main() -> std::io::Result<()> { +/// +/// // Create a reader from an iterator. This particular reader will always be +/// // ready. +/// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))])); +/// +/// let mut buf = BytesMut::new(); +/// let mut reads = 0; +/// +/// loop { +/// reads += 1; +/// let n = read_buf(&mut read, &mut buf).await?; +/// +/// if n == 0 { +/// break; +/// } +/// } +/// +/// // one or more reads might be necessary. +/// assert!(reads >= 1); +/// assert_eq!(&buf[..], &[0, 1, 2, 3]); +/// # Ok(()) +/// # } +/// ``` +pub async fn read_buf<R, B>(read: &mut R, buf: &mut B) -> io::Result<usize> +where + R: AsyncRead + Unpin, + B: BufMut, +{ + return ReadBufFn(read, buf).await; + + struct ReadBufFn<'a, R, B>(&'a mut R, &'a mut B); + + impl<'a, R, B> Future for ReadBufFn<'a, R, B> + where + R: AsyncRead + Unpin, + B: BufMut, + { + type Output = io::Result<usize>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = &mut *self; + crate::util::poll_read_buf(Pin::new(this.0), cx, this.1) + } + } +} diff --git a/third_party/rust/tokio-util/src/io/reader_stream.rs b/third_party/rust/tokio-util/src/io/reader_stream.rs new file mode 100644 index 0000000000..866c11408d --- /dev/null +++ b/third_party/rust/tokio-util/src/io/reader_stream.rs @@ -0,0 +1,118 @@ +use bytes::{Bytes, BytesMut}; +use futures_core::stream::Stream; +use pin_project_lite::pin_project; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::AsyncRead; + +const DEFAULT_CAPACITY: usize = 4096; + +pin_project! { + /// Convert an [`AsyncRead`] into a [`Stream`] of byte chunks. + /// + /// This stream is fused. It performs the inverse operation of + /// [`StreamReader`]. + /// + /// # Example + /// + /// ``` + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// use tokio_stream::StreamExt; + /// use tokio_util::io::ReaderStream; + /// + /// // Create a stream of data. + /// let data = b"hello, world!"; + /// let mut stream = ReaderStream::new(&data[..]); + /// + /// // Read all of the chunks into a vector. + /// let mut stream_contents = Vec::new(); + /// while let Some(chunk) = stream.next().await { + /// stream_contents.extend_from_slice(&chunk?); + /// } + /// + /// // Once the chunks are concatenated, we should have the + /// // original data. + /// assert_eq!(stream_contents, data); + /// # Ok(()) + /// # } + /// ``` + /// + /// [`AsyncRead`]: tokio::io::AsyncRead + /// [`StreamReader`]: crate::io::StreamReader + /// [`Stream`]: futures_core::Stream + #[derive(Debug)] + pub struct ReaderStream<R> { + // Reader itself. + // + // This value is `None` if the stream has terminated. + #[pin] + reader: Option<R>, + // Working buffer, used to optimize allocations. + buf: BytesMut, + capacity: usize, + } +} + +impl<R: AsyncRead> ReaderStream<R> { + /// Convert an [`AsyncRead`] into a [`Stream`] with item type + /// `Result<Bytes, std::io::Error>`. + /// + /// [`AsyncRead`]: tokio::io::AsyncRead + /// [`Stream`]: futures_core::Stream + pub fn new(reader: R) -> Self { + ReaderStream { + reader: Some(reader), + buf: BytesMut::new(), + capacity: DEFAULT_CAPACITY, + } + } + + /// Convert an [`AsyncRead`] into a [`Stream`] with item type + /// `Result<Bytes, std::io::Error>`, + /// with a specific read buffer initial capacity. + /// + /// [`AsyncRead`]: tokio::io::AsyncRead + /// [`Stream`]: futures_core::Stream + pub fn with_capacity(reader: R, capacity: usize) -> Self { + ReaderStream { + reader: Some(reader), + buf: BytesMut::with_capacity(capacity), + capacity, + } + } +} + +impl<R: AsyncRead> Stream for ReaderStream<R> { + type Item = std::io::Result<Bytes>; + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + use crate::util::poll_read_buf; + + let mut this = self.as_mut().project(); + + let reader = match this.reader.as_pin_mut() { + Some(r) => r, + None => return Poll::Ready(None), + }; + + if this.buf.capacity() == 0 { + this.buf.reserve(*this.capacity); + } + + match poll_read_buf(reader, cx, &mut this.buf) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(err)) => { + self.project().reader.set(None); + Poll::Ready(Some(Err(err))) + } + Poll::Ready(Ok(0)) => { + self.project().reader.set(None); + Poll::Ready(None) + } + Poll::Ready(Ok(_)) => { + let chunk = this.buf.split(); + Poll::Ready(Some(Ok(chunk.freeze()))) + } + } + } +} diff --git a/third_party/rust/tokio-util/src/io/stream_reader.rs b/third_party/rust/tokio-util/src/io/stream_reader.rs new file mode 100644 index 0000000000..05ae886557 --- /dev/null +++ b/third_party/rust/tokio-util/src/io/stream_reader.rs @@ -0,0 +1,203 @@ +use bytes::Buf; +use futures_core::stream::Stream; +use pin_project_lite::pin_project; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf}; + +pin_project! { + /// Convert a [`Stream`] of byte chunks into an [`AsyncRead`]. + /// + /// This type performs the inverse operation of [`ReaderStream`]. + /// + /// # Example + /// + /// ``` + /// use bytes::Bytes; + /// use tokio::io::{AsyncReadExt, Result}; + /// use tokio_util::io::StreamReader; + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// + /// // Create a stream from an iterator. + /// let stream = tokio_stream::iter(vec![ + /// Result::Ok(Bytes::from_static(&[0, 1, 2, 3])), + /// Result::Ok(Bytes::from_static(&[4, 5, 6, 7])), + /// Result::Ok(Bytes::from_static(&[8, 9, 10, 11])), + /// ]); + /// + /// // Convert it to an AsyncRead. + /// let mut read = StreamReader::new(stream); + /// + /// // Read five bytes from the stream. + /// let mut buf = [0; 5]; + /// read.read_exact(&mut buf).await?; + /// assert_eq!(buf, [0, 1, 2, 3, 4]); + /// + /// // Read the rest of the current chunk. + /// assert_eq!(read.read(&mut buf).await?, 3); + /// assert_eq!(&buf[..3], [5, 6, 7]); + /// + /// // Read the next chunk. + /// assert_eq!(read.read(&mut buf).await?, 4); + /// assert_eq!(&buf[..4], [8, 9, 10, 11]); + /// + /// // We have now reached the end. + /// assert_eq!(read.read(&mut buf).await?, 0); + /// + /// # Ok(()) + /// # } + /// ``` + /// + /// [`AsyncRead`]: tokio::io::AsyncRead + /// [`Stream`]: futures_core::Stream + /// [`ReaderStream`]: crate::io::ReaderStream + #[derive(Debug)] + pub struct StreamReader<S, B> { + #[pin] + inner: S, + chunk: Option<B>, + } +} + +impl<S, B, E> StreamReader<S, B> +where + S: Stream<Item = Result<B, E>>, + B: Buf, + E: Into<std::io::Error>, +{ + /// Convert a stream of byte chunks into an [`AsyncRead`](tokio::io::AsyncRead). + /// + /// The item should be a [`Result`] with the ok variant being something that + /// implements the [`Buf`] trait (e.g. `Vec<u8>` or `Bytes`). The error + /// should be convertible into an [io error]. + /// + /// [`Result`]: std::result::Result + /// [`Buf`]: bytes::Buf + /// [io error]: std::io::Error + pub fn new(stream: S) -> Self { + Self { + inner: stream, + chunk: None, + } + } + + /// Do we have a chunk and is it non-empty? + fn has_chunk(&self) -> bool { + if let Some(ref chunk) = self.chunk { + chunk.remaining() > 0 + } else { + false + } + } + + /// Consumes this `StreamReader`, returning a Tuple consisting + /// of the underlying stream and an Option of the interal buffer, + /// which is Some in case the buffer contains elements. + pub fn into_inner_with_chunk(self) -> (S, Option<B>) { + if self.has_chunk() { + (self.inner, self.chunk) + } else { + (self.inner, None) + } + } +} + +impl<S, B> StreamReader<S, B> { + /// Gets a reference to the underlying stream. + /// + /// It is inadvisable to directly read from the underlying stream. + pub fn get_ref(&self) -> &S { + &self.inner + } + + /// Gets a mutable reference to the underlying stream. + /// + /// It is inadvisable to directly read from the underlying stream. + pub fn get_mut(&mut self) -> &mut S { + &mut self.inner + } + + /// Gets a pinned mutable reference to the underlying stream. + /// + /// It is inadvisable to directly read from the underlying stream. + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> { + self.project().inner + } + + /// Consumes this `BufWriter`, returning the underlying stream. + /// + /// Note that any leftover data in the internal buffer is lost. + /// If you additionally want access to the internal buffer use + /// [`into_inner_with_chunk`]. + /// + /// [`into_inner_with_chunk`]: crate::io::StreamReader::into_inner_with_chunk + pub fn into_inner(self) -> S { + self.inner + } +} + +impl<S, B, E> AsyncRead for StreamReader<S, B> +where + S: Stream<Item = Result<B, E>>, + B: Buf, + E: Into<std::io::Error>, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + if buf.remaining() == 0 { + return Poll::Ready(Ok(())); + } + + let inner_buf = match self.as_mut().poll_fill_buf(cx) { + Poll::Ready(Ok(buf)) => buf, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => return Poll::Pending, + }; + let len = std::cmp::min(inner_buf.len(), buf.remaining()); + buf.put_slice(&inner_buf[..len]); + + self.consume(len); + Poll::Ready(Ok(())) + } +} + +impl<S, B, E> AsyncBufRead for StreamReader<S, B> +where + S: Stream<Item = Result<B, E>>, + B: Buf, + E: Into<std::io::Error>, +{ + fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { + loop { + if self.as_mut().has_chunk() { + // This unwrap is very sad, but it can't be avoided. + let buf = self.project().chunk.as_ref().unwrap().chunk(); + return Poll::Ready(Ok(buf)); + } else { + match self.as_mut().project().inner.poll_next(cx) { + Poll::Ready(Some(Ok(chunk))) => { + // Go around the loop in case the chunk is empty. + *self.as_mut().project().chunk = Some(chunk); + } + Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())), + Poll::Ready(None) => return Poll::Ready(Ok(&[])), + Poll::Pending => return Poll::Pending, + } + } + } + } + fn consume(self: Pin<&mut Self>, amt: usize) { + if amt > 0 { + self.project() + .chunk + .as_mut() + .expect("No chunk present") + .advance(amt); + } + } +} diff --git a/third_party/rust/tokio-util/src/io/sync_bridge.rs b/third_party/rust/tokio-util/src/io/sync_bridge.rs new file mode 100644 index 0000000000..9be9446a7d --- /dev/null +++ b/third_party/rust/tokio-util/src/io/sync_bridge.rs @@ -0,0 +1,103 @@ +use std::io::{Read, Write}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +/// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or +/// a [`tokio::io::AsyncWrite`] as a [`std::io::Write`]. +#[derive(Debug)] +pub struct SyncIoBridge<T> { + src: T, + rt: tokio::runtime::Handle, +} + +impl<T: AsyncRead + Unpin> Read for SyncIoBridge<T> { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> { + let src = &mut self.src; + self.rt.block_on(AsyncReadExt::read(src, buf)) + } + + fn read_to_end(&mut self, buf: &mut Vec<u8>) -> std::io::Result<usize> { + let src = &mut self.src; + self.rt.block_on(src.read_to_end(buf)) + } + + fn read_to_string(&mut self, buf: &mut String) -> std::io::Result<usize> { + let src = &mut self.src; + self.rt.block_on(src.read_to_string(buf)) + } + + fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> { + let src = &mut self.src; + // The AsyncRead trait returns the count, synchronous doesn't. + let _n = self.rt.block_on(src.read_exact(buf))?; + Ok(()) + } +} + +impl<T: AsyncWrite + Unpin> Write for SyncIoBridge<T> { + fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { + let src = &mut self.src; + self.rt.block_on(src.write(buf)) + } + + fn flush(&mut self) -> std::io::Result<()> { + let src = &mut self.src; + self.rt.block_on(src.flush()) + } + + fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> { + let src = &mut self.src; + self.rt.block_on(src.write_all(buf)) + } + + fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result<usize> { + let src = &mut self.src; + self.rt.block_on(src.write_vectored(bufs)) + } +} + +// Because https://doc.rust-lang.org/std/io/trait.Write.html#method.is_write_vectored is at the time +// of this writing still unstable, we expose this as part of a standalone method. +impl<T: AsyncWrite> SyncIoBridge<T> { + /// Determines if the underlying [`tokio::io::AsyncWrite`] target supports efficient vectored writes. + /// + /// See [`tokio::io::AsyncWrite::is_write_vectored`]. + pub fn is_write_vectored(&self) -> bool { + self.src.is_write_vectored() + } +} + +impl<T: Unpin> SyncIoBridge<T> { + /// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or + /// a [`tokio::io::AsyncWrite`] as a [`std::io::Write`]. + /// + /// When this struct is created, it captures a handle to the current thread's runtime with [`tokio::runtime::Handle::current`]. + /// It is hence OK to move this struct into a separate thread outside the runtime, as created + /// by e.g. [`tokio::task::spawn_blocking`]. + /// + /// Stated even more strongly: to make use of this bridge, you *must* move + /// it into a separate thread outside the runtime. The synchronous I/O will use the + /// underlying handle to block on the backing asynchronous source, via + /// [`tokio::runtime::Handle::block_on`]. As noted in the documentation for that + /// function, an attempt to `block_on` from an asynchronous execution context + /// will panic. + /// + /// # Wrapping `!Unpin` types + /// + /// Use e.g. `SyncIoBridge::new(Box::pin(src))`. + /// + /// # Panic + /// + /// This will panic if called outside the context of a Tokio runtime. + pub fn new(src: T) -> Self { + Self::new_with_handle(src, tokio::runtime::Handle::current()) + } + + /// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or + /// a [`tokio::io::AsyncWrite`] as a [`std::io::Write`]. + /// + /// This is the same as [`SyncIoBridge::new`], but allows passing an arbitrary handle and hence may + /// be initially invoked outside of an asynchronous context. + pub fn new_with_handle(src: T, rt: tokio::runtime::Handle) -> Self { + Self { src, rt } + } +} diff --git a/third_party/rust/tokio-util/src/lib.rs b/third_party/rust/tokio-util/src/lib.rs new file mode 100644 index 0000000000..fd14a8ac94 --- /dev/null +++ b/third_party/rust/tokio-util/src/lib.rs @@ -0,0 +1,201 @@ +#![allow(clippy::needless_doctest_main)] +#![warn( + missing_debug_implementations, + missing_docs, + rust_2018_idioms, + unreachable_pub +)] +#![doc(test( + no_crate_inject, + attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables)) +))] +#![cfg_attr(docsrs, feature(doc_cfg))] + +//! Utilities for working with Tokio. +//! +//! This crate is not versioned in lockstep with the core +//! [`tokio`] crate. However, `tokio-util` _will_ respect Rust's +//! semantic versioning policy, especially with regard to breaking changes. +//! +//! [`tokio`]: https://docs.rs/tokio + +#[macro_use] +mod cfg; + +mod loom; + +cfg_codec! { + pub mod codec; +} + +cfg_net! { + pub mod udp; + pub mod net; +} + +cfg_compat! { + pub mod compat; +} + +cfg_io! { + pub mod io; +} + +cfg_rt! { + pub mod context; + pub mod task; +} + +cfg_time! { + pub mod time; +} + +pub mod sync; + +pub mod either; + +#[cfg(any(feature = "io", feature = "codec"))] +mod util { + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + + use bytes::{Buf, BufMut}; + use futures_core::ready; + use std::io::{self, IoSlice}; + use std::mem::MaybeUninit; + use std::pin::Pin; + use std::task::{Context, Poll}; + + /// Try to read data from an `AsyncRead` into an implementer of the [`BufMut`] trait. + /// + /// [`BufMut`]: bytes::Buf + /// + /// # Example + /// + /// ``` + /// use bytes::{Bytes, BytesMut}; + /// use tokio_stream as stream; + /// use tokio::io::Result; + /// use tokio_util::io::{StreamReader, poll_read_buf}; + /// use futures::future::poll_fn; + /// use std::pin::Pin; + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// + /// // Create a reader from an iterator. This particular reader will always be + /// // ready. + /// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))])); + /// + /// let mut buf = BytesMut::new(); + /// let mut reads = 0; + /// + /// loop { + /// reads += 1; + /// let n = poll_fn(|cx| poll_read_buf(Pin::new(&mut read), cx, &mut buf)).await?; + /// + /// if n == 0 { + /// break; + /// } + /// } + /// + /// // one or more reads might be necessary. + /// assert!(reads >= 1); + /// assert_eq!(&buf[..], &[0, 1, 2, 3]); + /// # Ok(()) + /// # } + /// ``` + #[cfg_attr(not(feature = "io"), allow(unreachable_pub))] + pub fn poll_read_buf<T: AsyncRead, B: BufMut>( + io: Pin<&mut T>, + cx: &mut Context<'_>, + buf: &mut B, + ) -> Poll<io::Result<usize>> { + if !buf.has_remaining_mut() { + return Poll::Ready(Ok(0)); + } + + let n = { + let dst = buf.chunk_mut(); + let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) }; + let mut buf = ReadBuf::uninit(dst); + let ptr = buf.filled().as_ptr(); + ready!(io.poll_read(cx, &mut buf)?); + + // Ensure the pointer does not change from under us + assert_eq!(ptr, buf.filled().as_ptr()); + buf.filled().len() + }; + + // Safety: This is guaranteed to be the number of initialized (and read) + // bytes due to the invariants provided by `ReadBuf::filled`. + unsafe { + buf.advance_mut(n); + } + + Poll::Ready(Ok(n)) + } + + /// Try to write data from an implementer of the [`Buf`] trait to an + /// [`AsyncWrite`], advancing the buffer's internal cursor. + /// + /// This function will use [vectored writes] when the [`AsyncWrite`] supports + /// vectored writes. + /// + /// # Examples + /// + /// [`File`] implements [`AsyncWrite`] and [`Cursor<&[u8]>`] implements + /// [`Buf`]: + /// + /// ```no_run + /// use tokio_util::io::poll_write_buf; + /// use tokio::io; + /// use tokio::fs::File; + /// + /// use bytes::Buf; + /// use std::io::Cursor; + /// use std::pin::Pin; + /// use futures::future::poll_fn; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut file = File::create("foo.txt").await?; + /// let mut buf = Cursor::new(b"data to write"); + /// + /// // Loop until the entire contents of the buffer are written to + /// // the file. + /// while buf.has_remaining() { + /// poll_fn(|cx| poll_write_buf(Pin::new(&mut file), cx, &mut buf)).await?; + /// } + /// + /// Ok(()) + /// } + /// ``` + /// + /// [`Buf`]: bytes::Buf + /// [`AsyncWrite`]: tokio::io::AsyncWrite + /// [`File`]: tokio::fs::File + /// [vectored writes]: tokio::io::AsyncWrite::poll_write_vectored + #[cfg_attr(not(feature = "io"), allow(unreachable_pub))] + pub fn poll_write_buf<T: AsyncWrite, B: Buf>( + io: Pin<&mut T>, + cx: &mut Context<'_>, + buf: &mut B, + ) -> Poll<io::Result<usize>> { + const MAX_BUFS: usize = 64; + + if !buf.has_remaining() { + return Poll::Ready(Ok(0)); + } + + let n = if io.is_write_vectored() { + let mut slices = [IoSlice::new(&[]); MAX_BUFS]; + let cnt = buf.chunks_vectored(&mut slices); + ready!(io.poll_write_vectored(cx, &slices[..cnt]))? + } else { + ready!(io.poll_write(cx, buf.chunk()))? + }; + + buf.advance(n); + + Poll::Ready(Ok(n)) + } +} diff --git a/third_party/rust/tokio-util/src/loom.rs b/third_party/rust/tokio-util/src/loom.rs new file mode 100644 index 0000000000..dd03feaba1 --- /dev/null +++ b/third_party/rust/tokio-util/src/loom.rs @@ -0,0 +1 @@ +pub(crate) use std::sync; diff --git a/third_party/rust/tokio-util/src/net/mod.rs b/third_party/rust/tokio-util/src/net/mod.rs new file mode 100644 index 0000000000..4817e10d0f --- /dev/null +++ b/third_party/rust/tokio-util/src/net/mod.rs @@ -0,0 +1,97 @@ +//! TCP/UDP/Unix helpers for tokio. + +use crate::either::Either; +use std::future::Future; +use std::io::Result; +use std::pin::Pin; +use std::task::{Context, Poll}; + +#[cfg(unix)] +pub mod unix; + +/// A trait for a listener: `TcpListener` and `UnixListener`. +pub trait Listener { + /// The stream's type of this listener. + type Io: tokio::io::AsyncRead + tokio::io::AsyncWrite; + /// The socket address type of this listener. + type Addr; + + /// Polls to accept a new incoming connection to this listener. + fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<(Self::Io, Self::Addr)>>; + + /// Accepts a new incoming connection from this listener. + fn accept(&mut self) -> ListenerAcceptFut<'_, Self> + where + Self: Sized, + { + ListenerAcceptFut { listener: self } + } + + /// Returns the local address that this listener is bound to. + fn local_addr(&self) -> Result<Self::Addr>; +} + +impl Listener for tokio::net::TcpListener { + type Io = tokio::net::TcpStream; + type Addr = std::net::SocketAddr; + + fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<(Self::Io, Self::Addr)>> { + Self::poll_accept(self, cx) + } + + fn local_addr(&self) -> Result<Self::Addr> { + self.local_addr().map(Into::into) + } +} + +/// Future for accepting a new connection from a listener. +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct ListenerAcceptFut<'a, L> { + listener: &'a mut L, +} + +impl<'a, L> Future for ListenerAcceptFut<'a, L> +where + L: Listener, +{ + type Output = Result<(L::Io, L::Addr)>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + self.listener.poll_accept(cx) + } +} + +impl<L, R> Either<L, R> +where + L: Listener, + R: Listener, +{ + /// Accepts a new incoming connection from this listener. + pub async fn accept(&mut self) -> Result<Either<(L::Io, L::Addr), (R::Io, R::Addr)>> { + match self { + Either::Left(listener) => { + let (stream, addr) = listener.accept().await?; + Ok(Either::Left((stream, addr))) + } + Either::Right(listener) => { + let (stream, addr) = listener.accept().await?; + Ok(Either::Right((stream, addr))) + } + } + } + + /// Returns the local address that this listener is bound to. + pub fn local_addr(&self) -> Result<Either<L::Addr, R::Addr>> { + match self { + Either::Left(listener) => { + let addr = listener.local_addr()?; + Ok(Either::Left(addr)) + } + Either::Right(listener) => { + let addr = listener.local_addr()?; + Ok(Either::Right(addr)) + } + } + } +} diff --git a/third_party/rust/tokio-util/src/net/unix/mod.rs b/third_party/rust/tokio-util/src/net/unix/mod.rs new file mode 100644 index 0000000000..0b522c90a3 --- /dev/null +++ b/third_party/rust/tokio-util/src/net/unix/mod.rs @@ -0,0 +1,18 @@ +//! Unix domain socket helpers. + +use super::Listener; +use std::io::Result; +use std::task::{Context, Poll}; + +impl Listener for tokio::net::UnixListener { + type Io = tokio::net::UnixStream; + type Addr = tokio::net::unix::SocketAddr; + + fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<(Self::Io, Self::Addr)>> { + Self::poll_accept(self, cx) + } + + fn local_addr(&self) -> Result<Self::Addr> { + self.local_addr().map(Into::into) + } +} diff --git a/third_party/rust/tokio-util/src/sync/cancellation_token.rs b/third_party/rust/tokio-util/src/sync/cancellation_token.rs new file mode 100644 index 0000000000..2a6ef392bd --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/cancellation_token.rs @@ -0,0 +1,224 @@ +//! An asynchronously awaitable `CancellationToken`. +//! The token allows to signal a cancellation request to one or more tasks. +pub(crate) mod guard; +mod tree_node; + +use crate::loom::sync::Arc; +use core::future::Future; +use core::pin::Pin; +use core::task::{Context, Poll}; + +use guard::DropGuard; +use pin_project_lite::pin_project; + +/// A token which can be used to signal a cancellation request to one or more +/// tasks. +/// +/// Tasks can call [`CancellationToken::cancelled()`] in order to +/// obtain a Future which will be resolved when cancellation is requested. +/// +/// Cancellation can be requested through the [`CancellationToken::cancel`] method. +/// +/// # Examples +/// +/// ```no_run +/// use tokio::select; +/// use tokio_util::sync::CancellationToken; +/// +/// #[tokio::main] +/// async fn main() { +/// let token = CancellationToken::new(); +/// let cloned_token = token.clone(); +/// +/// let join_handle = tokio::spawn(async move { +/// // Wait for either cancellation or a very long time +/// select! { +/// _ = cloned_token.cancelled() => { +/// // The token was cancelled +/// 5 +/// } +/// _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => { +/// 99 +/// } +/// } +/// }); +/// +/// tokio::spawn(async move { +/// tokio::time::sleep(std::time::Duration::from_millis(10)).await; +/// token.cancel(); +/// }); +/// +/// assert_eq!(5, join_handle.await.unwrap()); +/// } +/// ``` +pub struct CancellationToken { + inner: Arc<tree_node::TreeNode>, +} + +pin_project! { + /// A Future that is resolved once the corresponding [`CancellationToken`] + /// is cancelled. + #[must_use = "futures do nothing unless polled"] + pub struct WaitForCancellationFuture<'a> { + cancellation_token: &'a CancellationToken, + #[pin] + future: tokio::sync::futures::Notified<'a>, + } +} + +// ===== impl CancellationToken ===== + +impl core::fmt::Debug for CancellationToken { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("CancellationToken") + .field("is_cancelled", &self.is_cancelled()) + .finish() + } +} + +impl Clone for CancellationToken { + fn clone(&self) -> Self { + tree_node::increase_handle_refcount(&self.inner); + CancellationToken { + inner: self.inner.clone(), + } + } +} + +impl Drop for CancellationToken { + fn drop(&mut self) { + tree_node::decrease_handle_refcount(&self.inner); + } +} + +impl Default for CancellationToken { + fn default() -> CancellationToken { + CancellationToken::new() + } +} + +impl CancellationToken { + /// Creates a new CancellationToken in the non-cancelled state. + pub fn new() -> CancellationToken { + CancellationToken { + inner: Arc::new(tree_node::TreeNode::new()), + } + } + + /// Creates a `CancellationToken` which will get cancelled whenever the + /// current token gets cancelled. + /// + /// If the current token is already cancelled, the child token will get + /// returned in cancelled state. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::select; + /// use tokio_util::sync::CancellationToken; + /// + /// #[tokio::main] + /// async fn main() { + /// let token = CancellationToken::new(); + /// let child_token = token.child_token(); + /// + /// let join_handle = tokio::spawn(async move { + /// // Wait for either cancellation or a very long time + /// select! { + /// _ = child_token.cancelled() => { + /// // The token was cancelled + /// 5 + /// } + /// _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => { + /// 99 + /// } + /// } + /// }); + /// + /// tokio::spawn(async move { + /// tokio::time::sleep(std::time::Duration::from_millis(10)).await; + /// token.cancel(); + /// }); + /// + /// assert_eq!(5, join_handle.await.unwrap()); + /// } + /// ``` + pub fn child_token(&self) -> CancellationToken { + CancellationToken { + inner: tree_node::child_node(&self.inner), + } + } + + /// Cancel the [`CancellationToken`] and all child tokens which had been + /// derived from it. + /// + /// This will wake up all tasks which are waiting for cancellation. + /// + /// Be aware that cancellation is not an atomic operation. It is possible + /// for another thread running in parallel with a call to `cancel` to first + /// receive `true` from `is_cancelled` on one child node, and then receive + /// `false` from `is_cancelled` on another child node. However, once the + /// call to `cancel` returns, all child nodes have been fully cancelled. + pub fn cancel(&self) { + tree_node::cancel(&self.inner); + } + + /// Returns `true` if the `CancellationToken` is cancelled. + pub fn is_cancelled(&self) -> bool { + tree_node::is_cancelled(&self.inner) + } + + /// Returns a `Future` that gets fulfilled when cancellation is requested. + /// + /// The future will complete immediately if the token is already cancelled + /// when this method is called. + /// + /// # Cancel safety + /// + /// This method is cancel safe. + pub fn cancelled(&self) -> WaitForCancellationFuture<'_> { + WaitForCancellationFuture { + cancellation_token: self, + future: self.inner.notified(), + } + } + + /// Creates a `DropGuard` for this token. + /// + /// Returned guard will cancel this token (and all its children) on drop + /// unless disarmed. + pub fn drop_guard(self) -> DropGuard { + DropGuard { inner: Some(self) } + } +} + +// ===== impl WaitForCancellationFuture ===== + +impl<'a> core::fmt::Debug for WaitForCancellationFuture<'a> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("WaitForCancellationFuture").finish() + } +} + +impl<'a> Future for WaitForCancellationFuture<'a> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let mut this = self.project(); + loop { + if this.cancellation_token.is_cancelled() { + return Poll::Ready(()); + } + + // No wakeups can be lost here because there is always a call to + // `is_cancelled` between the creation of the future and the call to + // `poll`, and the code that sets the cancelled flag does so before + // waking the `Notified`. + if this.future.as_mut().poll(cx).is_pending() { + return Poll::Pending; + } + + this.future.set(this.cancellation_token.inner.notified()); + } + } +} diff --git a/third_party/rust/tokio-util/src/sync/cancellation_token/guard.rs b/third_party/rust/tokio-util/src/sync/cancellation_token/guard.rs new file mode 100644 index 0000000000..54ed7ea2ed --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/cancellation_token/guard.rs @@ -0,0 +1,27 @@ +use crate::sync::CancellationToken; + +/// A wrapper for cancellation token which automatically cancels +/// it on drop. It is created using `drop_guard` method on the `CancellationToken`. +#[derive(Debug)] +pub struct DropGuard { + pub(super) inner: Option<CancellationToken>, +} + +impl DropGuard { + /// Returns stored cancellation token and removes this drop guard instance + /// (i.e. it will no longer cancel token). Other guards for this token + /// are not affected. + pub fn disarm(mut self) -> CancellationToken { + self.inner + .take() + .expect("`inner` can be only None in a destructor") + } +} + +impl Drop for DropGuard { + fn drop(&mut self) { + if let Some(inner) = &self.inner { + inner.cancel(); + } + } +} diff --git a/third_party/rust/tokio-util/src/sync/cancellation_token/tree_node.rs b/third_party/rust/tokio-util/src/sync/cancellation_token/tree_node.rs new file mode 100644 index 0000000000..b6cd698e23 --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/cancellation_token/tree_node.rs @@ -0,0 +1,373 @@ +//! This mod provides the logic for the inner tree structure of the CancellationToken. +//! +//! CancellationTokens are only light handles with references to TreeNode. +//! All the logic is actually implemented in the TreeNode. +//! +//! A TreeNode is part of the cancellation tree and may have one parent and an arbitrary number of +//! children. +//! +//! A TreeNode can receive the request to perform a cancellation through a CancellationToken. +//! This cancellation request will cancel the node and all of its descendants. +//! +//! As soon as a node cannot get cancelled any more (because it was already cancelled or it has no +//! more CancellationTokens pointing to it any more), it gets removed from the tree, to keep the +//! tree as small as possible. +//! +//! # Invariants +//! +//! Those invariants shall be true at any time. +//! +//! 1. A node that has no parents and no handles can no longer be cancelled. +//! This is important during both cancellation and refcounting. +//! +//! 2. If node B *is* or *was* a child of node A, then node B was created *after* node A. +//! This is important for deadlock safety, as it is used for lock order. +//! Node B can only become the child of node A in two ways: +//! - being created with `child_node()`, in which case it is trivially true that +//! node A already existed when node B was created +//! - being moved A->C->B to A->B because node C was removed in `decrease_handle_refcount()` +//! or `cancel()`. In this case the invariant still holds, as B was younger than C, and C +//! was younger than A, therefore B is also younger than A. +//! +//! 3. If two nodes are both unlocked and node A is the parent of node B, then node B is a child of +//! node A. It is important to always restore that invariant before dropping the lock of a node. +//! +//! # Deadlock safety +//! +//! We always lock in the order of creation time. We can prove this through invariant #2. +//! Specifically, through invariant #2, we know that we always have to lock a parent +//! before its child. +//! +use crate::loom::sync::{Arc, Mutex, MutexGuard}; + +/// A node of the cancellation tree structure +/// +/// The actual data it holds is wrapped inside a mutex for synchronization. +pub(crate) struct TreeNode { + inner: Mutex<Inner>, + waker: tokio::sync::Notify, +} +impl TreeNode { + pub(crate) fn new() -> Self { + Self { + inner: Mutex::new(Inner { + parent: None, + parent_idx: 0, + children: vec![], + is_cancelled: false, + num_handles: 1, + }), + waker: tokio::sync::Notify::new(), + } + } + + pub(crate) fn notified(&self) -> tokio::sync::futures::Notified<'_> { + self.waker.notified() + } +} + +/// The data contained inside a TreeNode. +/// +/// This struct exists so that the data of the node can be wrapped +/// in a Mutex. +struct Inner { + parent: Option<Arc<TreeNode>>, + parent_idx: usize, + children: Vec<Arc<TreeNode>>, + is_cancelled: bool, + num_handles: usize, +} + +/// Returns whether or not the node is cancelled +pub(crate) fn is_cancelled(node: &Arc<TreeNode>) -> bool { + node.inner.lock().unwrap().is_cancelled +} + +/// Creates a child node +pub(crate) fn child_node(parent: &Arc<TreeNode>) -> Arc<TreeNode> { + let mut locked_parent = parent.inner.lock().unwrap(); + + // Do not register as child if we are already cancelled. + // Cancelled trees can never be uncancelled and therefore + // need no connection to parents or children any more. + if locked_parent.is_cancelled { + return Arc::new(TreeNode { + inner: Mutex::new(Inner { + parent: None, + parent_idx: 0, + children: vec![], + is_cancelled: true, + num_handles: 1, + }), + waker: tokio::sync::Notify::new(), + }); + } + + let child = Arc::new(TreeNode { + inner: Mutex::new(Inner { + parent: Some(parent.clone()), + parent_idx: locked_parent.children.len(), + children: vec![], + is_cancelled: false, + num_handles: 1, + }), + waker: tokio::sync::Notify::new(), + }); + + locked_parent.children.push(child.clone()); + + child +} + +/// Disconnects the given parent from all of its children. +/// +/// Takes a reference to [Inner] to make sure the parent is already locked. +fn disconnect_children(node: &mut Inner) { + for child in std::mem::take(&mut node.children) { + let mut locked_child = child.inner.lock().unwrap(); + locked_child.parent_idx = 0; + locked_child.parent = None; + } +} + +/// Figures out the parent of the node and locks the node and its parent atomically. +/// +/// The basic principle of preventing deadlocks in the tree is +/// that we always lock the parent first, and then the child. +/// For more info look at *deadlock safety* and *invariant #2*. +/// +/// Sadly, it's impossible to figure out the parent of a node without +/// locking it. To then achieve locking order consistency, the node +/// has to be unlocked before the parent gets locked. +/// This leaves a small window where we already assume that we know the parent, +/// but neither the parent nor the node is locked. Therefore, the parent could change. +/// +/// To prevent that this problem leaks into the rest of the code, it is abstracted +/// in this function. +/// +/// The locked child and optionally its locked parent, if a parent exists, get passed +/// to the `func` argument via (node, None) or (node, Some(parent)). +fn with_locked_node_and_parent<F, Ret>(node: &Arc<TreeNode>, func: F) -> Ret +where + F: FnOnce(MutexGuard<'_, Inner>, Option<MutexGuard<'_, Inner>>) -> Ret, +{ + let mut potential_parent = { + let locked_node = node.inner.lock().unwrap(); + match locked_node.parent.clone() { + Some(parent) => parent, + // If we locked the node and its parent is `None`, we are in a valid state + // and can return. + None => return func(locked_node, None), + } + }; + + loop { + // Deadlock safety: + // + // Due to invariant #2, we know that we have to lock the parent first, and then the child. + // This is true even if the potential_parent is no longer the current parent or even its + // sibling, as the invariant still holds. + let locked_parent = potential_parent.inner.lock().unwrap(); + let locked_node = node.inner.lock().unwrap(); + + let actual_parent = match locked_node.parent.clone() { + Some(parent) => parent, + // If we locked the node and its parent is `None`, we are in a valid state + // and can return. + None => { + // Was the wrong parent, so unlock it before calling `func` + drop(locked_parent); + return func(locked_node, None); + } + }; + + // Loop until we managed to lock both the node and its parent + if Arc::ptr_eq(&actual_parent, &potential_parent) { + return func(locked_node, Some(locked_parent)); + } + + // Drop locked_parent before reassigning to potential_parent, + // as potential_parent is borrowed in it + drop(locked_node); + drop(locked_parent); + + potential_parent = actual_parent; + } +} + +/// Moves all children from `node` to `parent`. +/// +/// `parent` MUST have been a parent of the node when they both got locked, +/// otherwise there is a potential for a deadlock as invariant #2 would be violated. +/// +/// To aquire the locks for node and parent, use [with_locked_node_and_parent]. +fn move_children_to_parent(node: &mut Inner, parent: &mut Inner) { + // Pre-allocate in the parent, for performance + parent.children.reserve(node.children.len()); + + for child in std::mem::take(&mut node.children) { + { + let mut child_locked = child.inner.lock().unwrap(); + child_locked.parent = node.parent.clone(); + child_locked.parent_idx = parent.children.len(); + } + parent.children.push(child); + } +} + +/// Removes a child from the parent. +/// +/// `parent` MUST be the parent of `node`. +/// To aquire the locks for node and parent, use [with_locked_node_and_parent]. +fn remove_child(parent: &mut Inner, mut node: MutexGuard<'_, Inner>) { + // Query the position from where to remove a node + let pos = node.parent_idx; + node.parent = None; + node.parent_idx = 0; + + // Unlock node, so that only one child at a time is locked. + // Otherwise we would violate the lock order (see 'deadlock safety') as we + // don't know the creation order of the child nodes + drop(node); + + // If `node` is the last element in the list, we don't need any swapping + if parent.children.len() == pos + 1 { + parent.children.pop().unwrap(); + } else { + // If `node` is not the last element in the list, we need to + // replace it with the last element + let replacement_child = parent.children.pop().unwrap(); + replacement_child.inner.lock().unwrap().parent_idx = pos; + parent.children[pos] = replacement_child; + } + + let len = parent.children.len(); + if 4 * len <= parent.children.capacity() { + // equal to: + // parent.children.shrink_to(2 * len); + // but shrink_to was not yet stabilized in our minimal compatible version + let old_children = std::mem::replace(&mut parent.children, Vec::with_capacity(2 * len)); + parent.children.extend(old_children); + } +} + +/// Increases the reference count of handles. +pub(crate) fn increase_handle_refcount(node: &Arc<TreeNode>) { + let mut locked_node = node.inner.lock().unwrap(); + + // Once no handles are left over, the node gets detached from the tree. + // There should never be a new handle once all handles are dropped. + assert!(locked_node.num_handles > 0); + + locked_node.num_handles += 1; +} + +/// Decreases the reference count of handles. +/// +/// Once no handle is left, we can remove the node from the +/// tree and connect its parent directly to its children. +pub(crate) fn decrease_handle_refcount(node: &Arc<TreeNode>) { + let num_handles = { + let mut locked_node = node.inner.lock().unwrap(); + locked_node.num_handles -= 1; + locked_node.num_handles + }; + + if num_handles == 0 { + with_locked_node_and_parent(node, |mut node, parent| { + // Remove the node from the tree + match parent { + Some(mut parent) => { + // As we want to remove ourselves from the tree, + // we have to move the children to the parent, so that + // they still receive the cancellation event without us. + // Moving them does not violate invariant #1. + move_children_to_parent(&mut node, &mut parent); + + // Remove the node from the parent + remove_child(&mut parent, node); + } + None => { + // Due to invariant #1, we can assume that our + // children can no longer be cancelled through us. + // (as we now have neither a parent nor handles) + // Therefore we can disconnect them. + disconnect_children(&mut node); + } + } + }); + } +} + +/// Cancels a node and its children. +pub(crate) fn cancel(node: &Arc<TreeNode>) { + let mut locked_node = node.inner.lock().unwrap(); + + if locked_node.is_cancelled { + return; + } + + // One by one, adopt grandchildren and then cancel and detach the child + while let Some(child) = locked_node.children.pop() { + // This can't deadlock because the mutex we are already + // holding is the parent of child. + let mut locked_child = child.inner.lock().unwrap(); + + // Detach the child from node + // No need to modify node.children, as the child already got removed with `.pop` + locked_child.parent = None; + locked_child.parent_idx = 0; + + // If child is already cancelled, detaching is enough + if locked_child.is_cancelled { + continue; + } + + // Cancel or adopt grandchildren + while let Some(grandchild) = locked_child.children.pop() { + // This can't deadlock because the two mutexes we are already + // holding is the parent and grandparent of grandchild. + let mut locked_grandchild = grandchild.inner.lock().unwrap(); + + // Detach the grandchild + locked_grandchild.parent = None; + locked_grandchild.parent_idx = 0; + + // If grandchild is already cancelled, detaching is enough + if locked_grandchild.is_cancelled { + continue; + } + + // For performance reasons, only adopt grandchildren that have children. + // Otherwise, just cancel them right away, no need for another iteration. + if locked_grandchild.children.is_empty() { + // Cancel the grandchild + locked_grandchild.is_cancelled = true; + locked_grandchild.children = Vec::new(); + drop(locked_grandchild); + grandchild.waker.notify_waiters(); + } else { + // Otherwise, adopt grandchild + locked_grandchild.parent = Some(node.clone()); + locked_grandchild.parent_idx = locked_node.children.len(); + drop(locked_grandchild); + locked_node.children.push(grandchild); + } + } + + // Cancel the child + locked_child.is_cancelled = true; + locked_child.children = Vec::new(); + drop(locked_child); + child.waker.notify_waiters(); + + // Now the child is cancelled and detached and all its children are adopted. + // Just continue until all (including adopted) children are cancelled and detached. + } + + // Cancel the node itself. + locked_node.is_cancelled = true; + locked_node.children = Vec::new(); + drop(locked_node); + node.waker.notify_waiters(); +} diff --git a/third_party/rust/tokio-util/src/sync/mod.rs b/third_party/rust/tokio-util/src/sync/mod.rs new file mode 100644 index 0000000000..de392f0bb1 --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/mod.rs @@ -0,0 +1,13 @@ +//! Synchronization primitives + +mod cancellation_token; +pub use cancellation_token::{guard::DropGuard, CancellationToken, WaitForCancellationFuture}; + +mod mpsc; +pub use mpsc::{PollSendError, PollSender}; + +mod poll_semaphore; +pub use poll_semaphore::PollSemaphore; + +mod reusable_box; +pub use reusable_box::ReusableBoxFuture; diff --git a/third_party/rust/tokio-util/src/sync/mpsc.rs b/third_party/rust/tokio-util/src/sync/mpsc.rs new file mode 100644 index 0000000000..34a47c1891 --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/mpsc.rs @@ -0,0 +1,283 @@ +use futures_sink::Sink; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{fmt, mem}; +use tokio::sync::mpsc::OwnedPermit; +use tokio::sync::mpsc::Sender; + +use super::ReusableBoxFuture; + +/// Error returned by the `PollSender` when the channel is closed. +#[derive(Debug)] +pub struct PollSendError<T>(Option<T>); + +impl<T> PollSendError<T> { + /// Consumes the stored value, if any. + /// + /// If this error was encountered when calling `start_send`/`send_item`, this will be the item + /// that the caller attempted to send. Otherwise, it will be `None`. + pub fn into_inner(self) -> Option<T> { + self.0 + } +} + +impl<T> fmt::Display for PollSendError<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } +} + +impl<T: fmt::Debug> std::error::Error for PollSendError<T> {} + +#[derive(Debug)] +enum State<T> { + Idle(Sender<T>), + Acquiring, + ReadyToSend(OwnedPermit<T>), + Closed, +} + +/// A wrapper around [`mpsc::Sender`] that can be polled. +/// +/// [`mpsc::Sender`]: tokio::sync::mpsc::Sender +#[derive(Debug)] +pub struct PollSender<T> { + sender: Option<Sender<T>>, + state: State<T>, + acquire: ReusableBoxFuture<'static, Result<OwnedPermit<T>, PollSendError<T>>>, +} + +// Creates a future for acquiring a permit from the underlying channel. This is used to ensure +// there's capacity for a send to complete. +// +// By reusing the same async fn for both `Some` and `None`, we make sure every future passed to +// ReusableBoxFuture has the same underlying type, and hence the same size and alignment. +async fn make_acquire_future<T>( + data: Option<Sender<T>>, +) -> Result<OwnedPermit<T>, PollSendError<T>> { + match data { + Some(sender) => sender + .reserve_owned() + .await + .map_err(|_| PollSendError(None)), + None => unreachable!("this future should not be pollable in this state"), + } +} + +impl<T: Send + 'static> PollSender<T> { + /// Creates a new `PollSender`. + pub fn new(sender: Sender<T>) -> Self { + Self { + sender: Some(sender.clone()), + state: State::Idle(sender), + acquire: ReusableBoxFuture::new(make_acquire_future(None)), + } + } + + fn take_state(&mut self) -> State<T> { + mem::replace(&mut self.state, State::Closed) + } + + /// Attempts to prepare the sender to receive a value. + /// + /// This method must be called and return `Poll::Ready(Ok(()))` prior to each call to + /// `send_item`. + /// + /// This method returns `Poll::Ready` once the underlying channel is ready to receive a value, + /// by reserving a slot in the channel for the item to be sent. If this method returns + /// `Poll::Pending`, the current task is registered to be notified (via + /// `cx.waker().wake_by_ref()`) when `poll_reserve` should be called again. + /// + /// # Errors + /// + /// If the channel is closed, an error will be returned. This is a permanent state. + pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), PollSendError<T>>> { + loop { + let (result, next_state) = match self.take_state() { + State::Idle(sender) => { + // Start trying to acquire a permit to reserve a slot for our send, and + // immediately loop back around to poll it the first time. + self.acquire.set(make_acquire_future(Some(sender))); + (None, State::Acquiring) + } + State::Acquiring => match self.acquire.poll(cx) { + // Channel has capacity. + Poll::Ready(Ok(permit)) => { + (Some(Poll::Ready(Ok(()))), State::ReadyToSend(permit)) + } + // Channel is closed. + Poll::Ready(Err(e)) => (Some(Poll::Ready(Err(e))), State::Closed), + // Channel doesn't have capacity yet, so we need to wait. + Poll::Pending => (Some(Poll::Pending), State::Acquiring), + }, + // We're closed, either by choice or because the underlying sender was closed. + s @ State::Closed => (Some(Poll::Ready(Err(PollSendError(None)))), s), + // We're already ready to send an item. + s @ State::ReadyToSend(_) => (Some(Poll::Ready(Ok(()))), s), + }; + + self.state = next_state; + if let Some(result) = result { + return result; + } + } + } + + /// Sends an item to the channel. + /// + /// Before calling `send_item`, `poll_reserve` must be called with a successful return + /// value of `Poll::Ready(Ok(()))`. + /// + /// # Errors + /// + /// If the channel is closed, an error will be returned. This is a permanent state. + /// + /// # Panics + /// + /// If `poll_reserve` was not successfully called prior to calling `send_item`, then this method + /// will panic. + pub fn send_item(&mut self, value: T) -> Result<(), PollSendError<T>> { + let (result, next_state) = match self.take_state() { + State::Idle(_) | State::Acquiring => { + panic!("`send_item` called without first calling `poll_reserve`") + } + // We have a permit to send our item, so go ahead, which gets us our sender back. + State::ReadyToSend(permit) => (Ok(()), State::Idle(permit.send(value))), + // We're closed, either by choice or because the underlying sender was closed. + State::Closed => (Err(PollSendError(Some(value))), State::Closed), + }; + + // Handle deferred closing if `close` was called between `poll_reserve` and `send_item`. + self.state = if self.sender.is_some() { + next_state + } else { + State::Closed + }; + result + } + + /// Checks whether this sender is been closed. + /// + /// The underlying channel that this sender was wrapping may still be open. + pub fn is_closed(&self) -> bool { + matches!(self.state, State::Closed) || self.sender.is_none() + } + + /// Gets a reference to the `Sender` of the underlying channel. + /// + /// If `PollSender` has been closed, `None` is returned. The underlying channel that this sender + /// was wrapping may still be open. + pub fn get_ref(&self) -> Option<&Sender<T>> { + self.sender.as_ref() + } + + /// Closes this sender. + /// + /// No more messages will be able to be sent from this sender, but the underlying channel will + /// remain open until all senders have dropped, or until the [`Receiver`] closes the channel. + /// + /// If a slot was previously reserved by calling `poll_reserve`, then a final call can be made + /// to `send_item` in order to consume the reserved slot. After that, no further sends will be + /// possible. If you do not intend to send another item, you can release the reserved slot back + /// to the underlying sender by calling [`abort_send`]. + /// + /// [`abort_send`]: crate::sync::PollSender::abort_send + /// [`Receiver`]: tokio::sync::mpsc::Receiver + pub fn close(&mut self) { + // Mark ourselves officially closed by dropping our main sender. + self.sender = None; + + // If we're already idle, closed, or we haven't yet reserved a slot, we can quickly + // transition to the closed state. Otherwise, leave the existing permit in place for the + // caller if they want to complete the send. + match self.state { + State::Idle(_) => self.state = State::Closed, + State::Acquiring => { + self.acquire.set(make_acquire_future(None)); + self.state = State::Closed; + } + _ => {} + } + } + + /// Aborts the current in-progress send, if any. + /// + /// Returns `true` if a send was aborted. If the sender was closed prior to calling + /// `abort_send`, then the sender will remain in the closed state, otherwise the sender will be + /// ready to attempt another send. + pub fn abort_send(&mut self) -> bool { + // We may have been closed in the meantime, after a call to `poll_reserve` already + // succeeded. We'll check if `self.sender` is `None` to see if we should transition to the + // closed state when we actually abort a send, rather than resetting ourselves back to idle. + + let (result, next_state) = match self.take_state() { + // We're currently trying to reserve a slot to send into. + State::Acquiring => { + // Replacing the future drops the in-flight one. + self.acquire.set(make_acquire_future(None)); + + // If we haven't closed yet, we have to clone our stored sender since we have no way + // to get it back from the acquire future we just dropped. + let state = match self.sender.clone() { + Some(sender) => State::Idle(sender), + None => State::Closed, + }; + (true, state) + } + // We got the permit. If we haven't closed yet, get the sender back. + State::ReadyToSend(permit) => { + let state = if self.sender.is_some() { + State::Idle(permit.release()) + } else { + State::Closed + }; + (true, state) + } + s => (false, s), + }; + + self.state = next_state; + result + } +} + +impl<T> Clone for PollSender<T> { + /// Clones this `PollSender`. + /// + /// The resulting `PollSender` will have an initial state identical to calling `PollSender::new`. + fn clone(&self) -> PollSender<T> { + let (sender, state) = match self.sender.clone() { + Some(sender) => (Some(sender.clone()), State::Idle(sender)), + None => (None, State::Closed), + }; + + Self { + sender, + state, + // We don't use `make_acquire_future` here because our relaxed bounds on `T` are not + // compatible with the transitive bounds required by `Sender<T>`. + acquire: ReusableBoxFuture::new(async { unreachable!() }), + } + } +} + +impl<T: Send + 'static> Sink<T> for PollSender<T> { + type Error = PollSendError<T>; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Pin::into_inner(self).poll_reserve(cx) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { + Pin::into_inner(self).send_item(item) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Pin::into_inner(self).close(); + Poll::Ready(Ok(())) + } +} diff --git a/third_party/rust/tokio-util/src/sync/poll_semaphore.rs b/third_party/rust/tokio-util/src/sync/poll_semaphore.rs new file mode 100644 index 0000000000..d0b1dedc27 --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/poll_semaphore.rs @@ -0,0 +1,136 @@ +use futures_core::{ready, Stream}; +use std::fmt; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore, TryAcquireError}; + +use super::ReusableBoxFuture; + +/// A wrapper around [`Semaphore`] that provides a `poll_acquire` method. +/// +/// [`Semaphore`]: tokio::sync::Semaphore +pub struct PollSemaphore { + semaphore: Arc<Semaphore>, + permit_fut: Option<ReusableBoxFuture<'static, Result<OwnedSemaphorePermit, AcquireError>>>, +} + +impl PollSemaphore { + /// Create a new `PollSemaphore`. + pub fn new(semaphore: Arc<Semaphore>) -> Self { + Self { + semaphore, + permit_fut: None, + } + } + + /// Closes the semaphore. + pub fn close(&self) { + self.semaphore.close() + } + + /// Obtain a clone of the inner semaphore. + pub fn clone_inner(&self) -> Arc<Semaphore> { + self.semaphore.clone() + } + + /// Get back the inner semaphore. + pub fn into_inner(self) -> Arc<Semaphore> { + self.semaphore + } + + /// Poll to acquire a permit from the semaphore. + /// + /// This can return the following values: + /// + /// - `Poll::Pending` if a permit is not currently available. + /// - `Poll::Ready(Some(permit))` if a permit was acquired. + /// - `Poll::Ready(None)` if the semaphore has been closed. + /// + /// When this method returns `Poll::Pending`, the current task is scheduled + /// to receive a wakeup when a permit becomes available, or when the + /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. + pub fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> { + let permit_future = match self.permit_fut.as_mut() { + Some(fut) => fut, + None => { + // avoid allocations completely if we can grab a permit immediately + match Arc::clone(&self.semaphore).try_acquire_owned() { + Ok(permit) => return Poll::Ready(Some(permit)), + Err(TryAcquireError::Closed) => return Poll::Ready(None), + Err(TryAcquireError::NoPermits) => {} + } + + let next_fut = Arc::clone(&self.semaphore).acquire_owned(); + self.permit_fut + .get_or_insert(ReusableBoxFuture::new(next_fut)) + } + }; + + let result = ready!(permit_future.poll(cx)); + + let next_fut = Arc::clone(&self.semaphore).acquire_owned(); + permit_future.set(next_fut); + + match result { + Ok(permit) => Poll::Ready(Some(permit)), + Err(_closed) => { + self.permit_fut = None; + Poll::Ready(None) + } + } + } + + /// Returns the current number of available permits. + /// + /// This is equivalent to the [`Semaphore::available_permits`] method on the + /// `tokio::sync::Semaphore` type. + /// + /// [`Semaphore::available_permits`]: tokio::sync::Semaphore::available_permits + pub fn available_permits(&self) -> usize { + self.semaphore.available_permits() + } + + /// Adds `n` new permits to the semaphore. + /// + /// The maximum number of permits is `usize::MAX >> 3`, and this function + /// will panic if the limit is exceeded. + /// + /// This is equivalent to the [`Semaphore::add_permits`] method on the + /// `tokio::sync::Semaphore` type. + /// + /// [`Semaphore::add_permits`]: tokio::sync::Semaphore::add_permits + pub fn add_permits(&self, n: usize) { + self.semaphore.add_permits(n); + } +} + +impl Stream for PollSemaphore { + type Item = OwnedSemaphorePermit; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> { + Pin::into_inner(self).poll_acquire(cx) + } +} + +impl Clone for PollSemaphore { + fn clone(&self) -> PollSemaphore { + PollSemaphore::new(self.clone_inner()) + } +} + +impl fmt::Debug for PollSemaphore { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PollSemaphore") + .field("semaphore", &self.semaphore) + .finish() + } +} + +impl AsRef<Semaphore> for PollSemaphore { + fn as_ref(&self) -> &Semaphore { + &*self.semaphore + } +} diff --git a/third_party/rust/tokio-util/src/sync/reusable_box.rs b/third_party/rust/tokio-util/src/sync/reusable_box.rs new file mode 100644 index 0000000000..3204207db7 --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/reusable_box.rs @@ -0,0 +1,148 @@ +use std::alloc::Layout; +use std::future::Future; +use std::panic::AssertUnwindSafe; +use std::pin::Pin; +use std::ptr::{self, NonNull}; +use std::task::{Context, Poll}; +use std::{fmt, panic}; + +/// A reusable `Pin<Box<dyn Future<Output = T> + Send + 'a>>`. +/// +/// This type lets you replace the future stored in the box without +/// reallocating when the size and alignment permits this. +pub struct ReusableBoxFuture<'a, T> { + boxed: NonNull<dyn Future<Output = T> + Send + 'a>, +} + +impl<'a, T> ReusableBoxFuture<'a, T> { + /// Create a new `ReusableBoxFuture<T>` containing the provided future. + pub fn new<F>(future: F) -> Self + where + F: Future<Output = T> + Send + 'a, + { + let boxed: Box<dyn Future<Output = T> + Send + 'a> = Box::new(future); + + let boxed = NonNull::from(Box::leak(boxed)); + + Self { boxed } + } + + /// Replace the future currently stored in this box. + /// + /// This reallocates if and only if the layout of the provided future is + /// different from the layout of the currently stored future. + pub fn set<F>(&mut self, future: F) + where + F: Future<Output = T> + Send + 'a, + { + if let Err(future) = self.try_set(future) { + *self = Self::new(future); + } + } + + /// Replace the future currently stored in this box. + /// + /// This function never reallocates, but returns an error if the provided + /// future has a different size or alignment from the currently stored + /// future. + pub fn try_set<F>(&mut self, future: F) -> Result<(), F> + where + F: Future<Output = T> + Send + 'a, + { + // SAFETY: The pointer is not dangling. + let self_layout = { + let dyn_future: &(dyn Future<Output = T> + Send) = unsafe { self.boxed.as_ref() }; + Layout::for_value(dyn_future) + }; + + if Layout::new::<F>() == self_layout { + // SAFETY: We just checked that the layout of F is correct. + unsafe { + self.set_same_layout(future); + } + + Ok(()) + } else { + Err(future) + } + } + + /// Set the current future. + /// + /// # Safety + /// + /// This function requires that the layout of the provided future is the + /// same as `self.layout`. + unsafe fn set_same_layout<F>(&mut self, future: F) + where + F: Future<Output = T> + Send + 'a, + { + // Drop the existing future, catching any panics. + let result = panic::catch_unwind(AssertUnwindSafe(|| { + ptr::drop_in_place(self.boxed.as_ptr()); + })); + + // Overwrite the future behind the pointer. This is safe because the + // allocation was allocated with the same size and alignment as the type F. + let self_ptr: *mut F = self.boxed.as_ptr() as *mut F; + ptr::write(self_ptr, future); + + // Update the vtable of self.boxed. The pointer is not null because we + // just got it from self.boxed, which is not null. + self.boxed = NonNull::new_unchecked(self_ptr); + + // If the old future's destructor panicked, resume unwinding. + match result { + Ok(()) => {} + Err(payload) => { + panic::resume_unwind(payload); + } + } + } + + /// Get a pinned reference to the underlying future. + pub fn get_pin(&mut self) -> Pin<&mut (dyn Future<Output = T> + Send)> { + // SAFETY: The user of this box cannot move the box, and we do not move it + // either. + unsafe { Pin::new_unchecked(self.boxed.as_mut()) } + } + + /// Poll the future stored inside this box. + pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<T> { + self.get_pin().poll(cx) + } +} + +impl<T> Future for ReusableBoxFuture<'_, T> { + type Output = T; + + /// Poll the future stored inside this box. + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> { + Pin::into_inner(self).get_pin().poll(cx) + } +} + +// The future stored inside ReusableBoxFuture<'_, T> must be Send. +unsafe impl<T> Send for ReusableBoxFuture<'_, T> {} + +// The only method called on self.boxed is poll, which takes &mut self, so this +// struct being Sync does not permit any invalid access to the Future, even if +// the future is not Sync. +unsafe impl<T> Sync for ReusableBoxFuture<'_, T> {} + +// Just like a Pin<Box<dyn Future>> is always Unpin, so is this type. +impl<T> Unpin for ReusableBoxFuture<'_, T> {} + +impl<T> Drop for ReusableBoxFuture<'_, T> { + fn drop(&mut self) { + unsafe { + drop(Box::from_raw(self.boxed.as_ptr())); + } + } +} + +impl<T> fmt::Debug for ReusableBoxFuture<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ReusableBoxFuture").finish() + } +} diff --git a/third_party/rust/tokio-util/src/sync/tests/loom_cancellation_token.rs b/third_party/rust/tokio-util/src/sync/tests/loom_cancellation_token.rs new file mode 100644 index 0000000000..e9c9f3dd98 --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/tests/loom_cancellation_token.rs @@ -0,0 +1,155 @@ +use crate::sync::CancellationToken; + +use loom::{future::block_on, thread}; +use tokio_test::assert_ok; + +#[test] +fn cancel_token() { + loom::model(|| { + let token = CancellationToken::new(); + let token1 = token.clone(); + + let th1 = thread::spawn(move || { + block_on(async { + token1.cancelled().await; + }); + }); + + let th2 = thread::spawn(move || { + token.cancel(); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + }); +} + +#[test] +fn cancel_with_child() { + loom::model(|| { + let token = CancellationToken::new(); + let token1 = token.clone(); + let token2 = token.clone(); + let child_token = token.child_token(); + + let th1 = thread::spawn(move || { + block_on(async { + token1.cancelled().await; + }); + }); + + let th2 = thread::spawn(move || { + token2.cancel(); + }); + + let th3 = thread::spawn(move || { + block_on(async { + child_token.cancelled().await; + }); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn drop_token_no_child() { + loom::model(|| { + let token = CancellationToken::new(); + let token1 = token.clone(); + let token2 = token.clone(); + + let th1 = thread::spawn(move || { + drop(token1); + }); + + let th2 = thread::spawn(move || { + drop(token2); + }); + + let th3 = thread::spawn(move || { + drop(token); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn drop_token_with_childs() { + loom::model(|| { + let token1 = CancellationToken::new(); + let child_token1 = token1.child_token(); + let child_token2 = token1.child_token(); + + let th1 = thread::spawn(move || { + drop(token1); + }); + + let th2 = thread::spawn(move || { + drop(child_token1); + }); + + let th3 = thread::spawn(move || { + drop(child_token2); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn drop_and_cancel_token() { + loom::model(|| { + let token1 = CancellationToken::new(); + let token2 = token1.clone(); + let child_token = token1.child_token(); + + let th1 = thread::spawn(move || { + drop(token1); + }); + + let th2 = thread::spawn(move || { + token2.cancel(); + }); + + let th3 = thread::spawn(move || { + drop(child_token); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn cancel_parent_and_child() { + loom::model(|| { + let token1 = CancellationToken::new(); + let token2 = token1.clone(); + let child_token = token1.child_token(); + + let th1 = thread::spawn(move || { + drop(token1); + }); + + let th2 = thread::spawn(move || { + token2.cancel(); + }); + + let th3 = thread::spawn(move || { + child_token.cancel(); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} diff --git a/third_party/rust/tokio-util/src/sync/tests/mod.rs b/third_party/rust/tokio-util/src/sync/tests/mod.rs new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/tests/mod.rs @@ -0,0 +1 @@ + diff --git a/third_party/rust/tokio-util/src/task/mod.rs b/third_party/rust/tokio-util/src/task/mod.rs new file mode 100644 index 0000000000..5aa33df2dc --- /dev/null +++ b/third_party/rust/tokio-util/src/task/mod.rs @@ -0,0 +1,4 @@ +//! Extra utilities for spawning tasks + +mod spawn_pinned; +pub use spawn_pinned::LocalPoolHandle; diff --git a/third_party/rust/tokio-util/src/task/spawn_pinned.rs b/third_party/rust/tokio-util/src/task/spawn_pinned.rs new file mode 100644 index 0000000000..6f553e9d07 --- /dev/null +++ b/third_party/rust/tokio-util/src/task/spawn_pinned.rs @@ -0,0 +1,307 @@ +use futures_util::future::{AbortHandle, Abortable}; +use std::fmt; +use std::fmt::{Debug, Formatter}; +use std::future::Future; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use tokio::runtime::Builder; +use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; +use tokio::sync::oneshot; +use tokio::task::{spawn_local, JoinHandle, LocalSet}; + +/// A handle to a local pool, used for spawning `!Send` tasks. +#[derive(Clone)] +pub struct LocalPoolHandle { + pool: Arc<LocalPool>, +} + +impl LocalPoolHandle { + /// Create a new pool of threads to handle `!Send` tasks. Spawn tasks onto this + /// pool via [`LocalPoolHandle::spawn_pinned`]. + /// + /// # Panics + /// Panics if the pool size is less than one. + pub fn new(pool_size: usize) -> LocalPoolHandle { + assert!(pool_size > 0); + + let workers = (0..pool_size) + .map(|_| LocalWorkerHandle::new_worker()) + .collect(); + + let pool = Arc::new(LocalPool { workers }); + + LocalPoolHandle { pool } + } + + /// Spawn a task onto a worker thread and pin it there so it can't be moved + /// off of the thread. Note that the future is not [`Send`], but the + /// [`FnOnce`] which creates it is. + /// + /// # Examples + /// ``` + /// use std::rc::Rc; + /// use tokio_util::task::LocalPoolHandle; + /// + /// #[tokio::main] + /// async fn main() { + /// // Create the local pool + /// let pool = LocalPoolHandle::new(1); + /// + /// // Spawn a !Send future onto the pool and await it + /// let output = pool + /// .spawn_pinned(|| { + /// // Rc is !Send + !Sync + /// let local_data = Rc::new("test"); + /// + /// // This future holds an Rc, so it is !Send + /// async move { local_data.to_string() } + /// }) + /// .await + /// .unwrap(); + /// + /// assert_eq!(output, "test"); + /// } + /// ``` + pub fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output> + where + F: FnOnce() -> Fut, + F: Send + 'static, + Fut: Future + 'static, + Fut::Output: Send + 'static, + { + self.pool.spawn_pinned(create_task) + } +} + +impl Debug for LocalPoolHandle { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_str("LocalPoolHandle") + } +} + +struct LocalPool { + workers: Vec<LocalWorkerHandle>, +} + +impl LocalPool { + /// Spawn a `?Send` future onto a worker + fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output> + where + F: FnOnce() -> Fut, + F: Send + 'static, + Fut: Future + 'static, + Fut::Output: Send + 'static, + { + let (sender, receiver) = oneshot::channel(); + + let (worker, job_guard) = self.find_and_incr_least_burdened_worker(); + let worker_spawner = worker.spawner.clone(); + + // Spawn a future onto the worker's runtime so we can immediately return + // a join handle. + worker.runtime_handle.spawn(async move { + // Move the job guard into the task + let _job_guard = job_guard; + + // Propagate aborts via Abortable/AbortHandle + let (abort_handle, abort_registration) = AbortHandle::new_pair(); + let _abort_guard = AbortGuard(abort_handle); + + // Inside the future we can't run spawn_local yet because we're not + // in the context of a LocalSet. We need to send create_task to the + // LocalSet task for spawning. + let spawn_task = Box::new(move || { + // Once we're in the LocalSet context we can call spawn_local + let join_handle = + spawn_local( + async move { Abortable::new(create_task(), abort_registration).await }, + ); + + // Send the join handle back to the spawner. If sending fails, + // we assume the parent task was canceled, so cancel this task + // as well. + if let Err(join_handle) = sender.send(join_handle) { + join_handle.abort() + } + }); + + // Send the callback to the LocalSet task + if let Err(e) = worker_spawner.send(spawn_task) { + // Propagate the error as a panic in the join handle. + panic!("Failed to send job to worker: {}", e); + } + + // Wait for the task's join handle + let join_handle = match receiver.await { + Ok(handle) => handle, + Err(e) => { + // We sent the task successfully, but failed to get its + // join handle... We assume something happened to the worker + // and the task was not spawned. Propagate the error as a + // panic in the join handle. + panic!("Worker failed to send join handle: {}", e); + } + }; + + // Wait for the task to complete + let join_result = join_handle.await; + + match join_result { + Ok(Ok(output)) => output, + Ok(Err(_)) => { + // Pinned task was aborted. But that only happens if this + // task is aborted. So this is an impossible branch. + unreachable!( + "Reaching this branch means this task was previously \ + aborted but it continued running anyways" + ) + } + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else if e.is_cancelled() { + // No one else should have the join handle, so this is + // unexpected. Forward this error as a panic in the join + // handle. + panic!("spawn_pinned task was canceled: {}", e); + } else { + // Something unknown happened (not a panic or + // cancellation). Forward this error as a panic in the + // join handle. + panic!("spawn_pinned task failed: {}", e); + } + } + } + }) + } + + /// Find the worker with the least number of tasks, increment its task + /// count, and return its handle. Make sure to actually spawn a task on + /// the worker so the task count is kept consistent with load. + /// + /// A job count guard is also returned to ensure the task count gets + /// decremented when the job is done. + fn find_and_incr_least_burdened_worker(&self) -> (&LocalWorkerHandle, JobCountGuard) { + loop { + let (worker, task_count) = self + .workers + .iter() + .map(|worker| (worker, worker.task_count.load(Ordering::SeqCst))) + .min_by_key(|&(_, count)| count) + .expect("There must be more than one worker"); + + // Make sure the task count hasn't changed since when we choose this + // worker. Otherwise, restart the search. + if worker + .task_count + .compare_exchange( + task_count, + task_count + 1, + Ordering::SeqCst, + Ordering::Relaxed, + ) + .is_ok() + { + return (worker, JobCountGuard(Arc::clone(&worker.task_count))); + } + } + } +} + +/// Automatically decrements a worker's job count when a job finishes (when +/// this gets dropped). +struct JobCountGuard(Arc<AtomicUsize>); + +impl Drop for JobCountGuard { + fn drop(&mut self) { + // Decrement the job count + let previous_value = self.0.fetch_sub(1, Ordering::SeqCst); + debug_assert!(previous_value >= 1); + } +} + +/// Calls abort on the handle when dropped. +struct AbortGuard(AbortHandle); + +impl Drop for AbortGuard { + fn drop(&mut self) { + self.0.abort(); + } +} + +type PinnedFutureSpawner = Box<dyn FnOnce() + Send + 'static>; + +struct LocalWorkerHandle { + runtime_handle: tokio::runtime::Handle, + spawner: UnboundedSender<PinnedFutureSpawner>, + task_count: Arc<AtomicUsize>, +} + +impl LocalWorkerHandle { + /// Create a new worker for executing pinned tasks + fn new_worker() -> LocalWorkerHandle { + let (sender, receiver) = unbounded_channel(); + let runtime = Builder::new_current_thread() + .enable_all() + .build() + .expect("Failed to start a pinned worker thread runtime"); + let runtime_handle = runtime.handle().clone(); + let task_count = Arc::new(AtomicUsize::new(0)); + let task_count_clone = Arc::clone(&task_count); + + std::thread::spawn(|| Self::run(runtime, receiver, task_count_clone)); + + LocalWorkerHandle { + runtime_handle, + spawner: sender, + task_count, + } + } + + fn run( + runtime: tokio::runtime::Runtime, + mut task_receiver: UnboundedReceiver<PinnedFutureSpawner>, + task_count: Arc<AtomicUsize>, + ) { + let local_set = LocalSet::new(); + local_set.block_on(&runtime, async { + while let Some(spawn_task) = task_receiver.recv().await { + // Calls spawn_local(future) + (spawn_task)(); + } + }); + + // If there are any tasks on the runtime associated with a LocalSet task + // that has already completed, but whose output has not yet been + // reported, let that task complete. + // + // Since the task_count is decremented when the runtime task exits, + // reading that counter lets us know if any such tasks completed during + // the call to `block_on`. + // + // Tasks on the LocalSet can't complete during this loop since they're + // stored on the LocalSet and we aren't accessing it. + let mut previous_task_count = task_count.load(Ordering::SeqCst); + loop { + // This call will also run tasks spawned on the runtime. + runtime.block_on(tokio::task::yield_now()); + let new_task_count = task_count.load(Ordering::SeqCst); + if new_task_count == previous_task_count { + break; + } else { + previous_task_count = new_task_count; + } + } + + // It's now no longer possible for a task on the runtime to be + // associated with a LocalSet task that has completed. Drop both the + // LocalSet and runtime to let tasks on the runtime be cancelled if and + // only if they are still on the LocalSet. + // + // Drop the LocalSet task first so that anyone awaiting the runtime + // JoinHandle will see the cancelled error after the LocalSet task + // destructor has completed. + drop(local_set); + drop(runtime); + } +} diff --git a/third_party/rust/tokio-util/src/time/delay_queue.rs b/third_party/rust/tokio-util/src/time/delay_queue.rs new file mode 100644 index 0000000000..a0c5e5c5b0 --- /dev/null +++ b/third_party/rust/tokio-util/src/time/delay_queue.rs @@ -0,0 +1,1221 @@ +//! A queue of delayed elements. +//! +//! See [`DelayQueue`] for more details. +//! +//! [`DelayQueue`]: struct@DelayQueue + +use crate::time::wheel::{self, Wheel}; + +use futures_core::ready; +use tokio::time::{sleep_until, Duration, Instant, Sleep}; + +use core::ops::{Index, IndexMut}; +use slab::Slab; +use std::cmp; +use std::collections::HashMap; +use std::convert::From; +use std::fmt; +use std::fmt::Debug; +use std::future::Future; +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{self, Poll, Waker}; + +/// A queue of delayed elements. +/// +/// Once an element is inserted into the `DelayQueue`, it is yielded once the +/// specified deadline has been reached. +/// +/// # Usage +/// +/// Elements are inserted into `DelayQueue` using the [`insert`] or +/// [`insert_at`] methods. A deadline is provided with the item and a [`Key`] is +/// returned. The key is used to remove the entry or to change the deadline at +/// which it should be yielded back. +/// +/// Once delays have been configured, the `DelayQueue` is used via its +/// [`Stream`] implementation. [`poll_expired`] is called. If an entry has reached its +/// deadline, it is returned. If not, `Poll::Pending` is returned indicating that the +/// current task will be notified once the deadline has been reached. +/// +/// # `Stream` implementation +/// +/// Items are retrieved from the queue via [`DelayQueue::poll_expired`]. If no delays have +/// expired, no items are returned. In this case, `Poll::Pending` is returned and the +/// current task is registered to be notified once the next item's delay has +/// expired. +/// +/// If no items are in the queue, i.e. `is_empty()` returns `true`, then `poll` +/// returns `Poll::Ready(None)`. This indicates that the stream has reached an end. +/// However, if a new item is inserted *after*, `poll` will once again start +/// returning items or `Poll::Pending`. +/// +/// Items are returned ordered by their expirations. Items that are configured +/// to expire first will be returned first. There are no ordering guarantees +/// for items configured to expire at the same instant. Also note that delays are +/// rounded to the closest millisecond. +/// +/// # Implementation +/// +/// The [`DelayQueue`] is backed by a separate instance of a timer wheel similar to that used internally +/// by Tokio's standalone timer utilities such as [`sleep`]. Because of this, it offers the same +/// performance and scalability benefits. +/// +/// State associated with each entry is stored in a [`slab`]. This amortizes the cost of allocation, +/// and allows reuse of the memory allocated for expired entires. +/// +/// Capacity can be checked using [`capacity`] and allocated preemptively by using +/// the [`reserve`] method. +/// +/// # Usage +/// +/// Using `DelayQueue` to manage cache entries. +/// +/// ```rust,no_run +/// use tokio_util::time::{DelayQueue, delay_queue}; +/// +/// use futures::ready; +/// use std::collections::HashMap; +/// use std::task::{Context, Poll}; +/// use std::time::Duration; +/// # type CacheKey = String; +/// # type Value = String; +/// +/// struct Cache { +/// entries: HashMap<CacheKey, (Value, delay_queue::Key)>, +/// expirations: DelayQueue<CacheKey>, +/// } +/// +/// const TTL_SECS: u64 = 30; +/// +/// impl Cache { +/// fn insert(&mut self, key: CacheKey, value: Value) { +/// let delay = self.expirations +/// .insert(key.clone(), Duration::from_secs(TTL_SECS)); +/// +/// self.entries.insert(key, (value, delay)); +/// } +/// +/// fn get(&self, key: &CacheKey) -> Option<&Value> { +/// self.entries.get(key) +/// .map(|&(ref v, _)| v) +/// } +/// +/// fn remove(&mut self, key: &CacheKey) { +/// if let Some((_, cache_key)) = self.entries.remove(key) { +/// self.expirations.remove(&cache_key); +/// } +/// } +/// +/// fn poll_purge(&mut self, cx: &mut Context<'_>) -> Poll<()> { +/// while let Some(entry) = ready!(self.expirations.poll_expired(cx)) { +/// self.entries.remove(entry.get_ref()); +/// } +/// +/// Poll::Ready(()) +/// } +/// } +/// ``` +/// +/// [`insert`]: method@Self::insert +/// [`insert_at`]: method@Self::insert_at +/// [`Key`]: struct@Key +/// [`Stream`]: https://docs.rs/futures/0.1/futures/stream/trait.Stream.html +/// [`poll_expired`]: method@Self::poll_expired +/// [`Stream::poll_expired`]: method@Self::poll_expired +/// [`DelayQueue`]: struct@DelayQueue +/// [`sleep`]: fn@tokio::time::sleep +/// [`slab`]: slab +/// [`capacity`]: method@Self::capacity +/// [`reserve`]: method@Self::reserve +#[derive(Debug)] +pub struct DelayQueue<T> { + /// Stores data associated with entries + slab: SlabStorage<T>, + + /// Lookup structure tracking all delays in the queue + wheel: Wheel<Stack<T>>, + + /// Delays that were inserted when already expired. These cannot be stored + /// in the wheel + expired: Stack<T>, + + /// Delay expiring when the *first* item in the queue expires + delay: Option<Pin<Box<Sleep>>>, + + /// Wheel polling state + wheel_now: u64, + + /// Instant at which the timer starts + start: Instant, + + /// Waker that is invoked when we potentially need to reset the timer. + /// Because we lazily create the timer when the first entry is created, we + /// need to awaken any poller that polled us before that point. + waker: Option<Waker>, +} + +#[derive(Default)] +struct SlabStorage<T> { + inner: Slab<Data<T>>, + + // A `compact` call requires a re-mapping of the `Key`s that were changed + // during the `compact` call of the `slab`. Since the keys that were given out + // cannot be changed retroactively we need to keep track of these re-mappings. + // The keys of `key_map` correspond to the old keys that were given out and + // the values to the `Key`s that were re-mapped by the `compact` call. + key_map: HashMap<Key, KeyInternal>, + + // Index used to create new keys to hand out. + next_key_index: usize, + + // Whether `compact` has been called, necessary in order to decide whether + // to include keys in `key_map`. + compact_called: bool, +} + +impl<T> SlabStorage<T> { + pub(crate) fn with_capacity(capacity: usize) -> SlabStorage<T> { + SlabStorage { + inner: Slab::with_capacity(capacity), + key_map: HashMap::new(), + next_key_index: 0, + compact_called: false, + } + } + + // Inserts data into the inner slab and re-maps keys if necessary + pub(crate) fn insert(&mut self, val: Data<T>) -> Key { + let mut key = KeyInternal::new(self.inner.insert(val)); + let key_contained = self.key_map.contains_key(&key.into()); + + if key_contained { + // It's possible that a `compact` call creates capacitiy in `self.inner` in + // such a way that a `self.inner.insert` call creates a `key` which was + // previously given out during an `insert` call prior to the `compact` call. + // If `key` is contained in `self.key_map`, we have encountered this exact situation, + // We need to create a new key `key_to_give_out` and include the relation + // `key_to_give_out` -> `key` in `self.key_map`. + let key_to_give_out = self.create_new_key(); + assert!(!self.key_map.contains_key(&key_to_give_out.into())); + self.key_map.insert(key_to_give_out.into(), key); + key = key_to_give_out; + } else if self.compact_called { + // Include an identity mapping in `self.key_map` in order to allow us to + // panic if a key that was handed out is removed more than once. + self.key_map.insert(key.into(), key); + } + + key.into() + } + + // Re-map the key in case compact was previously called. + // Note: Since we include identity mappings in key_map after compact was called, + // we have information about all keys that were handed out. In the case in which + // compact was called and we try to remove a Key that was previously removed + // we can detect invalid keys if no key is found in `key_map`. This is necessary + // in order to prevent situations in which a previously removed key + // corresponds to a re-mapped key internally and which would then be incorrectly + // removed from the slab. + // + // Example to illuminate this problem: + // + // Let's assume our `key_map` is {1 -> 2, 2 -> 1} and we call remove(1). If we + // were to remove 1 again, we would not find it inside `key_map` anymore. + // If we were to imply from this that no re-mapping was necessary, we would + // incorrectly remove 1 from `self.slab.inner`, which corresponds to the + // handed-out key 2. + pub(crate) fn remove(&mut self, key: &Key) -> Data<T> { + let remapped_key = if self.compact_called { + match self.key_map.remove(key) { + Some(key_internal) => key_internal, + None => panic!("invalid key"), + } + } else { + (*key).into() + }; + + self.inner.remove(remapped_key.index) + } + + pub(crate) fn shrink_to_fit(&mut self) { + self.inner.shrink_to_fit(); + self.key_map.shrink_to_fit(); + } + + pub(crate) fn compact(&mut self) { + if !self.compact_called { + for (key, _) in self.inner.iter() { + self.key_map.insert(Key::new(key), KeyInternal::new(key)); + } + } + + let mut remapping = HashMap::new(); + self.inner.compact(|_, from, to| { + remapping.insert(from, to); + true + }); + + // At this point `key_map` contains a mapping for every element. + for internal_key in self.key_map.values_mut() { + if let Some(new_internal_key) = remapping.get(&internal_key.index) { + *internal_key = KeyInternal::new(*new_internal_key); + } + } + + if self.key_map.capacity() > 2 * self.key_map.len() { + self.key_map.shrink_to_fit(); + } + + self.compact_called = true; + } + + // Tries to re-map a `Key` that was given out to the user to its + // corresponding internal key. + fn remap_key(&self, key: &Key) -> Option<KeyInternal> { + let key_map = &self.key_map; + if self.compact_called { + key_map.get(&*key).copied() + } else { + Some((*key).into()) + } + } + + fn create_new_key(&mut self) -> KeyInternal { + while self.key_map.contains_key(&Key::new(self.next_key_index)) { + self.next_key_index = self.next_key_index.wrapping_add(1); + } + + KeyInternal::new(self.next_key_index) + } + + pub(crate) fn len(&self) -> usize { + self.inner.len() + } + + pub(crate) fn capacity(&self) -> usize { + self.inner.capacity() + } + + pub(crate) fn clear(&mut self) { + self.inner.clear(); + self.key_map.clear(); + self.compact_called = false; + } + + pub(crate) fn reserve(&mut self, additional: usize) { + self.inner.reserve(additional); + + if self.compact_called { + self.key_map.reserve(additional); + } + } + + pub(crate) fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + pub(crate) fn contains(&self, key: &Key) -> bool { + let remapped_key = self.remap_key(key); + + match remapped_key { + Some(internal_key) => self.inner.contains(internal_key.index), + None => false, + } + } +} + +impl<T> fmt::Debug for SlabStorage<T> +where + T: fmt::Debug, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + if fmt.alternate() { + fmt.debug_map().entries(self.inner.iter()).finish() + } else { + fmt.debug_struct("Slab") + .field("len", &self.len()) + .field("cap", &self.capacity()) + .finish() + } + } +} + +impl<T> Index<Key> for SlabStorage<T> { + type Output = Data<T>; + + fn index(&self, key: Key) -> &Self::Output { + let remapped_key = self.remap_key(&key); + + match remapped_key { + Some(internal_key) => &self.inner[internal_key.index], + None => panic!("Invalid index {}", key.index), + } + } +} + +impl<T> IndexMut<Key> for SlabStorage<T> { + fn index_mut(&mut self, key: Key) -> &mut Data<T> { + let remapped_key = self.remap_key(&key); + + match remapped_key { + Some(internal_key) => &mut self.inner[internal_key.index], + None => panic!("Invalid index {}", key.index), + } + } +} + +/// An entry in `DelayQueue` that has expired and been removed. +/// +/// Values are returned by [`DelayQueue::poll_expired`]. +/// +/// [`DelayQueue::poll_expired`]: method@DelayQueue::poll_expired +#[derive(Debug)] +pub struct Expired<T> { + /// The data stored in the queue + data: T, + + /// The expiration time + deadline: Instant, + + /// The key associated with the entry + key: Key, +} + +/// Token to a value stored in a `DelayQueue`. +/// +/// Instances of `Key` are returned by [`DelayQueue::insert`]. See [`DelayQueue`] +/// documentation for more details. +/// +/// [`DelayQueue`]: struct@DelayQueue +/// [`DelayQueue::insert`]: method@DelayQueue::insert +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct Key { + index: usize, +} + +// Whereas `Key` is given out to users that use `DelayQueue`, internally we use +// `KeyInternal` as the key type in order to make the logic of mapping between keys +// as a result of `compact` calls clearer. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +struct KeyInternal { + index: usize, +} + +#[derive(Debug)] +struct Stack<T> { + /// Head of the stack + head: Option<Key>, + _p: PhantomData<fn() -> T>, +} + +#[derive(Debug)] +struct Data<T> { + /// The data being stored in the queue and will be returned at the requested + /// instant. + inner: T, + + /// The instant at which the item is returned. + when: u64, + + /// Set to true when stored in the `expired` queue + expired: bool, + + /// Next entry in the stack + next: Option<Key>, + + /// Previous entry in the stack + prev: Option<Key>, +} + +/// Maximum number of entries the queue can handle +const MAX_ENTRIES: usize = (1 << 30) - 1; + +impl<T> DelayQueue<T> { + /// Creates a new, empty, `DelayQueue`. + /// + /// The queue will not allocate storage until items are inserted into it. + /// + /// # Examples + /// + /// ```rust + /// # use tokio_util::time::DelayQueue; + /// let delay_queue: DelayQueue<u32> = DelayQueue::new(); + /// ``` + pub fn new() -> DelayQueue<T> { + DelayQueue::with_capacity(0) + } + + /// Creates a new, empty, `DelayQueue` with the specified capacity. + /// + /// The queue will be able to hold at least `capacity` elements without + /// reallocating. If `capacity` is 0, the queue will not allocate for + /// storage. + /// + /// # Examples + /// + /// ```rust + /// # use tokio_util::time::DelayQueue; + /// # use std::time::Duration; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut delay_queue = DelayQueue::with_capacity(10); + /// + /// // These insertions are done without further allocation + /// for i in 0..10 { + /// delay_queue.insert(i, Duration::from_secs(i)); + /// } + /// + /// // This will make the queue allocate additional storage + /// delay_queue.insert(11, Duration::from_secs(11)); + /// # } + /// ``` + pub fn with_capacity(capacity: usize) -> DelayQueue<T> { + DelayQueue { + wheel: Wheel::new(), + slab: SlabStorage::with_capacity(capacity), + expired: Stack::default(), + delay: None, + wheel_now: 0, + start: Instant::now(), + waker: None, + } + } + + /// Inserts `value` into the queue set to expire at a specific instant in + /// time. + /// + /// This function is identical to `insert`, but takes an `Instant` instead + /// of a `Duration`. + /// + /// `value` is stored in the queue until `when` is reached. At which point, + /// `value` will be returned from [`poll_expired`]. If `when` has already been + /// reached, then `value` is immediately made available to poll. + /// + /// The return value represents the insertion and is used as an argument to + /// [`remove`] and [`reset`]. Note that [`Key`] is a token and is reused once + /// `value` is removed from the queue either by calling [`poll_expired`] after + /// `when` is reached or by calling [`remove`]. At this point, the caller + /// must take care to not use the returned [`Key`] again as it may reference + /// a different item in the queue. + /// + /// See [type] level documentation for more details. + /// + /// # Panics + /// + /// This function panics if `when` is too far in the future. + /// + /// # Examples + /// + /// Basic usage + /// + /// ```rust + /// use tokio::time::{Duration, Instant}; + /// use tokio_util::time::DelayQueue; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut delay_queue = DelayQueue::new(); + /// let key = delay_queue.insert_at( + /// "foo", Instant::now() + Duration::from_secs(5)); + /// + /// // Remove the entry + /// let item = delay_queue.remove(&key); + /// assert_eq!(*item.get_ref(), "foo"); + /// # } + /// ``` + /// + /// [`poll_expired`]: method@Self::poll_expired + /// [`remove`]: method@Self::remove + /// [`reset`]: method@Self::reset + /// [`Key`]: struct@Key + /// [type]: # + pub fn insert_at(&mut self, value: T, when: Instant) -> Key { + assert!(self.slab.len() < MAX_ENTRIES, "max entries exceeded"); + + // Normalize the deadline. Values cannot be set to expire in the past. + let when = self.normalize_deadline(when); + + // Insert the value in the store + let key = self.slab.insert(Data { + inner: value, + when, + expired: false, + next: None, + prev: None, + }); + + self.insert_idx(when, key); + + // Set a new delay if the current's deadline is later than the one of the new item + let should_set_delay = if let Some(ref delay) = self.delay { + let current_exp = self.normalize_deadline(delay.deadline()); + current_exp > when + } else { + true + }; + + if should_set_delay { + if let Some(waker) = self.waker.take() { + waker.wake(); + } + + let delay_time = self.start + Duration::from_millis(when); + if let Some(ref mut delay) = &mut self.delay { + delay.as_mut().reset(delay_time); + } else { + self.delay = Some(Box::pin(sleep_until(delay_time))); + } + } + + key + } + + /// Attempts to pull out the next value of the delay queue, registering the + /// current task for wakeup if the value is not yet available, and returning + /// `None` if the queue is exhausted. + pub fn poll_expired(&mut self, cx: &mut task::Context<'_>) -> Poll<Option<Expired<T>>> { + if !self + .waker + .as_ref() + .map(|w| w.will_wake(cx.waker())) + .unwrap_or(false) + { + self.waker = Some(cx.waker().clone()); + } + + let item = ready!(self.poll_idx(cx)); + Poll::Ready(item.map(|key| { + let data = self.slab.remove(&key); + debug_assert!(data.next.is_none()); + debug_assert!(data.prev.is_none()); + + Expired { + key, + data: data.inner, + deadline: self.start + Duration::from_millis(data.when), + } + })) + } + + /// Inserts `value` into the queue set to expire after the requested duration + /// elapses. + /// + /// This function is identical to `insert_at`, but takes a `Duration` + /// instead of an `Instant`. + /// + /// `value` is stored in the queue until `timeout` duration has + /// elapsed after `insert` was called. At that point, `value` will + /// be returned from [`poll_expired`]. If `timeout` is a `Duration` of + /// zero, then `value` is immediately made available to poll. + /// + /// The return value represents the insertion and is used as an + /// argument to [`remove`] and [`reset`]. Note that [`Key`] is a + /// token and is reused once `value` is removed from the queue + /// either by calling [`poll_expired`] after `timeout` has elapsed + /// or by calling [`remove`]. At this point, the caller must not + /// use the returned [`Key`] again as it may reference a different + /// item in the queue. + /// + /// See [type] level documentation for more details. + /// + /// # Panics + /// + /// This function panics if `timeout` is greater than the maximum + /// duration supported by the timer in the current `Runtime`. + /// + /// # Examples + /// + /// Basic usage + /// + /// ```rust + /// use tokio_util::time::DelayQueue; + /// use std::time::Duration; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut delay_queue = DelayQueue::new(); + /// let key = delay_queue.insert("foo", Duration::from_secs(5)); + /// + /// // Remove the entry + /// let item = delay_queue.remove(&key); + /// assert_eq!(*item.get_ref(), "foo"); + /// # } + /// ``` + /// + /// [`poll_expired`]: method@Self::poll_expired + /// [`remove`]: method@Self::remove + /// [`reset`]: method@Self::reset + /// [`Key`]: struct@Key + /// [type]: # + pub fn insert(&mut self, value: T, timeout: Duration) -> Key { + self.insert_at(value, Instant::now() + timeout) + } + + fn insert_idx(&mut self, when: u64, key: Key) { + use self::wheel::{InsertError, Stack}; + + // Register the deadline with the timer wheel + match self.wheel.insert(when, key, &mut self.slab) { + Ok(_) => {} + Err((_, InsertError::Elapsed)) => { + self.slab[key].expired = true; + // The delay is already expired, store it in the expired queue + self.expired.push(key, &mut self.slab); + } + Err((_, err)) => panic!("invalid deadline; err={:?}", err), + } + } + + /// Removes the key from the expired queue or the timer wheel + /// depending on its expiration status. + /// + /// # Panics + /// + /// Panics if the key is not contained in the expired queue or the wheel. + fn remove_key(&mut self, key: &Key) { + use crate::time::wheel::Stack; + + // Special case the `expired` queue + if self.slab[*key].expired { + self.expired.remove(key, &mut self.slab); + } else { + self.wheel.remove(key, &mut self.slab); + } + } + + /// Removes the item associated with `key` from the queue. + /// + /// There must be an item associated with `key`. The function returns the + /// removed item as well as the `Instant` at which it will the delay will + /// have expired. + /// + /// # Panics + /// + /// The function panics if `key` is not contained by the queue. + /// + /// # Examples + /// + /// Basic usage + /// + /// ```rust + /// use tokio_util::time::DelayQueue; + /// use std::time::Duration; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut delay_queue = DelayQueue::new(); + /// let key = delay_queue.insert("foo", Duration::from_secs(5)); + /// + /// // Remove the entry + /// let item = delay_queue.remove(&key); + /// assert_eq!(*item.get_ref(), "foo"); + /// # } + /// ``` + pub fn remove(&mut self, key: &Key) -> Expired<T> { + let prev_deadline = self.next_deadline(); + + self.remove_key(key); + let data = self.slab.remove(key); + + let next_deadline = self.next_deadline(); + if prev_deadline != next_deadline { + match (next_deadline, &mut self.delay) { + (None, _) => self.delay = None, + (Some(deadline), Some(delay)) => delay.as_mut().reset(deadline), + (Some(deadline), None) => self.delay = Some(Box::pin(sleep_until(deadline))), + } + } + + Expired { + key: Key::new(key.index), + data: data.inner, + deadline: self.start + Duration::from_millis(data.when), + } + } + + /// Sets the delay of the item associated with `key` to expire at `when`. + /// + /// This function is identical to `reset` but takes an `Instant` instead of + /// a `Duration`. + /// + /// The item remains in the queue but the delay is set to expire at `when`. + /// If `when` is in the past, then the item is immediately made available to + /// the caller. + /// + /// # Panics + /// + /// This function panics if `when` is too far in the future or if `key` is + /// not contained by the queue. + /// + /// # Examples + /// + /// Basic usage + /// + /// ```rust + /// use tokio::time::{Duration, Instant}; + /// use tokio_util::time::DelayQueue; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut delay_queue = DelayQueue::new(); + /// let key = delay_queue.insert("foo", Duration::from_secs(5)); + /// + /// // "foo" is scheduled to be returned in 5 seconds + /// + /// delay_queue.reset_at(&key, Instant::now() + Duration::from_secs(10)); + /// + /// // "foo" is now scheduled to be returned in 10 seconds + /// # } + /// ``` + pub fn reset_at(&mut self, key: &Key, when: Instant) { + self.remove_key(key); + + // Normalize the deadline. Values cannot be set to expire in the past. + let when = self.normalize_deadline(when); + + self.slab[*key].when = when; + self.slab[*key].expired = false; + + self.insert_idx(when, *key); + + let next_deadline = self.next_deadline(); + if let (Some(ref mut delay), Some(deadline)) = (&mut self.delay, next_deadline) { + // This should awaken us if necessary (ie, if already expired) + delay.as_mut().reset(deadline); + } + } + + /// Shrink the capacity of the slab, which `DelayQueue` uses internally for storage allocation. + /// This function is not guaranteed to, and in most cases, won't decrease the capacity of the slab + /// to the number of elements still contained in it, because elements cannot be moved to a different + /// index. To decrease the capacity to the size of the slab use [`compact`]. + /// + /// This function can take O(n) time even when the capacity cannot be reduced or the allocation is + /// shrunk in place. Repeated calls run in O(1) though. + /// + /// [`compact`]: method@Self::compact + pub fn shrink_to_fit(&mut self) { + self.slab.shrink_to_fit(); + } + + /// Shrink the capacity of the slab, which `DelayQueue` uses internally for storage allocation, + /// to the number of elements that are contained in it. + /// + /// This methods runs in O(n). + /// + /// # Examples + /// + /// Basic usage + /// + /// ```rust + /// use tokio_util::time::DelayQueue; + /// use std::time::Duration; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut delay_queue = DelayQueue::with_capacity(10); + /// + /// let key1 = delay_queue.insert(5, Duration::from_secs(5)); + /// let key2 = delay_queue.insert(10, Duration::from_secs(10)); + /// let key3 = delay_queue.insert(15, Duration::from_secs(15)); + /// + /// delay_queue.remove(&key2); + /// + /// delay_queue.compact(); + /// assert_eq!(delay_queue.capacity(), 2); + /// # } + /// ``` + pub fn compact(&mut self) { + self.slab.compact(); + } + + /// Returns the next time to poll as determined by the wheel + fn next_deadline(&mut self) -> Option<Instant> { + self.wheel + .poll_at() + .map(|poll_at| self.start + Duration::from_millis(poll_at)) + } + + /// Sets the delay of the item associated with `key` to expire after + /// `timeout`. + /// + /// This function is identical to `reset_at` but takes a `Duration` instead + /// of an `Instant`. + /// + /// The item remains in the queue but the delay is set to expire after + /// `timeout`. If `timeout` is zero, then the item is immediately made + /// available to the caller. + /// + /// # Panics + /// + /// This function panics if `timeout` is greater than the maximum supported + /// duration or if `key` is not contained by the queue. + /// + /// # Examples + /// + /// Basic usage + /// + /// ```rust + /// use tokio_util::time::DelayQueue; + /// use std::time::Duration; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut delay_queue = DelayQueue::new(); + /// let key = delay_queue.insert("foo", Duration::from_secs(5)); + /// + /// // "foo" is scheduled to be returned in 5 seconds + /// + /// delay_queue.reset(&key, Duration::from_secs(10)); + /// + /// // "foo"is now scheduled to be returned in 10 seconds + /// # } + /// ``` + pub fn reset(&mut self, key: &Key, timeout: Duration) { + self.reset_at(key, Instant::now() + timeout); + } + + /// Clears the queue, removing all items. + /// + /// After calling `clear`, [`poll_expired`] will return `Ok(Ready(None))`. + /// + /// Note that this method has no effect on the allocated capacity. + /// + /// [`poll_expired`]: method@Self::poll_expired + /// + /// # Examples + /// + /// ```rust + /// use tokio_util::time::DelayQueue; + /// use std::time::Duration; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut delay_queue = DelayQueue::new(); + /// + /// delay_queue.insert("foo", Duration::from_secs(5)); + /// + /// assert!(!delay_queue.is_empty()); + /// + /// delay_queue.clear(); + /// + /// assert!(delay_queue.is_empty()); + /// # } + /// ``` + pub fn clear(&mut self) { + self.slab.clear(); + self.expired = Stack::default(); + self.wheel = Wheel::new(); + self.delay = None; + } + + /// Returns the number of elements the queue can hold without reallocating. + /// + /// # Examples + /// + /// ```rust + /// use tokio_util::time::DelayQueue; + /// + /// let delay_queue: DelayQueue<i32> = DelayQueue::with_capacity(10); + /// assert_eq!(delay_queue.capacity(), 10); + /// ``` + pub fn capacity(&self) -> usize { + self.slab.capacity() + } + + /// Returns the number of elements currently in the queue. + /// + /// # Examples + /// + /// ```rust + /// use tokio_util::time::DelayQueue; + /// use std::time::Duration; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut delay_queue: DelayQueue<i32> = DelayQueue::with_capacity(10); + /// assert_eq!(delay_queue.len(), 0); + /// delay_queue.insert(3, Duration::from_secs(5)); + /// assert_eq!(delay_queue.len(), 1); + /// # } + /// ``` + pub fn len(&self) -> usize { + self.slab.len() + } + + /// Reserves capacity for at least `additional` more items to be queued + /// without allocating. + /// + /// `reserve` does nothing if the queue already has sufficient capacity for + /// `additional` more values. If more capacity is required, a new segment of + /// memory will be allocated and all existing values will be copied into it. + /// As such, if the queue is already very large, a call to `reserve` can end + /// up being expensive. + /// + /// The queue may reserve more than `additional` extra space in order to + /// avoid frequent reallocations. + /// + /// # Panics + /// + /// Panics if the new capacity exceeds the maximum number of entries the + /// queue can contain. + /// + /// # Examples + /// + /// ``` + /// use tokio_util::time::DelayQueue; + /// use std::time::Duration; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut delay_queue = DelayQueue::new(); + /// + /// delay_queue.insert("hello", Duration::from_secs(10)); + /// delay_queue.reserve(10); + /// + /// assert!(delay_queue.capacity() >= 11); + /// # } + /// ``` + pub fn reserve(&mut self, additional: usize) { + self.slab.reserve(additional); + } + + /// Returns `true` if there are no items in the queue. + /// + /// Note that this function returns `false` even if all items have not yet + /// expired and a call to `poll` will return `Poll::Pending`. + /// + /// # Examples + /// + /// ``` + /// use tokio_util::time::DelayQueue; + /// use std::time::Duration; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut delay_queue = DelayQueue::new(); + /// assert!(delay_queue.is_empty()); + /// + /// delay_queue.insert("hello", Duration::from_secs(5)); + /// assert!(!delay_queue.is_empty()); + /// # } + /// ``` + pub fn is_empty(&self) -> bool { + self.slab.is_empty() + } + + /// Polls the queue, returning the index of the next slot in the slab that + /// should be returned. + /// + /// A slot should be returned when the associated deadline has been reached. + fn poll_idx(&mut self, cx: &mut task::Context<'_>) -> Poll<Option<Key>> { + use self::wheel::Stack; + + let expired = self.expired.pop(&mut self.slab); + + if expired.is_some() { + return Poll::Ready(expired); + } + + loop { + if let Some(ref mut delay) = self.delay { + if !delay.is_elapsed() { + ready!(Pin::new(&mut *delay).poll(cx)); + } + + let now = crate::time::ms(delay.deadline() - self.start, crate::time::Round::Down); + + self.wheel_now = now; + } + + // We poll the wheel to get the next value out before finding the next deadline. + let wheel_idx = self.wheel.poll(self.wheel_now, &mut self.slab); + + self.delay = self.next_deadline().map(|when| Box::pin(sleep_until(when))); + + if let Some(idx) = wheel_idx { + return Poll::Ready(Some(idx)); + } + + if self.delay.is_none() { + return Poll::Ready(None); + } + } + } + + fn normalize_deadline(&self, when: Instant) -> u64 { + let when = if when < self.start { + 0 + } else { + crate::time::ms(when - self.start, crate::time::Round::Up) + }; + + cmp::max(when, self.wheel.elapsed()) + } +} + +// We never put `T` in a `Pin`... +impl<T> Unpin for DelayQueue<T> {} + +impl<T> Default for DelayQueue<T> { + fn default() -> DelayQueue<T> { + DelayQueue::new() + } +} + +impl<T> futures_core::Stream for DelayQueue<T> { + // DelayQueue seems much more specific, where a user may care that it + // has reached capacity, so return those errors instead of panicking. + type Item = Expired<T>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> { + DelayQueue::poll_expired(self.get_mut(), cx) + } +} + +impl<T> wheel::Stack for Stack<T> { + type Owned = Key; + type Borrowed = Key; + type Store = SlabStorage<T>; + + fn is_empty(&self) -> bool { + self.head.is_none() + } + + fn push(&mut self, item: Self::Owned, store: &mut Self::Store) { + // Ensure the entry is not already in a stack. + debug_assert!(store[item].next.is_none()); + debug_assert!(store[item].prev.is_none()); + + // Remove the old head entry + let old = self.head.take(); + + if let Some(idx) = old { + store[idx].prev = Some(item); + } + + store[item].next = old; + self.head = Some(item); + } + + fn pop(&mut self, store: &mut Self::Store) -> Option<Self::Owned> { + if let Some(key) = self.head { + self.head = store[key].next; + + if let Some(idx) = self.head { + store[idx].prev = None; + } + + store[key].next = None; + debug_assert!(store[key].prev.is_none()); + + Some(key) + } else { + None + } + } + + fn remove(&mut self, item: &Self::Borrowed, store: &mut Self::Store) { + let key = *item; + assert!(store.contains(item)); + + // Ensure that the entry is in fact contained by the stack + debug_assert!({ + // This walks the full linked list even if an entry is found. + let mut next = self.head; + let mut contains = false; + + while let Some(idx) = next { + let data = &store[idx]; + + if idx == *item { + debug_assert!(!contains); + contains = true; + } + + next = data.next; + } + + contains + }); + + if let Some(next) = store[key].next { + store[next].prev = store[key].prev; + } + + if let Some(prev) = store[key].prev { + store[prev].next = store[key].next; + } else { + self.head = store[key].next; + } + + store[key].next = None; + store[key].prev = None; + } + + fn when(item: &Self::Borrowed, store: &Self::Store) -> u64 { + store[*item].when + } +} + +impl<T> Default for Stack<T> { + fn default() -> Stack<T> { + Stack { + head: None, + _p: PhantomData, + } + } +} + +impl Key { + pub(crate) fn new(index: usize) -> Key { + Key { index } + } +} + +impl KeyInternal { + pub(crate) fn new(index: usize) -> KeyInternal { + KeyInternal { index } + } +} + +impl From<Key> for KeyInternal { + fn from(item: Key) -> Self { + KeyInternal::new(item.index) + } +} + +impl From<KeyInternal> for Key { + fn from(item: KeyInternal) -> Self { + Key::new(item.index) + } +} + +impl<T> Expired<T> { + /// Returns a reference to the inner value. + pub fn get_ref(&self) -> &T { + &self.data + } + + /// Returns a mutable reference to the inner value. + pub fn get_mut(&mut self) -> &mut T { + &mut self.data + } + + /// Consumes `self` and returns the inner value. + pub fn into_inner(self) -> T { + self.data + } + + /// Returns the deadline that the expiration was set to. + pub fn deadline(&self) -> Instant { + self.deadline + } + + /// Returns the key that the expiration is indexed by. + pub fn key(&self) -> Key { + self.key + } +} diff --git a/third_party/rust/tokio-util/src/time/mod.rs b/third_party/rust/tokio-util/src/time/mod.rs new file mode 100644 index 0000000000..2d34008360 --- /dev/null +++ b/third_party/rust/tokio-util/src/time/mod.rs @@ -0,0 +1,47 @@ +//! Additional utilities for tracking time. +//! +//! This module provides additional utilities for executing code after a set period +//! of time. Currently there is only one: +//! +//! * `DelayQueue`: A queue where items are returned once the requested delay +//! has expired. +//! +//! This type must be used from within the context of the `Runtime`. + +use std::time::Duration; + +mod wheel; + +pub mod delay_queue; + +#[doc(inline)] +pub use delay_queue::DelayQueue; + +// ===== Internal utils ===== + +enum Round { + Up, + Down, +} + +/// Convert a `Duration` to milliseconds, rounding up and saturating at +/// `u64::MAX`. +/// +/// The saturating is fine because `u64::MAX` milliseconds are still many +/// million years. +#[inline] +fn ms(duration: Duration, round: Round) -> u64 { + const NANOS_PER_MILLI: u32 = 1_000_000; + const MILLIS_PER_SEC: u64 = 1_000; + + // Round up. + let millis = match round { + Round::Up => (duration.subsec_nanos() + NANOS_PER_MILLI - 1) / NANOS_PER_MILLI, + Round::Down => duration.subsec_millis(), + }; + + duration + .as_secs() + .saturating_mul(MILLIS_PER_SEC) + .saturating_add(u64::from(millis)) +} diff --git a/third_party/rust/tokio-util/src/time/wheel/level.rs b/third_party/rust/tokio-util/src/time/wheel/level.rs new file mode 100644 index 0000000000..8ea30af30f --- /dev/null +++ b/third_party/rust/tokio-util/src/time/wheel/level.rs @@ -0,0 +1,253 @@ +use crate::time::wheel::Stack; + +use std::fmt; + +/// Wheel for a single level in the timer. This wheel contains 64 slots. +pub(crate) struct Level<T> { + level: usize, + + /// Bit field tracking which slots currently contain entries. + /// + /// Using a bit field to track slots that contain entries allows avoiding a + /// scan to find entries. This field is updated when entries are added or + /// removed from a slot. + /// + /// The least-significant bit represents slot zero. + occupied: u64, + + /// Slots + slot: [T; LEVEL_MULT], +} + +/// Indicates when a slot must be processed next. +#[derive(Debug)] +pub(crate) struct Expiration { + /// The level containing the slot. + pub(crate) level: usize, + + /// The slot index. + pub(crate) slot: usize, + + /// The instant at which the slot needs to be processed. + pub(crate) deadline: u64, +} + +/// Level multiplier. +/// +/// Being a power of 2 is very important. +const LEVEL_MULT: usize = 64; + +impl<T: Stack> Level<T> { + pub(crate) fn new(level: usize) -> Level<T> { + // Rust's derived implementations for arrays require that the value + // contained by the array be `Copy`. So, here we have to manually + // initialize every single slot. + macro_rules! s { + () => { + T::default() + }; + } + + Level { + level, + occupied: 0, + slot: [ + // It does not look like the necessary traits are + // derived for [T; 64]. + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + ], + } + } + + /// Finds the slot that needs to be processed next and returns the slot and + /// `Instant` at which this slot must be processed. + pub(crate) fn next_expiration(&self, now: u64) -> Option<Expiration> { + // Use the `occupied` bit field to get the index of the next slot that + // needs to be processed. + let slot = match self.next_occupied_slot(now) { + Some(slot) => slot, + None => return None, + }; + + // From the slot index, calculate the `Instant` at which it needs to be + // processed. This value *must* be in the future with respect to `now`. + + let level_range = level_range(self.level); + let slot_range = slot_range(self.level); + + // TODO: This can probably be simplified w/ power of 2 math + let level_start = now - (now % level_range); + let deadline = level_start + slot as u64 * slot_range; + + debug_assert!( + deadline >= now, + "deadline={}; now={}; level={}; slot={}; occupied={:b}", + deadline, + now, + self.level, + slot, + self.occupied + ); + + Some(Expiration { + level: self.level, + slot, + deadline, + }) + } + + fn next_occupied_slot(&self, now: u64) -> Option<usize> { + if self.occupied == 0 { + return None; + } + + // Get the slot for now using Maths + let now_slot = (now / slot_range(self.level)) as usize; + let occupied = self.occupied.rotate_right(now_slot as u32); + let zeros = occupied.trailing_zeros() as usize; + let slot = (zeros + now_slot) % 64; + + Some(slot) + } + + pub(crate) fn add_entry(&mut self, when: u64, item: T::Owned, store: &mut T::Store) { + let slot = slot_for(when, self.level); + + self.slot[slot].push(item, store); + self.occupied |= occupied_bit(slot); + } + + pub(crate) fn remove_entry(&mut self, when: u64, item: &T::Borrowed, store: &mut T::Store) { + let slot = slot_for(when, self.level); + + self.slot[slot].remove(item, store); + + if self.slot[slot].is_empty() { + // The bit is currently set + debug_assert!(self.occupied & occupied_bit(slot) != 0); + + // Unset the bit + self.occupied ^= occupied_bit(slot); + } + } + + pub(crate) fn pop_entry_slot(&mut self, slot: usize, store: &mut T::Store) -> Option<T::Owned> { + let ret = self.slot[slot].pop(store); + + if ret.is_some() && self.slot[slot].is_empty() { + // The bit is currently set + debug_assert!(self.occupied & occupied_bit(slot) != 0); + + self.occupied ^= occupied_bit(slot); + } + + ret + } +} + +impl<T> fmt::Debug for Level<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Level") + .field("occupied", &self.occupied) + .finish() + } +} + +fn occupied_bit(slot: usize) -> u64 { + 1 << slot +} + +fn slot_range(level: usize) -> u64 { + LEVEL_MULT.pow(level as u32) as u64 +} + +fn level_range(level: usize) -> u64 { + LEVEL_MULT as u64 * slot_range(level) +} + +/// Convert a duration (milliseconds) and a level to a slot position +fn slot_for(duration: u64, level: usize) -> usize { + ((duration >> (level * 6)) % LEVEL_MULT as u64) as usize +} + +#[cfg(all(test, not(loom)))] +mod test { + use super::*; + + #[test] + fn test_slot_for() { + for pos in 0..64 { + assert_eq!(pos as usize, slot_for(pos, 0)); + } + + for level in 1..5 { + for pos in level..64 { + let a = pos * 64_usize.pow(level as u32); + assert_eq!(pos as usize, slot_for(a as u64, level)); + } + } + } +} diff --git a/third_party/rust/tokio-util/src/time/wheel/mod.rs b/third_party/rust/tokio-util/src/time/wheel/mod.rs new file mode 100644 index 0000000000..4191e401df --- /dev/null +++ b/third_party/rust/tokio-util/src/time/wheel/mod.rs @@ -0,0 +1,314 @@ +mod level; +pub(crate) use self::level::Expiration; +use self::level::Level; + +mod stack; +pub(crate) use self::stack::Stack; + +use std::borrow::Borrow; +use std::fmt::Debug; +use std::usize; + +/// Timing wheel implementation. +/// +/// This type provides the hashed timing wheel implementation that backs `Timer` +/// and `DelayQueue`. +/// +/// The structure is generic over `T: Stack`. This allows handling timeout data +/// being stored on the heap or in a slab. In order to support the latter case, +/// the slab must be passed into each function allowing the implementation to +/// lookup timer entries. +/// +/// See `Timer` documentation for some implementation notes. +#[derive(Debug)] +pub(crate) struct Wheel<T> { + /// The number of milliseconds elapsed since the wheel started. + elapsed: u64, + + /// Timer wheel. + /// + /// Levels: + /// + /// * 1 ms slots / 64 ms range + /// * 64 ms slots / ~ 4 sec range + /// * ~ 4 sec slots / ~ 4 min range + /// * ~ 4 min slots / ~ 4 hr range + /// * ~ 4 hr slots / ~ 12 day range + /// * ~ 12 day slots / ~ 2 yr range + levels: Vec<Level<T>>, +} + +/// Number of levels. Each level has 64 slots. By using 6 levels with 64 slots +/// each, the timer is able to track time up to 2 years into the future with a +/// precision of 1 millisecond. +const NUM_LEVELS: usize = 6; + +/// The maximum duration of a delay +const MAX_DURATION: u64 = (1 << (6 * NUM_LEVELS)) - 1; + +#[derive(Debug)] +pub(crate) enum InsertError { + Elapsed, + Invalid, +} + +impl<T> Wheel<T> +where + T: Stack, +{ + /// Create a new timing wheel + pub(crate) fn new() -> Wheel<T> { + let levels = (0..NUM_LEVELS).map(Level::new).collect(); + + Wheel { elapsed: 0, levels } + } + + /// Return the number of milliseconds that have elapsed since the timing + /// wheel's creation. + pub(crate) fn elapsed(&self) -> u64 { + self.elapsed + } + + /// Insert an entry into the timing wheel. + /// + /// # Arguments + /// + /// * `when`: is the instant at which the entry should be fired. It is + /// represented as the number of milliseconds since the creation + /// of the timing wheel. + /// + /// * `item`: The item to insert into the wheel. + /// + /// * `store`: The slab or `()` when using heap storage. + /// + /// # Return + /// + /// Returns `Ok` when the item is successfully inserted, `Err` otherwise. + /// + /// `Err(Elapsed)` indicates that `when` represents an instant that has + /// already passed. In this case, the caller should fire the timeout + /// immediately. + /// + /// `Err(Invalid)` indicates an invalid `when` argument as been supplied. + pub(crate) fn insert( + &mut self, + when: u64, + item: T::Owned, + store: &mut T::Store, + ) -> Result<(), (T::Owned, InsertError)> { + if when <= self.elapsed { + return Err((item, InsertError::Elapsed)); + } else if when - self.elapsed > MAX_DURATION { + return Err((item, InsertError::Invalid)); + } + + // Get the level at which the entry should be stored + let level = self.level_for(when); + + self.levels[level].add_entry(when, item, store); + + debug_assert!({ + self.levels[level] + .next_expiration(self.elapsed) + .map(|e| e.deadline >= self.elapsed) + .unwrap_or(true) + }); + + Ok(()) + } + + /// Remove `item` from the timing wheel. + pub(crate) fn remove(&mut self, item: &T::Borrowed, store: &mut T::Store) { + let when = T::when(item, store); + + assert!( + self.elapsed <= when, + "elapsed={}; when={}", + self.elapsed, + when + ); + + let level = self.level_for(when); + + self.levels[level].remove_entry(when, item, store); + } + + /// Instant at which to poll + pub(crate) fn poll_at(&self) -> Option<u64> { + self.next_expiration().map(|expiration| expiration.deadline) + } + + /// Advances the timer up to the instant represented by `now`. + pub(crate) fn poll(&mut self, now: u64, store: &mut T::Store) -> Option<T::Owned> { + loop { + let expiration = self.next_expiration().and_then(|expiration| { + if expiration.deadline > now { + None + } else { + Some(expiration) + } + }); + + match expiration { + Some(ref expiration) => { + if let Some(item) = self.poll_expiration(expiration, store) { + return Some(item); + } + + self.set_elapsed(expiration.deadline); + } + None => { + // in this case the poll did not indicate an expiration + // _and_ we were not able to find a next expiration in + // the current list of timers. advance to the poll's + // current time and do nothing else. + self.set_elapsed(now); + return None; + } + } + } + } + + /// Returns the instant at which the next timeout expires. + fn next_expiration(&self) -> Option<Expiration> { + // Check all levels + for level in 0..NUM_LEVELS { + if let Some(expiration) = self.levels[level].next_expiration(self.elapsed) { + // There cannot be any expirations at a higher level that happen + // before this one. + debug_assert!(self.no_expirations_before(level + 1, expiration.deadline)); + + return Some(expiration); + } + } + + None + } + + /// Used for debug assertions + fn no_expirations_before(&self, start_level: usize, before: u64) -> bool { + let mut res = true; + + for l2 in start_level..NUM_LEVELS { + if let Some(e2) = self.levels[l2].next_expiration(self.elapsed) { + if e2.deadline < before { + res = false; + } + } + } + + res + } + + /// iteratively find entries that are between the wheel's current + /// time and the expiration time. for each in that population either + /// return it for notification (in the case of the last level) or tier + /// it down to the next level (in all other cases). + pub(crate) fn poll_expiration( + &mut self, + expiration: &Expiration, + store: &mut T::Store, + ) -> Option<T::Owned> { + while let Some(item) = self.pop_entry(expiration, store) { + if expiration.level == 0 { + debug_assert_eq!(T::when(item.borrow(), store), expiration.deadline); + + return Some(item); + } else { + let when = T::when(item.borrow(), store); + + let next_level = expiration.level - 1; + + self.levels[next_level].add_entry(when, item, store); + } + } + + None + } + + fn set_elapsed(&mut self, when: u64) { + assert!( + self.elapsed <= when, + "elapsed={:?}; when={:?}", + self.elapsed, + when + ); + + if when > self.elapsed { + self.elapsed = when; + } + } + + fn pop_entry(&mut self, expiration: &Expiration, store: &mut T::Store) -> Option<T::Owned> { + self.levels[expiration.level].pop_entry_slot(expiration.slot, store) + } + + fn level_for(&self, when: u64) -> usize { + level_for(self.elapsed, when) + } +} + +fn level_for(elapsed: u64, when: u64) -> usize { + const SLOT_MASK: u64 = (1 << 6) - 1; + + // Mask in the trailing bits ignored by the level calculation in order to cap + // the possible leading zeros + let masked = elapsed ^ when | SLOT_MASK; + + let leading_zeros = masked.leading_zeros() as usize; + let significant = 63 - leading_zeros; + significant / 6 +} + +#[cfg(all(test, not(loom)))] +mod test { + use super::*; + + #[test] + fn test_level_for() { + for pos in 0..64 { + assert_eq!( + 0, + level_for(0, pos), + "level_for({}) -- binary = {:b}", + pos, + pos + ); + } + + for level in 1..5 { + for pos in level..64 { + let a = pos * 64_usize.pow(level as u32); + assert_eq!( + level, + level_for(0, a as u64), + "level_for({}) -- binary = {:b}", + a, + a + ); + + if pos > level { + let a = a - 1; + assert_eq!( + level, + level_for(0, a as u64), + "level_for({}) -- binary = {:b}", + a, + a + ); + } + + if pos < 64 { + let a = a + 1; + assert_eq!( + level, + level_for(0, a as u64), + "level_for({}) -- binary = {:b}", + a, + a + ); + } + } + } + } +} diff --git a/third_party/rust/tokio-util/src/time/wheel/stack.rs b/third_party/rust/tokio-util/src/time/wheel/stack.rs new file mode 100644 index 0000000000..c87adcafda --- /dev/null +++ b/third_party/rust/tokio-util/src/time/wheel/stack.rs @@ -0,0 +1,28 @@ +use std::borrow::Borrow; +use std::cmp::Eq; +use std::hash::Hash; + +/// Abstracts the stack operations needed to track timeouts. +pub(crate) trait Stack: Default { + /// Type of the item stored in the stack + type Owned: Borrow<Self::Borrowed>; + + /// Borrowed item + type Borrowed: Eq + Hash; + + /// Item storage, this allows a slab to be used instead of just the heap + type Store; + + /// Returns `true` if the stack is empty + fn is_empty(&self) -> bool; + + /// Push an item onto the stack + fn push(&mut self, item: Self::Owned, store: &mut Self::Store); + + /// Pop an item from the stack + fn pop(&mut self, store: &mut Self::Store) -> Option<Self::Owned>; + + fn remove(&mut self, item: &Self::Borrowed, store: &mut Self::Store); + + fn when(item: &Self::Borrowed, store: &Self::Store) -> u64; +} diff --git a/third_party/rust/tokio-util/src/udp/frame.rs b/third_party/rust/tokio-util/src/udp/frame.rs new file mode 100644 index 0000000000..d900fd7691 --- /dev/null +++ b/third_party/rust/tokio-util/src/udp/frame.rs @@ -0,0 +1,245 @@ +use crate::codec::{Decoder, Encoder}; + +use futures_core::Stream; +use tokio::{io::ReadBuf, net::UdpSocket}; + +use bytes::{BufMut, BytesMut}; +use futures_core::ready; +use futures_sink::Sink; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{ + borrow::Borrow, + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, +}; +use std::{io, mem::MaybeUninit}; + +/// A unified [`Stream`] and [`Sink`] interface to an underlying `UdpSocket`, using +/// the `Encoder` and `Decoder` traits to encode and decode frames. +/// +/// Raw UDP sockets work with datagrams, but higher-level code usually wants to +/// batch these into meaningful chunks, called "frames". This method layers +/// framing on top of this socket by using the `Encoder` and `Decoder` traits to +/// handle encoding and decoding of messages frames. Note that the incoming and +/// outgoing frame types may be distinct. +/// +/// This function returns a *single* object that is both [`Stream`] and [`Sink`]; +/// grouping this into a single object is often useful for layering things which +/// require both read and write access to the underlying object. +/// +/// If you want to work more directly with the streams and sink, consider +/// calling [`split`] on the `UdpFramed` returned by this method, which will break +/// them into separate objects, allowing them to interact more easily. +/// +/// [`Stream`]: futures_core::Stream +/// [`Sink`]: futures_sink::Sink +/// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split +#[must_use = "sinks do nothing unless polled"] +#[derive(Debug)] +pub struct UdpFramed<C, T = UdpSocket> { + socket: T, + codec: C, + rd: BytesMut, + wr: BytesMut, + out_addr: SocketAddr, + flushed: bool, + is_readable: bool, + current_addr: Option<SocketAddr>, +} + +const INITIAL_RD_CAPACITY: usize = 64 * 1024; +const INITIAL_WR_CAPACITY: usize = 8 * 1024; + +impl<C, T> Unpin for UdpFramed<C, T> {} + +impl<C, T> Stream for UdpFramed<C, T> +where + T: Borrow<UdpSocket>, + C: Decoder, +{ + type Item = Result<(C::Item, SocketAddr), C::Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + let pin = self.get_mut(); + + pin.rd.reserve(INITIAL_RD_CAPACITY); + + loop { + // Are there still bytes left in the read buffer to decode? + if pin.is_readable { + if let Some(frame) = pin.codec.decode_eof(&mut pin.rd)? { + let current_addr = pin + .current_addr + .expect("will always be set before this line is called"); + + return Poll::Ready(Some(Ok((frame, current_addr)))); + } + + // if this line has been reached then decode has returned `None`. + pin.is_readable = false; + pin.rd.clear(); + } + + // We're out of data. Try and fetch more data to decode + let addr = unsafe { + // Convert `&mut [MaybeUnit<u8>]` to `&mut [u8]` because we will be + // writing to it via `poll_recv_from` and therefore initializing the memory. + let buf = &mut *(pin.rd.chunk_mut() as *mut _ as *mut [MaybeUninit<u8>]); + let mut read = ReadBuf::uninit(buf); + let ptr = read.filled().as_ptr(); + let res = ready!(pin.socket.borrow().poll_recv_from(cx, &mut read)); + + assert_eq!(ptr, read.filled().as_ptr()); + let addr = res?; + pin.rd.advance_mut(read.filled().len()); + addr + }; + + pin.current_addr = Some(addr); + pin.is_readable = true; + } + } +} + +impl<I, C, T> Sink<(I, SocketAddr)> for UdpFramed<C, T> +where + T: Borrow<UdpSocket>, + C: Encoder<I>, +{ + type Error = C::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + if !self.flushed { + match self.poll_flush(cx)? { + Poll::Ready(()) => {} + Poll::Pending => return Poll::Pending, + } + } + + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error> { + let (frame, out_addr) = item; + + let pin = self.get_mut(); + + pin.codec.encode(frame, &mut pin.wr)?; + pin.out_addr = out_addr; + pin.flushed = false; + + Ok(()) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + if self.flushed { + return Poll::Ready(Ok(())); + } + + let Self { + ref socket, + ref mut out_addr, + ref mut wr, + .. + } = *self; + + let n = ready!(socket.borrow().poll_send_to(cx, wr, *out_addr))?; + + let wrote_all = n == self.wr.len(); + self.wr.clear(); + self.flushed = true; + + let res = if wrote_all { + Ok(()) + } else { + Err(io::Error::new( + io::ErrorKind::Other, + "failed to write entire datagram to socket", + ) + .into()) + }; + + Poll::Ready(res) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + ready!(self.poll_flush(cx))?; + Poll::Ready(Ok(())) + } +} + +impl<C, T> UdpFramed<C, T> +where + T: Borrow<UdpSocket>, +{ + /// Create a new `UdpFramed` backed by the given socket and codec. + /// + /// See struct level documentation for more details. + pub fn new(socket: T, codec: C) -> UdpFramed<C, T> { + Self { + socket, + codec, + out_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)), + rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY), + wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY), + flushed: true, + is_readable: false, + current_addr: None, + } + } + + /// Returns a reference to the underlying I/O stream wrapped by `Framed`. + /// + /// # Note + /// + /// Care should be taken to not tamper with the underlying stream of data + /// coming in as it may corrupt the stream of frames otherwise being worked + /// with. + pub fn get_ref(&self) -> &T { + &self.socket + } + + /// Returns a mutable reference to the underlying I/O stream wrapped by `Framed`. + /// + /// # Note + /// + /// Care should be taken to not tamper with the underlying stream of data + /// coming in as it may corrupt the stream of frames otherwise being worked + /// with. + pub fn get_mut(&mut self) -> &mut T { + &mut self.socket + } + + /// Returns a reference to the underlying codec wrapped by + /// `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying codec + /// as it may corrupt the stream of frames otherwise being worked with. + pub fn codec(&self) -> &C { + &self.codec + } + + /// Returns a mutable reference to the underlying codec wrapped by + /// `UdpFramed`. + /// + /// Note that care should be taken to not tamper with the underlying codec + /// as it may corrupt the stream of frames otherwise being worked with. + pub fn codec_mut(&mut self) -> &mut C { + &mut self.codec + } + + /// Returns a reference to the read buffer. + pub fn read_buffer(&self) -> &BytesMut { + &self.rd + } + + /// Returns a mutable reference to the read buffer. + pub fn read_buffer_mut(&mut self) -> &mut BytesMut { + &mut self.rd + } + + /// Consumes the `Framed`, returning its underlying I/O stream. + pub fn into_inner(self) -> T { + self.socket + } +} diff --git a/third_party/rust/tokio-util/src/udp/mod.rs b/third_party/rust/tokio-util/src/udp/mod.rs new file mode 100644 index 0000000000..f88ea030aa --- /dev/null +++ b/third_party/rust/tokio-util/src/udp/mod.rs @@ -0,0 +1,4 @@ +//! UDP framing + +mod frame; +pub use frame::UdpFramed; |