summaryrefslogtreecommitdiffstats
path: root/third_party/rust/tokio-util/src
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/tokio-util/src')
-rw-r--r--third_party/rust/tokio-util/src/cfg.rs71
-rw-r--r--third_party/rust/tokio-util/src/codec/any_delimiter_codec.rs263
-rw-r--r--third_party/rust/tokio-util/src/codec/bytes_codec.rs86
-rw-r--r--third_party/rust/tokio-util/src/codec/decoder.rs184
-rw-r--r--third_party/rust/tokio-util/src/codec/encoder.rs25
-rw-r--r--third_party/rust/tokio-util/src/codec/framed.rs373
-rw-r--r--third_party/rust/tokio-util/src/codec/framed_impl.rs308
-rw-r--r--third_party/rust/tokio-util/src/codec/framed_read.rs199
-rw-r--r--third_party/rust/tokio-util/src/codec/framed_write.rs178
-rw-r--r--third_party/rust/tokio-util/src/codec/length_delimited.rs1047
-rw-r--r--third_party/rust/tokio-util/src/codec/lines_codec.rs230
-rw-r--r--third_party/rust/tokio-util/src/codec/mod.rs290
-rw-r--r--third_party/rust/tokio-util/src/compat.rs274
-rw-r--r--third_party/rust/tokio-util/src/context.rs190
-rw-r--r--third_party/rust/tokio-util/src/either.rs188
-rw-r--r--third_party/rust/tokio-util/src/io/mod.rs24
-rw-r--r--third_party/rust/tokio-util/src/io/read_buf.rs65
-rw-r--r--third_party/rust/tokio-util/src/io/reader_stream.rs118
-rw-r--r--third_party/rust/tokio-util/src/io/stream_reader.rs203
-rw-r--r--third_party/rust/tokio-util/src/io/sync_bridge.rs103
-rw-r--r--third_party/rust/tokio-util/src/lib.rs201
-rw-r--r--third_party/rust/tokio-util/src/loom.rs1
-rw-r--r--third_party/rust/tokio-util/src/net/mod.rs97
-rw-r--r--third_party/rust/tokio-util/src/net/unix/mod.rs18
-rw-r--r--third_party/rust/tokio-util/src/sync/cancellation_token.rs224
-rw-r--r--third_party/rust/tokio-util/src/sync/cancellation_token/guard.rs27
-rw-r--r--third_party/rust/tokio-util/src/sync/cancellation_token/tree_node.rs373
-rw-r--r--third_party/rust/tokio-util/src/sync/mod.rs13
-rw-r--r--third_party/rust/tokio-util/src/sync/mpsc.rs283
-rw-r--r--third_party/rust/tokio-util/src/sync/poll_semaphore.rs136
-rw-r--r--third_party/rust/tokio-util/src/sync/reusable_box.rs148
-rw-r--r--third_party/rust/tokio-util/src/sync/tests/loom_cancellation_token.rs155
-rw-r--r--third_party/rust/tokio-util/src/sync/tests/mod.rs1
-rw-r--r--third_party/rust/tokio-util/src/task/mod.rs4
-rw-r--r--third_party/rust/tokio-util/src/task/spawn_pinned.rs307
-rw-r--r--third_party/rust/tokio-util/src/time/delay_queue.rs1221
-rw-r--r--third_party/rust/tokio-util/src/time/mod.rs47
-rw-r--r--third_party/rust/tokio-util/src/time/wheel/level.rs253
-rw-r--r--third_party/rust/tokio-util/src/time/wheel/mod.rs314
-rw-r--r--third_party/rust/tokio-util/src/time/wheel/stack.rs28
-rw-r--r--third_party/rust/tokio-util/src/udp/frame.rs245
-rw-r--r--third_party/rust/tokio-util/src/udp/mod.rs4
42 files changed, 8519 insertions, 0 deletions
diff --git a/third_party/rust/tokio-util/src/cfg.rs b/third_party/rust/tokio-util/src/cfg.rs
new file mode 100644
index 0000000000..4035255aff
--- /dev/null
+++ b/third_party/rust/tokio-util/src/cfg.rs
@@ -0,0 +1,71 @@
+macro_rules! cfg_codec {
+ ($($item:item)*) => {
+ $(
+ #[cfg(feature = "codec")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "codec")))]
+ $item
+ )*
+ }
+}
+
+macro_rules! cfg_compat {
+ ($($item:item)*) => {
+ $(
+ #[cfg(feature = "compat")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "compat")))]
+ $item
+ )*
+ }
+}
+
+macro_rules! cfg_net {
+ ($($item:item)*) => {
+ $(
+ #[cfg(all(feature = "net", feature = "codec"))]
+ #[cfg_attr(docsrs, doc(cfg(all(feature = "net", feature = "codec"))))]
+ $item
+ )*
+ }
+}
+
+macro_rules! cfg_io {
+ ($($item:item)*) => {
+ $(
+ #[cfg(feature = "io")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "io")))]
+ $item
+ )*
+ }
+}
+
+cfg_io! {
+ macro_rules! cfg_io_util {
+ ($($item:item)*) => {
+ $(
+ #[cfg(feature = "io-util")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
+ $item
+ )*
+ }
+ }
+}
+
+macro_rules! cfg_rt {
+ ($($item:item)*) => {
+ $(
+ #[cfg(feature = "rt")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
+ $item
+ )*
+ }
+}
+
+macro_rules! cfg_time {
+ ($($item:item)*) => {
+ $(
+ #[cfg(feature = "time")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "time")))]
+ $item
+ )*
+ }
+}
diff --git a/third_party/rust/tokio-util/src/codec/any_delimiter_codec.rs b/third_party/rust/tokio-util/src/codec/any_delimiter_codec.rs
new file mode 100644
index 0000000000..3dbfd456b0
--- /dev/null
+++ b/third_party/rust/tokio-util/src/codec/any_delimiter_codec.rs
@@ -0,0 +1,263 @@
+use crate::codec::decoder::Decoder;
+use crate::codec::encoder::Encoder;
+
+use bytes::{Buf, BufMut, Bytes, BytesMut};
+use std::{cmp, fmt, io, str, usize};
+
+const DEFAULT_SEEK_DELIMITERS: &[u8] = b",;\n\r";
+const DEFAULT_SEQUENCE_WRITER: &[u8] = b",";
+/// A simple [`Decoder`] and [`Encoder`] implementation that splits up data into chunks based on any character in the given delimiter string.
+///
+/// [`Decoder`]: crate::codec::Decoder
+/// [`Encoder`]: crate::codec::Encoder
+///
+/// # Example
+/// Decode string of bytes containing various different delimiters.
+///
+/// [`BytesMut`]: bytes::BytesMut
+/// [`Error`]: std::io::Error
+///
+/// ```
+/// use tokio_util::codec::{AnyDelimiterCodec, Decoder};
+/// use bytes::{BufMut, BytesMut};
+///
+/// #
+/// # #[tokio::main(flavor = "current_thread")]
+/// # async fn main() -> Result<(), std::io::Error> {
+/// let mut codec = AnyDelimiterCodec::new(b",;\r\n".to_vec(),b";".to_vec());
+/// let buf = &mut BytesMut::new();
+/// buf.reserve(200);
+/// buf.put_slice(b"chunk 1,chunk 2;chunk 3\n\r");
+/// assert_eq!("chunk 1", codec.decode(buf).unwrap().unwrap());
+/// assert_eq!("chunk 2", codec.decode(buf).unwrap().unwrap());
+/// assert_eq!("chunk 3", codec.decode(buf).unwrap().unwrap());
+/// assert_eq!("", codec.decode(buf).unwrap().unwrap());
+/// assert_eq!(None, codec.decode(buf).unwrap());
+/// # Ok(())
+/// # }
+/// ```
+///
+#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
+pub struct AnyDelimiterCodec {
+ // Stored index of the next index to examine for the delimiter character.
+ // This is used to optimize searching.
+ // For example, if `decode` was called with `abc` and the delimiter is '{}', it would hold `3`,
+ // because that is the next index to examine.
+ // The next time `decode` is called with `abcde}`, the method will
+ // only look at `de}` before returning.
+ next_index: usize,
+
+ /// The maximum length for a given chunk. If `usize::MAX`, chunks will be
+ /// read until a delimiter character is reached.
+ max_length: usize,
+
+ /// Are we currently discarding the remainder of a chunk which was over
+ /// the length limit?
+ is_discarding: bool,
+
+ /// The bytes that are using for search during decode
+ seek_delimiters: Vec<u8>,
+
+ /// The bytes that are using for encoding
+ sequence_writer: Vec<u8>,
+}
+
+impl AnyDelimiterCodec {
+ /// Returns a `AnyDelimiterCodec` for splitting up data into chunks.
+ ///
+ /// # Note
+ ///
+ /// The returned `AnyDelimiterCodec` will not have an upper bound on the length
+ /// of a buffered chunk. See the documentation for [`new_with_max_length`]
+ /// for information on why this could be a potential security risk.
+ ///
+ /// [`new_with_max_length`]: crate::codec::AnyDelimiterCodec::new_with_max_length()
+ pub fn new(seek_delimiters: Vec<u8>, sequence_writer: Vec<u8>) -> AnyDelimiterCodec {
+ AnyDelimiterCodec {
+ next_index: 0,
+ max_length: usize::MAX,
+ is_discarding: false,
+ seek_delimiters,
+ sequence_writer,
+ }
+ }
+
+ /// Returns a `AnyDelimiterCodec` with a maximum chunk length limit.
+ ///
+ /// If this is set, calls to `AnyDelimiterCodec::decode` will return a
+ /// [`AnyDelimiterCodecError`] when a chunk exceeds the length limit. Subsequent calls
+ /// will discard up to `limit` bytes from that chunk until a delimiter
+ /// character is reached, returning `None` until the delimiter over the limit
+ /// has been fully discarded. After that point, calls to `decode` will
+ /// function as normal.
+ ///
+ /// # Note
+ ///
+ /// Setting a length limit is highly recommended for any `AnyDelimiterCodec` which
+ /// will be exposed to untrusted input. Otherwise, the size of the buffer
+ /// that holds the chunk currently being read is unbounded. An attacker could
+ /// exploit this unbounded buffer by sending an unbounded amount of input
+ /// without any delimiter characters, causing unbounded memory consumption.
+ ///
+ /// [`AnyDelimiterCodecError`]: crate::codec::AnyDelimiterCodecError
+ pub fn new_with_max_length(
+ seek_delimiters: Vec<u8>,
+ sequence_writer: Vec<u8>,
+ max_length: usize,
+ ) -> Self {
+ AnyDelimiterCodec {
+ max_length,
+ ..AnyDelimiterCodec::new(seek_delimiters, sequence_writer)
+ }
+ }
+
+ /// Returns the maximum chunk length when decoding.
+ ///
+ /// ```
+ /// use std::usize;
+ /// use tokio_util::codec::AnyDelimiterCodec;
+ ///
+ /// let codec = AnyDelimiterCodec::new(b",;\n".to_vec(), b";".to_vec());
+ /// assert_eq!(codec.max_length(), usize::MAX);
+ /// ```
+ /// ```
+ /// use tokio_util::codec::AnyDelimiterCodec;
+ ///
+ /// let codec = AnyDelimiterCodec::new_with_max_length(b",;\n".to_vec(), b";".to_vec(), 256);
+ /// assert_eq!(codec.max_length(), 256);
+ /// ```
+ pub fn max_length(&self) -> usize {
+ self.max_length
+ }
+}
+
+impl Decoder for AnyDelimiterCodec {
+ type Item = Bytes;
+ type Error = AnyDelimiterCodecError;
+
+ fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Bytes>, AnyDelimiterCodecError> {
+ loop {
+ // Determine how far into the buffer we'll search for a delimiter. If
+ // there's no max_length set, we'll read to the end of the buffer.
+ let read_to = cmp::min(self.max_length.saturating_add(1), buf.len());
+
+ let new_chunk_offset = buf[self.next_index..read_to].iter().position(|b| {
+ self.seek_delimiters
+ .iter()
+ .any(|delimiter| *b == *delimiter)
+ });
+
+ match (self.is_discarding, new_chunk_offset) {
+ (true, Some(offset)) => {
+ // If we found a new chunk, discard up to that offset and
+ // then stop discarding. On the next iteration, we'll try
+ // to read a chunk normally.
+ buf.advance(offset + self.next_index + 1);
+ self.is_discarding = false;
+ self.next_index = 0;
+ }
+ (true, None) => {
+ // Otherwise, we didn't find a new chunk, so we'll discard
+ // everything we read. On the next iteration, we'll continue
+ // discarding up to max_len bytes unless we find a new chunk.
+ buf.advance(read_to);
+ self.next_index = 0;
+ if buf.is_empty() {
+ return Ok(None);
+ }
+ }
+ (false, Some(offset)) => {
+ // Found a chunk!
+ let new_chunk_index = offset + self.next_index;
+ self.next_index = 0;
+ let mut chunk = buf.split_to(new_chunk_index + 1);
+ chunk.truncate(chunk.len() - 1);
+ let chunk = chunk.freeze();
+ return Ok(Some(chunk));
+ }
+ (false, None) if buf.len() > self.max_length => {
+ // Reached the maximum length without finding a
+ // new chunk, return an error and start discarding on the
+ // next call.
+ self.is_discarding = true;
+ return Err(AnyDelimiterCodecError::MaxChunkLengthExceeded);
+ }
+ (false, None) => {
+ // We didn't find a chunk or reach the length limit, so the next
+ // call will resume searching at the current offset.
+ self.next_index = read_to;
+ return Ok(None);
+ }
+ }
+ }
+ }
+
+ fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Bytes>, AnyDelimiterCodecError> {
+ Ok(match self.decode(buf)? {
+ Some(frame) => Some(frame),
+ None => {
+ // return remaining data, if any
+ if buf.is_empty() {
+ None
+ } else {
+ let chunk = buf.split_to(buf.len());
+ self.next_index = 0;
+ Some(chunk.freeze())
+ }
+ }
+ })
+ }
+}
+
+impl<T> Encoder<T> for AnyDelimiterCodec
+where
+ T: AsRef<str>,
+{
+ type Error = AnyDelimiterCodecError;
+
+ fn encode(&mut self, chunk: T, buf: &mut BytesMut) -> Result<(), AnyDelimiterCodecError> {
+ let chunk = chunk.as_ref();
+ buf.reserve(chunk.len() + 1);
+ buf.put(chunk.as_bytes());
+ buf.put(self.sequence_writer.as_ref());
+
+ Ok(())
+ }
+}
+
+impl Default for AnyDelimiterCodec {
+ fn default() -> Self {
+ Self::new(
+ DEFAULT_SEEK_DELIMITERS.to_vec(),
+ DEFAULT_SEQUENCE_WRITER.to_vec(),
+ )
+ }
+}
+
+/// An error occurred while encoding or decoding a chunk.
+#[derive(Debug)]
+pub enum AnyDelimiterCodecError {
+ /// The maximum chunk length was exceeded.
+ MaxChunkLengthExceeded,
+ /// An IO error occurred.
+ Io(io::Error),
+}
+
+impl fmt::Display for AnyDelimiterCodecError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ AnyDelimiterCodecError::MaxChunkLengthExceeded => {
+ write!(f, "max chunk length exceeded")
+ }
+ AnyDelimiterCodecError::Io(e) => write!(f, "{}", e),
+ }
+ }
+}
+
+impl From<io::Error> for AnyDelimiterCodecError {
+ fn from(e: io::Error) -> AnyDelimiterCodecError {
+ AnyDelimiterCodecError::Io(e)
+ }
+}
+
+impl std::error::Error for AnyDelimiterCodecError {}
diff --git a/third_party/rust/tokio-util/src/codec/bytes_codec.rs b/third_party/rust/tokio-util/src/codec/bytes_codec.rs
new file mode 100644
index 0000000000..ceab228b94
--- /dev/null
+++ b/third_party/rust/tokio-util/src/codec/bytes_codec.rs
@@ -0,0 +1,86 @@
+use crate::codec::decoder::Decoder;
+use crate::codec::encoder::Encoder;
+
+use bytes::{BufMut, Bytes, BytesMut};
+use std::io;
+
+/// A simple [`Decoder`] and [`Encoder`] implementation that just ships bytes around.
+///
+/// [`Decoder`]: crate::codec::Decoder
+/// [`Encoder`]: crate::codec::Encoder
+///
+/// # Example
+///
+/// Turn an [`AsyncRead`] into a stream of `Result<`[`BytesMut`]`, `[`Error`]`>`.
+///
+/// [`AsyncRead`]: tokio::io::AsyncRead
+/// [`BytesMut`]: bytes::BytesMut
+/// [`Error`]: std::io::Error
+///
+/// ```
+/// # mod hidden {
+/// # #[allow(unused_imports)]
+/// use tokio::fs::File;
+/// # }
+/// use tokio::io::AsyncRead;
+/// use tokio_util::codec::{FramedRead, BytesCodec};
+///
+/// # enum File {}
+/// # impl File {
+/// # async fn open(_name: &str) -> Result<impl AsyncRead, std::io::Error> {
+/// # use std::io::Cursor;
+/// # Ok(Cursor::new(vec![0, 1, 2, 3, 4, 5]))
+/// # }
+/// # }
+/// #
+/// # #[tokio::main(flavor = "current_thread")]
+/// # async fn main() -> Result<(), std::io::Error> {
+/// let my_async_read = File::open("filename.txt").await?;
+/// let my_stream_of_bytes = FramedRead::new(my_async_read, BytesCodec::new());
+/// # Ok(())
+/// # }
+/// ```
+///
+#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Default)]
+pub struct BytesCodec(());
+
+impl BytesCodec {
+ /// Creates a new `BytesCodec` for shipping around raw bytes.
+ pub fn new() -> BytesCodec {
+ BytesCodec(())
+ }
+}
+
+impl Decoder for BytesCodec {
+ type Item = BytesMut;
+ type Error = io::Error;
+
+ fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> {
+ if !buf.is_empty() {
+ let len = buf.len();
+ Ok(Some(buf.split_to(len)))
+ } else {
+ Ok(None)
+ }
+ }
+}
+
+impl Encoder<Bytes> for BytesCodec {
+ type Error = io::Error;
+
+ fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> {
+ buf.reserve(data.len());
+ buf.put(data);
+ Ok(())
+ }
+}
+
+impl Encoder<BytesMut> for BytesCodec {
+ type Error = io::Error;
+
+ fn encode(&mut self, data: BytesMut, buf: &mut BytesMut) -> Result<(), io::Error> {
+ buf.reserve(data.len());
+ buf.put(data);
+ Ok(())
+ }
+}
diff --git a/third_party/rust/tokio-util/src/codec/decoder.rs b/third_party/rust/tokio-util/src/codec/decoder.rs
new file mode 100644
index 0000000000..c5927783d1
--- /dev/null
+++ b/third_party/rust/tokio-util/src/codec/decoder.rs
@@ -0,0 +1,184 @@
+use crate::codec::Framed;
+
+use tokio::io::{AsyncRead, AsyncWrite};
+
+use bytes::BytesMut;
+use std::io;
+
+/// Decoding of frames via buffers.
+///
+/// This trait is used when constructing an instance of [`Framed`] or
+/// [`FramedRead`]. An implementation of `Decoder` takes a byte stream that has
+/// already been buffered in `src` and decodes the data into a stream of
+/// `Self::Item` frames.
+///
+/// Implementations are able to track state on `self`, which enables
+/// implementing stateful streaming parsers. In many cases, though, this type
+/// will simply be a unit struct (e.g. `struct HttpDecoder`).
+///
+/// For some underlying data-sources, namely files and FIFOs,
+/// it's possible to temporarily read 0 bytes by reaching EOF.
+///
+/// In these cases `decode_eof` will be called until it signals
+/// fullfillment of all closing frames by returning `Ok(None)`.
+/// After that, repeated attempts to read from the [`Framed`] or [`FramedRead`]
+/// will not invoke `decode` or `decode_eof` again, until data can be read
+/// during a retry.
+///
+/// It is up to the Decoder to keep track of a restart after an EOF,
+/// and to decide how to handle such an event by, for example,
+/// allowing frames to cross EOF boundaries, re-emitting opening frames, or
+/// resetting the entire internal state.
+///
+/// [`Framed`]: crate::codec::Framed
+/// [`FramedRead`]: crate::codec::FramedRead
+pub trait Decoder {
+ /// The type of decoded frames.
+ type Item;
+
+ /// The type of unrecoverable frame decoding errors.
+ ///
+ /// If an individual message is ill-formed but can be ignored without
+ /// interfering with the processing of future messages, it may be more
+ /// useful to report the failure as an `Item`.
+ ///
+ /// `From<io::Error>` is required in the interest of making `Error` suitable
+ /// for returning directly from a [`FramedRead`], and to enable the default
+ /// implementation of `decode_eof` to yield an `io::Error` when the decoder
+ /// fails to consume all available data.
+ ///
+ /// Note that implementors of this trait can simply indicate `type Error =
+ /// io::Error` to use I/O errors as this type.
+ ///
+ /// [`FramedRead`]: crate::codec::FramedRead
+ type Error: From<io::Error>;
+
+ /// Attempts to decode a frame from the provided buffer of bytes.
+ ///
+ /// This method is called by [`FramedRead`] whenever bytes are ready to be
+ /// parsed. The provided buffer of bytes is what's been read so far, and
+ /// this instance of `Decode` can determine whether an entire frame is in
+ /// the buffer and is ready to be returned.
+ ///
+ /// If an entire frame is available, then this instance will remove those
+ /// bytes from the buffer provided and return them as a decoded
+ /// frame. Note that removing bytes from the provided buffer doesn't always
+ /// necessarily copy the bytes, so this should be an efficient operation in
+ /// most circumstances.
+ ///
+ /// If the bytes look valid, but a frame isn't fully available yet, then
+ /// `Ok(None)` is returned. This indicates to the [`Framed`] instance that
+ /// it needs to read some more bytes before calling this method again.
+ ///
+ /// Note that the bytes provided may be empty. If a previous call to
+ /// `decode` consumed all the bytes in the buffer then `decode` will be
+ /// called again until it returns `Ok(None)`, indicating that more bytes need to
+ /// be read.
+ ///
+ /// Finally, if the bytes in the buffer are malformed then an error is
+ /// returned indicating why. This informs [`Framed`] that the stream is now
+ /// corrupt and should be terminated.
+ ///
+ /// [`Framed`]: crate::codec::Framed
+ /// [`FramedRead`]: crate::codec::FramedRead
+ ///
+ /// # Buffer management
+ ///
+ /// Before returning from the function, implementations should ensure that
+ /// the buffer has appropriate capacity in anticipation of future calls to
+ /// `decode`. Failing to do so leads to inefficiency.
+ ///
+ /// For example, if frames have a fixed length, or if the length of the
+ /// current frame is known from a header, a possible buffer management
+ /// strategy is:
+ ///
+ /// ```no_run
+ /// # use std::io;
+ /// #
+ /// # use bytes::BytesMut;
+ /// # use tokio_util::codec::Decoder;
+ /// #
+ /// # struct MyCodec;
+ /// #
+ /// impl Decoder for MyCodec {
+ /// // ...
+ /// # type Item = BytesMut;
+ /// # type Error = io::Error;
+ ///
+ /// fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
+ /// // ...
+ ///
+ /// // Reserve enough to complete decoding of the current frame.
+ /// let current_frame_len: usize = 1000; // Example.
+ /// // And to start decoding the next frame.
+ /// let next_frame_header_len: usize = 10; // Example.
+ /// src.reserve(current_frame_len + next_frame_header_len);
+ ///
+ /// return Ok(None);
+ /// }
+ /// }
+ /// ```
+ ///
+ /// An optimal buffer management strategy minimizes reallocations and
+ /// over-allocations.
+ fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error>;
+
+ /// A default method available to be called when there are no more bytes
+ /// available to be read from the underlying I/O.
+ ///
+ /// This method defaults to calling `decode` and returns an error if
+ /// `Ok(None)` is returned while there is unconsumed data in `buf`.
+ /// Typically this doesn't need to be implemented unless the framing
+ /// protocol differs near the end of the stream, or if you need to construct
+ /// frames _across_ eof boundaries on sources that can be resumed.
+ ///
+ /// Note that the `buf` argument may be empty. If a previous call to
+ /// `decode_eof` consumed all the bytes in the buffer, `decode_eof` will be
+ /// called again until it returns `None`, indicating that there are no more
+ /// frames to yield. This behavior enables returning finalization frames
+ /// that may not be based on inbound data.
+ ///
+ /// Once `None` has been returned, `decode_eof` won't be called again until
+ /// an attempt to resume the stream has been made, where the underlying stream
+ /// actually returned more data.
+ fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
+ match self.decode(buf)? {
+ Some(frame) => Ok(Some(frame)),
+ None => {
+ if buf.is_empty() {
+ Ok(None)
+ } else {
+ Err(io::Error::new(io::ErrorKind::Other, "bytes remaining on stream").into())
+ }
+ }
+ }
+ }
+
+ /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this
+ /// `Io` object, using `Decode` and `Encode` to read and write the raw data.
+ ///
+ /// Raw I/O objects work with byte sequences, but higher-level code usually
+ /// wants to batch these into meaningful chunks, called "frames". This
+ /// method layers framing on top of an I/O object, by using the `Codec`
+ /// traits to handle encoding and decoding of messages frames. Note that
+ /// the incoming and outgoing frame types may be distinct.
+ ///
+ /// This function returns a *single* object that is both `Stream` and
+ /// `Sink`; grouping this into a single object is often useful for layering
+ /// things like gzip or TLS, which require both read and write access to the
+ /// underlying object.
+ ///
+ /// If you want to work more directly with the streams and sink, consider
+ /// calling `split` on the [`Framed`] returned by this method, which will
+ /// break them into separate objects, allowing them to interact more easily.
+ ///
+ /// [`Stream`]: futures_core::Stream
+ /// [`Sink`]: futures_sink::Sink
+ /// [`Framed`]: crate::codec::Framed
+ fn framed<T: AsyncRead + AsyncWrite + Sized>(self, io: T) -> Framed<T, Self>
+ where
+ Self: Sized,
+ {
+ Framed::new(io, self)
+ }
+}
diff --git a/third_party/rust/tokio-util/src/codec/encoder.rs b/third_party/rust/tokio-util/src/codec/encoder.rs
new file mode 100644
index 0000000000..770a10fa9b
--- /dev/null
+++ b/third_party/rust/tokio-util/src/codec/encoder.rs
@@ -0,0 +1,25 @@
+use bytes::BytesMut;
+use std::io;
+
+/// Trait of helper objects to write out messages as bytes, for use with
+/// [`FramedWrite`].
+///
+/// [`FramedWrite`]: crate::codec::FramedWrite
+pub trait Encoder<Item> {
+ /// The type of encoding errors.
+ ///
+ /// [`FramedWrite`] requires `Encoder`s errors to implement `From<io::Error>`
+ /// in the interest letting it return `Error`s directly.
+ ///
+ /// [`FramedWrite`]: crate::codec::FramedWrite
+ type Error: From<io::Error>;
+
+ /// Encodes a frame into the buffer provided.
+ ///
+ /// This method will encode `item` into the byte buffer provided by `dst`.
+ /// The `dst` provided is an internal buffer of the [`FramedWrite`] instance and
+ /// will be written out when possible.
+ ///
+ /// [`FramedWrite`]: crate::codec::FramedWrite
+ fn encode(&mut self, item: Item, dst: &mut BytesMut) -> Result<(), Self::Error>;
+}
diff --git a/third_party/rust/tokio-util/src/codec/framed.rs b/third_party/rust/tokio-util/src/codec/framed.rs
new file mode 100644
index 0000000000..d89b8b6dc3
--- /dev/null
+++ b/third_party/rust/tokio-util/src/codec/framed.rs
@@ -0,0 +1,373 @@
+use crate::codec::decoder::Decoder;
+use crate::codec::encoder::Encoder;
+use crate::codec::framed_impl::{FramedImpl, RWFrames, ReadFrame, WriteFrame};
+
+use futures_core::Stream;
+use tokio::io::{AsyncRead, AsyncWrite};
+
+use bytes::BytesMut;
+use futures_sink::Sink;
+use pin_project_lite::pin_project;
+use std::fmt;
+use std::io;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+pin_project! {
+ /// A unified [`Stream`] and [`Sink`] interface to an underlying I/O object, using
+ /// the `Encoder` and `Decoder` traits to encode and decode frames.
+ ///
+ /// You can create a `Framed` instance by using the [`Decoder::framed`] adapter, or
+ /// by using the `new` function seen below.
+ ///
+ /// [`Stream`]: futures_core::Stream
+ /// [`Sink`]: futures_sink::Sink
+ /// [`AsyncRead`]: tokio::io::AsyncRead
+ /// [`Decoder::framed`]: crate::codec::Decoder::framed()
+ pub struct Framed<T, U> {
+ #[pin]
+ inner: FramedImpl<T, U, RWFrames>
+ }
+}
+
+impl<T, U> Framed<T, U>
+where
+ T: AsyncRead + AsyncWrite,
+{
+ /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this
+ /// I/O object, using [`Decoder`] and [`Encoder`] to read and write the raw data.
+ ///
+ /// Raw I/O objects work with byte sequences, but higher-level code usually
+ /// wants to batch these into meaningful chunks, called "frames". This
+ /// method layers framing on top of an I/O object, by using the codec
+ /// traits to handle encoding and decoding of messages frames. Note that
+ /// the incoming and outgoing frame types may be distinct.
+ ///
+ /// This function returns a *single* object that is both [`Stream`] and
+ /// [`Sink`]; grouping this into a single object is often useful for layering
+ /// things like gzip or TLS, which require both read and write access to the
+ /// underlying object.
+ ///
+ /// If you want to work more directly with the streams and sink, consider
+ /// calling [`split`] on the `Framed` returned by this method, which will
+ /// break them into separate objects, allowing them to interact more easily.
+ ///
+ /// Note that, for some byte sources, the stream can be resumed after an EOF
+ /// by reading from it, even after it has returned `None`. Repeated attempts
+ /// to do so, without new data available, continue to return `None` without
+ /// creating more (closing) frames.
+ ///
+ /// [`Stream`]: futures_core::Stream
+ /// [`Sink`]: futures_sink::Sink
+ /// [`Decode`]: crate::codec::Decoder
+ /// [`Encoder`]: crate::codec::Encoder
+ /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split
+ pub fn new(inner: T, codec: U) -> Framed<T, U> {
+ Framed {
+ inner: FramedImpl {
+ inner,
+ codec,
+ state: Default::default(),
+ },
+ }
+ }
+
+ /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this
+ /// I/O object, using [`Decoder`] and [`Encoder`] to read and write the raw data,
+ /// with a specific read buffer initial capacity.
+ ///
+ /// Raw I/O objects work with byte sequences, but higher-level code usually
+ /// wants to batch these into meaningful chunks, called "frames". This
+ /// method layers framing on top of an I/O object, by using the codec
+ /// traits to handle encoding and decoding of messages frames. Note that
+ /// the incoming and outgoing frame types may be distinct.
+ ///
+ /// This function returns a *single* object that is both [`Stream`] and
+ /// [`Sink`]; grouping this into a single object is often useful for layering
+ /// things like gzip or TLS, which require both read and write access to the
+ /// underlying object.
+ ///
+ /// If you want to work more directly with the streams and sink, consider
+ /// calling [`split`] on the `Framed` returned by this method, which will
+ /// break them into separate objects, allowing them to interact more easily.
+ ///
+ /// [`Stream`]: futures_core::Stream
+ /// [`Sink`]: futures_sink::Sink
+ /// [`Decode`]: crate::codec::Decoder
+ /// [`Encoder`]: crate::codec::Encoder
+ /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split
+ pub fn with_capacity(inner: T, codec: U, capacity: usize) -> Framed<T, U> {
+ Framed {
+ inner: FramedImpl {
+ inner,
+ codec,
+ state: RWFrames {
+ read: ReadFrame {
+ eof: false,
+ is_readable: false,
+ buffer: BytesMut::with_capacity(capacity),
+ has_errored: false,
+ },
+ write: WriteFrame::default(),
+ },
+ },
+ }
+ }
+}
+
+impl<T, U> Framed<T, U> {
+ /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this
+ /// I/O object, using [`Decoder`] and [`Encoder`] to read and write the raw data.
+ ///
+ /// Raw I/O objects work with byte sequences, but higher-level code usually
+ /// wants to batch these into meaningful chunks, called "frames". This
+ /// method layers framing on top of an I/O object, by using the `Codec`
+ /// traits to handle encoding and decoding of messages frames. Note that
+ /// the incoming and outgoing frame types may be distinct.
+ ///
+ /// This function returns a *single* object that is both [`Stream`] and
+ /// [`Sink`]; grouping this into a single object is often useful for layering
+ /// things like gzip or TLS, which require both read and write access to the
+ /// underlying object.
+ ///
+ /// This objects takes a stream and a readbuffer and a writebuffer. These field
+ /// can be obtained from an existing `Framed` with the [`into_parts`] method.
+ ///
+ /// If you want to work more directly with the streams and sink, consider
+ /// calling [`split`] on the `Framed` returned by this method, which will
+ /// break them into separate objects, allowing them to interact more easily.
+ ///
+ /// [`Stream`]: futures_core::Stream
+ /// [`Sink`]: futures_sink::Sink
+ /// [`Decoder`]: crate::codec::Decoder
+ /// [`Encoder`]: crate::codec::Encoder
+ /// [`into_parts`]: crate::codec::Framed::into_parts()
+ /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split
+ pub fn from_parts(parts: FramedParts<T, U>) -> Framed<T, U> {
+ Framed {
+ inner: FramedImpl {
+ inner: parts.io,
+ codec: parts.codec,
+ state: RWFrames {
+ read: parts.read_buf.into(),
+ write: parts.write_buf.into(),
+ },
+ },
+ }
+ }
+
+ /// Returns a reference to the underlying I/O stream wrapped by
+ /// `Framed`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn get_ref(&self) -> &T {
+ &self.inner.inner
+ }
+
+ /// Returns a mutable reference to the underlying I/O stream wrapped by
+ /// `Framed`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn get_mut(&mut self) -> &mut T {
+ &mut self.inner.inner
+ }
+
+ /// Returns a pinned mutable reference to the underlying I/O stream wrapped by
+ /// `Framed`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
+ self.project().inner.project().inner
+ }
+
+ /// Returns a reference to the underlying codec wrapped by
+ /// `Framed`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying codec
+ /// as it may corrupt the stream of frames otherwise being worked with.
+ pub fn codec(&self) -> &U {
+ &self.inner.codec
+ }
+
+ /// Returns a mutable reference to the underlying codec wrapped by
+ /// `Framed`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying codec
+ /// as it may corrupt the stream of frames otherwise being worked with.
+ pub fn codec_mut(&mut self) -> &mut U {
+ &mut self.inner.codec
+ }
+
+ /// Maps the codec `U` to `C`, preserving the read and write buffers
+ /// wrapped by `Framed`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying codec
+ /// as it may corrupt the stream of frames otherwise being worked with.
+ pub fn map_codec<C, F>(self, map: F) -> Framed<T, C>
+ where
+ F: FnOnce(U) -> C,
+ {
+ // This could be potentially simplified once rust-lang/rust#86555 hits stable
+ let parts = self.into_parts();
+ Framed::from_parts(FramedParts {
+ io: parts.io,
+ codec: map(parts.codec),
+ read_buf: parts.read_buf,
+ write_buf: parts.write_buf,
+ _priv: (),
+ })
+ }
+
+ /// Returns a mutable reference to the underlying codec wrapped by
+ /// `Framed`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying codec
+ /// as it may corrupt the stream of frames otherwise being worked with.
+ pub fn codec_pin_mut(self: Pin<&mut Self>) -> &mut U {
+ self.project().inner.project().codec
+ }
+
+ /// Returns a reference to the read buffer.
+ pub fn read_buffer(&self) -> &BytesMut {
+ &self.inner.state.read.buffer
+ }
+
+ /// Returns a mutable reference to the read buffer.
+ pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
+ &mut self.inner.state.read.buffer
+ }
+
+ /// Returns a reference to the write buffer.
+ pub fn write_buffer(&self) -> &BytesMut {
+ &self.inner.state.write.buffer
+ }
+
+ /// Returns a mutable reference to the write buffer.
+ pub fn write_buffer_mut(&mut self) -> &mut BytesMut {
+ &mut self.inner.state.write.buffer
+ }
+
+ /// Consumes the `Framed`, returning its underlying I/O stream.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn into_inner(self) -> T {
+ self.inner.inner
+ }
+
+ /// Consumes the `Framed`, returning its underlying I/O stream, the buffer
+ /// with unprocessed data, and the codec.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn into_parts(self) -> FramedParts<T, U> {
+ FramedParts {
+ io: self.inner.inner,
+ codec: self.inner.codec,
+ read_buf: self.inner.state.read.buffer,
+ write_buf: self.inner.state.write.buffer,
+ _priv: (),
+ }
+ }
+}
+
+// This impl just defers to the underlying FramedImpl
+impl<T, U> Stream for Framed<T, U>
+where
+ T: AsyncRead,
+ U: Decoder,
+{
+ type Item = Result<U::Item, U::Error>;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ self.project().inner.poll_next(cx)
+ }
+}
+
+// This impl just defers to the underlying FramedImpl
+impl<T, I, U> Sink<I> for Framed<T, U>
+where
+ T: AsyncWrite,
+ U: Encoder<I>,
+ U::Error: From<io::Error>,
+{
+ type Error = U::Error;
+
+ fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ self.project().inner.poll_ready(cx)
+ }
+
+ fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
+ self.project().inner.start_send(item)
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ self.project().inner.poll_flush(cx)
+ }
+
+ fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ self.project().inner.poll_close(cx)
+ }
+}
+
+impl<T, U> fmt::Debug for Framed<T, U>
+where
+ T: fmt::Debug,
+ U: fmt::Debug,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("Framed")
+ .field("io", self.get_ref())
+ .field("codec", self.codec())
+ .finish()
+ }
+}
+
+/// `FramedParts` contains an export of the data of a Framed transport.
+/// It can be used to construct a new [`Framed`] with a different codec.
+/// It contains all current buffers and the inner transport.
+///
+/// [`Framed`]: crate::codec::Framed
+#[derive(Debug)]
+#[allow(clippy::manual_non_exhaustive)]
+pub struct FramedParts<T, U> {
+ /// The inner transport used to read bytes to and write bytes to
+ pub io: T,
+
+ /// The codec
+ pub codec: U,
+
+ /// The buffer with read but unprocessed data.
+ pub read_buf: BytesMut,
+
+ /// A buffer with unprocessed data which are not written yet.
+ pub write_buf: BytesMut,
+
+ /// This private field allows us to add additional fields in the future in a
+ /// backwards compatible way.
+ _priv: (),
+}
+
+impl<T, U> FramedParts<T, U> {
+ /// Create a new, default, `FramedParts`
+ pub fn new<I>(io: T, codec: U) -> FramedParts<T, U>
+ where
+ U: Encoder<I>,
+ {
+ FramedParts {
+ io,
+ codec,
+ read_buf: BytesMut::new(),
+ write_buf: BytesMut::new(),
+ _priv: (),
+ }
+ }
+}
diff --git a/third_party/rust/tokio-util/src/codec/framed_impl.rs b/third_party/rust/tokio-util/src/codec/framed_impl.rs
new file mode 100644
index 0000000000..ce1a6db873
--- /dev/null
+++ b/third_party/rust/tokio-util/src/codec/framed_impl.rs
@@ -0,0 +1,308 @@
+use crate::codec::decoder::Decoder;
+use crate::codec::encoder::Encoder;
+
+use futures_core::Stream;
+use tokio::io::{AsyncRead, AsyncWrite};
+
+use bytes::BytesMut;
+use futures_core::ready;
+use futures_sink::Sink;
+use pin_project_lite::pin_project;
+use std::borrow::{Borrow, BorrowMut};
+use std::io;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tracing::trace;
+
+pin_project! {
+ #[derive(Debug)]
+ pub(crate) struct FramedImpl<T, U, State> {
+ #[pin]
+ pub(crate) inner: T,
+ pub(crate) state: State,
+ pub(crate) codec: U,
+ }
+}
+
+const INITIAL_CAPACITY: usize = 8 * 1024;
+const BACKPRESSURE_BOUNDARY: usize = INITIAL_CAPACITY;
+
+#[derive(Debug)]
+pub(crate) struct ReadFrame {
+ pub(crate) eof: bool,
+ pub(crate) is_readable: bool,
+ pub(crate) buffer: BytesMut,
+ pub(crate) has_errored: bool,
+}
+
+pub(crate) struct WriteFrame {
+ pub(crate) buffer: BytesMut,
+}
+
+#[derive(Default)]
+pub(crate) struct RWFrames {
+ pub(crate) read: ReadFrame,
+ pub(crate) write: WriteFrame,
+}
+
+impl Default for ReadFrame {
+ fn default() -> Self {
+ Self {
+ eof: false,
+ is_readable: false,
+ buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
+ has_errored: false,
+ }
+ }
+}
+
+impl Default for WriteFrame {
+ fn default() -> Self {
+ Self {
+ buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
+ }
+ }
+}
+
+impl From<BytesMut> for ReadFrame {
+ fn from(mut buffer: BytesMut) -> Self {
+ let size = buffer.capacity();
+ if size < INITIAL_CAPACITY {
+ buffer.reserve(INITIAL_CAPACITY - size);
+ }
+
+ Self {
+ buffer,
+ is_readable: size > 0,
+ eof: false,
+ has_errored: false,
+ }
+ }
+}
+
+impl From<BytesMut> for WriteFrame {
+ fn from(mut buffer: BytesMut) -> Self {
+ let size = buffer.capacity();
+ if size < INITIAL_CAPACITY {
+ buffer.reserve(INITIAL_CAPACITY - size);
+ }
+
+ Self { buffer }
+ }
+}
+
+impl Borrow<ReadFrame> for RWFrames {
+ fn borrow(&self) -> &ReadFrame {
+ &self.read
+ }
+}
+impl BorrowMut<ReadFrame> for RWFrames {
+ fn borrow_mut(&mut self) -> &mut ReadFrame {
+ &mut self.read
+ }
+}
+impl Borrow<WriteFrame> for RWFrames {
+ fn borrow(&self) -> &WriteFrame {
+ &self.write
+ }
+}
+impl BorrowMut<WriteFrame> for RWFrames {
+ fn borrow_mut(&mut self) -> &mut WriteFrame {
+ &mut self.write
+ }
+}
+impl<T, U, R> Stream for FramedImpl<T, U, R>
+where
+ T: AsyncRead,
+ U: Decoder,
+ R: BorrowMut<ReadFrame>,
+{
+ type Item = Result<U::Item, U::Error>;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ use crate::util::poll_read_buf;
+
+ let mut pinned = self.project();
+ let state: &mut ReadFrame = pinned.state.borrow_mut();
+ // The following loops implements a state machine with each state corresponding
+ // to a combination of the `is_readable` and `eof` flags. States persist across
+ // loop entries and most state transitions occur with a return.
+ //
+ // The initial state is `reading`.
+ //
+ // | state | eof | is_readable | has_errored |
+ // |---------|-------|-------------|-------------|
+ // | reading | false | false | false |
+ // | framing | false | true | false |
+ // | pausing | true | true | false |
+ // | paused | true | false | false |
+ // | errored | <any> | <any> | true |
+ // `decode_eof` returns Err
+ // ┌────────────────────────────────────────────────────────┐
+ // `decode_eof` returns │ │
+ // `Ok(Some)` │ │
+ // ┌─────┐ │ `decode_eof` returns After returning │
+ // Read 0 bytes ├─────▼──┴┐ `Ok(None)` ┌────────┐ ◄───┐ `None` ┌───▼─────┐
+ // ┌────────────────►│ Pausing ├───────────────────────►│ Paused ├─┐ └───────────┤ Errored │
+ // │ └─────────┘ └─┬──▲───┘ │ └───▲───▲─┘
+ // Pending read │ │ │ │ │ │
+ // ┌──────┐ │ `decode` returns `Some` │ └─────┘ │ │
+ // │ │ │ ┌──────┐ │ Pending │ │
+ // │ ┌────▼──┴─┐ Read n>0 bytes ┌┴──────▼─┐ read n>0 bytes │ read │ │
+ // └─┤ Reading ├───────────────►│ Framing │◄────────────────────────┘ │ │
+ // └──┬─▲────┘ └─────┬──┬┘ │ │
+ // │ │ │ │ `decode` returns Err │ │
+ // │ └───decode` returns `None`──┘ └───────────────────────────────────────────────────────┘ │
+ // │ read returns Err │
+ // └────────────────────────────────────────────────────────────────────────────────────────────┘
+ loop {
+ // Return `None` if we have encountered an error from the underlying decoder
+ // See: https://github.com/tokio-rs/tokio/issues/3976
+ if state.has_errored {
+ // preparing has_errored -> paused
+ trace!("Returning None and setting paused");
+ state.is_readable = false;
+ state.has_errored = false;
+ return Poll::Ready(None);
+ }
+
+ // Repeatedly call `decode` or `decode_eof` while the buffer is "readable",
+ // i.e. it _might_ contain data consumable as a frame or closing frame.
+ // Both signal that there is no such data by returning `None`.
+ //
+ // If `decode` couldn't read a frame and the upstream source has returned eof,
+ // `decode_eof` will attempt to decode the remaining bytes as closing frames.
+ //
+ // If the underlying AsyncRead is resumable, we may continue after an EOF,
+ // but must finish emitting all of it's associated `decode_eof` frames.
+ // Furthermore, we don't want to emit any `decode_eof` frames on retried
+ // reads after an EOF unless we've actually read more data.
+ if state.is_readable {
+ // pausing or framing
+ if state.eof {
+ // pausing
+ let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| {
+ trace!("Got an error, going to errored state");
+ state.has_errored = true;
+ err
+ })?;
+ if frame.is_none() {
+ state.is_readable = false; // prepare pausing -> paused
+ }
+ // implicit pausing -> pausing or pausing -> paused
+ return Poll::Ready(frame.map(Ok));
+ }
+
+ // framing
+ trace!("attempting to decode a frame");
+
+ if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| {
+ trace!("Got an error, going to errored state");
+ state.has_errored = true;
+ op
+ })? {
+ trace!("frame decoded from buffer");
+ // implicit framing -> framing
+ return Poll::Ready(Some(Ok(frame)));
+ }
+
+ // framing -> reading
+ state.is_readable = false;
+ }
+ // reading or paused
+ // If we can't build a frame yet, try to read more data and try again.
+ // Make sure we've got room for at least one byte to read to ensure
+ // that we don't get a spurious 0 that looks like EOF.
+ state.buffer.reserve(1);
+ let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err(
+ |err| {
+ trace!("Got an error, going to errored state");
+ state.has_errored = true;
+ err
+ },
+ )? {
+ Poll::Ready(ct) => ct,
+ // implicit reading -> reading or implicit paused -> paused
+ Poll::Pending => return Poll::Pending,
+ };
+ if bytect == 0 {
+ if state.eof {
+ // We're already at an EOF, and since we've reached this path
+ // we're also not readable. This implies that we've already finished
+ // our `decode_eof` handling, so we can simply return `None`.
+ // implicit paused -> paused
+ return Poll::Ready(None);
+ }
+ // prepare reading -> paused
+ state.eof = true;
+ } else {
+ // prepare paused -> framing or noop reading -> framing
+ state.eof = false;
+ }
+
+ // paused -> framing or reading -> framing or reading -> pausing
+ state.is_readable = true;
+ }
+ }
+}
+
+impl<T, I, U, W> Sink<I> for FramedImpl<T, U, W>
+where
+ T: AsyncWrite,
+ U: Encoder<I>,
+ U::Error: From<io::Error>,
+ W: BorrowMut<WriteFrame>,
+{
+ type Error = U::Error;
+
+ fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ if self.state.borrow().buffer.len() >= BACKPRESSURE_BOUNDARY {
+ self.as_mut().poll_flush(cx)
+ } else {
+ Poll::Ready(Ok(()))
+ }
+ }
+
+ fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
+ let pinned = self.project();
+ pinned
+ .codec
+ .encode(item, &mut pinned.state.borrow_mut().buffer)?;
+ Ok(())
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ use crate::util::poll_write_buf;
+ trace!("flushing framed transport");
+ let mut pinned = self.project();
+
+ while !pinned.state.borrow_mut().buffer.is_empty() {
+ let WriteFrame { buffer } = pinned.state.borrow_mut();
+ trace!(remaining = buffer.len(), "writing;");
+
+ let n = ready!(poll_write_buf(pinned.inner.as_mut(), cx, buffer))?;
+
+ if n == 0 {
+ return Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::WriteZero,
+ "failed to \
+ write frame to transport",
+ )
+ .into()));
+ }
+ }
+
+ // Try flushing the underlying IO
+ ready!(pinned.inner.poll_flush(cx))?;
+
+ trace!("framed transport flushed");
+ Poll::Ready(Ok(()))
+ }
+
+ fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ ready!(self.as_mut().poll_flush(cx))?;
+ ready!(self.project().inner.poll_shutdown(cx))?;
+
+ Poll::Ready(Ok(()))
+ }
+}
diff --git a/third_party/rust/tokio-util/src/codec/framed_read.rs b/third_party/rust/tokio-util/src/codec/framed_read.rs
new file mode 100644
index 0000000000..184c567b49
--- /dev/null
+++ b/third_party/rust/tokio-util/src/codec/framed_read.rs
@@ -0,0 +1,199 @@
+use crate::codec::framed_impl::{FramedImpl, ReadFrame};
+use crate::codec::Decoder;
+
+use futures_core::Stream;
+use tokio::io::AsyncRead;
+
+use bytes::BytesMut;
+use futures_sink::Sink;
+use pin_project_lite::pin_project;
+use std::fmt;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+pin_project! {
+ /// A [`Stream`] of messages decoded from an [`AsyncRead`].
+ ///
+ /// [`Stream`]: futures_core::Stream
+ /// [`AsyncRead`]: tokio::io::AsyncRead
+ pub struct FramedRead<T, D> {
+ #[pin]
+ inner: FramedImpl<T, D, ReadFrame>,
+ }
+}
+
+// ===== impl FramedRead =====
+
+impl<T, D> FramedRead<T, D>
+where
+ T: AsyncRead,
+ D: Decoder,
+{
+ /// Creates a new `FramedRead` with the given `decoder`.
+ pub fn new(inner: T, decoder: D) -> FramedRead<T, D> {
+ FramedRead {
+ inner: FramedImpl {
+ inner,
+ codec: decoder,
+ state: Default::default(),
+ },
+ }
+ }
+
+ /// Creates a new `FramedRead` with the given `decoder` and a buffer of `capacity`
+ /// initial size.
+ pub fn with_capacity(inner: T, decoder: D, capacity: usize) -> FramedRead<T, D> {
+ FramedRead {
+ inner: FramedImpl {
+ inner,
+ codec: decoder,
+ state: ReadFrame {
+ eof: false,
+ is_readable: false,
+ buffer: BytesMut::with_capacity(capacity),
+ has_errored: false,
+ },
+ },
+ }
+ }
+}
+
+impl<T, D> FramedRead<T, D> {
+ /// Returns a reference to the underlying I/O stream wrapped by
+ /// `FramedRead`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn get_ref(&self) -> &T {
+ &self.inner.inner
+ }
+
+ /// Returns a mutable reference to the underlying I/O stream wrapped by
+ /// `FramedRead`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn get_mut(&mut self) -> &mut T {
+ &mut self.inner.inner
+ }
+
+ /// Returns a pinned mutable reference to the underlying I/O stream wrapped by
+ /// `FramedRead`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
+ self.project().inner.project().inner
+ }
+
+ /// Consumes the `FramedRead`, returning its underlying I/O stream.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn into_inner(self) -> T {
+ self.inner.inner
+ }
+
+ /// Returns a reference to the underlying decoder.
+ pub fn decoder(&self) -> &D {
+ &self.inner.codec
+ }
+
+ /// Returns a mutable reference to the underlying decoder.
+ pub fn decoder_mut(&mut self) -> &mut D {
+ &mut self.inner.codec
+ }
+
+ /// Maps the decoder `D` to `C`, preserving the read buffer
+ /// wrapped by `Framed`.
+ pub fn map_decoder<C, F>(self, map: F) -> FramedRead<T, C>
+ where
+ F: FnOnce(D) -> C,
+ {
+ // This could be potentially simplified once rust-lang/rust#86555 hits stable
+ let FramedImpl {
+ inner,
+ state,
+ codec,
+ } = self.inner;
+ FramedRead {
+ inner: FramedImpl {
+ inner,
+ state,
+ codec: map(codec),
+ },
+ }
+ }
+
+ /// Returns a mutable reference to the underlying decoder.
+ pub fn decoder_pin_mut(self: Pin<&mut Self>) -> &mut D {
+ self.project().inner.project().codec
+ }
+
+ /// Returns a reference to the read buffer.
+ pub fn read_buffer(&self) -> &BytesMut {
+ &self.inner.state.buffer
+ }
+
+ /// Returns a mutable reference to the read buffer.
+ pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
+ &mut self.inner.state.buffer
+ }
+}
+
+// This impl just defers to the underlying FramedImpl
+impl<T, D> Stream for FramedRead<T, D>
+where
+ T: AsyncRead,
+ D: Decoder,
+{
+ type Item = Result<D::Item, D::Error>;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ self.project().inner.poll_next(cx)
+ }
+}
+
+// This impl just defers to the underlying T: Sink
+impl<T, I, D> Sink<I> for FramedRead<T, D>
+where
+ T: Sink<I>,
+{
+ type Error = T::Error;
+
+ fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ self.project().inner.project().inner.poll_ready(cx)
+ }
+
+ fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
+ self.project().inner.project().inner.start_send(item)
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ self.project().inner.project().inner.poll_flush(cx)
+ }
+
+ fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ self.project().inner.project().inner.poll_close(cx)
+ }
+}
+
+impl<T, D> fmt::Debug for FramedRead<T, D>
+where
+ T: fmt::Debug,
+ D: fmt::Debug,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("FramedRead")
+ .field("inner", &self.get_ref())
+ .field("decoder", &self.decoder())
+ .field("eof", &self.inner.state.eof)
+ .field("is_readable", &self.inner.state.is_readable)
+ .field("buffer", &self.read_buffer())
+ .finish()
+ }
+}
diff --git a/third_party/rust/tokio-util/src/codec/framed_write.rs b/third_party/rust/tokio-util/src/codec/framed_write.rs
new file mode 100644
index 0000000000..aa4cec9820
--- /dev/null
+++ b/third_party/rust/tokio-util/src/codec/framed_write.rs
@@ -0,0 +1,178 @@
+use crate::codec::encoder::Encoder;
+use crate::codec::framed_impl::{FramedImpl, WriteFrame};
+
+use futures_core::Stream;
+use tokio::io::AsyncWrite;
+
+use bytes::BytesMut;
+use futures_sink::Sink;
+use pin_project_lite::pin_project;
+use std::fmt;
+use std::io;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+pin_project! {
+ /// A [`Sink`] of frames encoded to an `AsyncWrite`.
+ ///
+ /// [`Sink`]: futures_sink::Sink
+ pub struct FramedWrite<T, E> {
+ #[pin]
+ inner: FramedImpl<T, E, WriteFrame>,
+ }
+}
+
+impl<T, E> FramedWrite<T, E>
+where
+ T: AsyncWrite,
+{
+ /// Creates a new `FramedWrite` with the given `encoder`.
+ pub fn new(inner: T, encoder: E) -> FramedWrite<T, E> {
+ FramedWrite {
+ inner: FramedImpl {
+ inner,
+ codec: encoder,
+ state: WriteFrame::default(),
+ },
+ }
+ }
+}
+
+impl<T, E> FramedWrite<T, E> {
+ /// Returns a reference to the underlying I/O stream wrapped by
+ /// `FramedWrite`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn get_ref(&self) -> &T {
+ &self.inner.inner
+ }
+
+ /// Returns a mutable reference to the underlying I/O stream wrapped by
+ /// `FramedWrite`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn get_mut(&mut self) -> &mut T {
+ &mut self.inner.inner
+ }
+
+ /// Returns a pinned mutable reference to the underlying I/O stream wrapped by
+ /// `FramedWrite`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
+ self.project().inner.project().inner
+ }
+
+ /// Consumes the `FramedWrite`, returning its underlying I/O stream.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn into_inner(self) -> T {
+ self.inner.inner
+ }
+
+ /// Returns a reference to the underlying encoder.
+ pub fn encoder(&self) -> &E {
+ &self.inner.codec
+ }
+
+ /// Returns a mutable reference to the underlying encoder.
+ pub fn encoder_mut(&mut self) -> &mut E {
+ &mut self.inner.codec
+ }
+
+ /// Maps the encoder `E` to `C`, preserving the write buffer
+ /// wrapped by `Framed`.
+ pub fn map_encoder<C, F>(self, map: F) -> FramedWrite<T, C>
+ where
+ F: FnOnce(E) -> C,
+ {
+ // This could be potentially simplified once rust-lang/rust#86555 hits stable
+ let FramedImpl {
+ inner,
+ state,
+ codec,
+ } = self.inner;
+ FramedWrite {
+ inner: FramedImpl {
+ inner,
+ state,
+ codec: map(codec),
+ },
+ }
+ }
+
+ /// Returns a mutable reference to the underlying encoder.
+ pub fn encoder_pin_mut(self: Pin<&mut Self>) -> &mut E {
+ self.project().inner.project().codec
+ }
+
+ /// Returns a reference to the write buffer.
+ pub fn write_buffer(&self) -> &BytesMut {
+ &self.inner.state.buffer
+ }
+
+ /// Returns a mutable reference to the write buffer.
+ pub fn write_buffer_mut(&mut self) -> &mut BytesMut {
+ &mut self.inner.state.buffer
+ }
+}
+
+// This impl just defers to the underlying FramedImpl
+impl<T, I, E> Sink<I> for FramedWrite<T, E>
+where
+ T: AsyncWrite,
+ E: Encoder<I>,
+ E::Error: From<io::Error>,
+{
+ type Error = E::Error;
+
+ fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ self.project().inner.poll_ready(cx)
+ }
+
+ fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
+ self.project().inner.start_send(item)
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ self.project().inner.poll_flush(cx)
+ }
+
+ fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ self.project().inner.poll_close(cx)
+ }
+}
+
+// This impl just defers to the underlying T: Stream
+impl<T, D> Stream for FramedWrite<T, D>
+where
+ T: Stream,
+{
+ type Item = T::Item;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ self.project().inner.project().inner.poll_next(cx)
+ }
+}
+
+impl<T, U> fmt::Debug for FramedWrite<T, U>
+where
+ T: fmt::Debug,
+ U: fmt::Debug,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("FramedWrite")
+ .field("inner", &self.get_ref())
+ .field("encoder", &self.encoder())
+ .field("buffer", &self.inner.state.buffer)
+ .finish()
+ }
+}
diff --git a/third_party/rust/tokio-util/src/codec/length_delimited.rs b/third_party/rust/tokio-util/src/codec/length_delimited.rs
new file mode 100644
index 0000000000..93d2f180d0
--- /dev/null
+++ b/third_party/rust/tokio-util/src/codec/length_delimited.rs
@@ -0,0 +1,1047 @@
+//! Frame a stream of bytes based on a length prefix
+//!
+//! Many protocols delimit their frames by prefacing frame data with a
+//! frame head that specifies the length of the frame. The
+//! `length_delimited` module provides utilities for handling the length
+//! based framing. This allows the consumer to work with entire frames
+//! without having to worry about buffering or other framing logic.
+//!
+//! # Getting started
+//!
+//! If implementing a protocol from scratch, using length delimited framing
+//! is an easy way to get started. [`LengthDelimitedCodec::new()`] will
+//! return a length delimited codec using default configuration values.
+//! This can then be used to construct a framer to adapt a full-duplex
+//! byte stream into a stream of frames.
+//!
+//! ```
+//! use tokio::io::{AsyncRead, AsyncWrite};
+//! use tokio_util::codec::{Framed, LengthDelimitedCodec};
+//!
+//! fn bind_transport<T: AsyncRead + AsyncWrite>(io: T)
+//! -> Framed<T, LengthDelimitedCodec>
+//! {
+//! Framed::new(io, LengthDelimitedCodec::new())
+//! }
+//! # pub fn main() {}
+//! ```
+//!
+//! The returned transport implements `Sink + Stream` for `BytesMut`. It
+//! encodes the frame with a big-endian `u32` header denoting the frame
+//! payload length:
+//!
+//! ```text
+//! +----------+--------------------------------+
+//! | len: u32 | frame payload |
+//! +----------+--------------------------------+
+//! ```
+//!
+//! Specifically, given the following:
+//!
+//! ```
+//! use tokio::io::{AsyncRead, AsyncWrite};
+//! use tokio_util::codec::{Framed, LengthDelimitedCodec};
+//!
+//! use futures::SinkExt;
+//! use bytes::Bytes;
+//!
+//! async fn write_frame<T>(io: T) -> Result<(), Box<dyn std::error::Error>>
+//! where
+//! T: AsyncRead + AsyncWrite + Unpin,
+//! {
+//! let mut transport = Framed::new(io, LengthDelimitedCodec::new());
+//! let frame = Bytes::from("hello world");
+//!
+//! transport.send(frame).await?;
+//! Ok(())
+//! }
+//! ```
+//!
+//! The encoded frame will look like this:
+//!
+//! ```text
+//! +---- len: u32 ----+---- data ----+
+//! | \x00\x00\x00\x0b | hello world |
+//! +------------------+--------------+
+//! ```
+//!
+//! # Decoding
+//!
+//! [`FramedRead`] adapts an [`AsyncRead`] into a `Stream` of [`BytesMut`],
+//! such that each yielded [`BytesMut`] value contains the contents of an
+//! entire frame. There are many configuration parameters enabling
+//! [`FramedRead`] to handle a wide range of protocols. Here are some
+//! examples that will cover the various options at a high level.
+//!
+//! ## Example 1
+//!
+//! The following will parse a `u16` length field at offset 0, including the
+//! frame head in the yielded `BytesMut`.
+//!
+//! ```
+//! # use tokio::io::AsyncRead;
+//! # use tokio_util::codec::LengthDelimitedCodec;
+//! # fn bind_read<T: AsyncRead>(io: T) {
+//! LengthDelimitedCodec::builder()
+//! .length_field_offset(0) // default value
+//! .length_field_type::<u16>()
+//! .length_adjustment(0) // default value
+//! .num_skip(0) // Do not strip frame header
+//! .new_read(io);
+//! # }
+//! # pub fn main() {}
+//! ```
+//!
+//! The following frame will be decoded as such:
+//!
+//! ```text
+//! INPUT DECODED
+//! +-- len ---+--- Payload ---+ +-- len ---+--- Payload ---+
+//! | \x00\x0B | Hello world | --> | \x00\x0B | Hello world |
+//! +----------+---------------+ +----------+---------------+
+//! ```
+//!
+//! The value of the length field is 11 (`\x0B`) which represents the length
+//! of the payload, `hello world`. By default, [`FramedRead`] assumes that
+//! the length field represents the number of bytes that **follows** the
+//! length field. Thus, the entire frame has a length of 13: 2 bytes for the
+//! frame head + 11 bytes for the payload.
+//!
+//! ## Example 2
+//!
+//! The following will parse a `u16` length field at offset 0, omitting the
+//! frame head in the yielded `BytesMut`.
+//!
+//! ```
+//! # use tokio::io::AsyncRead;
+//! # use tokio_util::codec::LengthDelimitedCodec;
+//! # fn bind_read<T: AsyncRead>(io: T) {
+//! LengthDelimitedCodec::builder()
+//! .length_field_offset(0) // default value
+//! .length_field_type::<u16>()
+//! .length_adjustment(0) // default value
+//! // `num_skip` is not needed, the default is to skip
+//! .new_read(io);
+//! # }
+//! # pub fn main() {}
+//! ```
+//!
+//! The following frame will be decoded as such:
+//!
+//! ```text
+//! INPUT DECODED
+//! +-- len ---+--- Payload ---+ +--- Payload ---+
+//! | \x00\x0B | Hello world | --> | Hello world |
+//! +----------+---------------+ +---------------+
+//! ```
+//!
+//! This is similar to the first example, the only difference is that the
+//! frame head is **not** included in the yielded `BytesMut` value.
+//!
+//! ## Example 3
+//!
+//! The following will parse a `u16` length field at offset 0, including the
+//! frame head in the yielded `BytesMut`. In this case, the length field
+//! **includes** the frame head length.
+//!
+//! ```
+//! # use tokio::io::AsyncRead;
+//! # use tokio_util::codec::LengthDelimitedCodec;
+//! # fn bind_read<T: AsyncRead>(io: T) {
+//! LengthDelimitedCodec::builder()
+//! .length_field_offset(0) // default value
+//! .length_field_type::<u16>()
+//! .length_adjustment(-2) // size of head
+//! .num_skip(0)
+//! .new_read(io);
+//! # }
+//! # pub fn main() {}
+//! ```
+//!
+//! The following frame will be decoded as such:
+//!
+//! ```text
+//! INPUT DECODED
+//! +-- len ---+--- Payload ---+ +-- len ---+--- Payload ---+
+//! | \x00\x0D | Hello world | --> | \x00\x0D | Hello world |
+//! +----------+---------------+ +----------+---------------+
+//! ```
+//!
+//! In most cases, the length field represents the length of the payload
+//! only, as shown in the previous examples. However, in some protocols the
+//! length field represents the length of the whole frame, including the
+//! head. In such cases, we specify a negative `length_adjustment` to adjust
+//! the value provided in the frame head to represent the payload length.
+//!
+//! ## Example 4
+//!
+//! The following will parse a 3 byte length field at offset 0 in a 5 byte
+//! frame head, including the frame head in the yielded `BytesMut`.
+//!
+//! ```
+//! # use tokio::io::AsyncRead;
+//! # use tokio_util::codec::LengthDelimitedCodec;
+//! # fn bind_read<T: AsyncRead>(io: T) {
+//! LengthDelimitedCodec::builder()
+//! .length_field_offset(0) // default value
+//! .length_field_length(3)
+//! .length_adjustment(2) // remaining head
+//! .num_skip(0)
+//! .new_read(io);
+//! # }
+//! # pub fn main() {}
+//! ```
+//!
+//! The following frame will be decoded as such:
+//!
+//! ```text
+//! INPUT
+//! +---- len -----+- head -+--- Payload ---+
+//! | \x00\x00\x0B | \xCAFE | Hello world |
+//! +--------------+--------+---------------+
+//!
+//! DECODED
+//! +---- len -----+- head -+--- Payload ---+
+//! | \x00\x00\x0B | \xCAFE | Hello world |
+//! +--------------+--------+---------------+
+//! ```
+//!
+//! A more advanced example that shows a case where there is extra frame
+//! head data between the length field and the payload. In such cases, it is
+//! usually desirable to include the frame head as part of the yielded
+//! `BytesMut`. This lets consumers of the length delimited framer to
+//! process the frame head as needed.
+//!
+//! The positive `length_adjustment` value lets `FramedRead` factor in the
+//! additional head into the frame length calculation.
+//!
+//! ## Example 5
+//!
+//! The following will parse a `u16` length field at offset 1 of a 4 byte
+//! frame head. The first byte and the length field will be omitted from the
+//! yielded `BytesMut`, but the trailing 2 bytes of the frame head will be
+//! included.
+//!
+//! ```
+//! # use tokio::io::AsyncRead;
+//! # use tokio_util::codec::LengthDelimitedCodec;
+//! # fn bind_read<T: AsyncRead>(io: T) {
+//! LengthDelimitedCodec::builder()
+//! .length_field_offset(1) // length of hdr1
+//! .length_field_type::<u16>()
+//! .length_adjustment(1) // length of hdr2
+//! .num_skip(3) // length of hdr1 + LEN
+//! .new_read(io);
+//! # }
+//! # pub fn main() {}
+//! ```
+//!
+//! The following frame will be decoded as such:
+//!
+//! ```text
+//! INPUT
+//! +- hdr1 -+-- len ---+- hdr2 -+--- Payload ---+
+//! | \xCA | \x00\x0B | \xFE | Hello world |
+//! +--------+----------+--------+---------------+
+//!
+//! DECODED
+//! +- hdr2 -+--- Payload ---+
+//! | \xFE | Hello world |
+//! +--------+---------------+
+//! ```
+//!
+//! The length field is situated in the middle of the frame head. In this
+//! case, the first byte in the frame head could be a version or some other
+//! identifier that is not needed for processing. On the other hand, the
+//! second half of the head is needed.
+//!
+//! `length_field_offset` indicates how many bytes to skip before starting
+//! to read the length field. `length_adjustment` is the number of bytes to
+//! skip starting at the end of the length field. In this case, it is the
+//! second half of the head.
+//!
+//! ## Example 6
+//!
+//! The following will parse a `u16` length field at offset 1 of a 4 byte
+//! frame head. The first byte and the length field will be omitted from the
+//! yielded `BytesMut`, but the trailing 2 bytes of the frame head will be
+//! included. In this case, the length field **includes** the frame head
+//! length.
+//!
+//! ```
+//! # use tokio::io::AsyncRead;
+//! # use tokio_util::codec::LengthDelimitedCodec;
+//! # fn bind_read<T: AsyncRead>(io: T) {
+//! LengthDelimitedCodec::builder()
+//! .length_field_offset(1) // length of hdr1
+//! .length_field_type::<u16>()
+//! .length_adjustment(-3) // length of hdr1 + LEN, negative
+//! .num_skip(3)
+//! .new_read(io);
+//! # }
+//! ```
+//!
+//! The following frame will be decoded as such:
+//!
+//! ```text
+//! INPUT
+//! +- hdr1 -+-- len ---+- hdr2 -+--- Payload ---+
+//! | \xCA | \x00\x0F | \xFE | Hello world |
+//! +--------+----------+--------+---------------+
+//!
+//! DECODED
+//! +- hdr2 -+--- Payload ---+
+//! | \xFE | Hello world |
+//! +--------+---------------+
+//! ```
+//!
+//! Similar to the example above, the difference is that the length field
+//! represents the length of the entire frame instead of just the payload.
+//! The length of `hdr1` and `len` must be counted in `length_adjustment`.
+//! Note that the length of `hdr2` does **not** need to be explicitly set
+//! anywhere because it already is factored into the total frame length that
+//! is read from the byte stream.
+//!
+//! ## Example 7
+//!
+//! The following will parse a 3 byte length field at offset 0 in a 4 byte
+//! frame head, excluding the 4th byte from the yielded `BytesMut`.
+//!
+//! ```
+//! # use tokio::io::AsyncRead;
+//! # use tokio_util::codec::LengthDelimitedCodec;
+//! # fn bind_read<T: AsyncRead>(io: T) {
+//! LengthDelimitedCodec::builder()
+//! .length_field_offset(0) // default value
+//! .length_field_length(3)
+//! .length_adjustment(0) // default value
+//! .num_skip(4) // skip the first 4 bytes
+//! .new_read(io);
+//! # }
+//! # pub fn main() {}
+//! ```
+//!
+//! The following frame will be decoded as such:
+//!
+//! ```text
+//! INPUT DECODED
+//! +------- len ------+--- Payload ---+ +--- Payload ---+
+//! | \x00\x00\x0B\xFF | Hello world | => | Hello world |
+//! +------------------+---------------+ +---------------+
+//! ```
+//!
+//! A simple example where there are unused bytes between the length field
+//! and the payload.
+//!
+//! # Encoding
+//!
+//! [`FramedWrite`] adapts an [`AsyncWrite`] into a `Sink` of [`BytesMut`],
+//! such that each submitted [`BytesMut`] is prefaced by a length field.
+//! There are fewer configuration options than [`FramedRead`]. Given
+//! protocols that have more complex frame heads, an encoder should probably
+//! be written by hand using [`Encoder`].
+//!
+//! Here is a simple example, given a `FramedWrite` with the following
+//! configuration:
+//!
+//! ```
+//! # use tokio::io::AsyncWrite;
+//! # use tokio_util::codec::LengthDelimitedCodec;
+//! # fn write_frame<T: AsyncWrite>(io: T) {
+//! # let _ =
+//! LengthDelimitedCodec::builder()
+//! .length_field_type::<u16>()
+//! .new_write(io);
+//! # }
+//! # pub fn main() {}
+//! ```
+//!
+//! A payload of `hello world` will be encoded as:
+//!
+//! ```text
+//! +- len: u16 -+---- data ----+
+//! | \x00\x0b | hello world |
+//! +------------+--------------+
+//! ```
+//!
+//! [`LengthDelimitedCodec::new()`]: method@LengthDelimitedCodec::new
+//! [`FramedRead`]: struct@FramedRead
+//! [`FramedWrite`]: struct@FramedWrite
+//! [`AsyncRead`]: trait@tokio::io::AsyncRead
+//! [`AsyncWrite`]: trait@tokio::io::AsyncWrite
+//! [`Encoder`]: trait@Encoder
+//! [`BytesMut`]: bytes::BytesMut
+
+use crate::codec::{Decoder, Encoder, Framed, FramedRead, FramedWrite};
+
+use tokio::io::{AsyncRead, AsyncWrite};
+
+use bytes::{Buf, BufMut, Bytes, BytesMut};
+use std::error::Error as StdError;
+use std::io::{self, Cursor};
+use std::{cmp, fmt, mem};
+
+/// Configure length delimited `LengthDelimitedCodec`s.
+///
+/// `Builder` enables constructing configured length delimited codecs. Note
+/// that not all configuration settings apply to both encoding and decoding. See
+/// the documentation for specific methods for more detail.
+#[derive(Debug, Clone, Copy)]
+pub struct Builder {
+ // Maximum frame length
+ max_frame_len: usize,
+
+ // Number of bytes representing the field length
+ length_field_len: usize,
+
+ // Number of bytes in the header before the length field
+ length_field_offset: usize,
+
+ // Adjust the length specified in the header field by this amount
+ length_adjustment: isize,
+
+ // Total number of bytes to skip before reading the payload, if not set,
+ // `length_field_len + length_field_offset`
+ num_skip: Option<usize>,
+
+ // Length field byte order (little or big endian)
+ length_field_is_big_endian: bool,
+}
+
+/// An error when the number of bytes read is more than max frame length.
+pub struct LengthDelimitedCodecError {
+ _priv: (),
+}
+
+/// A codec for frames delimited by a frame head specifying their lengths.
+///
+/// This allows the consumer to work with entire frames without having to worry
+/// about buffering or other framing logic.
+///
+/// See [module level] documentation for more detail.
+///
+/// [module level]: index.html
+#[derive(Debug, Clone)]
+pub struct LengthDelimitedCodec {
+ // Configuration values
+ builder: Builder,
+
+ // Read state
+ state: DecodeState,
+}
+
+#[derive(Debug, Clone, Copy)]
+enum DecodeState {
+ Head,
+ Data(usize),
+}
+
+// ===== impl LengthDelimitedCodec ======
+
+impl LengthDelimitedCodec {
+ /// Creates a new `LengthDelimitedCodec` with the default configuration values.
+ pub fn new() -> Self {
+ Self {
+ builder: Builder::new(),
+ state: DecodeState::Head,
+ }
+ }
+
+ /// Creates a new length delimited codec builder with default configuration
+ /// values.
+ pub fn builder() -> Builder {
+ Builder::new()
+ }
+
+ /// Returns the current max frame setting
+ ///
+ /// This is the largest size this codec will accept from the wire. Larger
+ /// frames will be rejected.
+ pub fn max_frame_length(&self) -> usize {
+ self.builder.max_frame_len
+ }
+
+ /// Updates the max frame setting.
+ ///
+ /// The change takes effect the next time a frame is decoded. In other
+ /// words, if a frame is currently in process of being decoded with a frame
+ /// size greater than `val` but less than the max frame length in effect
+ /// before calling this function, then the frame will be allowed.
+ pub fn set_max_frame_length(&mut self, val: usize) {
+ self.builder.max_frame_length(val);
+ }
+
+ fn decode_head(&mut self, src: &mut BytesMut) -> io::Result<Option<usize>> {
+ let head_len = self.builder.num_head_bytes();
+ let field_len = self.builder.length_field_len;
+
+ if src.len() < head_len {
+ // Not enough data
+ return Ok(None);
+ }
+
+ let n = {
+ let mut src = Cursor::new(&mut *src);
+
+ // Skip the required bytes
+ src.advance(self.builder.length_field_offset);
+
+ // match endianness
+ let n = if self.builder.length_field_is_big_endian {
+ src.get_uint(field_len)
+ } else {
+ src.get_uint_le(field_len)
+ };
+
+ if n > self.builder.max_frame_len as u64 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ LengthDelimitedCodecError { _priv: () },
+ ));
+ }
+
+ // The check above ensures there is no overflow
+ let n = n as usize;
+
+ // Adjust `n` with bounds checking
+ let n = if self.builder.length_adjustment < 0 {
+ n.checked_sub(-self.builder.length_adjustment as usize)
+ } else {
+ n.checked_add(self.builder.length_adjustment as usize)
+ };
+
+ // Error handling
+ match n {
+ Some(n) => n,
+ None => {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "provided length would overflow after adjustment",
+ ));
+ }
+ }
+ };
+
+ let num_skip = self.builder.get_num_skip();
+
+ if num_skip > 0 {
+ src.advance(num_skip);
+ }
+
+ // Ensure that the buffer has enough space to read the incoming
+ // payload
+ src.reserve(n);
+
+ Ok(Some(n))
+ }
+
+ fn decode_data(&self, n: usize, src: &mut BytesMut) -> Option<BytesMut> {
+ // At this point, the buffer has already had the required capacity
+ // reserved. All there is to do is read.
+ if src.len() < n {
+ return None;
+ }
+
+ Some(src.split_to(n))
+ }
+}
+
+impl Decoder for LengthDelimitedCodec {
+ type Item = BytesMut;
+ type Error = io::Error;
+
+ fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<BytesMut>> {
+ let n = match self.state {
+ DecodeState::Head => match self.decode_head(src)? {
+ Some(n) => {
+ self.state = DecodeState::Data(n);
+ n
+ }
+ None => return Ok(None),
+ },
+ DecodeState::Data(n) => n,
+ };
+
+ match self.decode_data(n, src) {
+ Some(data) => {
+ // Update the decode state
+ self.state = DecodeState::Head;
+
+ // Make sure the buffer has enough space to read the next head
+ src.reserve(self.builder.num_head_bytes());
+
+ Ok(Some(data))
+ }
+ None => Ok(None),
+ }
+ }
+}
+
+impl Encoder<Bytes> for LengthDelimitedCodec {
+ type Error = io::Error;
+
+ fn encode(&mut self, data: Bytes, dst: &mut BytesMut) -> Result<(), io::Error> {
+ let n = data.len();
+
+ if n > self.builder.max_frame_len {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ LengthDelimitedCodecError { _priv: () },
+ ));
+ }
+
+ // Adjust `n` with bounds checking
+ let n = if self.builder.length_adjustment < 0 {
+ n.checked_add(-self.builder.length_adjustment as usize)
+ } else {
+ n.checked_sub(self.builder.length_adjustment as usize)
+ };
+
+ let n = n.ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "provided length would overflow after adjustment",
+ )
+ })?;
+
+ // Reserve capacity in the destination buffer to fit the frame and
+ // length field (plus adjustment).
+ dst.reserve(self.builder.length_field_len + n);
+
+ if self.builder.length_field_is_big_endian {
+ dst.put_uint(n as u64, self.builder.length_field_len);
+ } else {
+ dst.put_uint_le(n as u64, self.builder.length_field_len);
+ }
+
+ // Write the frame to the buffer
+ dst.extend_from_slice(&data[..]);
+
+ Ok(())
+ }
+}
+
+impl Default for LengthDelimitedCodec {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+// ===== impl Builder =====
+
+mod builder {
+ /// Types that can be used with `Builder::length_field_type`.
+ pub trait LengthFieldType {}
+
+ impl LengthFieldType for u8 {}
+ impl LengthFieldType for u16 {}
+ impl LengthFieldType for u32 {}
+ impl LengthFieldType for u64 {}
+
+ #[cfg(any(
+ target_pointer_width = "8",
+ target_pointer_width = "16",
+ target_pointer_width = "32",
+ target_pointer_width = "64",
+ ))]
+ impl LengthFieldType for usize {}
+}
+
+impl Builder {
+ /// Creates a new length delimited codec builder with default configuration
+ /// values.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncRead;
+ /// use tokio_util::codec::LengthDelimitedCodec;
+ ///
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .length_field_offset(0)
+ /// .length_field_type::<u16>()
+ /// .length_adjustment(0)
+ /// .num_skip(0)
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn new() -> Builder {
+ Builder {
+ // Default max frame length of 8MB
+ max_frame_len: 8 * 1_024 * 1_024,
+
+ // Default byte length of 4
+ length_field_len: 4,
+
+ // Default to the header field being at the start of the header.
+ length_field_offset: 0,
+
+ length_adjustment: 0,
+
+ // Total number of bytes to skip before reading the payload, if not set,
+ // `length_field_len + length_field_offset`
+ num_skip: None,
+
+ // Default to reading the length field in network (big) endian.
+ length_field_is_big_endian: true,
+ }
+ }
+
+ /// Read the length field as a big endian integer
+ ///
+ /// This is the default setting.
+ ///
+ /// This configuration option applies to both encoding and decoding.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncRead;
+ /// use tokio_util::codec::LengthDelimitedCodec;
+ ///
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .big_endian()
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn big_endian(&mut self) -> &mut Self {
+ self.length_field_is_big_endian = true;
+ self
+ }
+
+ /// Read the length field as a little endian integer
+ ///
+ /// The default setting is big endian.
+ ///
+ /// This configuration option applies to both encoding and decoding.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncRead;
+ /// use tokio_util::codec::LengthDelimitedCodec;
+ ///
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .little_endian()
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn little_endian(&mut self) -> &mut Self {
+ self.length_field_is_big_endian = false;
+ self
+ }
+
+ /// Read the length field as a native endian integer
+ ///
+ /// The default setting is big endian.
+ ///
+ /// This configuration option applies to both encoding and decoding.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncRead;
+ /// use tokio_util::codec::LengthDelimitedCodec;
+ ///
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .native_endian()
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn native_endian(&mut self) -> &mut Self {
+ if cfg!(target_endian = "big") {
+ self.big_endian()
+ } else {
+ self.little_endian()
+ }
+ }
+
+ /// Sets the max frame length in bytes
+ ///
+ /// This configuration option applies to both encoding and decoding. The
+ /// default value is 8MB.
+ ///
+ /// When decoding, the length field read from the byte stream is checked
+ /// against this setting **before** any adjustments are applied. When
+ /// encoding, the length of the submitted payload is checked against this
+ /// setting.
+ ///
+ /// When frames exceed the max length, an `io::Error` with the custom value
+ /// of the `LengthDelimitedCodecError` type will be returned.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncRead;
+ /// use tokio_util::codec::LengthDelimitedCodec;
+ ///
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .max_frame_length(8 * 1024 * 1024)
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn max_frame_length(&mut self, val: usize) -> &mut Self {
+ self.max_frame_len = val;
+ self
+ }
+
+ /// Sets the unsigned integer type used to represent the length field.
+ ///
+ /// The default type is [`u32`]. The max type is [`u64`] (or [`usize`] on
+ /// 64-bit targets).
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncRead;
+ /// use tokio_util::codec::LengthDelimitedCodec;
+ ///
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .length_field_type::<u32>()
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ ///
+ /// Unlike [`Builder::length_field_length`], this does not fail at runtime
+ /// and instead produces a compile error:
+ ///
+ /// ```compile_fail
+ /// # use tokio::io::AsyncRead;
+ /// # use tokio_util::codec::LengthDelimitedCodec;
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .length_field_type::<u128>()
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn length_field_type<T: builder::LengthFieldType>(&mut self) -> &mut Self {
+ self.length_field_length(mem::size_of::<T>())
+ }
+
+ /// Sets the number of bytes used to represent the length field
+ ///
+ /// The default value is `4`. The max value is `8`.
+ ///
+ /// This configuration option applies to both encoding and decoding.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncRead;
+ /// use tokio_util::codec::LengthDelimitedCodec;
+ ///
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .length_field_length(4)
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn length_field_length(&mut self, val: usize) -> &mut Self {
+ assert!(val > 0 && val <= 8, "invalid length field length");
+ self.length_field_len = val;
+ self
+ }
+
+ /// Sets the number of bytes in the header before the length field
+ ///
+ /// This configuration option only applies to decoding.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncRead;
+ /// use tokio_util::codec::LengthDelimitedCodec;
+ ///
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .length_field_offset(1)
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn length_field_offset(&mut self, val: usize) -> &mut Self {
+ self.length_field_offset = val;
+ self
+ }
+
+ /// Delta between the payload length specified in the header and the real
+ /// payload length
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncRead;
+ /// use tokio_util::codec::LengthDelimitedCodec;
+ ///
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .length_adjustment(-2)
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn length_adjustment(&mut self, val: isize) -> &mut Self {
+ self.length_adjustment = val;
+ self
+ }
+
+ /// Sets the number of bytes to skip before reading the payload
+ ///
+ /// Default value is `length_field_len + length_field_offset`
+ ///
+ /// This configuration option only applies to decoding
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncRead;
+ /// use tokio_util::codec::LengthDelimitedCodec;
+ ///
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .num_skip(4)
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn num_skip(&mut self, val: usize) -> &mut Self {
+ self.num_skip = Some(val);
+ self
+ }
+
+ /// Create a configured length delimited `LengthDelimitedCodec`
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio_util::codec::LengthDelimitedCodec;
+ /// # pub fn main() {
+ /// LengthDelimitedCodec::builder()
+ /// .length_field_offset(0)
+ /// .length_field_type::<u16>()
+ /// .length_adjustment(0)
+ /// .num_skip(0)
+ /// .new_codec();
+ /// # }
+ /// ```
+ pub fn new_codec(&self) -> LengthDelimitedCodec {
+ LengthDelimitedCodec {
+ builder: *self,
+ state: DecodeState::Head,
+ }
+ }
+
+ /// Create a configured length delimited `FramedRead`
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncRead;
+ /// use tokio_util::codec::LengthDelimitedCodec;
+ ///
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .length_field_offset(0)
+ /// .length_field_type::<u16>()
+ /// .length_adjustment(0)
+ /// .num_skip(0)
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn new_read<T>(&self, upstream: T) -> FramedRead<T, LengthDelimitedCodec>
+ where
+ T: AsyncRead,
+ {
+ FramedRead::new(upstream, self.new_codec())
+ }
+
+ /// Create a configured length delimited `FramedWrite`
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncWrite;
+ /// # use tokio_util::codec::LengthDelimitedCodec;
+ /// # fn write_frame<T: AsyncWrite>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .length_field_type::<u16>()
+ /// .new_write(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn new_write<T>(&self, inner: T) -> FramedWrite<T, LengthDelimitedCodec>
+ where
+ T: AsyncWrite,
+ {
+ FramedWrite::new(inner, self.new_codec())
+ }
+
+ /// Create a configured length delimited `Framed`
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::{AsyncRead, AsyncWrite};
+ /// # use tokio_util::codec::LengthDelimitedCodec;
+ /// # fn write_frame<T: AsyncRead + AsyncWrite>(io: T) {
+ /// # let _ =
+ /// LengthDelimitedCodec::builder()
+ /// .length_field_type::<u16>()
+ /// .new_framed(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn new_framed<T>(&self, inner: T) -> Framed<T, LengthDelimitedCodec>
+ where
+ T: AsyncRead + AsyncWrite,
+ {
+ Framed::new(inner, self.new_codec())
+ }
+
+ fn num_head_bytes(&self) -> usize {
+ let num = self.length_field_offset + self.length_field_len;
+ cmp::max(num, self.num_skip.unwrap_or(0))
+ }
+
+ fn get_num_skip(&self) -> usize {
+ self.num_skip
+ .unwrap_or(self.length_field_offset + self.length_field_len)
+ }
+}
+
+impl Default for Builder {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+// ===== impl LengthDelimitedCodecError =====
+
+impl fmt::Debug for LengthDelimitedCodecError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("LengthDelimitedCodecError").finish()
+ }
+}
+
+impl fmt::Display for LengthDelimitedCodecError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.write_str("frame size too big")
+ }
+}
+
+impl StdError for LengthDelimitedCodecError {}
diff --git a/third_party/rust/tokio-util/src/codec/lines_codec.rs b/third_party/rust/tokio-util/src/codec/lines_codec.rs
new file mode 100644
index 0000000000..7a0a8f0454
--- /dev/null
+++ b/third_party/rust/tokio-util/src/codec/lines_codec.rs
@@ -0,0 +1,230 @@
+use crate::codec::decoder::Decoder;
+use crate::codec::encoder::Encoder;
+
+use bytes::{Buf, BufMut, BytesMut};
+use std::{cmp, fmt, io, str, usize};
+
+/// A simple [`Decoder`] and [`Encoder`] implementation that splits up data into lines.
+///
+/// [`Decoder`]: crate::codec::Decoder
+/// [`Encoder`]: crate::codec::Encoder
+#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
+pub struct LinesCodec {
+ // Stored index of the next index to examine for a `\n` character.
+ // This is used to optimize searching.
+ // For example, if `decode` was called with `abc`, it would hold `3`,
+ // because that is the next index to examine.
+ // The next time `decode` is called with `abcde\n`, the method will
+ // only look at `de\n` before returning.
+ next_index: usize,
+
+ /// The maximum length for a given line. If `usize::MAX`, lines will be
+ /// read until a `\n` character is reached.
+ max_length: usize,
+
+ /// Are we currently discarding the remainder of a line which was over
+ /// the length limit?
+ is_discarding: bool,
+}
+
+impl LinesCodec {
+ /// Returns a `LinesCodec` for splitting up data into lines.
+ ///
+ /// # Note
+ ///
+ /// The returned `LinesCodec` will not have an upper bound on the length
+ /// of a buffered line. See the documentation for [`new_with_max_length`]
+ /// for information on why this could be a potential security risk.
+ ///
+ /// [`new_with_max_length`]: crate::codec::LinesCodec::new_with_max_length()
+ pub fn new() -> LinesCodec {
+ LinesCodec {
+ next_index: 0,
+ max_length: usize::MAX,
+ is_discarding: false,
+ }
+ }
+
+ /// Returns a `LinesCodec` with a maximum line length limit.
+ ///
+ /// If this is set, calls to `LinesCodec::decode` will return a
+ /// [`LinesCodecError`] when a line exceeds the length limit. Subsequent calls
+ /// will discard up to `limit` bytes from that line until a newline
+ /// character is reached, returning `None` until the line over the limit
+ /// has been fully discarded. After that point, calls to `decode` will
+ /// function as normal.
+ ///
+ /// # Note
+ ///
+ /// Setting a length limit is highly recommended for any `LinesCodec` which
+ /// will be exposed to untrusted input. Otherwise, the size of the buffer
+ /// that holds the line currently being read is unbounded. An attacker could
+ /// exploit this unbounded buffer by sending an unbounded amount of input
+ /// without any `\n` characters, causing unbounded memory consumption.
+ ///
+ /// [`LinesCodecError`]: crate::codec::LinesCodecError
+ pub fn new_with_max_length(max_length: usize) -> Self {
+ LinesCodec {
+ max_length,
+ ..LinesCodec::new()
+ }
+ }
+
+ /// Returns the maximum line length when decoding.
+ ///
+ /// ```
+ /// use std::usize;
+ /// use tokio_util::codec::LinesCodec;
+ ///
+ /// let codec = LinesCodec::new();
+ /// assert_eq!(codec.max_length(), usize::MAX);
+ /// ```
+ /// ```
+ /// use tokio_util::codec::LinesCodec;
+ ///
+ /// let codec = LinesCodec::new_with_max_length(256);
+ /// assert_eq!(codec.max_length(), 256);
+ /// ```
+ pub fn max_length(&self) -> usize {
+ self.max_length
+ }
+}
+
+fn utf8(buf: &[u8]) -> Result<&str, io::Error> {
+ str::from_utf8(buf)
+ .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Unable to decode input as UTF8"))
+}
+
+fn without_carriage_return(s: &[u8]) -> &[u8] {
+ if let Some(&b'\r') = s.last() {
+ &s[..s.len() - 1]
+ } else {
+ s
+ }
+}
+
+impl Decoder for LinesCodec {
+ type Item = String;
+ type Error = LinesCodecError;
+
+ fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<String>, LinesCodecError> {
+ loop {
+ // Determine how far into the buffer we'll search for a newline. If
+ // there's no max_length set, we'll read to the end of the buffer.
+ let read_to = cmp::min(self.max_length.saturating_add(1), buf.len());
+
+ let newline_offset = buf[self.next_index..read_to]
+ .iter()
+ .position(|b| *b == b'\n');
+
+ match (self.is_discarding, newline_offset) {
+ (true, Some(offset)) => {
+ // If we found a newline, discard up to that offset and
+ // then stop discarding. On the next iteration, we'll try
+ // to read a line normally.
+ buf.advance(offset + self.next_index + 1);
+ self.is_discarding = false;
+ self.next_index = 0;
+ }
+ (true, None) => {
+ // Otherwise, we didn't find a newline, so we'll discard
+ // everything we read. On the next iteration, we'll continue
+ // discarding up to max_len bytes unless we find a newline.
+ buf.advance(read_to);
+ self.next_index = 0;
+ if buf.is_empty() {
+ return Ok(None);
+ }
+ }
+ (false, Some(offset)) => {
+ // Found a line!
+ let newline_index = offset + self.next_index;
+ self.next_index = 0;
+ let line = buf.split_to(newline_index + 1);
+ let line = &line[..line.len() - 1];
+ let line = without_carriage_return(line);
+ let line = utf8(line)?;
+ return Ok(Some(line.to_string()));
+ }
+ (false, None) if buf.len() > self.max_length => {
+ // Reached the maximum length without finding a
+ // newline, return an error and start discarding on the
+ // next call.
+ self.is_discarding = true;
+ return Err(LinesCodecError::MaxLineLengthExceeded);
+ }
+ (false, None) => {
+ // We didn't find a line or reach the length limit, so the next
+ // call will resume searching at the current offset.
+ self.next_index = read_to;
+ return Ok(None);
+ }
+ }
+ }
+ }
+
+ fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<String>, LinesCodecError> {
+ Ok(match self.decode(buf)? {
+ Some(frame) => Some(frame),
+ None => {
+ // No terminating newline - return remaining data, if any
+ if buf.is_empty() || buf == &b"\r"[..] {
+ None
+ } else {
+ let line = buf.split_to(buf.len());
+ let line = without_carriage_return(&line);
+ let line = utf8(line)?;
+ self.next_index = 0;
+ Some(line.to_string())
+ }
+ }
+ })
+ }
+}
+
+impl<T> Encoder<T> for LinesCodec
+where
+ T: AsRef<str>,
+{
+ type Error = LinesCodecError;
+
+ fn encode(&mut self, line: T, buf: &mut BytesMut) -> Result<(), LinesCodecError> {
+ let line = line.as_ref();
+ buf.reserve(line.len() + 1);
+ buf.put(line.as_bytes());
+ buf.put_u8(b'\n');
+ Ok(())
+ }
+}
+
+impl Default for LinesCodec {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+/// An error occurred while encoding or decoding a line.
+#[derive(Debug)]
+pub enum LinesCodecError {
+ /// The maximum line length was exceeded.
+ MaxLineLengthExceeded,
+ /// An IO error occurred.
+ Io(io::Error),
+}
+
+impl fmt::Display for LinesCodecError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ LinesCodecError::MaxLineLengthExceeded => write!(f, "max line length exceeded"),
+ LinesCodecError::Io(e) => write!(f, "{}", e),
+ }
+ }
+}
+
+impl From<io::Error> for LinesCodecError {
+ fn from(e: io::Error) -> LinesCodecError {
+ LinesCodecError::Io(e)
+ }
+}
+
+impl std::error::Error for LinesCodecError {}
diff --git a/third_party/rust/tokio-util/src/codec/mod.rs b/third_party/rust/tokio-util/src/codec/mod.rs
new file mode 100644
index 0000000000..2295176bdc
--- /dev/null
+++ b/third_party/rust/tokio-util/src/codec/mod.rs
@@ -0,0 +1,290 @@
+//! Adaptors from AsyncRead/AsyncWrite to Stream/Sink
+//!
+//! Raw I/O objects work with byte sequences, but higher-level code usually
+//! wants to batch these into meaningful chunks, called "frames".
+//!
+//! This module contains adapters to go from streams of bytes, [`AsyncRead`] and
+//! [`AsyncWrite`], to framed streams implementing [`Sink`] and [`Stream`].
+//! Framed streams are also known as transports.
+//!
+//! # The Decoder trait
+//!
+//! A [`Decoder`] is used together with [`FramedRead`] or [`Framed`] to turn an
+//! [`AsyncRead`] into a [`Stream`]. The job of the decoder trait is to specify
+//! how sequences of bytes are turned into a sequence of frames, and to
+//! determine where the boundaries between frames are. The job of the
+//! `FramedRead` is to repeatedly switch between reading more data from the IO
+//! resource, and asking the decoder whether we have received enough data to
+//! decode another frame of data.
+//!
+//! The main method on the `Decoder` trait is the [`decode`] method. This method
+//! takes as argument the data that has been read so far, and when it is called,
+//! it will be in one of the following situations:
+//!
+//! 1. The buffer contains less than a full frame.
+//! 2. The buffer contains exactly a full frame.
+//! 3. The buffer contains more than a full frame.
+//!
+//! In the first situation, the decoder should return `Ok(None)`.
+//!
+//! In the second situation, the decoder should clear the provided buffer and
+//! return `Ok(Some(the_decoded_frame))`.
+//!
+//! In the third situation, the decoder should use a method such as [`split_to`]
+//! or [`advance`] to modify the buffer such that the frame is removed from the
+//! buffer, but any data in the buffer after that frame should still remain in
+//! the buffer. The decoder should also return `Ok(Some(the_decoded_frame))` in
+//! this case.
+//!
+//! Finally the decoder may return an error if the data is invalid in some way.
+//! The decoder should _not_ return an error just because it has yet to receive
+//! a full frame.
+//!
+//! It is guaranteed that, from one call to `decode` to another, the provided
+//! buffer will contain the exact same data as before, except that if more data
+//! has arrived through the IO resource, that data will have been appended to
+//! the buffer. This means that reading frames from a `FramedRead` is
+//! essentially equivalent to the following loop:
+//!
+//! ```no_run
+//! use tokio::io::AsyncReadExt;
+//! # // This uses async_stream to create an example that compiles.
+//! # fn foo() -> impl futures_core::Stream<Item = std::io::Result<bytes::BytesMut>> { async_stream::try_stream! {
+//! # use tokio_util::codec::Decoder;
+//! # let mut decoder = tokio_util::codec::BytesCodec::new();
+//! # let io_resource = &mut &[0u8, 1, 2, 3][..];
+//!
+//! let mut buf = bytes::BytesMut::new();
+//! loop {
+//! // The read_buf call will append to buf rather than overwrite existing data.
+//! let len = io_resource.read_buf(&mut buf).await?;
+//!
+//! if len == 0 {
+//! while let Some(frame) = decoder.decode_eof(&mut buf)? {
+//! yield frame;
+//! }
+//! break;
+//! }
+//!
+//! while let Some(frame) = decoder.decode(&mut buf)? {
+//! yield frame;
+//! }
+//! }
+//! # }}
+//! ```
+//! The example above uses `yield` whenever the `Stream` produces an item.
+//!
+//! ## Example decoder
+//!
+//! As an example, consider a protocol that can be used to send strings where
+//! each frame is a four byte integer that contains the length of the frame,
+//! followed by that many bytes of string data. The decoder fails with an error
+//! if the string data is not valid utf-8 or too long.
+//!
+//! Such a decoder can be written like this:
+//! ```
+//! use tokio_util::codec::Decoder;
+//! use bytes::{BytesMut, Buf};
+//!
+//! struct MyStringDecoder {}
+//!
+//! const MAX: usize = 8 * 1024 * 1024;
+//!
+//! impl Decoder for MyStringDecoder {
+//! type Item = String;
+//! type Error = std::io::Error;
+//!
+//! fn decode(
+//! &mut self,
+//! src: &mut BytesMut
+//! ) -> Result<Option<Self::Item>, Self::Error> {
+//! if src.len() < 4 {
+//! // Not enough data to read length marker.
+//! return Ok(None);
+//! }
+//!
+//! // Read length marker.
+//! let mut length_bytes = [0u8; 4];
+//! length_bytes.copy_from_slice(&src[..4]);
+//! let length = u32::from_le_bytes(length_bytes) as usize;
+//!
+//! // Check that the length is not too large to avoid a denial of
+//! // service attack where the server runs out of memory.
+//! if length > MAX {
+//! return Err(std::io::Error::new(
+//! std::io::ErrorKind::InvalidData,
+//! format!("Frame of length {} is too large.", length)
+//! ));
+//! }
+//!
+//! if src.len() < 4 + length {
+//! // The full string has not yet arrived.
+//! //
+//! // We reserve more space in the buffer. This is not strictly
+//! // necessary, but is a good idea performance-wise.
+//! src.reserve(4 + length - src.len());
+//!
+//! // We inform the Framed that we need more bytes to form the next
+//! // frame.
+//! return Ok(None);
+//! }
+//!
+//! // Use advance to modify src such that it no longer contains
+//! // this frame.
+//! let data = src[4..4 + length].to_vec();
+//! src.advance(4 + length);
+//!
+//! // Convert the data to a string, or fail if it is not valid utf-8.
+//! match String::from_utf8(data) {
+//! Ok(string) => Ok(Some(string)),
+//! Err(utf8_error) => {
+//! Err(std::io::Error::new(
+//! std::io::ErrorKind::InvalidData,
+//! utf8_error.utf8_error(),
+//! ))
+//! },
+//! }
+//! }
+//! }
+//! ```
+//!
+//! # The Encoder trait
+//!
+//! An [`Encoder`] is used together with [`FramedWrite`] or [`Framed`] to turn
+//! an [`AsyncWrite`] into a [`Sink`]. The job of the encoder trait is to
+//! specify how frames are turned into a sequences of bytes. The job of the
+//! `FramedWrite` is to take the resulting sequence of bytes and write it to the
+//! IO resource.
+//!
+//! The main method on the `Encoder` trait is the [`encode`] method. This method
+//! takes an item that is being written, and a buffer to write the item to. The
+//! buffer may already contain data, and in this case, the encoder should append
+//! the new frame the to buffer rather than overwrite the existing data.
+//!
+//! It is guaranteed that, from one call to `encode` to another, the provided
+//! buffer will contain the exact same data as before, except that some of the
+//! data may have been removed from the front of the buffer. Writing to a
+//! `FramedWrite` is essentially equivalent to the following loop:
+//!
+//! ```no_run
+//! use tokio::io::AsyncWriteExt;
+//! use bytes::Buf; // for advance
+//! # use tokio_util::codec::Encoder;
+//! # async fn next_frame() -> bytes::Bytes { bytes::Bytes::new() }
+//! # async fn no_more_frames() { }
+//! # #[tokio::main] async fn main() -> std::io::Result<()> {
+//! # let mut io_resource = tokio::io::sink();
+//! # let mut encoder = tokio_util::codec::BytesCodec::new();
+//!
+//! const MAX: usize = 8192;
+//!
+//! let mut buf = bytes::BytesMut::new();
+//! loop {
+//! tokio::select! {
+//! num_written = io_resource.write(&buf), if !buf.is_empty() => {
+//! buf.advance(num_written?);
+//! },
+//! frame = next_frame(), if buf.len() < MAX => {
+//! encoder.encode(frame, &mut buf)?;
+//! },
+//! _ = no_more_frames() => {
+//! io_resource.write_all(&buf).await?;
+//! io_resource.shutdown().await?;
+//! return Ok(());
+//! },
+//! }
+//! }
+//! # }
+//! ```
+//! Here the `next_frame` method corresponds to any frames you write to the
+//! `FramedWrite`. The `no_more_frames` method corresponds to closing the
+//! `FramedWrite` with [`SinkExt::close`].
+//!
+//! ## Example encoder
+//!
+//! As an example, consider a protocol that can be used to send strings where
+//! each frame is a four byte integer that contains the length of the frame,
+//! followed by that many bytes of string data. The encoder will fail if the
+//! string is too long.
+//!
+//! Such an encoder can be written like this:
+//! ```
+//! use tokio_util::codec::Encoder;
+//! use bytes::BytesMut;
+//!
+//! struct MyStringEncoder {}
+//!
+//! const MAX: usize = 8 * 1024 * 1024;
+//!
+//! impl Encoder<String> for MyStringEncoder {
+//! type Error = std::io::Error;
+//!
+//! fn encode(&mut self, item: String, dst: &mut BytesMut) -> Result<(), Self::Error> {
+//! // Don't send a string if it is longer than the other end will
+//! // accept.
+//! if item.len() > MAX {
+//! return Err(std::io::Error::new(
+//! std::io::ErrorKind::InvalidData,
+//! format!("Frame of length {} is too large.", item.len())
+//! ));
+//! }
+//!
+//! // Convert the length into a byte array.
+//! // The cast to u32 cannot overflow due to the length check above.
+//! let len_slice = u32::to_le_bytes(item.len() as u32);
+//!
+//! // Reserve space in the buffer.
+//! dst.reserve(4 + item.len());
+//!
+//! // Write the length and string to the buffer.
+//! dst.extend_from_slice(&len_slice);
+//! dst.extend_from_slice(item.as_bytes());
+//! Ok(())
+//! }
+//! }
+//! ```
+//!
+//! [`AsyncRead`]: tokio::io::AsyncRead
+//! [`AsyncWrite`]: tokio::io::AsyncWrite
+//! [`Stream`]: futures_core::Stream
+//! [`Sink`]: futures_sink::Sink
+//! [`SinkExt::close`]: https://docs.rs/futures/0.3/futures/sink/trait.SinkExt.html#method.close
+//! [`FramedRead`]: struct@crate::codec::FramedRead
+//! [`FramedWrite`]: struct@crate::codec::FramedWrite
+//! [`Framed`]: struct@crate::codec::Framed
+//! [`Decoder`]: trait@crate::codec::Decoder
+//! [`decode`]: fn@crate::codec::Decoder::decode
+//! [`encode`]: fn@crate::codec::Encoder::encode
+//! [`split_to`]: fn@bytes::BytesMut::split_to
+//! [`advance`]: fn@bytes::Buf::advance
+
+mod bytes_codec;
+pub use self::bytes_codec::BytesCodec;
+
+mod decoder;
+pub use self::decoder::Decoder;
+
+mod encoder;
+pub use self::encoder::Encoder;
+
+mod framed_impl;
+#[allow(unused_imports)]
+pub(crate) use self::framed_impl::{FramedImpl, RWFrames, ReadFrame, WriteFrame};
+
+mod framed;
+pub use self::framed::{Framed, FramedParts};
+
+mod framed_read;
+pub use self::framed_read::FramedRead;
+
+mod framed_write;
+pub use self::framed_write::FramedWrite;
+
+pub mod length_delimited;
+pub use self::length_delimited::{LengthDelimitedCodec, LengthDelimitedCodecError};
+
+mod lines_codec;
+pub use self::lines_codec::{LinesCodec, LinesCodecError};
+
+mod any_delimiter_codec;
+pub use self::any_delimiter_codec::{AnyDelimiterCodec, AnyDelimiterCodecError};
diff --git a/third_party/rust/tokio-util/src/compat.rs b/third_party/rust/tokio-util/src/compat.rs
new file mode 100644
index 0000000000..6a8802d969
--- /dev/null
+++ b/third_party/rust/tokio-util/src/compat.rs
@@ -0,0 +1,274 @@
+//! Compatibility between the `tokio::io` and `futures-io` versions of the
+//! `AsyncRead` and `AsyncWrite` traits.
+use futures_core::ready;
+use pin_project_lite::pin_project;
+use std::io;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+pin_project! {
+ /// A compatibility layer that allows conversion between the
+ /// `tokio::io` and `futures-io` `AsyncRead` and `AsyncWrite` traits.
+ #[derive(Copy, Clone, Debug)]
+ pub struct Compat<T> {
+ #[pin]
+ inner: T,
+ seek_pos: Option<io::SeekFrom>,
+ }
+}
+
+/// Extension trait that allows converting a type implementing
+/// `futures_io::AsyncRead` to implement `tokio::io::AsyncRead`.
+pub trait FuturesAsyncReadCompatExt: futures_io::AsyncRead {
+ /// Wraps `self` with a compatibility layer that implements
+ /// `tokio_io::AsyncRead`.
+ fn compat(self) -> Compat<Self>
+ where
+ Self: Sized,
+ {
+ Compat::new(self)
+ }
+}
+
+impl<T: futures_io::AsyncRead> FuturesAsyncReadCompatExt for T {}
+
+/// Extension trait that allows converting a type implementing
+/// `futures_io::AsyncWrite` to implement `tokio::io::AsyncWrite`.
+pub trait FuturesAsyncWriteCompatExt: futures_io::AsyncWrite {
+ /// Wraps `self` with a compatibility layer that implements
+ /// `tokio::io::AsyncWrite`.
+ fn compat_write(self) -> Compat<Self>
+ where
+ Self: Sized,
+ {
+ Compat::new(self)
+ }
+}
+
+impl<T: futures_io::AsyncWrite> FuturesAsyncWriteCompatExt for T {}
+
+/// Extension trait that allows converting a type implementing
+/// `tokio::io::AsyncRead` to implement `futures_io::AsyncRead`.
+pub trait TokioAsyncReadCompatExt: tokio::io::AsyncRead {
+ /// Wraps `self` with a compatibility layer that implements
+ /// `futures_io::AsyncRead`.
+ fn compat(self) -> Compat<Self>
+ where
+ Self: Sized,
+ {
+ Compat::new(self)
+ }
+}
+
+impl<T: tokio::io::AsyncRead> TokioAsyncReadCompatExt for T {}
+
+/// Extension trait that allows converting a type implementing
+/// `tokio::io::AsyncWrite` to implement `futures_io::AsyncWrite`.
+pub trait TokioAsyncWriteCompatExt: tokio::io::AsyncWrite {
+ /// Wraps `self` with a compatibility layer that implements
+ /// `futures_io::AsyncWrite`.
+ fn compat_write(self) -> Compat<Self>
+ where
+ Self: Sized,
+ {
+ Compat::new(self)
+ }
+}
+
+impl<T: tokio::io::AsyncWrite> TokioAsyncWriteCompatExt for T {}
+
+// === impl Compat ===
+
+impl<T> Compat<T> {
+ fn new(inner: T) -> Self {
+ Self {
+ inner,
+ seek_pos: None,
+ }
+ }
+
+ /// Get a reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object
+ /// contained within.
+ pub fn get_ref(&self) -> &T {
+ &self.inner
+ }
+
+ /// Get a mutable reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object
+ /// contained within.
+ pub fn get_mut(&mut self) -> &mut T {
+ &mut self.inner
+ }
+
+ /// Returns the wrapped item.
+ pub fn into_inner(self) -> T {
+ self.inner
+ }
+}
+
+impl<T> tokio::io::AsyncRead for Compat<T>
+where
+ T: futures_io::AsyncRead,
+{
+ fn poll_read(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut tokio::io::ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ // We can't trust the inner type to not peak at the bytes,
+ // so we must defensively initialize the buffer.
+ let slice = buf.initialize_unfilled();
+ let n = ready!(futures_io::AsyncRead::poll_read(
+ self.project().inner,
+ cx,
+ slice
+ ))?;
+ buf.advance(n);
+ Poll::Ready(Ok(()))
+ }
+}
+
+impl<T> futures_io::AsyncRead for Compat<T>
+where
+ T: tokio::io::AsyncRead,
+{
+ fn poll_read(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ slice: &mut [u8],
+ ) -> Poll<io::Result<usize>> {
+ let mut buf = tokio::io::ReadBuf::new(slice);
+ ready!(tokio::io::AsyncRead::poll_read(
+ self.project().inner,
+ cx,
+ &mut buf
+ ))?;
+ Poll::Ready(Ok(buf.filled().len()))
+ }
+}
+
+impl<T> tokio::io::AsyncBufRead for Compat<T>
+where
+ T: futures_io::AsyncBufRead,
+{
+ fn poll_fill_buf<'a>(
+ self: Pin<&'a mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<io::Result<&'a [u8]>> {
+ futures_io::AsyncBufRead::poll_fill_buf(self.project().inner, cx)
+ }
+
+ fn consume(self: Pin<&mut Self>, amt: usize) {
+ futures_io::AsyncBufRead::consume(self.project().inner, amt)
+ }
+}
+
+impl<T> futures_io::AsyncBufRead for Compat<T>
+where
+ T: tokio::io::AsyncBufRead,
+{
+ fn poll_fill_buf<'a>(
+ self: Pin<&'a mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<io::Result<&'a [u8]>> {
+ tokio::io::AsyncBufRead::poll_fill_buf(self.project().inner, cx)
+ }
+
+ fn consume(self: Pin<&mut Self>, amt: usize) {
+ tokio::io::AsyncBufRead::consume(self.project().inner, amt)
+ }
+}
+
+impl<T> tokio::io::AsyncWrite for Compat<T>
+where
+ T: futures_io::AsyncWrite,
+{
+ fn poll_write(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<io::Result<usize>> {
+ futures_io::AsyncWrite::poll_write(self.project().inner, cx, buf)
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ futures_io::AsyncWrite::poll_flush(self.project().inner, cx)
+ }
+
+ fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ futures_io::AsyncWrite::poll_close(self.project().inner, cx)
+ }
+}
+
+impl<T> futures_io::AsyncWrite for Compat<T>
+where
+ T: tokio::io::AsyncWrite,
+{
+ fn poll_write(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<io::Result<usize>> {
+ tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
+ }
+
+ fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
+ }
+}
+
+impl<T: tokio::io::AsyncSeek> futures_io::AsyncSeek for Compat<T> {
+ fn poll_seek(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ pos: io::SeekFrom,
+ ) -> Poll<io::Result<u64>> {
+ if self.seek_pos != Some(pos) {
+ self.as_mut().project().inner.start_seek(pos)?;
+ *self.as_mut().project().seek_pos = Some(pos);
+ }
+ let res = ready!(self.as_mut().project().inner.poll_complete(cx));
+ *self.as_mut().project().seek_pos = None;
+ Poll::Ready(res.map(|p| p as u64))
+ }
+}
+
+impl<T: futures_io::AsyncSeek> tokio::io::AsyncSeek for Compat<T> {
+ fn start_seek(mut self: Pin<&mut Self>, pos: io::SeekFrom) -> io::Result<()> {
+ *self.as_mut().project().seek_pos = Some(pos);
+ Ok(())
+ }
+
+ fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
+ let pos = match self.seek_pos {
+ None => {
+ // tokio 1.x AsyncSeek recommends calling poll_complete before start_seek.
+ // We don't have to guarantee that the value returned by
+ // poll_complete called without start_seek is correct,
+ // so we'll return 0.
+ return Poll::Ready(Ok(0));
+ }
+ Some(pos) => pos,
+ };
+ let res = ready!(self.as_mut().project().inner.poll_seek(cx, pos));
+ *self.as_mut().project().seek_pos = None;
+ Poll::Ready(res.map(|p| p as u64))
+ }
+}
+
+#[cfg(unix)]
+impl<T: std::os::unix::io::AsRawFd> std::os::unix::io::AsRawFd for Compat<T> {
+ fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
+ self.inner.as_raw_fd()
+ }
+}
+
+#[cfg(windows)]
+impl<T: std::os::windows::io::AsRawHandle> std::os::windows::io::AsRawHandle for Compat<T> {
+ fn as_raw_handle(&self) -> std::os::windows::io::RawHandle {
+ self.inner.as_raw_handle()
+ }
+}
diff --git a/third_party/rust/tokio-util/src/context.rs b/third_party/rust/tokio-util/src/context.rs
new file mode 100644
index 0000000000..a7a5e02949
--- /dev/null
+++ b/third_party/rust/tokio-util/src/context.rs
@@ -0,0 +1,190 @@
+//! Tokio context aware futures utilities.
+//!
+//! This module includes utilities around integrating tokio with other runtimes
+//! by allowing the context to be attached to futures. This allows spawning
+//! futures on other executors while still using tokio to drive them. This
+//! can be useful if you need to use a tokio based library in an executor/runtime
+//! that does not provide a tokio context.
+
+use pin_project_lite::pin_project;
+use std::{
+ future::Future,
+ pin::Pin,
+ task::{Context, Poll},
+};
+use tokio::runtime::{Handle, Runtime};
+
+pin_project! {
+ /// `TokioContext` allows running futures that must be inside Tokio's
+ /// context on a non-Tokio runtime.
+ ///
+ /// It contains a [`Handle`] to the runtime. A handle to the runtime can be
+ /// obtain by calling the [`Runtime::handle()`] method.
+ ///
+ /// Note that the `TokioContext` wrapper only works if the `Runtime` it is
+ /// connected to has not yet been destroyed. You must keep the `Runtime`
+ /// alive until the future has finished executing.
+ ///
+ /// **Warning:** If `TokioContext` is used together with a [current thread]
+ /// runtime, that runtime must be inside a call to `block_on` for the
+ /// wrapped future to work. For this reason, it is recommended to use a
+ /// [multi thread] runtime, even if you configure it to only spawn one
+ /// worker thread.
+ ///
+ /// # Examples
+ ///
+ /// This example creates two runtimes, but only [enables time] on one of
+ /// them. It then uses the context of the runtime with the timer enabled to
+ /// execute a [`sleep`] future on the runtime with timing disabled.
+ /// ```
+ /// use tokio::time::{sleep, Duration};
+ /// use tokio_util::context::RuntimeExt;
+ ///
+ /// // This runtime has timers enabled.
+ /// let rt = tokio::runtime::Builder::new_multi_thread()
+ /// .enable_all()
+ /// .build()
+ /// .unwrap();
+ ///
+ /// // This runtime has timers disabled.
+ /// let rt2 = tokio::runtime::Builder::new_multi_thread()
+ /// .build()
+ /// .unwrap();
+ ///
+ /// // Wrap the sleep future in the context of rt.
+ /// let fut = rt.wrap(async { sleep(Duration::from_millis(2)).await });
+ ///
+ /// // Execute the future on rt2.
+ /// rt2.block_on(fut);
+ /// ```
+ ///
+ /// [`Handle`]: struct@tokio::runtime::Handle
+ /// [`Runtime::handle()`]: fn@tokio::runtime::Runtime::handle
+ /// [`RuntimeExt`]: trait@crate::context::RuntimeExt
+ /// [`new_static`]: fn@Self::new_static
+ /// [`sleep`]: fn@tokio::time::sleep
+ /// [current thread]: fn@tokio::runtime::Builder::new_current_thread
+ /// [enables time]: fn@tokio::runtime::Builder::enable_time
+ /// [multi thread]: fn@tokio::runtime::Builder::new_multi_thread
+ pub struct TokioContext<F> {
+ #[pin]
+ inner: F,
+ handle: Handle,
+ }
+}
+
+impl<F> TokioContext<F> {
+ /// Associate the provided future with the context of the runtime behind
+ /// the provided `Handle`.
+ ///
+ /// This constructor uses a `'static` lifetime to opt-out of checking that
+ /// the runtime still exists.
+ ///
+ /// # Examples
+ ///
+ /// This is the same as the example above, but uses the `new` constructor
+ /// rather than [`RuntimeExt::wrap`].
+ ///
+ /// [`RuntimeExt::wrap`]: fn@RuntimeExt::wrap
+ ///
+ /// ```
+ /// use tokio::time::{sleep, Duration};
+ /// use tokio_util::context::TokioContext;
+ ///
+ /// // This runtime has timers enabled.
+ /// let rt = tokio::runtime::Builder::new_multi_thread()
+ /// .enable_all()
+ /// .build()
+ /// .unwrap();
+ ///
+ /// // This runtime has timers disabled.
+ /// let rt2 = tokio::runtime::Builder::new_multi_thread()
+ /// .build()
+ /// .unwrap();
+ ///
+ /// let fut = TokioContext::new(
+ /// async { sleep(Duration::from_millis(2)).await },
+ /// rt.handle().clone(),
+ /// );
+ ///
+ /// // Execute the future on rt2.
+ /// rt2.block_on(fut);
+ /// ```
+ pub fn new(future: F, handle: Handle) -> TokioContext<F> {
+ TokioContext {
+ inner: future,
+ handle,
+ }
+ }
+
+ /// Obtain a reference to the handle inside this `TokioContext`.
+ pub fn handle(&self) -> &Handle {
+ &self.handle
+ }
+
+ /// Remove the association between the Tokio runtime and the wrapped future.
+ pub fn into_inner(self) -> F {
+ self.inner
+ }
+}
+
+impl<F: Future> Future for TokioContext<F> {
+ type Output = F::Output;
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let me = self.project();
+ let handle = me.handle;
+ let fut = me.inner;
+
+ let _enter = handle.enter();
+ fut.poll(cx)
+ }
+}
+
+/// Extension trait that simplifies bundling a `Handle` with a `Future`.
+pub trait RuntimeExt {
+ /// Create a [`TokioContext`] that wraps the provided future and runs it in
+ /// this runtime's context.
+ ///
+ /// # Examples
+ ///
+ /// This example creates two runtimes, but only [enables time] on one of
+ /// them. It then uses the context of the runtime with the timer enabled to
+ /// execute a [`sleep`] future on the runtime with timing disabled.
+ ///
+ /// ```
+ /// use tokio::time::{sleep, Duration};
+ /// use tokio_util::context::RuntimeExt;
+ ///
+ /// // This runtime has timers enabled.
+ /// let rt = tokio::runtime::Builder::new_multi_thread()
+ /// .enable_all()
+ /// .build()
+ /// .unwrap();
+ ///
+ /// // This runtime has timers disabled.
+ /// let rt2 = tokio::runtime::Builder::new_multi_thread()
+ /// .build()
+ /// .unwrap();
+ ///
+ /// // Wrap the sleep future in the context of rt.
+ /// let fut = rt.wrap(async { sleep(Duration::from_millis(2)).await });
+ ///
+ /// // Execute the future on rt2.
+ /// rt2.block_on(fut);
+ /// ```
+ ///
+ /// [`TokioContext`]: struct@crate::context::TokioContext
+ /// [`sleep`]: fn@tokio::time::sleep
+ /// [enables time]: fn@tokio::runtime::Builder::enable_time
+ fn wrap<F: Future>(&self, fut: F) -> TokioContext<F>;
+}
+
+impl RuntimeExt for Runtime {
+ fn wrap<F: Future>(&self, fut: F) -> TokioContext<F> {
+ TokioContext {
+ inner: fut,
+ handle: self.handle().clone(),
+ }
+ }
+}
diff --git a/third_party/rust/tokio-util/src/either.rs b/third_party/rust/tokio-util/src/either.rs
new file mode 100644
index 0000000000..9225e53ca6
--- /dev/null
+++ b/third_party/rust/tokio-util/src/either.rs
@@ -0,0 +1,188 @@
+//! Module defining an Either type.
+use std::{
+ future::Future,
+ io::SeekFrom,
+ pin::Pin,
+ task::{Context, Poll},
+};
+use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf, Result};
+
+/// Combines two different futures, streams, or sinks having the same associated types into a single type.
+///
+/// This type implements common asynchronous traits such as [`Future`] and those in Tokio.
+///
+/// [`Future`]: std::future::Future
+///
+/// # Example
+///
+/// The following code will not work:
+///
+/// ```compile_fail
+/// # fn some_condition() -> bool { true }
+/// # async fn some_async_function() -> u32 { 10 }
+/// # async fn other_async_function() -> u32 { 20 }
+/// #[tokio::main]
+/// async fn main() {
+/// let result = if some_condition() {
+/// some_async_function()
+/// } else {
+/// other_async_function() // <- Will print: "`if` and `else` have incompatible types"
+/// };
+///
+/// println!("Result is {}", result.await);
+/// }
+/// ```
+///
+// This is because although the output types for both futures is the same, the exact future
+// types are different, but the compiler must be able to choose a single type for the
+// `result` variable.
+///
+/// When the output type is the same, we can wrap each future in `Either` to avoid the
+/// issue:
+///
+/// ```
+/// use tokio_util::either::Either;
+/// # fn some_condition() -> bool { true }
+/// # async fn some_async_function() -> u32 { 10 }
+/// # async fn other_async_function() -> u32 { 20 }
+///
+/// #[tokio::main]
+/// async fn main() {
+/// let result = if some_condition() {
+/// Either::Left(some_async_function())
+/// } else {
+/// Either::Right(other_async_function())
+/// };
+///
+/// let value = result.await;
+/// println!("Result is {}", value);
+/// # assert_eq!(value, 10);
+/// }
+/// ```
+#[allow(missing_docs)] // Doc-comments for variants in this particular case don't make much sense.
+#[derive(Debug, Clone)]
+pub enum Either<L, R> {
+ Left(L),
+ Right(R),
+}
+
+/// A small helper macro which reduces amount of boilerplate in the actual trait method implementation.
+/// It takes an invocation of method as an argument (e.g. `self.poll(cx)`), and redirects it to either
+/// enum variant held in `self`.
+macro_rules! delegate_call {
+ ($self:ident.$method:ident($($args:ident),+)) => {
+ unsafe {
+ match $self.get_unchecked_mut() {
+ Self::Left(l) => Pin::new_unchecked(l).$method($($args),+),
+ Self::Right(r) => Pin::new_unchecked(r).$method($($args),+),
+ }
+ }
+ }
+}
+
+impl<L, R, O> Future for Either<L, R>
+where
+ L: Future<Output = O>,
+ R: Future<Output = O>,
+{
+ type Output = O;
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ delegate_call!(self.poll(cx))
+ }
+}
+
+impl<L, R> AsyncRead for Either<L, R>
+where
+ L: AsyncRead,
+ R: AsyncRead,
+{
+ fn poll_read(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<Result<()>> {
+ delegate_call!(self.poll_read(cx, buf))
+ }
+}
+
+impl<L, R> AsyncBufRead for Either<L, R>
+where
+ L: AsyncBufRead,
+ R: AsyncBufRead,
+{
+ fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<&[u8]>> {
+ delegate_call!(self.poll_fill_buf(cx))
+ }
+
+ fn consume(self: Pin<&mut Self>, amt: usize) {
+ delegate_call!(self.consume(amt))
+ }
+}
+
+impl<L, R> AsyncSeek for Either<L, R>
+where
+ L: AsyncSeek,
+ R: AsyncSeek,
+{
+ fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> Result<()> {
+ delegate_call!(self.start_seek(position))
+ }
+
+ fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<u64>> {
+ delegate_call!(self.poll_complete(cx))
+ }
+}
+
+impl<L, R> AsyncWrite for Either<L, R>
+where
+ L: AsyncWrite,
+ R: AsyncWrite,
+{
+ fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
+ delegate_call!(self.poll_write(cx, buf))
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
+ delegate_call!(self.poll_flush(cx))
+ }
+
+ fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
+ delegate_call!(self.poll_shutdown(cx))
+ }
+}
+
+impl<L, R> futures_core::stream::Stream for Either<L, R>
+where
+ L: futures_core::stream::Stream,
+ R: futures_core::stream::Stream<Item = L::Item>,
+{
+ type Item = L::Item;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ delegate_call!(self.poll_next(cx))
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use tokio::io::{repeat, AsyncReadExt, Repeat};
+ use tokio_stream::{once, Once, StreamExt};
+
+ #[tokio::test]
+ async fn either_is_stream() {
+ let mut either: Either<Once<u32>, Once<u32>> = Either::Left(once(1));
+
+ assert_eq!(Some(1u32), either.next().await);
+ }
+
+ #[tokio::test]
+ async fn either_is_async_read() {
+ let mut buffer = [0; 3];
+ let mut either: Either<Repeat, Repeat> = Either::Right(repeat(0b101));
+
+ either.read_exact(&mut buffer).await.unwrap();
+ assert_eq!(buffer, [0b101, 0b101, 0b101]);
+ }
+}
diff --git a/third_party/rust/tokio-util/src/io/mod.rs b/third_party/rust/tokio-util/src/io/mod.rs
new file mode 100644
index 0000000000..eb48a21fb9
--- /dev/null
+++ b/third_party/rust/tokio-util/src/io/mod.rs
@@ -0,0 +1,24 @@
+//! Helpers for IO related tasks.
+//!
+//! The stream types are often used in combination with hyper or reqwest, as they
+//! allow converting between a hyper [`Body`] and [`AsyncRead`].
+//!
+//! The [`SyncIoBridge`] type converts from the world of async I/O
+//! to synchronous I/O; this may often come up when using synchronous APIs
+//! inside [`tokio::task::spawn_blocking`].
+//!
+//! [`Body`]: https://docs.rs/hyper/0.13/hyper/struct.Body.html
+//! [`AsyncRead`]: tokio::io::AsyncRead
+
+mod read_buf;
+mod reader_stream;
+mod stream_reader;
+cfg_io_util! {
+ mod sync_bridge;
+ pub use self::sync_bridge::SyncIoBridge;
+}
+
+pub use self::read_buf::read_buf;
+pub use self::reader_stream::ReaderStream;
+pub use self::stream_reader::StreamReader;
+pub use crate::util::{poll_read_buf, poll_write_buf};
diff --git a/third_party/rust/tokio-util/src/io/read_buf.rs b/third_party/rust/tokio-util/src/io/read_buf.rs
new file mode 100644
index 0000000000..d7938a3bc1
--- /dev/null
+++ b/third_party/rust/tokio-util/src/io/read_buf.rs
@@ -0,0 +1,65 @@
+use bytes::BufMut;
+use std::future::Future;
+use std::io;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tokio::io::AsyncRead;
+
+/// Read data from an `AsyncRead` into an implementer of the [`BufMut`] trait.
+///
+/// [`BufMut`]: bytes::BufMut
+///
+/// # Example
+///
+/// ```
+/// use bytes::{Bytes, BytesMut};
+/// use tokio_stream as stream;
+/// use tokio::io::Result;
+/// use tokio_util::io::{StreamReader, read_buf};
+/// # #[tokio::main]
+/// # async fn main() -> std::io::Result<()> {
+///
+/// // Create a reader from an iterator. This particular reader will always be
+/// // ready.
+/// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))]));
+///
+/// let mut buf = BytesMut::new();
+/// let mut reads = 0;
+///
+/// loop {
+/// reads += 1;
+/// let n = read_buf(&mut read, &mut buf).await?;
+///
+/// if n == 0 {
+/// break;
+/// }
+/// }
+///
+/// // one or more reads might be necessary.
+/// assert!(reads >= 1);
+/// assert_eq!(&buf[..], &[0, 1, 2, 3]);
+/// # Ok(())
+/// # }
+/// ```
+pub async fn read_buf<R, B>(read: &mut R, buf: &mut B) -> io::Result<usize>
+where
+ R: AsyncRead + Unpin,
+ B: BufMut,
+{
+ return ReadBufFn(read, buf).await;
+
+ struct ReadBufFn<'a, R, B>(&'a mut R, &'a mut B);
+
+ impl<'a, R, B> Future for ReadBufFn<'a, R, B>
+ where
+ R: AsyncRead + Unpin,
+ B: BufMut,
+ {
+ type Output = io::Result<usize>;
+
+ fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let this = &mut *self;
+ crate::util::poll_read_buf(Pin::new(this.0), cx, this.1)
+ }
+ }
+}
diff --git a/third_party/rust/tokio-util/src/io/reader_stream.rs b/third_party/rust/tokio-util/src/io/reader_stream.rs
new file mode 100644
index 0000000000..866c11408d
--- /dev/null
+++ b/third_party/rust/tokio-util/src/io/reader_stream.rs
@@ -0,0 +1,118 @@
+use bytes::{Bytes, BytesMut};
+use futures_core::stream::Stream;
+use pin_project_lite::pin_project;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tokio::io::AsyncRead;
+
+const DEFAULT_CAPACITY: usize = 4096;
+
+pin_project! {
+ /// Convert an [`AsyncRead`] into a [`Stream`] of byte chunks.
+ ///
+ /// This stream is fused. It performs the inverse operation of
+ /// [`StreamReader`].
+ ///
+ /// # Example
+ ///
+ /// ```
+ /// # #[tokio::main]
+ /// # async fn main() -> std::io::Result<()> {
+ /// use tokio_stream::StreamExt;
+ /// use tokio_util::io::ReaderStream;
+ ///
+ /// // Create a stream of data.
+ /// let data = b"hello, world!";
+ /// let mut stream = ReaderStream::new(&data[..]);
+ ///
+ /// // Read all of the chunks into a vector.
+ /// let mut stream_contents = Vec::new();
+ /// while let Some(chunk) = stream.next().await {
+ /// stream_contents.extend_from_slice(&chunk?);
+ /// }
+ ///
+ /// // Once the chunks are concatenated, we should have the
+ /// // original data.
+ /// assert_eq!(stream_contents, data);
+ /// # Ok(())
+ /// # }
+ /// ```
+ ///
+ /// [`AsyncRead`]: tokio::io::AsyncRead
+ /// [`StreamReader`]: crate::io::StreamReader
+ /// [`Stream`]: futures_core::Stream
+ #[derive(Debug)]
+ pub struct ReaderStream<R> {
+ // Reader itself.
+ //
+ // This value is `None` if the stream has terminated.
+ #[pin]
+ reader: Option<R>,
+ // Working buffer, used to optimize allocations.
+ buf: BytesMut,
+ capacity: usize,
+ }
+}
+
+impl<R: AsyncRead> ReaderStream<R> {
+ /// Convert an [`AsyncRead`] into a [`Stream`] with item type
+ /// `Result<Bytes, std::io::Error>`.
+ ///
+ /// [`AsyncRead`]: tokio::io::AsyncRead
+ /// [`Stream`]: futures_core::Stream
+ pub fn new(reader: R) -> Self {
+ ReaderStream {
+ reader: Some(reader),
+ buf: BytesMut::new(),
+ capacity: DEFAULT_CAPACITY,
+ }
+ }
+
+ /// Convert an [`AsyncRead`] into a [`Stream`] with item type
+ /// `Result<Bytes, std::io::Error>`,
+ /// with a specific read buffer initial capacity.
+ ///
+ /// [`AsyncRead`]: tokio::io::AsyncRead
+ /// [`Stream`]: futures_core::Stream
+ pub fn with_capacity(reader: R, capacity: usize) -> Self {
+ ReaderStream {
+ reader: Some(reader),
+ buf: BytesMut::with_capacity(capacity),
+ capacity,
+ }
+ }
+}
+
+impl<R: AsyncRead> Stream for ReaderStream<R> {
+ type Item = std::io::Result<Bytes>;
+ fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ use crate::util::poll_read_buf;
+
+ let mut this = self.as_mut().project();
+
+ let reader = match this.reader.as_pin_mut() {
+ Some(r) => r,
+ None => return Poll::Ready(None),
+ };
+
+ if this.buf.capacity() == 0 {
+ this.buf.reserve(*this.capacity);
+ }
+
+ match poll_read_buf(reader, cx, &mut this.buf) {
+ Poll::Pending => Poll::Pending,
+ Poll::Ready(Err(err)) => {
+ self.project().reader.set(None);
+ Poll::Ready(Some(Err(err)))
+ }
+ Poll::Ready(Ok(0)) => {
+ self.project().reader.set(None);
+ Poll::Ready(None)
+ }
+ Poll::Ready(Ok(_)) => {
+ let chunk = this.buf.split();
+ Poll::Ready(Some(Ok(chunk.freeze())))
+ }
+ }
+ }
+}
diff --git a/third_party/rust/tokio-util/src/io/stream_reader.rs b/third_party/rust/tokio-util/src/io/stream_reader.rs
new file mode 100644
index 0000000000..05ae886557
--- /dev/null
+++ b/third_party/rust/tokio-util/src/io/stream_reader.rs
@@ -0,0 +1,203 @@
+use bytes::Buf;
+use futures_core::stream::Stream;
+use pin_project_lite::pin_project;
+use std::io;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf};
+
+pin_project! {
+ /// Convert a [`Stream`] of byte chunks into an [`AsyncRead`].
+ ///
+ /// This type performs the inverse operation of [`ReaderStream`].
+ ///
+ /// # Example
+ ///
+ /// ```
+ /// use bytes::Bytes;
+ /// use tokio::io::{AsyncReadExt, Result};
+ /// use tokio_util::io::StreamReader;
+ /// # #[tokio::main]
+ /// # async fn main() -> std::io::Result<()> {
+ ///
+ /// // Create a stream from an iterator.
+ /// let stream = tokio_stream::iter(vec![
+ /// Result::Ok(Bytes::from_static(&[0, 1, 2, 3])),
+ /// Result::Ok(Bytes::from_static(&[4, 5, 6, 7])),
+ /// Result::Ok(Bytes::from_static(&[8, 9, 10, 11])),
+ /// ]);
+ ///
+ /// // Convert it to an AsyncRead.
+ /// let mut read = StreamReader::new(stream);
+ ///
+ /// // Read five bytes from the stream.
+ /// let mut buf = [0; 5];
+ /// read.read_exact(&mut buf).await?;
+ /// assert_eq!(buf, [0, 1, 2, 3, 4]);
+ ///
+ /// // Read the rest of the current chunk.
+ /// assert_eq!(read.read(&mut buf).await?, 3);
+ /// assert_eq!(&buf[..3], [5, 6, 7]);
+ ///
+ /// // Read the next chunk.
+ /// assert_eq!(read.read(&mut buf).await?, 4);
+ /// assert_eq!(&buf[..4], [8, 9, 10, 11]);
+ ///
+ /// // We have now reached the end.
+ /// assert_eq!(read.read(&mut buf).await?, 0);
+ ///
+ /// # Ok(())
+ /// # }
+ /// ```
+ ///
+ /// [`AsyncRead`]: tokio::io::AsyncRead
+ /// [`Stream`]: futures_core::Stream
+ /// [`ReaderStream`]: crate::io::ReaderStream
+ #[derive(Debug)]
+ pub struct StreamReader<S, B> {
+ #[pin]
+ inner: S,
+ chunk: Option<B>,
+ }
+}
+
+impl<S, B, E> StreamReader<S, B>
+where
+ S: Stream<Item = Result<B, E>>,
+ B: Buf,
+ E: Into<std::io::Error>,
+{
+ /// Convert a stream of byte chunks into an [`AsyncRead`](tokio::io::AsyncRead).
+ ///
+ /// The item should be a [`Result`] with the ok variant being something that
+ /// implements the [`Buf`] trait (e.g. `Vec<u8>` or `Bytes`). The error
+ /// should be convertible into an [io error].
+ ///
+ /// [`Result`]: std::result::Result
+ /// [`Buf`]: bytes::Buf
+ /// [io error]: std::io::Error
+ pub fn new(stream: S) -> Self {
+ Self {
+ inner: stream,
+ chunk: None,
+ }
+ }
+
+ /// Do we have a chunk and is it non-empty?
+ fn has_chunk(&self) -> bool {
+ if let Some(ref chunk) = self.chunk {
+ chunk.remaining() > 0
+ } else {
+ false
+ }
+ }
+
+ /// Consumes this `StreamReader`, returning a Tuple consisting
+ /// of the underlying stream and an Option of the interal buffer,
+ /// which is Some in case the buffer contains elements.
+ pub fn into_inner_with_chunk(self) -> (S, Option<B>) {
+ if self.has_chunk() {
+ (self.inner, self.chunk)
+ } else {
+ (self.inner, None)
+ }
+ }
+}
+
+impl<S, B> StreamReader<S, B> {
+ /// Gets a reference to the underlying stream.
+ ///
+ /// It is inadvisable to directly read from the underlying stream.
+ pub fn get_ref(&self) -> &S {
+ &self.inner
+ }
+
+ /// Gets a mutable reference to the underlying stream.
+ ///
+ /// It is inadvisable to directly read from the underlying stream.
+ pub fn get_mut(&mut self) -> &mut S {
+ &mut self.inner
+ }
+
+ /// Gets a pinned mutable reference to the underlying stream.
+ ///
+ /// It is inadvisable to directly read from the underlying stream.
+ pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
+ self.project().inner
+ }
+
+ /// Consumes this `BufWriter`, returning the underlying stream.
+ ///
+ /// Note that any leftover data in the internal buffer is lost.
+ /// If you additionally want access to the internal buffer use
+ /// [`into_inner_with_chunk`].
+ ///
+ /// [`into_inner_with_chunk`]: crate::io::StreamReader::into_inner_with_chunk
+ pub fn into_inner(self) -> S {
+ self.inner
+ }
+}
+
+impl<S, B, E> AsyncRead for StreamReader<S, B>
+where
+ S: Stream<Item = Result<B, E>>,
+ B: Buf,
+ E: Into<std::io::Error>,
+{
+ fn poll_read(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ if buf.remaining() == 0 {
+ return Poll::Ready(Ok(()));
+ }
+
+ let inner_buf = match self.as_mut().poll_fill_buf(cx) {
+ Poll::Ready(Ok(buf)) => buf,
+ Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
+ Poll::Pending => return Poll::Pending,
+ };
+ let len = std::cmp::min(inner_buf.len(), buf.remaining());
+ buf.put_slice(&inner_buf[..len]);
+
+ self.consume(len);
+ Poll::Ready(Ok(()))
+ }
+}
+
+impl<S, B, E> AsyncBufRead for StreamReader<S, B>
+where
+ S: Stream<Item = Result<B, E>>,
+ B: Buf,
+ E: Into<std::io::Error>,
+{
+ fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
+ loop {
+ if self.as_mut().has_chunk() {
+ // This unwrap is very sad, but it can't be avoided.
+ let buf = self.project().chunk.as_ref().unwrap().chunk();
+ return Poll::Ready(Ok(buf));
+ } else {
+ match self.as_mut().project().inner.poll_next(cx) {
+ Poll::Ready(Some(Ok(chunk))) => {
+ // Go around the loop in case the chunk is empty.
+ *self.as_mut().project().chunk = Some(chunk);
+ }
+ Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())),
+ Poll::Ready(None) => return Poll::Ready(Ok(&[])),
+ Poll::Pending => return Poll::Pending,
+ }
+ }
+ }
+ }
+ fn consume(self: Pin<&mut Self>, amt: usize) {
+ if amt > 0 {
+ self.project()
+ .chunk
+ .as_mut()
+ .expect("No chunk present")
+ .advance(amt);
+ }
+ }
+}
diff --git a/third_party/rust/tokio-util/src/io/sync_bridge.rs b/third_party/rust/tokio-util/src/io/sync_bridge.rs
new file mode 100644
index 0000000000..9be9446a7d
--- /dev/null
+++ b/third_party/rust/tokio-util/src/io/sync_bridge.rs
@@ -0,0 +1,103 @@
+use std::io::{Read, Write};
+use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
+
+/// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or
+/// a [`tokio::io::AsyncWrite`] as a [`std::io::Write`].
+#[derive(Debug)]
+pub struct SyncIoBridge<T> {
+ src: T,
+ rt: tokio::runtime::Handle,
+}
+
+impl<T: AsyncRead + Unpin> Read for SyncIoBridge<T> {
+ fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
+ let src = &mut self.src;
+ self.rt.block_on(AsyncReadExt::read(src, buf))
+ }
+
+ fn read_to_end(&mut self, buf: &mut Vec<u8>) -> std::io::Result<usize> {
+ let src = &mut self.src;
+ self.rt.block_on(src.read_to_end(buf))
+ }
+
+ fn read_to_string(&mut self, buf: &mut String) -> std::io::Result<usize> {
+ let src = &mut self.src;
+ self.rt.block_on(src.read_to_string(buf))
+ }
+
+ fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
+ let src = &mut self.src;
+ // The AsyncRead trait returns the count, synchronous doesn't.
+ let _n = self.rt.block_on(src.read_exact(buf))?;
+ Ok(())
+ }
+}
+
+impl<T: AsyncWrite + Unpin> Write for SyncIoBridge<T> {
+ fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
+ let src = &mut self.src;
+ self.rt.block_on(src.write(buf))
+ }
+
+ fn flush(&mut self) -> std::io::Result<()> {
+ let src = &mut self.src;
+ self.rt.block_on(src.flush())
+ }
+
+ fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
+ let src = &mut self.src;
+ self.rt.block_on(src.write_all(buf))
+ }
+
+ fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result<usize> {
+ let src = &mut self.src;
+ self.rt.block_on(src.write_vectored(bufs))
+ }
+}
+
+// Because https://doc.rust-lang.org/std/io/trait.Write.html#method.is_write_vectored is at the time
+// of this writing still unstable, we expose this as part of a standalone method.
+impl<T: AsyncWrite> SyncIoBridge<T> {
+ /// Determines if the underlying [`tokio::io::AsyncWrite`] target supports efficient vectored writes.
+ ///
+ /// See [`tokio::io::AsyncWrite::is_write_vectored`].
+ pub fn is_write_vectored(&self) -> bool {
+ self.src.is_write_vectored()
+ }
+}
+
+impl<T: Unpin> SyncIoBridge<T> {
+ /// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or
+ /// a [`tokio::io::AsyncWrite`] as a [`std::io::Write`].
+ ///
+ /// When this struct is created, it captures a handle to the current thread's runtime with [`tokio::runtime::Handle::current`].
+ /// It is hence OK to move this struct into a separate thread outside the runtime, as created
+ /// by e.g. [`tokio::task::spawn_blocking`].
+ ///
+ /// Stated even more strongly: to make use of this bridge, you *must* move
+ /// it into a separate thread outside the runtime. The synchronous I/O will use the
+ /// underlying handle to block on the backing asynchronous source, via
+ /// [`tokio::runtime::Handle::block_on`]. As noted in the documentation for that
+ /// function, an attempt to `block_on` from an asynchronous execution context
+ /// will panic.
+ ///
+ /// # Wrapping `!Unpin` types
+ ///
+ /// Use e.g. `SyncIoBridge::new(Box::pin(src))`.
+ ///
+ /// # Panic
+ ///
+ /// This will panic if called outside the context of a Tokio runtime.
+ pub fn new(src: T) -> Self {
+ Self::new_with_handle(src, tokio::runtime::Handle::current())
+ }
+
+ /// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or
+ /// a [`tokio::io::AsyncWrite`] as a [`std::io::Write`].
+ ///
+ /// This is the same as [`SyncIoBridge::new`], but allows passing an arbitrary handle and hence may
+ /// be initially invoked outside of an asynchronous context.
+ pub fn new_with_handle(src: T, rt: tokio::runtime::Handle) -> Self {
+ Self { src, rt }
+ }
+}
diff --git a/third_party/rust/tokio-util/src/lib.rs b/third_party/rust/tokio-util/src/lib.rs
new file mode 100644
index 0000000000..fd14a8ac94
--- /dev/null
+++ b/third_party/rust/tokio-util/src/lib.rs
@@ -0,0 +1,201 @@
+#![allow(clippy::needless_doctest_main)]
+#![warn(
+ missing_debug_implementations,
+ missing_docs,
+ rust_2018_idioms,
+ unreachable_pub
+)]
+#![doc(test(
+ no_crate_inject,
+ attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables))
+))]
+#![cfg_attr(docsrs, feature(doc_cfg))]
+
+//! Utilities for working with Tokio.
+//!
+//! This crate is not versioned in lockstep with the core
+//! [`tokio`] crate. However, `tokio-util` _will_ respect Rust's
+//! semantic versioning policy, especially with regard to breaking changes.
+//!
+//! [`tokio`]: https://docs.rs/tokio
+
+#[macro_use]
+mod cfg;
+
+mod loom;
+
+cfg_codec! {
+ pub mod codec;
+}
+
+cfg_net! {
+ pub mod udp;
+ pub mod net;
+}
+
+cfg_compat! {
+ pub mod compat;
+}
+
+cfg_io! {
+ pub mod io;
+}
+
+cfg_rt! {
+ pub mod context;
+ pub mod task;
+}
+
+cfg_time! {
+ pub mod time;
+}
+
+pub mod sync;
+
+pub mod either;
+
+#[cfg(any(feature = "io", feature = "codec"))]
+mod util {
+ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+
+ use bytes::{Buf, BufMut};
+ use futures_core::ready;
+ use std::io::{self, IoSlice};
+ use std::mem::MaybeUninit;
+ use std::pin::Pin;
+ use std::task::{Context, Poll};
+
+ /// Try to read data from an `AsyncRead` into an implementer of the [`BufMut`] trait.
+ ///
+ /// [`BufMut`]: bytes::Buf
+ ///
+ /// # Example
+ ///
+ /// ```
+ /// use bytes::{Bytes, BytesMut};
+ /// use tokio_stream as stream;
+ /// use tokio::io::Result;
+ /// use tokio_util::io::{StreamReader, poll_read_buf};
+ /// use futures::future::poll_fn;
+ /// use std::pin::Pin;
+ /// # #[tokio::main]
+ /// # async fn main() -> std::io::Result<()> {
+ ///
+ /// // Create a reader from an iterator. This particular reader will always be
+ /// // ready.
+ /// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))]));
+ ///
+ /// let mut buf = BytesMut::new();
+ /// let mut reads = 0;
+ ///
+ /// loop {
+ /// reads += 1;
+ /// let n = poll_fn(|cx| poll_read_buf(Pin::new(&mut read), cx, &mut buf)).await?;
+ ///
+ /// if n == 0 {
+ /// break;
+ /// }
+ /// }
+ ///
+ /// // one or more reads might be necessary.
+ /// assert!(reads >= 1);
+ /// assert_eq!(&buf[..], &[0, 1, 2, 3]);
+ /// # Ok(())
+ /// # }
+ /// ```
+ #[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
+ pub fn poll_read_buf<T: AsyncRead, B: BufMut>(
+ io: Pin<&mut T>,
+ cx: &mut Context<'_>,
+ buf: &mut B,
+ ) -> Poll<io::Result<usize>> {
+ if !buf.has_remaining_mut() {
+ return Poll::Ready(Ok(0));
+ }
+
+ let n = {
+ let dst = buf.chunk_mut();
+ let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) };
+ let mut buf = ReadBuf::uninit(dst);
+ let ptr = buf.filled().as_ptr();
+ ready!(io.poll_read(cx, &mut buf)?);
+
+ // Ensure the pointer does not change from under us
+ assert_eq!(ptr, buf.filled().as_ptr());
+ buf.filled().len()
+ };
+
+ // Safety: This is guaranteed to be the number of initialized (and read)
+ // bytes due to the invariants provided by `ReadBuf::filled`.
+ unsafe {
+ buf.advance_mut(n);
+ }
+
+ Poll::Ready(Ok(n))
+ }
+
+ /// Try to write data from an implementer of the [`Buf`] trait to an
+ /// [`AsyncWrite`], advancing the buffer's internal cursor.
+ ///
+ /// This function will use [vectored writes] when the [`AsyncWrite`] supports
+ /// vectored writes.
+ ///
+ /// # Examples
+ ///
+ /// [`File`] implements [`AsyncWrite`] and [`Cursor<&[u8]>`] implements
+ /// [`Buf`]:
+ ///
+ /// ```no_run
+ /// use tokio_util::io::poll_write_buf;
+ /// use tokio::io;
+ /// use tokio::fs::File;
+ ///
+ /// use bytes::Buf;
+ /// use std::io::Cursor;
+ /// use std::pin::Pin;
+ /// use futures::future::poll_fn;
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> io::Result<()> {
+ /// let mut file = File::create("foo.txt").await?;
+ /// let mut buf = Cursor::new(b"data to write");
+ ///
+ /// // Loop until the entire contents of the buffer are written to
+ /// // the file.
+ /// while buf.has_remaining() {
+ /// poll_fn(|cx| poll_write_buf(Pin::new(&mut file), cx, &mut buf)).await?;
+ /// }
+ ///
+ /// Ok(())
+ /// }
+ /// ```
+ ///
+ /// [`Buf`]: bytes::Buf
+ /// [`AsyncWrite`]: tokio::io::AsyncWrite
+ /// [`File`]: tokio::fs::File
+ /// [vectored writes]: tokio::io::AsyncWrite::poll_write_vectored
+ #[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
+ pub fn poll_write_buf<T: AsyncWrite, B: Buf>(
+ io: Pin<&mut T>,
+ cx: &mut Context<'_>,
+ buf: &mut B,
+ ) -> Poll<io::Result<usize>> {
+ const MAX_BUFS: usize = 64;
+
+ if !buf.has_remaining() {
+ return Poll::Ready(Ok(0));
+ }
+
+ let n = if io.is_write_vectored() {
+ let mut slices = [IoSlice::new(&[]); MAX_BUFS];
+ let cnt = buf.chunks_vectored(&mut slices);
+ ready!(io.poll_write_vectored(cx, &slices[..cnt]))?
+ } else {
+ ready!(io.poll_write(cx, buf.chunk()))?
+ };
+
+ buf.advance(n);
+
+ Poll::Ready(Ok(n))
+ }
+}
diff --git a/third_party/rust/tokio-util/src/loom.rs b/third_party/rust/tokio-util/src/loom.rs
new file mode 100644
index 0000000000..dd03feaba1
--- /dev/null
+++ b/third_party/rust/tokio-util/src/loom.rs
@@ -0,0 +1 @@
+pub(crate) use std::sync;
diff --git a/third_party/rust/tokio-util/src/net/mod.rs b/third_party/rust/tokio-util/src/net/mod.rs
new file mode 100644
index 0000000000..4817e10d0f
--- /dev/null
+++ b/third_party/rust/tokio-util/src/net/mod.rs
@@ -0,0 +1,97 @@
+//! TCP/UDP/Unix helpers for tokio.
+
+use crate::either::Either;
+use std::future::Future;
+use std::io::Result;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+#[cfg(unix)]
+pub mod unix;
+
+/// A trait for a listener: `TcpListener` and `UnixListener`.
+pub trait Listener {
+ /// The stream's type of this listener.
+ type Io: tokio::io::AsyncRead + tokio::io::AsyncWrite;
+ /// The socket address type of this listener.
+ type Addr;
+
+ /// Polls to accept a new incoming connection to this listener.
+ fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<(Self::Io, Self::Addr)>>;
+
+ /// Accepts a new incoming connection from this listener.
+ fn accept(&mut self) -> ListenerAcceptFut<'_, Self>
+ where
+ Self: Sized,
+ {
+ ListenerAcceptFut { listener: self }
+ }
+
+ /// Returns the local address that this listener is bound to.
+ fn local_addr(&self) -> Result<Self::Addr>;
+}
+
+impl Listener for tokio::net::TcpListener {
+ type Io = tokio::net::TcpStream;
+ type Addr = std::net::SocketAddr;
+
+ fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<(Self::Io, Self::Addr)>> {
+ Self::poll_accept(self, cx)
+ }
+
+ fn local_addr(&self) -> Result<Self::Addr> {
+ self.local_addr().map(Into::into)
+ }
+}
+
+/// Future for accepting a new connection from a listener.
+#[derive(Debug)]
+#[must_use = "futures do nothing unless you `.await` or poll them"]
+pub struct ListenerAcceptFut<'a, L> {
+ listener: &'a mut L,
+}
+
+impl<'a, L> Future for ListenerAcceptFut<'a, L>
+where
+ L: Listener,
+{
+ type Output = Result<(L::Io, L::Addr)>;
+
+ fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ self.listener.poll_accept(cx)
+ }
+}
+
+impl<L, R> Either<L, R>
+where
+ L: Listener,
+ R: Listener,
+{
+ /// Accepts a new incoming connection from this listener.
+ pub async fn accept(&mut self) -> Result<Either<(L::Io, L::Addr), (R::Io, R::Addr)>> {
+ match self {
+ Either::Left(listener) => {
+ let (stream, addr) = listener.accept().await?;
+ Ok(Either::Left((stream, addr)))
+ }
+ Either::Right(listener) => {
+ let (stream, addr) = listener.accept().await?;
+ Ok(Either::Right((stream, addr)))
+ }
+ }
+ }
+
+ /// Returns the local address that this listener is bound to.
+ pub fn local_addr(&self) -> Result<Either<L::Addr, R::Addr>> {
+ match self {
+ Either::Left(listener) => {
+ let addr = listener.local_addr()?;
+ Ok(Either::Left(addr))
+ }
+ Either::Right(listener) => {
+ let addr = listener.local_addr()?;
+ Ok(Either::Right(addr))
+ }
+ }
+ }
+}
diff --git a/third_party/rust/tokio-util/src/net/unix/mod.rs b/third_party/rust/tokio-util/src/net/unix/mod.rs
new file mode 100644
index 0000000000..0b522c90a3
--- /dev/null
+++ b/third_party/rust/tokio-util/src/net/unix/mod.rs
@@ -0,0 +1,18 @@
+//! Unix domain socket helpers.
+
+use super::Listener;
+use std::io::Result;
+use std::task::{Context, Poll};
+
+impl Listener for tokio::net::UnixListener {
+ type Io = tokio::net::UnixStream;
+ type Addr = tokio::net::unix::SocketAddr;
+
+ fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<(Self::Io, Self::Addr)>> {
+ Self::poll_accept(self, cx)
+ }
+
+ fn local_addr(&self) -> Result<Self::Addr> {
+ self.local_addr().map(Into::into)
+ }
+}
diff --git a/third_party/rust/tokio-util/src/sync/cancellation_token.rs b/third_party/rust/tokio-util/src/sync/cancellation_token.rs
new file mode 100644
index 0000000000..2a6ef392bd
--- /dev/null
+++ b/third_party/rust/tokio-util/src/sync/cancellation_token.rs
@@ -0,0 +1,224 @@
+//! An asynchronously awaitable `CancellationToken`.
+//! The token allows to signal a cancellation request to one or more tasks.
+pub(crate) mod guard;
+mod tree_node;
+
+use crate::loom::sync::Arc;
+use core::future::Future;
+use core::pin::Pin;
+use core::task::{Context, Poll};
+
+use guard::DropGuard;
+use pin_project_lite::pin_project;
+
+/// A token which can be used to signal a cancellation request to one or more
+/// tasks.
+///
+/// Tasks can call [`CancellationToken::cancelled()`] in order to
+/// obtain a Future which will be resolved when cancellation is requested.
+///
+/// Cancellation can be requested through the [`CancellationToken::cancel`] method.
+///
+/// # Examples
+///
+/// ```no_run
+/// use tokio::select;
+/// use tokio_util::sync::CancellationToken;
+///
+/// #[tokio::main]
+/// async fn main() {
+/// let token = CancellationToken::new();
+/// let cloned_token = token.clone();
+///
+/// let join_handle = tokio::spawn(async move {
+/// // Wait for either cancellation or a very long time
+/// select! {
+/// _ = cloned_token.cancelled() => {
+/// // The token was cancelled
+/// 5
+/// }
+/// _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => {
+/// 99
+/// }
+/// }
+/// });
+///
+/// tokio::spawn(async move {
+/// tokio::time::sleep(std::time::Duration::from_millis(10)).await;
+/// token.cancel();
+/// });
+///
+/// assert_eq!(5, join_handle.await.unwrap());
+/// }
+/// ```
+pub struct CancellationToken {
+ inner: Arc<tree_node::TreeNode>,
+}
+
+pin_project! {
+ /// A Future that is resolved once the corresponding [`CancellationToken`]
+ /// is cancelled.
+ #[must_use = "futures do nothing unless polled"]
+ pub struct WaitForCancellationFuture<'a> {
+ cancellation_token: &'a CancellationToken,
+ #[pin]
+ future: tokio::sync::futures::Notified<'a>,
+ }
+}
+
+// ===== impl CancellationToken =====
+
+impl core::fmt::Debug for CancellationToken {
+ fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+ f.debug_struct("CancellationToken")
+ .field("is_cancelled", &self.is_cancelled())
+ .finish()
+ }
+}
+
+impl Clone for CancellationToken {
+ fn clone(&self) -> Self {
+ tree_node::increase_handle_refcount(&self.inner);
+ CancellationToken {
+ inner: self.inner.clone(),
+ }
+ }
+}
+
+impl Drop for CancellationToken {
+ fn drop(&mut self) {
+ tree_node::decrease_handle_refcount(&self.inner);
+ }
+}
+
+impl Default for CancellationToken {
+ fn default() -> CancellationToken {
+ CancellationToken::new()
+ }
+}
+
+impl CancellationToken {
+ /// Creates a new CancellationToken in the non-cancelled state.
+ pub fn new() -> CancellationToken {
+ CancellationToken {
+ inner: Arc::new(tree_node::TreeNode::new()),
+ }
+ }
+
+ /// Creates a `CancellationToken` which will get cancelled whenever the
+ /// current token gets cancelled.
+ ///
+ /// If the current token is already cancelled, the child token will get
+ /// returned in cancelled state.
+ ///
+ /// # Examples
+ ///
+ /// ```no_run
+ /// use tokio::select;
+ /// use tokio_util::sync::CancellationToken;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let token = CancellationToken::new();
+ /// let child_token = token.child_token();
+ ///
+ /// let join_handle = tokio::spawn(async move {
+ /// // Wait for either cancellation or a very long time
+ /// select! {
+ /// _ = child_token.cancelled() => {
+ /// // The token was cancelled
+ /// 5
+ /// }
+ /// _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => {
+ /// 99
+ /// }
+ /// }
+ /// });
+ ///
+ /// tokio::spawn(async move {
+ /// tokio::time::sleep(std::time::Duration::from_millis(10)).await;
+ /// token.cancel();
+ /// });
+ ///
+ /// assert_eq!(5, join_handle.await.unwrap());
+ /// }
+ /// ```
+ pub fn child_token(&self) -> CancellationToken {
+ CancellationToken {
+ inner: tree_node::child_node(&self.inner),
+ }
+ }
+
+ /// Cancel the [`CancellationToken`] and all child tokens which had been
+ /// derived from it.
+ ///
+ /// This will wake up all tasks which are waiting for cancellation.
+ ///
+ /// Be aware that cancellation is not an atomic operation. It is possible
+ /// for another thread running in parallel with a call to `cancel` to first
+ /// receive `true` from `is_cancelled` on one child node, and then receive
+ /// `false` from `is_cancelled` on another child node. However, once the
+ /// call to `cancel` returns, all child nodes have been fully cancelled.
+ pub fn cancel(&self) {
+ tree_node::cancel(&self.inner);
+ }
+
+ /// Returns `true` if the `CancellationToken` is cancelled.
+ pub fn is_cancelled(&self) -> bool {
+ tree_node::is_cancelled(&self.inner)
+ }
+
+ /// Returns a `Future` that gets fulfilled when cancellation is requested.
+ ///
+ /// The future will complete immediately if the token is already cancelled
+ /// when this method is called.
+ ///
+ /// # Cancel safety
+ ///
+ /// This method is cancel safe.
+ pub fn cancelled(&self) -> WaitForCancellationFuture<'_> {
+ WaitForCancellationFuture {
+ cancellation_token: self,
+ future: self.inner.notified(),
+ }
+ }
+
+ /// Creates a `DropGuard` for this token.
+ ///
+ /// Returned guard will cancel this token (and all its children) on drop
+ /// unless disarmed.
+ pub fn drop_guard(self) -> DropGuard {
+ DropGuard { inner: Some(self) }
+ }
+}
+
+// ===== impl WaitForCancellationFuture =====
+
+impl<'a> core::fmt::Debug for WaitForCancellationFuture<'a> {
+ fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+ f.debug_struct("WaitForCancellationFuture").finish()
+ }
+}
+
+impl<'a> Future for WaitForCancellationFuture<'a> {
+ type Output = ();
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
+ let mut this = self.project();
+ loop {
+ if this.cancellation_token.is_cancelled() {
+ return Poll::Ready(());
+ }
+
+ // No wakeups can be lost here because there is always a call to
+ // `is_cancelled` between the creation of the future and the call to
+ // `poll`, and the code that sets the cancelled flag does so before
+ // waking the `Notified`.
+ if this.future.as_mut().poll(cx).is_pending() {
+ return Poll::Pending;
+ }
+
+ this.future.set(this.cancellation_token.inner.notified());
+ }
+ }
+}
diff --git a/third_party/rust/tokio-util/src/sync/cancellation_token/guard.rs b/third_party/rust/tokio-util/src/sync/cancellation_token/guard.rs
new file mode 100644
index 0000000000..54ed7ea2ed
--- /dev/null
+++ b/third_party/rust/tokio-util/src/sync/cancellation_token/guard.rs
@@ -0,0 +1,27 @@
+use crate::sync::CancellationToken;
+
+/// A wrapper for cancellation token which automatically cancels
+/// it on drop. It is created using `drop_guard` method on the `CancellationToken`.
+#[derive(Debug)]
+pub struct DropGuard {
+ pub(super) inner: Option<CancellationToken>,
+}
+
+impl DropGuard {
+ /// Returns stored cancellation token and removes this drop guard instance
+ /// (i.e. it will no longer cancel token). Other guards for this token
+ /// are not affected.
+ pub fn disarm(mut self) -> CancellationToken {
+ self.inner
+ .take()
+ .expect("`inner` can be only None in a destructor")
+ }
+}
+
+impl Drop for DropGuard {
+ fn drop(&mut self) {
+ if let Some(inner) = &self.inner {
+ inner.cancel();
+ }
+ }
+}
diff --git a/third_party/rust/tokio-util/src/sync/cancellation_token/tree_node.rs b/third_party/rust/tokio-util/src/sync/cancellation_token/tree_node.rs
new file mode 100644
index 0000000000..b6cd698e23
--- /dev/null
+++ b/third_party/rust/tokio-util/src/sync/cancellation_token/tree_node.rs
@@ -0,0 +1,373 @@
+//! This mod provides the logic for the inner tree structure of the CancellationToken.
+//!
+//! CancellationTokens are only light handles with references to TreeNode.
+//! All the logic is actually implemented in the TreeNode.
+//!
+//! A TreeNode is part of the cancellation tree and may have one parent and an arbitrary number of
+//! children.
+//!
+//! A TreeNode can receive the request to perform a cancellation through a CancellationToken.
+//! This cancellation request will cancel the node and all of its descendants.
+//!
+//! As soon as a node cannot get cancelled any more (because it was already cancelled or it has no
+//! more CancellationTokens pointing to it any more), it gets removed from the tree, to keep the
+//! tree as small as possible.
+//!
+//! # Invariants
+//!
+//! Those invariants shall be true at any time.
+//!
+//! 1. A node that has no parents and no handles can no longer be cancelled.
+//! This is important during both cancellation and refcounting.
+//!
+//! 2. If node B *is* or *was* a child of node A, then node B was created *after* node A.
+//! This is important for deadlock safety, as it is used for lock order.
+//! Node B can only become the child of node A in two ways:
+//! - being created with `child_node()`, in which case it is trivially true that
+//! node A already existed when node B was created
+//! - being moved A->C->B to A->B because node C was removed in `decrease_handle_refcount()`
+//! or `cancel()`. In this case the invariant still holds, as B was younger than C, and C
+//! was younger than A, therefore B is also younger than A.
+//!
+//! 3. If two nodes are both unlocked and node A is the parent of node B, then node B is a child of
+//! node A. It is important to always restore that invariant before dropping the lock of a node.
+//!
+//! # Deadlock safety
+//!
+//! We always lock in the order of creation time. We can prove this through invariant #2.
+//! Specifically, through invariant #2, we know that we always have to lock a parent
+//! before its child.
+//!
+use crate::loom::sync::{Arc, Mutex, MutexGuard};
+
+/// A node of the cancellation tree structure
+///
+/// The actual data it holds is wrapped inside a mutex for synchronization.
+pub(crate) struct TreeNode {
+ inner: Mutex<Inner>,
+ waker: tokio::sync::Notify,
+}
+impl TreeNode {
+ pub(crate) fn new() -> Self {
+ Self {
+ inner: Mutex::new(Inner {
+ parent: None,
+ parent_idx: 0,
+ children: vec![],
+ is_cancelled: false,
+ num_handles: 1,
+ }),
+ waker: tokio::sync::Notify::new(),
+ }
+ }
+
+ pub(crate) fn notified(&self) -> tokio::sync::futures::Notified<'_> {
+ self.waker.notified()
+ }
+}
+
+/// The data contained inside a TreeNode.
+///
+/// This struct exists so that the data of the node can be wrapped
+/// in a Mutex.
+struct Inner {
+ parent: Option<Arc<TreeNode>>,
+ parent_idx: usize,
+ children: Vec<Arc<TreeNode>>,
+ is_cancelled: bool,
+ num_handles: usize,
+}
+
+/// Returns whether or not the node is cancelled
+pub(crate) fn is_cancelled(node: &Arc<TreeNode>) -> bool {
+ node.inner.lock().unwrap().is_cancelled
+}
+
+/// Creates a child node
+pub(crate) fn child_node(parent: &Arc<TreeNode>) -> Arc<TreeNode> {
+ let mut locked_parent = parent.inner.lock().unwrap();
+
+ // Do not register as child if we are already cancelled.
+ // Cancelled trees can never be uncancelled and therefore
+ // need no connection to parents or children any more.
+ if locked_parent.is_cancelled {
+ return Arc::new(TreeNode {
+ inner: Mutex::new(Inner {
+ parent: None,
+ parent_idx: 0,
+ children: vec![],
+ is_cancelled: true,
+ num_handles: 1,
+ }),
+ waker: tokio::sync::Notify::new(),
+ });
+ }
+
+ let child = Arc::new(TreeNode {
+ inner: Mutex::new(Inner {
+ parent: Some(parent.clone()),
+ parent_idx: locked_parent.children.len(),
+ children: vec![],
+ is_cancelled: false,
+ num_handles: 1,
+ }),
+ waker: tokio::sync::Notify::new(),
+ });
+
+ locked_parent.children.push(child.clone());
+
+ child
+}
+
+/// Disconnects the given parent from all of its children.
+///
+/// Takes a reference to [Inner] to make sure the parent is already locked.
+fn disconnect_children(node: &mut Inner) {
+ for child in std::mem::take(&mut node.children) {
+ let mut locked_child = child.inner.lock().unwrap();
+ locked_child.parent_idx = 0;
+ locked_child.parent = None;
+ }
+}
+
+/// Figures out the parent of the node and locks the node and its parent atomically.
+///
+/// The basic principle of preventing deadlocks in the tree is
+/// that we always lock the parent first, and then the child.
+/// For more info look at *deadlock safety* and *invariant #2*.
+///
+/// Sadly, it's impossible to figure out the parent of a node without
+/// locking it. To then achieve locking order consistency, the node
+/// has to be unlocked before the parent gets locked.
+/// This leaves a small window where we already assume that we know the parent,
+/// but neither the parent nor the node is locked. Therefore, the parent could change.
+///
+/// To prevent that this problem leaks into the rest of the code, it is abstracted
+/// in this function.
+///
+/// The locked child and optionally its locked parent, if a parent exists, get passed
+/// to the `func` argument via (node, None) or (node, Some(parent)).
+fn with_locked_node_and_parent<F, Ret>(node: &Arc<TreeNode>, func: F) -> Ret
+where
+ F: FnOnce(MutexGuard<'_, Inner>, Option<MutexGuard<'_, Inner>>) -> Ret,
+{
+ let mut potential_parent = {
+ let locked_node = node.inner.lock().unwrap();
+ match locked_node.parent.clone() {
+ Some(parent) => parent,
+ // If we locked the node and its parent is `None`, we are in a valid state
+ // and can return.
+ None => return func(locked_node, None),
+ }
+ };
+
+ loop {
+ // Deadlock safety:
+ //
+ // Due to invariant #2, we know that we have to lock the parent first, and then the child.
+ // This is true even if the potential_parent is no longer the current parent or even its
+ // sibling, as the invariant still holds.
+ let locked_parent = potential_parent.inner.lock().unwrap();
+ let locked_node = node.inner.lock().unwrap();
+
+ let actual_parent = match locked_node.parent.clone() {
+ Some(parent) => parent,
+ // If we locked the node and its parent is `None`, we are in a valid state
+ // and can return.
+ None => {
+ // Was the wrong parent, so unlock it before calling `func`
+ drop(locked_parent);
+ return func(locked_node, None);
+ }
+ };
+
+ // Loop until we managed to lock both the node and its parent
+ if Arc::ptr_eq(&actual_parent, &potential_parent) {
+ return func(locked_node, Some(locked_parent));
+ }
+
+ // Drop locked_parent before reassigning to potential_parent,
+ // as potential_parent is borrowed in it
+ drop(locked_node);
+ drop(locked_parent);
+
+ potential_parent = actual_parent;
+ }
+}
+
+/// Moves all children from `node` to `parent`.
+///
+/// `parent` MUST have been a parent of the node when they both got locked,
+/// otherwise there is a potential for a deadlock as invariant #2 would be violated.
+///
+/// To aquire the locks for node and parent, use [with_locked_node_and_parent].
+fn move_children_to_parent(node: &mut Inner, parent: &mut Inner) {
+ // Pre-allocate in the parent, for performance
+ parent.children.reserve(node.children.len());
+
+ for child in std::mem::take(&mut node.children) {
+ {
+ let mut child_locked = child.inner.lock().unwrap();
+ child_locked.parent = node.parent.clone();
+ child_locked.parent_idx = parent.children.len();
+ }
+ parent.children.push(child);
+ }
+}
+
+/// Removes a child from the parent.
+///
+/// `parent` MUST be the parent of `node`.
+/// To aquire the locks for node and parent, use [with_locked_node_and_parent].
+fn remove_child(parent: &mut Inner, mut node: MutexGuard<'_, Inner>) {
+ // Query the position from where to remove a node
+ let pos = node.parent_idx;
+ node.parent = None;
+ node.parent_idx = 0;
+
+ // Unlock node, so that only one child at a time is locked.
+ // Otherwise we would violate the lock order (see 'deadlock safety') as we
+ // don't know the creation order of the child nodes
+ drop(node);
+
+ // If `node` is the last element in the list, we don't need any swapping
+ if parent.children.len() == pos + 1 {
+ parent.children.pop().unwrap();
+ } else {
+ // If `node` is not the last element in the list, we need to
+ // replace it with the last element
+ let replacement_child = parent.children.pop().unwrap();
+ replacement_child.inner.lock().unwrap().parent_idx = pos;
+ parent.children[pos] = replacement_child;
+ }
+
+ let len = parent.children.len();
+ if 4 * len <= parent.children.capacity() {
+ // equal to:
+ // parent.children.shrink_to(2 * len);
+ // but shrink_to was not yet stabilized in our minimal compatible version
+ let old_children = std::mem::replace(&mut parent.children, Vec::with_capacity(2 * len));
+ parent.children.extend(old_children);
+ }
+}
+
+/// Increases the reference count of handles.
+pub(crate) fn increase_handle_refcount(node: &Arc<TreeNode>) {
+ let mut locked_node = node.inner.lock().unwrap();
+
+ // Once no handles are left over, the node gets detached from the tree.
+ // There should never be a new handle once all handles are dropped.
+ assert!(locked_node.num_handles > 0);
+
+ locked_node.num_handles += 1;
+}
+
+/// Decreases the reference count of handles.
+///
+/// Once no handle is left, we can remove the node from the
+/// tree and connect its parent directly to its children.
+pub(crate) fn decrease_handle_refcount(node: &Arc<TreeNode>) {
+ let num_handles = {
+ let mut locked_node = node.inner.lock().unwrap();
+ locked_node.num_handles -= 1;
+ locked_node.num_handles
+ };
+
+ if num_handles == 0 {
+ with_locked_node_and_parent(node, |mut node, parent| {
+ // Remove the node from the tree
+ match parent {
+ Some(mut parent) => {
+ // As we want to remove ourselves from the tree,
+ // we have to move the children to the parent, so that
+ // they still receive the cancellation event without us.
+ // Moving them does not violate invariant #1.
+ move_children_to_parent(&mut node, &mut parent);
+
+ // Remove the node from the parent
+ remove_child(&mut parent, node);
+ }
+ None => {
+ // Due to invariant #1, we can assume that our
+ // children can no longer be cancelled through us.
+ // (as we now have neither a parent nor handles)
+ // Therefore we can disconnect them.
+ disconnect_children(&mut node);
+ }
+ }
+ });
+ }
+}
+
+/// Cancels a node and its children.
+pub(crate) fn cancel(node: &Arc<TreeNode>) {
+ let mut locked_node = node.inner.lock().unwrap();
+
+ if locked_node.is_cancelled {
+ return;
+ }
+
+ // One by one, adopt grandchildren and then cancel and detach the child
+ while let Some(child) = locked_node.children.pop() {
+ // This can't deadlock because the mutex we are already
+ // holding is the parent of child.
+ let mut locked_child = child.inner.lock().unwrap();
+
+ // Detach the child from node
+ // No need to modify node.children, as the child already got removed with `.pop`
+ locked_child.parent = None;
+ locked_child.parent_idx = 0;
+
+ // If child is already cancelled, detaching is enough
+ if locked_child.is_cancelled {
+ continue;
+ }
+
+ // Cancel or adopt grandchildren
+ while let Some(grandchild) = locked_child.children.pop() {
+ // This can't deadlock because the two mutexes we are already
+ // holding is the parent and grandparent of grandchild.
+ let mut locked_grandchild = grandchild.inner.lock().unwrap();
+
+ // Detach the grandchild
+ locked_grandchild.parent = None;
+ locked_grandchild.parent_idx = 0;
+
+ // If grandchild is already cancelled, detaching is enough
+ if locked_grandchild.is_cancelled {
+ continue;
+ }
+
+ // For performance reasons, only adopt grandchildren that have children.
+ // Otherwise, just cancel them right away, no need for another iteration.
+ if locked_grandchild.children.is_empty() {
+ // Cancel the grandchild
+ locked_grandchild.is_cancelled = true;
+ locked_grandchild.children = Vec::new();
+ drop(locked_grandchild);
+ grandchild.waker.notify_waiters();
+ } else {
+ // Otherwise, adopt grandchild
+ locked_grandchild.parent = Some(node.clone());
+ locked_grandchild.parent_idx = locked_node.children.len();
+ drop(locked_grandchild);
+ locked_node.children.push(grandchild);
+ }
+ }
+
+ // Cancel the child
+ locked_child.is_cancelled = true;
+ locked_child.children = Vec::new();
+ drop(locked_child);
+ child.waker.notify_waiters();
+
+ // Now the child is cancelled and detached and all its children are adopted.
+ // Just continue until all (including adopted) children are cancelled and detached.
+ }
+
+ // Cancel the node itself.
+ locked_node.is_cancelled = true;
+ locked_node.children = Vec::new();
+ drop(locked_node);
+ node.waker.notify_waiters();
+}
diff --git a/third_party/rust/tokio-util/src/sync/mod.rs b/third_party/rust/tokio-util/src/sync/mod.rs
new file mode 100644
index 0000000000..de392f0bb1
--- /dev/null
+++ b/third_party/rust/tokio-util/src/sync/mod.rs
@@ -0,0 +1,13 @@
+//! Synchronization primitives
+
+mod cancellation_token;
+pub use cancellation_token::{guard::DropGuard, CancellationToken, WaitForCancellationFuture};
+
+mod mpsc;
+pub use mpsc::{PollSendError, PollSender};
+
+mod poll_semaphore;
+pub use poll_semaphore::PollSemaphore;
+
+mod reusable_box;
+pub use reusable_box::ReusableBoxFuture;
diff --git a/third_party/rust/tokio-util/src/sync/mpsc.rs b/third_party/rust/tokio-util/src/sync/mpsc.rs
new file mode 100644
index 0000000000..34a47c1891
--- /dev/null
+++ b/third_party/rust/tokio-util/src/sync/mpsc.rs
@@ -0,0 +1,283 @@
+use futures_sink::Sink;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use std::{fmt, mem};
+use tokio::sync::mpsc::OwnedPermit;
+use tokio::sync::mpsc::Sender;
+
+use super::ReusableBoxFuture;
+
+/// Error returned by the `PollSender` when the channel is closed.
+#[derive(Debug)]
+pub struct PollSendError<T>(Option<T>);
+
+impl<T> PollSendError<T> {
+ /// Consumes the stored value, if any.
+ ///
+ /// If this error was encountered when calling `start_send`/`send_item`, this will be the item
+ /// that the caller attempted to send. Otherwise, it will be `None`.
+ pub fn into_inner(self) -> Option<T> {
+ self.0
+ }
+}
+
+impl<T> fmt::Display for PollSendError<T> {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(fmt, "channel closed")
+ }
+}
+
+impl<T: fmt::Debug> std::error::Error for PollSendError<T> {}
+
+#[derive(Debug)]
+enum State<T> {
+ Idle(Sender<T>),
+ Acquiring,
+ ReadyToSend(OwnedPermit<T>),
+ Closed,
+}
+
+/// A wrapper around [`mpsc::Sender`] that can be polled.
+///
+/// [`mpsc::Sender`]: tokio::sync::mpsc::Sender
+#[derive(Debug)]
+pub struct PollSender<T> {
+ sender: Option<Sender<T>>,
+ state: State<T>,
+ acquire: ReusableBoxFuture<'static, Result<OwnedPermit<T>, PollSendError<T>>>,
+}
+
+// Creates a future for acquiring a permit from the underlying channel. This is used to ensure
+// there's capacity for a send to complete.
+//
+// By reusing the same async fn for both `Some` and `None`, we make sure every future passed to
+// ReusableBoxFuture has the same underlying type, and hence the same size and alignment.
+async fn make_acquire_future<T>(
+ data: Option<Sender<T>>,
+) -> Result<OwnedPermit<T>, PollSendError<T>> {
+ match data {
+ Some(sender) => sender
+ .reserve_owned()
+ .await
+ .map_err(|_| PollSendError(None)),
+ None => unreachable!("this future should not be pollable in this state"),
+ }
+}
+
+impl<T: Send + 'static> PollSender<T> {
+ /// Creates a new `PollSender`.
+ pub fn new(sender: Sender<T>) -> Self {
+ Self {
+ sender: Some(sender.clone()),
+ state: State::Idle(sender),
+ acquire: ReusableBoxFuture::new(make_acquire_future(None)),
+ }
+ }
+
+ fn take_state(&mut self) -> State<T> {
+ mem::replace(&mut self.state, State::Closed)
+ }
+
+ /// Attempts to prepare the sender to receive a value.
+ ///
+ /// This method must be called and return `Poll::Ready(Ok(()))` prior to each call to
+ /// `send_item`.
+ ///
+ /// This method returns `Poll::Ready` once the underlying channel is ready to receive a value,
+ /// by reserving a slot in the channel for the item to be sent. If this method returns
+ /// `Poll::Pending`, the current task is registered to be notified (via
+ /// `cx.waker().wake_by_ref()`) when `poll_reserve` should be called again.
+ ///
+ /// # Errors
+ ///
+ /// If the channel is closed, an error will be returned. This is a permanent state.
+ pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), PollSendError<T>>> {
+ loop {
+ let (result, next_state) = match self.take_state() {
+ State::Idle(sender) => {
+ // Start trying to acquire a permit to reserve a slot for our send, and
+ // immediately loop back around to poll it the first time.
+ self.acquire.set(make_acquire_future(Some(sender)));
+ (None, State::Acquiring)
+ }
+ State::Acquiring => match self.acquire.poll(cx) {
+ // Channel has capacity.
+ Poll::Ready(Ok(permit)) => {
+ (Some(Poll::Ready(Ok(()))), State::ReadyToSend(permit))
+ }
+ // Channel is closed.
+ Poll::Ready(Err(e)) => (Some(Poll::Ready(Err(e))), State::Closed),
+ // Channel doesn't have capacity yet, so we need to wait.
+ Poll::Pending => (Some(Poll::Pending), State::Acquiring),
+ },
+ // We're closed, either by choice or because the underlying sender was closed.
+ s @ State::Closed => (Some(Poll::Ready(Err(PollSendError(None)))), s),
+ // We're already ready to send an item.
+ s @ State::ReadyToSend(_) => (Some(Poll::Ready(Ok(()))), s),
+ };
+
+ self.state = next_state;
+ if let Some(result) = result {
+ return result;
+ }
+ }
+ }
+
+ /// Sends an item to the channel.
+ ///
+ /// Before calling `send_item`, `poll_reserve` must be called with a successful return
+ /// value of `Poll::Ready(Ok(()))`.
+ ///
+ /// # Errors
+ ///
+ /// If the channel is closed, an error will be returned. This is a permanent state.
+ ///
+ /// # Panics
+ ///
+ /// If `poll_reserve` was not successfully called prior to calling `send_item`, then this method
+ /// will panic.
+ pub fn send_item(&mut self, value: T) -> Result<(), PollSendError<T>> {
+ let (result, next_state) = match self.take_state() {
+ State::Idle(_) | State::Acquiring => {
+ panic!("`send_item` called without first calling `poll_reserve`")
+ }
+ // We have a permit to send our item, so go ahead, which gets us our sender back.
+ State::ReadyToSend(permit) => (Ok(()), State::Idle(permit.send(value))),
+ // We're closed, either by choice or because the underlying sender was closed.
+ State::Closed => (Err(PollSendError(Some(value))), State::Closed),
+ };
+
+ // Handle deferred closing if `close` was called between `poll_reserve` and `send_item`.
+ self.state = if self.sender.is_some() {
+ next_state
+ } else {
+ State::Closed
+ };
+ result
+ }
+
+ /// Checks whether this sender is been closed.
+ ///
+ /// The underlying channel that this sender was wrapping may still be open.
+ pub fn is_closed(&self) -> bool {
+ matches!(self.state, State::Closed) || self.sender.is_none()
+ }
+
+ /// Gets a reference to the `Sender` of the underlying channel.
+ ///
+ /// If `PollSender` has been closed, `None` is returned. The underlying channel that this sender
+ /// was wrapping may still be open.
+ pub fn get_ref(&self) -> Option<&Sender<T>> {
+ self.sender.as_ref()
+ }
+
+ /// Closes this sender.
+ ///
+ /// No more messages will be able to be sent from this sender, but the underlying channel will
+ /// remain open until all senders have dropped, or until the [`Receiver`] closes the channel.
+ ///
+ /// If a slot was previously reserved by calling `poll_reserve`, then a final call can be made
+ /// to `send_item` in order to consume the reserved slot. After that, no further sends will be
+ /// possible. If you do not intend to send another item, you can release the reserved slot back
+ /// to the underlying sender by calling [`abort_send`].
+ ///
+ /// [`abort_send`]: crate::sync::PollSender::abort_send
+ /// [`Receiver`]: tokio::sync::mpsc::Receiver
+ pub fn close(&mut self) {
+ // Mark ourselves officially closed by dropping our main sender.
+ self.sender = None;
+
+ // If we're already idle, closed, or we haven't yet reserved a slot, we can quickly
+ // transition to the closed state. Otherwise, leave the existing permit in place for the
+ // caller if they want to complete the send.
+ match self.state {
+ State::Idle(_) => self.state = State::Closed,
+ State::Acquiring => {
+ self.acquire.set(make_acquire_future(None));
+ self.state = State::Closed;
+ }
+ _ => {}
+ }
+ }
+
+ /// Aborts the current in-progress send, if any.
+ ///
+ /// Returns `true` if a send was aborted. If the sender was closed prior to calling
+ /// `abort_send`, then the sender will remain in the closed state, otherwise the sender will be
+ /// ready to attempt another send.
+ pub fn abort_send(&mut self) -> bool {
+ // We may have been closed in the meantime, after a call to `poll_reserve` already
+ // succeeded. We'll check if `self.sender` is `None` to see if we should transition to the
+ // closed state when we actually abort a send, rather than resetting ourselves back to idle.
+
+ let (result, next_state) = match self.take_state() {
+ // We're currently trying to reserve a slot to send into.
+ State::Acquiring => {
+ // Replacing the future drops the in-flight one.
+ self.acquire.set(make_acquire_future(None));
+
+ // If we haven't closed yet, we have to clone our stored sender since we have no way
+ // to get it back from the acquire future we just dropped.
+ let state = match self.sender.clone() {
+ Some(sender) => State::Idle(sender),
+ None => State::Closed,
+ };
+ (true, state)
+ }
+ // We got the permit. If we haven't closed yet, get the sender back.
+ State::ReadyToSend(permit) => {
+ let state = if self.sender.is_some() {
+ State::Idle(permit.release())
+ } else {
+ State::Closed
+ };
+ (true, state)
+ }
+ s => (false, s),
+ };
+
+ self.state = next_state;
+ result
+ }
+}
+
+impl<T> Clone for PollSender<T> {
+ /// Clones this `PollSender`.
+ ///
+ /// The resulting `PollSender` will have an initial state identical to calling `PollSender::new`.
+ fn clone(&self) -> PollSender<T> {
+ let (sender, state) = match self.sender.clone() {
+ Some(sender) => (Some(sender.clone()), State::Idle(sender)),
+ None => (None, State::Closed),
+ };
+
+ Self {
+ sender,
+ state,
+ // We don't use `make_acquire_future` here because our relaxed bounds on `T` are not
+ // compatible with the transitive bounds required by `Sender<T>`.
+ acquire: ReusableBoxFuture::new(async { unreachable!() }),
+ }
+ }
+}
+
+impl<T: Send + 'static> Sink<T> for PollSender<T> {
+ type Error = PollSendError<T>;
+
+ fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ Pin::into_inner(self).poll_reserve(cx)
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
+ Pin::into_inner(self).send_item(item)
+ }
+
+ fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ Pin::into_inner(self).close();
+ Poll::Ready(Ok(()))
+ }
+}
diff --git a/third_party/rust/tokio-util/src/sync/poll_semaphore.rs b/third_party/rust/tokio-util/src/sync/poll_semaphore.rs
new file mode 100644
index 0000000000..d0b1dedc27
--- /dev/null
+++ b/third_party/rust/tokio-util/src/sync/poll_semaphore.rs
@@ -0,0 +1,136 @@
+use futures_core::{ready, Stream};
+use std::fmt;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+
+use super::ReusableBoxFuture;
+
+/// A wrapper around [`Semaphore`] that provides a `poll_acquire` method.
+///
+/// [`Semaphore`]: tokio::sync::Semaphore
+pub struct PollSemaphore {
+ semaphore: Arc<Semaphore>,
+ permit_fut: Option<ReusableBoxFuture<'static, Result<OwnedSemaphorePermit, AcquireError>>>,
+}
+
+impl PollSemaphore {
+ /// Create a new `PollSemaphore`.
+ pub fn new(semaphore: Arc<Semaphore>) -> Self {
+ Self {
+ semaphore,
+ permit_fut: None,
+ }
+ }
+
+ /// Closes the semaphore.
+ pub fn close(&self) {
+ self.semaphore.close()
+ }
+
+ /// Obtain a clone of the inner semaphore.
+ pub fn clone_inner(&self) -> Arc<Semaphore> {
+ self.semaphore.clone()
+ }
+
+ /// Get back the inner semaphore.
+ pub fn into_inner(self) -> Arc<Semaphore> {
+ self.semaphore
+ }
+
+ /// Poll to acquire a permit from the semaphore.
+ ///
+ /// This can return the following values:
+ ///
+ /// - `Poll::Pending` if a permit is not currently available.
+ /// - `Poll::Ready(Some(permit))` if a permit was acquired.
+ /// - `Poll::Ready(None)` if the semaphore has been closed.
+ ///
+ /// When this method returns `Poll::Pending`, the current task is scheduled
+ /// to receive a wakeup when a permit becomes available, or when the
+ /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only
+ /// the `Waker` from the `Context` passed to the most recent call is
+ /// scheduled to receive a wakeup.
+ pub fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> {
+ let permit_future = match self.permit_fut.as_mut() {
+ Some(fut) => fut,
+ None => {
+ // avoid allocations completely if we can grab a permit immediately
+ match Arc::clone(&self.semaphore).try_acquire_owned() {
+ Ok(permit) => return Poll::Ready(Some(permit)),
+ Err(TryAcquireError::Closed) => return Poll::Ready(None),
+ Err(TryAcquireError::NoPermits) => {}
+ }
+
+ let next_fut = Arc::clone(&self.semaphore).acquire_owned();
+ self.permit_fut
+ .get_or_insert(ReusableBoxFuture::new(next_fut))
+ }
+ };
+
+ let result = ready!(permit_future.poll(cx));
+
+ let next_fut = Arc::clone(&self.semaphore).acquire_owned();
+ permit_future.set(next_fut);
+
+ match result {
+ Ok(permit) => Poll::Ready(Some(permit)),
+ Err(_closed) => {
+ self.permit_fut = None;
+ Poll::Ready(None)
+ }
+ }
+ }
+
+ /// Returns the current number of available permits.
+ ///
+ /// This is equivalent to the [`Semaphore::available_permits`] method on the
+ /// `tokio::sync::Semaphore` type.
+ ///
+ /// [`Semaphore::available_permits`]: tokio::sync::Semaphore::available_permits
+ pub fn available_permits(&self) -> usize {
+ self.semaphore.available_permits()
+ }
+
+ /// Adds `n` new permits to the semaphore.
+ ///
+ /// The maximum number of permits is `usize::MAX >> 3`, and this function
+ /// will panic if the limit is exceeded.
+ ///
+ /// This is equivalent to the [`Semaphore::add_permits`] method on the
+ /// `tokio::sync::Semaphore` type.
+ ///
+ /// [`Semaphore::add_permits`]: tokio::sync::Semaphore::add_permits
+ pub fn add_permits(&self, n: usize) {
+ self.semaphore.add_permits(n);
+ }
+}
+
+impl Stream for PollSemaphore {
+ type Item = OwnedSemaphorePermit;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> {
+ Pin::into_inner(self).poll_acquire(cx)
+ }
+}
+
+impl Clone for PollSemaphore {
+ fn clone(&self) -> PollSemaphore {
+ PollSemaphore::new(self.clone_inner())
+ }
+}
+
+impl fmt::Debug for PollSemaphore {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("PollSemaphore")
+ .field("semaphore", &self.semaphore)
+ .finish()
+ }
+}
+
+impl AsRef<Semaphore> for PollSemaphore {
+ fn as_ref(&self) -> &Semaphore {
+ &*self.semaphore
+ }
+}
diff --git a/third_party/rust/tokio-util/src/sync/reusable_box.rs b/third_party/rust/tokio-util/src/sync/reusable_box.rs
new file mode 100644
index 0000000000..3204207db7
--- /dev/null
+++ b/third_party/rust/tokio-util/src/sync/reusable_box.rs
@@ -0,0 +1,148 @@
+use std::alloc::Layout;
+use std::future::Future;
+use std::panic::AssertUnwindSafe;
+use std::pin::Pin;
+use std::ptr::{self, NonNull};
+use std::task::{Context, Poll};
+use std::{fmt, panic};
+
+/// A reusable `Pin<Box<dyn Future<Output = T> + Send + 'a>>`.
+///
+/// This type lets you replace the future stored in the box without
+/// reallocating when the size and alignment permits this.
+pub struct ReusableBoxFuture<'a, T> {
+ boxed: NonNull<dyn Future<Output = T> + Send + 'a>,
+}
+
+impl<'a, T> ReusableBoxFuture<'a, T> {
+ /// Create a new `ReusableBoxFuture<T>` containing the provided future.
+ pub fn new<F>(future: F) -> Self
+ where
+ F: Future<Output = T> + Send + 'a,
+ {
+ let boxed: Box<dyn Future<Output = T> + Send + 'a> = Box::new(future);
+
+ let boxed = NonNull::from(Box::leak(boxed));
+
+ Self { boxed }
+ }
+
+ /// Replace the future currently stored in this box.
+ ///
+ /// This reallocates if and only if the layout of the provided future is
+ /// different from the layout of the currently stored future.
+ pub fn set<F>(&mut self, future: F)
+ where
+ F: Future<Output = T> + Send + 'a,
+ {
+ if let Err(future) = self.try_set(future) {
+ *self = Self::new(future);
+ }
+ }
+
+ /// Replace the future currently stored in this box.
+ ///
+ /// This function never reallocates, but returns an error if the provided
+ /// future has a different size or alignment from the currently stored
+ /// future.
+ pub fn try_set<F>(&mut self, future: F) -> Result<(), F>
+ where
+ F: Future<Output = T> + Send + 'a,
+ {
+ // SAFETY: The pointer is not dangling.
+ let self_layout = {
+ let dyn_future: &(dyn Future<Output = T> + Send) = unsafe { self.boxed.as_ref() };
+ Layout::for_value(dyn_future)
+ };
+
+ if Layout::new::<F>() == self_layout {
+ // SAFETY: We just checked that the layout of F is correct.
+ unsafe {
+ self.set_same_layout(future);
+ }
+
+ Ok(())
+ } else {
+ Err(future)
+ }
+ }
+
+ /// Set the current future.
+ ///
+ /// # Safety
+ ///
+ /// This function requires that the layout of the provided future is the
+ /// same as `self.layout`.
+ unsafe fn set_same_layout<F>(&mut self, future: F)
+ where
+ F: Future<Output = T> + Send + 'a,
+ {
+ // Drop the existing future, catching any panics.
+ let result = panic::catch_unwind(AssertUnwindSafe(|| {
+ ptr::drop_in_place(self.boxed.as_ptr());
+ }));
+
+ // Overwrite the future behind the pointer. This is safe because the
+ // allocation was allocated with the same size and alignment as the type F.
+ let self_ptr: *mut F = self.boxed.as_ptr() as *mut F;
+ ptr::write(self_ptr, future);
+
+ // Update the vtable of self.boxed. The pointer is not null because we
+ // just got it from self.boxed, which is not null.
+ self.boxed = NonNull::new_unchecked(self_ptr);
+
+ // If the old future's destructor panicked, resume unwinding.
+ match result {
+ Ok(()) => {}
+ Err(payload) => {
+ panic::resume_unwind(payload);
+ }
+ }
+ }
+
+ /// Get a pinned reference to the underlying future.
+ pub fn get_pin(&mut self) -> Pin<&mut (dyn Future<Output = T> + Send)> {
+ // SAFETY: The user of this box cannot move the box, and we do not move it
+ // either.
+ unsafe { Pin::new_unchecked(self.boxed.as_mut()) }
+ }
+
+ /// Poll the future stored inside this box.
+ pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<T> {
+ self.get_pin().poll(cx)
+ }
+}
+
+impl<T> Future for ReusableBoxFuture<'_, T> {
+ type Output = T;
+
+ /// Poll the future stored inside this box.
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
+ Pin::into_inner(self).get_pin().poll(cx)
+ }
+}
+
+// The future stored inside ReusableBoxFuture<'_, T> must be Send.
+unsafe impl<T> Send for ReusableBoxFuture<'_, T> {}
+
+// The only method called on self.boxed is poll, which takes &mut self, so this
+// struct being Sync does not permit any invalid access to the Future, even if
+// the future is not Sync.
+unsafe impl<T> Sync for ReusableBoxFuture<'_, T> {}
+
+// Just like a Pin<Box<dyn Future>> is always Unpin, so is this type.
+impl<T> Unpin for ReusableBoxFuture<'_, T> {}
+
+impl<T> Drop for ReusableBoxFuture<'_, T> {
+ fn drop(&mut self) {
+ unsafe {
+ drop(Box::from_raw(self.boxed.as_ptr()));
+ }
+ }
+}
+
+impl<T> fmt::Debug for ReusableBoxFuture<'_, T> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("ReusableBoxFuture").finish()
+ }
+}
diff --git a/third_party/rust/tokio-util/src/sync/tests/loom_cancellation_token.rs b/third_party/rust/tokio-util/src/sync/tests/loom_cancellation_token.rs
new file mode 100644
index 0000000000..e9c9f3dd98
--- /dev/null
+++ b/third_party/rust/tokio-util/src/sync/tests/loom_cancellation_token.rs
@@ -0,0 +1,155 @@
+use crate::sync::CancellationToken;
+
+use loom::{future::block_on, thread};
+use tokio_test::assert_ok;
+
+#[test]
+fn cancel_token() {
+ loom::model(|| {
+ let token = CancellationToken::new();
+ let token1 = token.clone();
+
+ let th1 = thread::spawn(move || {
+ block_on(async {
+ token1.cancelled().await;
+ });
+ });
+
+ let th2 = thread::spawn(move || {
+ token.cancel();
+ });
+
+ assert_ok!(th1.join());
+ assert_ok!(th2.join());
+ });
+}
+
+#[test]
+fn cancel_with_child() {
+ loom::model(|| {
+ let token = CancellationToken::new();
+ let token1 = token.clone();
+ let token2 = token.clone();
+ let child_token = token.child_token();
+
+ let th1 = thread::spawn(move || {
+ block_on(async {
+ token1.cancelled().await;
+ });
+ });
+
+ let th2 = thread::spawn(move || {
+ token2.cancel();
+ });
+
+ let th3 = thread::spawn(move || {
+ block_on(async {
+ child_token.cancelled().await;
+ });
+ });
+
+ assert_ok!(th1.join());
+ assert_ok!(th2.join());
+ assert_ok!(th3.join());
+ });
+}
+
+#[test]
+fn drop_token_no_child() {
+ loom::model(|| {
+ let token = CancellationToken::new();
+ let token1 = token.clone();
+ let token2 = token.clone();
+
+ let th1 = thread::spawn(move || {
+ drop(token1);
+ });
+
+ let th2 = thread::spawn(move || {
+ drop(token2);
+ });
+
+ let th3 = thread::spawn(move || {
+ drop(token);
+ });
+
+ assert_ok!(th1.join());
+ assert_ok!(th2.join());
+ assert_ok!(th3.join());
+ });
+}
+
+#[test]
+fn drop_token_with_childs() {
+ loom::model(|| {
+ let token1 = CancellationToken::new();
+ let child_token1 = token1.child_token();
+ let child_token2 = token1.child_token();
+
+ let th1 = thread::spawn(move || {
+ drop(token1);
+ });
+
+ let th2 = thread::spawn(move || {
+ drop(child_token1);
+ });
+
+ let th3 = thread::spawn(move || {
+ drop(child_token2);
+ });
+
+ assert_ok!(th1.join());
+ assert_ok!(th2.join());
+ assert_ok!(th3.join());
+ });
+}
+
+#[test]
+fn drop_and_cancel_token() {
+ loom::model(|| {
+ let token1 = CancellationToken::new();
+ let token2 = token1.clone();
+ let child_token = token1.child_token();
+
+ let th1 = thread::spawn(move || {
+ drop(token1);
+ });
+
+ let th2 = thread::spawn(move || {
+ token2.cancel();
+ });
+
+ let th3 = thread::spawn(move || {
+ drop(child_token);
+ });
+
+ assert_ok!(th1.join());
+ assert_ok!(th2.join());
+ assert_ok!(th3.join());
+ });
+}
+
+#[test]
+fn cancel_parent_and_child() {
+ loom::model(|| {
+ let token1 = CancellationToken::new();
+ let token2 = token1.clone();
+ let child_token = token1.child_token();
+
+ let th1 = thread::spawn(move || {
+ drop(token1);
+ });
+
+ let th2 = thread::spawn(move || {
+ token2.cancel();
+ });
+
+ let th3 = thread::spawn(move || {
+ child_token.cancel();
+ });
+
+ assert_ok!(th1.join());
+ assert_ok!(th2.join());
+ assert_ok!(th3.join());
+ });
+}
diff --git a/third_party/rust/tokio-util/src/sync/tests/mod.rs b/third_party/rust/tokio-util/src/sync/tests/mod.rs
new file mode 100644
index 0000000000..8b13789179
--- /dev/null
+++ b/third_party/rust/tokio-util/src/sync/tests/mod.rs
@@ -0,0 +1 @@
+
diff --git a/third_party/rust/tokio-util/src/task/mod.rs b/third_party/rust/tokio-util/src/task/mod.rs
new file mode 100644
index 0000000000..5aa33df2dc
--- /dev/null
+++ b/third_party/rust/tokio-util/src/task/mod.rs
@@ -0,0 +1,4 @@
+//! Extra utilities for spawning tasks
+
+mod spawn_pinned;
+pub use spawn_pinned::LocalPoolHandle;
diff --git a/third_party/rust/tokio-util/src/task/spawn_pinned.rs b/third_party/rust/tokio-util/src/task/spawn_pinned.rs
new file mode 100644
index 0000000000..6f553e9d07
--- /dev/null
+++ b/third_party/rust/tokio-util/src/task/spawn_pinned.rs
@@ -0,0 +1,307 @@
+use futures_util::future::{AbortHandle, Abortable};
+use std::fmt;
+use std::fmt::{Debug, Formatter};
+use std::future::Future;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::Arc;
+use tokio::runtime::Builder;
+use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
+use tokio::sync::oneshot;
+use tokio::task::{spawn_local, JoinHandle, LocalSet};
+
+/// A handle to a local pool, used for spawning `!Send` tasks.
+#[derive(Clone)]
+pub struct LocalPoolHandle {
+ pool: Arc<LocalPool>,
+}
+
+impl LocalPoolHandle {
+ /// Create a new pool of threads to handle `!Send` tasks. Spawn tasks onto this
+ /// pool via [`LocalPoolHandle::spawn_pinned`].
+ ///
+ /// # Panics
+ /// Panics if the pool size is less than one.
+ pub fn new(pool_size: usize) -> LocalPoolHandle {
+ assert!(pool_size > 0);
+
+ let workers = (0..pool_size)
+ .map(|_| LocalWorkerHandle::new_worker())
+ .collect();
+
+ let pool = Arc::new(LocalPool { workers });
+
+ LocalPoolHandle { pool }
+ }
+
+ /// Spawn a task onto a worker thread and pin it there so it can't be moved
+ /// off of the thread. Note that the future is not [`Send`], but the
+ /// [`FnOnce`] which creates it is.
+ ///
+ /// # Examples
+ /// ```
+ /// use std::rc::Rc;
+ /// use tokio_util::task::LocalPoolHandle;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// // Create the local pool
+ /// let pool = LocalPoolHandle::new(1);
+ ///
+ /// // Spawn a !Send future onto the pool and await it
+ /// let output = pool
+ /// .spawn_pinned(|| {
+ /// // Rc is !Send + !Sync
+ /// let local_data = Rc::new("test");
+ ///
+ /// // This future holds an Rc, so it is !Send
+ /// async move { local_data.to_string() }
+ /// })
+ /// .await
+ /// .unwrap();
+ ///
+ /// assert_eq!(output, "test");
+ /// }
+ /// ```
+ pub fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output>
+ where
+ F: FnOnce() -> Fut,
+ F: Send + 'static,
+ Fut: Future + 'static,
+ Fut::Output: Send + 'static,
+ {
+ self.pool.spawn_pinned(create_task)
+ }
+}
+
+impl Debug for LocalPoolHandle {
+ fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+ f.write_str("LocalPoolHandle")
+ }
+}
+
+struct LocalPool {
+ workers: Vec<LocalWorkerHandle>,
+}
+
+impl LocalPool {
+ /// Spawn a `?Send` future onto a worker
+ fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output>
+ where
+ F: FnOnce() -> Fut,
+ F: Send + 'static,
+ Fut: Future + 'static,
+ Fut::Output: Send + 'static,
+ {
+ let (sender, receiver) = oneshot::channel();
+
+ let (worker, job_guard) = self.find_and_incr_least_burdened_worker();
+ let worker_spawner = worker.spawner.clone();
+
+ // Spawn a future onto the worker's runtime so we can immediately return
+ // a join handle.
+ worker.runtime_handle.spawn(async move {
+ // Move the job guard into the task
+ let _job_guard = job_guard;
+
+ // Propagate aborts via Abortable/AbortHandle
+ let (abort_handle, abort_registration) = AbortHandle::new_pair();
+ let _abort_guard = AbortGuard(abort_handle);
+
+ // Inside the future we can't run spawn_local yet because we're not
+ // in the context of a LocalSet. We need to send create_task to the
+ // LocalSet task for spawning.
+ let spawn_task = Box::new(move || {
+ // Once we're in the LocalSet context we can call spawn_local
+ let join_handle =
+ spawn_local(
+ async move { Abortable::new(create_task(), abort_registration).await },
+ );
+
+ // Send the join handle back to the spawner. If sending fails,
+ // we assume the parent task was canceled, so cancel this task
+ // as well.
+ if let Err(join_handle) = sender.send(join_handle) {
+ join_handle.abort()
+ }
+ });
+
+ // Send the callback to the LocalSet task
+ if let Err(e) = worker_spawner.send(spawn_task) {
+ // Propagate the error as a panic in the join handle.
+ panic!("Failed to send job to worker: {}", e);
+ }
+
+ // Wait for the task's join handle
+ let join_handle = match receiver.await {
+ Ok(handle) => handle,
+ Err(e) => {
+ // We sent the task successfully, but failed to get its
+ // join handle... We assume something happened to the worker
+ // and the task was not spawned. Propagate the error as a
+ // panic in the join handle.
+ panic!("Worker failed to send join handle: {}", e);
+ }
+ };
+
+ // Wait for the task to complete
+ let join_result = join_handle.await;
+
+ match join_result {
+ Ok(Ok(output)) => output,
+ Ok(Err(_)) => {
+ // Pinned task was aborted. But that only happens if this
+ // task is aborted. So this is an impossible branch.
+ unreachable!(
+ "Reaching this branch means this task was previously \
+ aborted but it continued running anyways"
+ )
+ }
+ Err(e) => {
+ if e.is_panic() {
+ std::panic::resume_unwind(e.into_panic());
+ } else if e.is_cancelled() {
+ // No one else should have the join handle, so this is
+ // unexpected. Forward this error as a panic in the join
+ // handle.
+ panic!("spawn_pinned task was canceled: {}", e);
+ } else {
+ // Something unknown happened (not a panic or
+ // cancellation). Forward this error as a panic in the
+ // join handle.
+ panic!("spawn_pinned task failed: {}", e);
+ }
+ }
+ }
+ })
+ }
+
+ /// Find the worker with the least number of tasks, increment its task
+ /// count, and return its handle. Make sure to actually spawn a task on
+ /// the worker so the task count is kept consistent with load.
+ ///
+ /// A job count guard is also returned to ensure the task count gets
+ /// decremented when the job is done.
+ fn find_and_incr_least_burdened_worker(&self) -> (&LocalWorkerHandle, JobCountGuard) {
+ loop {
+ let (worker, task_count) = self
+ .workers
+ .iter()
+ .map(|worker| (worker, worker.task_count.load(Ordering::SeqCst)))
+ .min_by_key(|&(_, count)| count)
+ .expect("There must be more than one worker");
+
+ // Make sure the task count hasn't changed since when we choose this
+ // worker. Otherwise, restart the search.
+ if worker
+ .task_count
+ .compare_exchange(
+ task_count,
+ task_count + 1,
+ Ordering::SeqCst,
+ Ordering::Relaxed,
+ )
+ .is_ok()
+ {
+ return (worker, JobCountGuard(Arc::clone(&worker.task_count)));
+ }
+ }
+ }
+}
+
+/// Automatically decrements a worker's job count when a job finishes (when
+/// this gets dropped).
+struct JobCountGuard(Arc<AtomicUsize>);
+
+impl Drop for JobCountGuard {
+ fn drop(&mut self) {
+ // Decrement the job count
+ let previous_value = self.0.fetch_sub(1, Ordering::SeqCst);
+ debug_assert!(previous_value >= 1);
+ }
+}
+
+/// Calls abort on the handle when dropped.
+struct AbortGuard(AbortHandle);
+
+impl Drop for AbortGuard {
+ fn drop(&mut self) {
+ self.0.abort();
+ }
+}
+
+type PinnedFutureSpawner = Box<dyn FnOnce() + Send + 'static>;
+
+struct LocalWorkerHandle {
+ runtime_handle: tokio::runtime::Handle,
+ spawner: UnboundedSender<PinnedFutureSpawner>,
+ task_count: Arc<AtomicUsize>,
+}
+
+impl LocalWorkerHandle {
+ /// Create a new worker for executing pinned tasks
+ fn new_worker() -> LocalWorkerHandle {
+ let (sender, receiver) = unbounded_channel();
+ let runtime = Builder::new_current_thread()
+ .enable_all()
+ .build()
+ .expect("Failed to start a pinned worker thread runtime");
+ let runtime_handle = runtime.handle().clone();
+ let task_count = Arc::new(AtomicUsize::new(0));
+ let task_count_clone = Arc::clone(&task_count);
+
+ std::thread::spawn(|| Self::run(runtime, receiver, task_count_clone));
+
+ LocalWorkerHandle {
+ runtime_handle,
+ spawner: sender,
+ task_count,
+ }
+ }
+
+ fn run(
+ runtime: tokio::runtime::Runtime,
+ mut task_receiver: UnboundedReceiver<PinnedFutureSpawner>,
+ task_count: Arc<AtomicUsize>,
+ ) {
+ let local_set = LocalSet::new();
+ local_set.block_on(&runtime, async {
+ while let Some(spawn_task) = task_receiver.recv().await {
+ // Calls spawn_local(future)
+ (spawn_task)();
+ }
+ });
+
+ // If there are any tasks on the runtime associated with a LocalSet task
+ // that has already completed, but whose output has not yet been
+ // reported, let that task complete.
+ //
+ // Since the task_count is decremented when the runtime task exits,
+ // reading that counter lets us know if any such tasks completed during
+ // the call to `block_on`.
+ //
+ // Tasks on the LocalSet can't complete during this loop since they're
+ // stored on the LocalSet and we aren't accessing it.
+ let mut previous_task_count = task_count.load(Ordering::SeqCst);
+ loop {
+ // This call will also run tasks spawned on the runtime.
+ runtime.block_on(tokio::task::yield_now());
+ let new_task_count = task_count.load(Ordering::SeqCst);
+ if new_task_count == previous_task_count {
+ break;
+ } else {
+ previous_task_count = new_task_count;
+ }
+ }
+
+ // It's now no longer possible for a task on the runtime to be
+ // associated with a LocalSet task that has completed. Drop both the
+ // LocalSet and runtime to let tasks on the runtime be cancelled if and
+ // only if they are still on the LocalSet.
+ //
+ // Drop the LocalSet task first so that anyone awaiting the runtime
+ // JoinHandle will see the cancelled error after the LocalSet task
+ // destructor has completed.
+ drop(local_set);
+ drop(runtime);
+ }
+}
diff --git a/third_party/rust/tokio-util/src/time/delay_queue.rs b/third_party/rust/tokio-util/src/time/delay_queue.rs
new file mode 100644
index 0000000000..a0c5e5c5b0
--- /dev/null
+++ b/third_party/rust/tokio-util/src/time/delay_queue.rs
@@ -0,0 +1,1221 @@
+//! A queue of delayed elements.
+//!
+//! See [`DelayQueue`] for more details.
+//!
+//! [`DelayQueue`]: struct@DelayQueue
+
+use crate::time::wheel::{self, Wheel};
+
+use futures_core::ready;
+use tokio::time::{sleep_until, Duration, Instant, Sleep};
+
+use core::ops::{Index, IndexMut};
+use slab::Slab;
+use std::cmp;
+use std::collections::HashMap;
+use std::convert::From;
+use std::fmt;
+use std::fmt::Debug;
+use std::future::Future;
+use std::marker::PhantomData;
+use std::pin::Pin;
+use std::task::{self, Poll, Waker};
+
+/// A queue of delayed elements.
+///
+/// Once an element is inserted into the `DelayQueue`, it is yielded once the
+/// specified deadline has been reached.
+///
+/// # Usage
+///
+/// Elements are inserted into `DelayQueue` using the [`insert`] or
+/// [`insert_at`] methods. A deadline is provided with the item and a [`Key`] is
+/// returned. The key is used to remove the entry or to change the deadline at
+/// which it should be yielded back.
+///
+/// Once delays have been configured, the `DelayQueue` is used via its
+/// [`Stream`] implementation. [`poll_expired`] is called. If an entry has reached its
+/// deadline, it is returned. If not, `Poll::Pending` is returned indicating that the
+/// current task will be notified once the deadline has been reached.
+///
+/// # `Stream` implementation
+///
+/// Items are retrieved from the queue via [`DelayQueue::poll_expired`]. If no delays have
+/// expired, no items are returned. In this case, `Poll::Pending` is returned and the
+/// current task is registered to be notified once the next item's delay has
+/// expired.
+///
+/// If no items are in the queue, i.e. `is_empty()` returns `true`, then `poll`
+/// returns `Poll::Ready(None)`. This indicates that the stream has reached an end.
+/// However, if a new item is inserted *after*, `poll` will once again start
+/// returning items or `Poll::Pending`.
+///
+/// Items are returned ordered by their expirations. Items that are configured
+/// to expire first will be returned first. There are no ordering guarantees
+/// for items configured to expire at the same instant. Also note that delays are
+/// rounded to the closest millisecond.
+///
+/// # Implementation
+///
+/// The [`DelayQueue`] is backed by a separate instance of a timer wheel similar to that used internally
+/// by Tokio's standalone timer utilities such as [`sleep`]. Because of this, it offers the same
+/// performance and scalability benefits.
+///
+/// State associated with each entry is stored in a [`slab`]. This amortizes the cost of allocation,
+/// and allows reuse of the memory allocated for expired entires.
+///
+/// Capacity can be checked using [`capacity`] and allocated preemptively by using
+/// the [`reserve`] method.
+///
+/// # Usage
+///
+/// Using `DelayQueue` to manage cache entries.
+///
+/// ```rust,no_run
+/// use tokio_util::time::{DelayQueue, delay_queue};
+///
+/// use futures::ready;
+/// use std::collections::HashMap;
+/// use std::task::{Context, Poll};
+/// use std::time::Duration;
+/// # type CacheKey = String;
+/// # type Value = String;
+///
+/// struct Cache {
+/// entries: HashMap<CacheKey, (Value, delay_queue::Key)>,
+/// expirations: DelayQueue<CacheKey>,
+/// }
+///
+/// const TTL_SECS: u64 = 30;
+///
+/// impl Cache {
+/// fn insert(&mut self, key: CacheKey, value: Value) {
+/// let delay = self.expirations
+/// .insert(key.clone(), Duration::from_secs(TTL_SECS));
+///
+/// self.entries.insert(key, (value, delay));
+/// }
+///
+/// fn get(&self, key: &CacheKey) -> Option<&Value> {
+/// self.entries.get(key)
+/// .map(|&(ref v, _)| v)
+/// }
+///
+/// fn remove(&mut self, key: &CacheKey) {
+/// if let Some((_, cache_key)) = self.entries.remove(key) {
+/// self.expirations.remove(&cache_key);
+/// }
+/// }
+///
+/// fn poll_purge(&mut self, cx: &mut Context<'_>) -> Poll<()> {
+/// while let Some(entry) = ready!(self.expirations.poll_expired(cx)) {
+/// self.entries.remove(entry.get_ref());
+/// }
+///
+/// Poll::Ready(())
+/// }
+/// }
+/// ```
+///
+/// [`insert`]: method@Self::insert
+/// [`insert_at`]: method@Self::insert_at
+/// [`Key`]: struct@Key
+/// [`Stream`]: https://docs.rs/futures/0.1/futures/stream/trait.Stream.html
+/// [`poll_expired`]: method@Self::poll_expired
+/// [`Stream::poll_expired`]: method@Self::poll_expired
+/// [`DelayQueue`]: struct@DelayQueue
+/// [`sleep`]: fn@tokio::time::sleep
+/// [`slab`]: slab
+/// [`capacity`]: method@Self::capacity
+/// [`reserve`]: method@Self::reserve
+#[derive(Debug)]
+pub struct DelayQueue<T> {
+ /// Stores data associated with entries
+ slab: SlabStorage<T>,
+
+ /// Lookup structure tracking all delays in the queue
+ wheel: Wheel<Stack<T>>,
+
+ /// Delays that were inserted when already expired. These cannot be stored
+ /// in the wheel
+ expired: Stack<T>,
+
+ /// Delay expiring when the *first* item in the queue expires
+ delay: Option<Pin<Box<Sleep>>>,
+
+ /// Wheel polling state
+ wheel_now: u64,
+
+ /// Instant at which the timer starts
+ start: Instant,
+
+ /// Waker that is invoked when we potentially need to reset the timer.
+ /// Because we lazily create the timer when the first entry is created, we
+ /// need to awaken any poller that polled us before that point.
+ waker: Option<Waker>,
+}
+
+#[derive(Default)]
+struct SlabStorage<T> {
+ inner: Slab<Data<T>>,
+
+ // A `compact` call requires a re-mapping of the `Key`s that were changed
+ // during the `compact` call of the `slab`. Since the keys that were given out
+ // cannot be changed retroactively we need to keep track of these re-mappings.
+ // The keys of `key_map` correspond to the old keys that were given out and
+ // the values to the `Key`s that were re-mapped by the `compact` call.
+ key_map: HashMap<Key, KeyInternal>,
+
+ // Index used to create new keys to hand out.
+ next_key_index: usize,
+
+ // Whether `compact` has been called, necessary in order to decide whether
+ // to include keys in `key_map`.
+ compact_called: bool,
+}
+
+impl<T> SlabStorage<T> {
+ pub(crate) fn with_capacity(capacity: usize) -> SlabStorage<T> {
+ SlabStorage {
+ inner: Slab::with_capacity(capacity),
+ key_map: HashMap::new(),
+ next_key_index: 0,
+ compact_called: false,
+ }
+ }
+
+ // Inserts data into the inner slab and re-maps keys if necessary
+ pub(crate) fn insert(&mut self, val: Data<T>) -> Key {
+ let mut key = KeyInternal::new(self.inner.insert(val));
+ let key_contained = self.key_map.contains_key(&key.into());
+
+ if key_contained {
+ // It's possible that a `compact` call creates capacitiy in `self.inner` in
+ // such a way that a `self.inner.insert` call creates a `key` which was
+ // previously given out during an `insert` call prior to the `compact` call.
+ // If `key` is contained in `self.key_map`, we have encountered this exact situation,
+ // We need to create a new key `key_to_give_out` and include the relation
+ // `key_to_give_out` -> `key` in `self.key_map`.
+ let key_to_give_out = self.create_new_key();
+ assert!(!self.key_map.contains_key(&key_to_give_out.into()));
+ self.key_map.insert(key_to_give_out.into(), key);
+ key = key_to_give_out;
+ } else if self.compact_called {
+ // Include an identity mapping in `self.key_map` in order to allow us to
+ // panic if a key that was handed out is removed more than once.
+ self.key_map.insert(key.into(), key);
+ }
+
+ key.into()
+ }
+
+ // Re-map the key in case compact was previously called.
+ // Note: Since we include identity mappings in key_map after compact was called,
+ // we have information about all keys that were handed out. In the case in which
+ // compact was called and we try to remove a Key that was previously removed
+ // we can detect invalid keys if no key is found in `key_map`. This is necessary
+ // in order to prevent situations in which a previously removed key
+ // corresponds to a re-mapped key internally and which would then be incorrectly
+ // removed from the slab.
+ //
+ // Example to illuminate this problem:
+ //
+ // Let's assume our `key_map` is {1 -> 2, 2 -> 1} and we call remove(1). If we
+ // were to remove 1 again, we would not find it inside `key_map` anymore.
+ // If we were to imply from this that no re-mapping was necessary, we would
+ // incorrectly remove 1 from `self.slab.inner`, which corresponds to the
+ // handed-out key 2.
+ pub(crate) fn remove(&mut self, key: &Key) -> Data<T> {
+ let remapped_key = if self.compact_called {
+ match self.key_map.remove(key) {
+ Some(key_internal) => key_internal,
+ None => panic!("invalid key"),
+ }
+ } else {
+ (*key).into()
+ };
+
+ self.inner.remove(remapped_key.index)
+ }
+
+ pub(crate) fn shrink_to_fit(&mut self) {
+ self.inner.shrink_to_fit();
+ self.key_map.shrink_to_fit();
+ }
+
+ pub(crate) fn compact(&mut self) {
+ if !self.compact_called {
+ for (key, _) in self.inner.iter() {
+ self.key_map.insert(Key::new(key), KeyInternal::new(key));
+ }
+ }
+
+ let mut remapping = HashMap::new();
+ self.inner.compact(|_, from, to| {
+ remapping.insert(from, to);
+ true
+ });
+
+ // At this point `key_map` contains a mapping for every element.
+ for internal_key in self.key_map.values_mut() {
+ if let Some(new_internal_key) = remapping.get(&internal_key.index) {
+ *internal_key = KeyInternal::new(*new_internal_key);
+ }
+ }
+
+ if self.key_map.capacity() > 2 * self.key_map.len() {
+ self.key_map.shrink_to_fit();
+ }
+
+ self.compact_called = true;
+ }
+
+ // Tries to re-map a `Key` that was given out to the user to its
+ // corresponding internal key.
+ fn remap_key(&self, key: &Key) -> Option<KeyInternal> {
+ let key_map = &self.key_map;
+ if self.compact_called {
+ key_map.get(&*key).copied()
+ } else {
+ Some((*key).into())
+ }
+ }
+
+ fn create_new_key(&mut self) -> KeyInternal {
+ while self.key_map.contains_key(&Key::new(self.next_key_index)) {
+ self.next_key_index = self.next_key_index.wrapping_add(1);
+ }
+
+ KeyInternal::new(self.next_key_index)
+ }
+
+ pub(crate) fn len(&self) -> usize {
+ self.inner.len()
+ }
+
+ pub(crate) fn capacity(&self) -> usize {
+ self.inner.capacity()
+ }
+
+ pub(crate) fn clear(&mut self) {
+ self.inner.clear();
+ self.key_map.clear();
+ self.compact_called = false;
+ }
+
+ pub(crate) fn reserve(&mut self, additional: usize) {
+ self.inner.reserve(additional);
+
+ if self.compact_called {
+ self.key_map.reserve(additional);
+ }
+ }
+
+ pub(crate) fn is_empty(&self) -> bool {
+ self.inner.is_empty()
+ }
+
+ pub(crate) fn contains(&self, key: &Key) -> bool {
+ let remapped_key = self.remap_key(key);
+
+ match remapped_key {
+ Some(internal_key) => self.inner.contains(internal_key.index),
+ None => false,
+ }
+ }
+}
+
+impl<T> fmt::Debug for SlabStorage<T>
+where
+ T: fmt::Debug,
+{
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ if fmt.alternate() {
+ fmt.debug_map().entries(self.inner.iter()).finish()
+ } else {
+ fmt.debug_struct("Slab")
+ .field("len", &self.len())
+ .field("cap", &self.capacity())
+ .finish()
+ }
+ }
+}
+
+impl<T> Index<Key> for SlabStorage<T> {
+ type Output = Data<T>;
+
+ fn index(&self, key: Key) -> &Self::Output {
+ let remapped_key = self.remap_key(&key);
+
+ match remapped_key {
+ Some(internal_key) => &self.inner[internal_key.index],
+ None => panic!("Invalid index {}", key.index),
+ }
+ }
+}
+
+impl<T> IndexMut<Key> for SlabStorage<T> {
+ fn index_mut(&mut self, key: Key) -> &mut Data<T> {
+ let remapped_key = self.remap_key(&key);
+
+ match remapped_key {
+ Some(internal_key) => &mut self.inner[internal_key.index],
+ None => panic!("Invalid index {}", key.index),
+ }
+ }
+}
+
+/// An entry in `DelayQueue` that has expired and been removed.
+///
+/// Values are returned by [`DelayQueue::poll_expired`].
+///
+/// [`DelayQueue::poll_expired`]: method@DelayQueue::poll_expired
+#[derive(Debug)]
+pub struct Expired<T> {
+ /// The data stored in the queue
+ data: T,
+
+ /// The expiration time
+ deadline: Instant,
+
+ /// The key associated with the entry
+ key: Key,
+}
+
+/// Token to a value stored in a `DelayQueue`.
+///
+/// Instances of `Key` are returned by [`DelayQueue::insert`]. See [`DelayQueue`]
+/// documentation for more details.
+///
+/// [`DelayQueue`]: struct@DelayQueue
+/// [`DelayQueue::insert`]: method@DelayQueue::insert
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub struct Key {
+ index: usize,
+}
+
+// Whereas `Key` is given out to users that use `DelayQueue`, internally we use
+// `KeyInternal` as the key type in order to make the logic of mapping between keys
+// as a result of `compact` calls clearer.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+struct KeyInternal {
+ index: usize,
+}
+
+#[derive(Debug)]
+struct Stack<T> {
+ /// Head of the stack
+ head: Option<Key>,
+ _p: PhantomData<fn() -> T>,
+}
+
+#[derive(Debug)]
+struct Data<T> {
+ /// The data being stored in the queue and will be returned at the requested
+ /// instant.
+ inner: T,
+
+ /// The instant at which the item is returned.
+ when: u64,
+
+ /// Set to true when stored in the `expired` queue
+ expired: bool,
+
+ /// Next entry in the stack
+ next: Option<Key>,
+
+ /// Previous entry in the stack
+ prev: Option<Key>,
+}
+
+/// Maximum number of entries the queue can handle
+const MAX_ENTRIES: usize = (1 << 30) - 1;
+
+impl<T> DelayQueue<T> {
+ /// Creates a new, empty, `DelayQueue`.
+ ///
+ /// The queue will not allocate storage until items are inserted into it.
+ ///
+ /// # Examples
+ ///
+ /// ```rust
+ /// # use tokio_util::time::DelayQueue;
+ /// let delay_queue: DelayQueue<u32> = DelayQueue::new();
+ /// ```
+ pub fn new() -> DelayQueue<T> {
+ DelayQueue::with_capacity(0)
+ }
+
+ /// Creates a new, empty, `DelayQueue` with the specified capacity.
+ ///
+ /// The queue will be able to hold at least `capacity` elements without
+ /// reallocating. If `capacity` is 0, the queue will not allocate for
+ /// storage.
+ ///
+ /// # Examples
+ ///
+ /// ```rust
+ /// # use tokio_util::time::DelayQueue;
+ /// # use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::with_capacity(10);
+ ///
+ /// // These insertions are done without further allocation
+ /// for i in 0..10 {
+ /// delay_queue.insert(i, Duration::from_secs(i));
+ /// }
+ ///
+ /// // This will make the queue allocate additional storage
+ /// delay_queue.insert(11, Duration::from_secs(11));
+ /// # }
+ /// ```
+ pub fn with_capacity(capacity: usize) -> DelayQueue<T> {
+ DelayQueue {
+ wheel: Wheel::new(),
+ slab: SlabStorage::with_capacity(capacity),
+ expired: Stack::default(),
+ delay: None,
+ wheel_now: 0,
+ start: Instant::now(),
+ waker: None,
+ }
+ }
+
+ /// Inserts `value` into the queue set to expire at a specific instant in
+ /// time.
+ ///
+ /// This function is identical to `insert`, but takes an `Instant` instead
+ /// of a `Duration`.
+ ///
+ /// `value` is stored in the queue until `when` is reached. At which point,
+ /// `value` will be returned from [`poll_expired`]. If `when` has already been
+ /// reached, then `value` is immediately made available to poll.
+ ///
+ /// The return value represents the insertion and is used as an argument to
+ /// [`remove`] and [`reset`]. Note that [`Key`] is a token and is reused once
+ /// `value` is removed from the queue either by calling [`poll_expired`] after
+ /// `when` is reached or by calling [`remove`]. At this point, the caller
+ /// must take care to not use the returned [`Key`] again as it may reference
+ /// a different item in the queue.
+ ///
+ /// See [type] level documentation for more details.
+ ///
+ /// # Panics
+ ///
+ /// This function panics if `when` is too far in the future.
+ ///
+ /// # Examples
+ ///
+ /// Basic usage
+ ///
+ /// ```rust
+ /// use tokio::time::{Duration, Instant};
+ /// use tokio_util::time::DelayQueue;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::new();
+ /// let key = delay_queue.insert_at(
+ /// "foo", Instant::now() + Duration::from_secs(5));
+ ///
+ /// // Remove the entry
+ /// let item = delay_queue.remove(&key);
+ /// assert_eq!(*item.get_ref(), "foo");
+ /// # }
+ /// ```
+ ///
+ /// [`poll_expired`]: method@Self::poll_expired
+ /// [`remove`]: method@Self::remove
+ /// [`reset`]: method@Self::reset
+ /// [`Key`]: struct@Key
+ /// [type]: #
+ pub fn insert_at(&mut self, value: T, when: Instant) -> Key {
+ assert!(self.slab.len() < MAX_ENTRIES, "max entries exceeded");
+
+ // Normalize the deadline. Values cannot be set to expire in the past.
+ let when = self.normalize_deadline(when);
+
+ // Insert the value in the store
+ let key = self.slab.insert(Data {
+ inner: value,
+ when,
+ expired: false,
+ next: None,
+ prev: None,
+ });
+
+ self.insert_idx(when, key);
+
+ // Set a new delay if the current's deadline is later than the one of the new item
+ let should_set_delay = if let Some(ref delay) = self.delay {
+ let current_exp = self.normalize_deadline(delay.deadline());
+ current_exp > when
+ } else {
+ true
+ };
+
+ if should_set_delay {
+ if let Some(waker) = self.waker.take() {
+ waker.wake();
+ }
+
+ let delay_time = self.start + Duration::from_millis(when);
+ if let Some(ref mut delay) = &mut self.delay {
+ delay.as_mut().reset(delay_time);
+ } else {
+ self.delay = Some(Box::pin(sleep_until(delay_time)));
+ }
+ }
+
+ key
+ }
+
+ /// Attempts to pull out the next value of the delay queue, registering the
+ /// current task for wakeup if the value is not yet available, and returning
+ /// `None` if the queue is exhausted.
+ pub fn poll_expired(&mut self, cx: &mut task::Context<'_>) -> Poll<Option<Expired<T>>> {
+ if !self
+ .waker
+ .as_ref()
+ .map(|w| w.will_wake(cx.waker()))
+ .unwrap_or(false)
+ {
+ self.waker = Some(cx.waker().clone());
+ }
+
+ let item = ready!(self.poll_idx(cx));
+ Poll::Ready(item.map(|key| {
+ let data = self.slab.remove(&key);
+ debug_assert!(data.next.is_none());
+ debug_assert!(data.prev.is_none());
+
+ Expired {
+ key,
+ data: data.inner,
+ deadline: self.start + Duration::from_millis(data.when),
+ }
+ }))
+ }
+
+ /// Inserts `value` into the queue set to expire after the requested duration
+ /// elapses.
+ ///
+ /// This function is identical to `insert_at`, but takes a `Duration`
+ /// instead of an `Instant`.
+ ///
+ /// `value` is stored in the queue until `timeout` duration has
+ /// elapsed after `insert` was called. At that point, `value` will
+ /// be returned from [`poll_expired`]. If `timeout` is a `Duration` of
+ /// zero, then `value` is immediately made available to poll.
+ ///
+ /// The return value represents the insertion and is used as an
+ /// argument to [`remove`] and [`reset`]. Note that [`Key`] is a
+ /// token and is reused once `value` is removed from the queue
+ /// either by calling [`poll_expired`] after `timeout` has elapsed
+ /// or by calling [`remove`]. At this point, the caller must not
+ /// use the returned [`Key`] again as it may reference a different
+ /// item in the queue.
+ ///
+ /// See [type] level documentation for more details.
+ ///
+ /// # Panics
+ ///
+ /// This function panics if `timeout` is greater than the maximum
+ /// duration supported by the timer in the current `Runtime`.
+ ///
+ /// # Examples
+ ///
+ /// Basic usage
+ ///
+ /// ```rust
+ /// use tokio_util::time::DelayQueue;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::new();
+ /// let key = delay_queue.insert("foo", Duration::from_secs(5));
+ ///
+ /// // Remove the entry
+ /// let item = delay_queue.remove(&key);
+ /// assert_eq!(*item.get_ref(), "foo");
+ /// # }
+ /// ```
+ ///
+ /// [`poll_expired`]: method@Self::poll_expired
+ /// [`remove`]: method@Self::remove
+ /// [`reset`]: method@Self::reset
+ /// [`Key`]: struct@Key
+ /// [type]: #
+ pub fn insert(&mut self, value: T, timeout: Duration) -> Key {
+ self.insert_at(value, Instant::now() + timeout)
+ }
+
+ fn insert_idx(&mut self, when: u64, key: Key) {
+ use self::wheel::{InsertError, Stack};
+
+ // Register the deadline with the timer wheel
+ match self.wheel.insert(when, key, &mut self.slab) {
+ Ok(_) => {}
+ Err((_, InsertError::Elapsed)) => {
+ self.slab[key].expired = true;
+ // The delay is already expired, store it in the expired queue
+ self.expired.push(key, &mut self.slab);
+ }
+ Err((_, err)) => panic!("invalid deadline; err={:?}", err),
+ }
+ }
+
+ /// Removes the key from the expired queue or the timer wheel
+ /// depending on its expiration status.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the key is not contained in the expired queue or the wheel.
+ fn remove_key(&mut self, key: &Key) {
+ use crate::time::wheel::Stack;
+
+ // Special case the `expired` queue
+ if self.slab[*key].expired {
+ self.expired.remove(key, &mut self.slab);
+ } else {
+ self.wheel.remove(key, &mut self.slab);
+ }
+ }
+
+ /// Removes the item associated with `key` from the queue.
+ ///
+ /// There must be an item associated with `key`. The function returns the
+ /// removed item as well as the `Instant` at which it will the delay will
+ /// have expired.
+ ///
+ /// # Panics
+ ///
+ /// The function panics if `key` is not contained by the queue.
+ ///
+ /// # Examples
+ ///
+ /// Basic usage
+ ///
+ /// ```rust
+ /// use tokio_util::time::DelayQueue;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::new();
+ /// let key = delay_queue.insert("foo", Duration::from_secs(5));
+ ///
+ /// // Remove the entry
+ /// let item = delay_queue.remove(&key);
+ /// assert_eq!(*item.get_ref(), "foo");
+ /// # }
+ /// ```
+ pub fn remove(&mut self, key: &Key) -> Expired<T> {
+ let prev_deadline = self.next_deadline();
+
+ self.remove_key(key);
+ let data = self.slab.remove(key);
+
+ let next_deadline = self.next_deadline();
+ if prev_deadline != next_deadline {
+ match (next_deadline, &mut self.delay) {
+ (None, _) => self.delay = None,
+ (Some(deadline), Some(delay)) => delay.as_mut().reset(deadline),
+ (Some(deadline), None) => self.delay = Some(Box::pin(sleep_until(deadline))),
+ }
+ }
+
+ Expired {
+ key: Key::new(key.index),
+ data: data.inner,
+ deadline: self.start + Duration::from_millis(data.when),
+ }
+ }
+
+ /// Sets the delay of the item associated with `key` to expire at `when`.
+ ///
+ /// This function is identical to `reset` but takes an `Instant` instead of
+ /// a `Duration`.
+ ///
+ /// The item remains in the queue but the delay is set to expire at `when`.
+ /// If `when` is in the past, then the item is immediately made available to
+ /// the caller.
+ ///
+ /// # Panics
+ ///
+ /// This function panics if `when` is too far in the future or if `key` is
+ /// not contained by the queue.
+ ///
+ /// # Examples
+ ///
+ /// Basic usage
+ ///
+ /// ```rust
+ /// use tokio::time::{Duration, Instant};
+ /// use tokio_util::time::DelayQueue;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::new();
+ /// let key = delay_queue.insert("foo", Duration::from_secs(5));
+ ///
+ /// // "foo" is scheduled to be returned in 5 seconds
+ ///
+ /// delay_queue.reset_at(&key, Instant::now() + Duration::from_secs(10));
+ ///
+ /// // "foo" is now scheduled to be returned in 10 seconds
+ /// # }
+ /// ```
+ pub fn reset_at(&mut self, key: &Key, when: Instant) {
+ self.remove_key(key);
+
+ // Normalize the deadline. Values cannot be set to expire in the past.
+ let when = self.normalize_deadline(when);
+
+ self.slab[*key].when = when;
+ self.slab[*key].expired = false;
+
+ self.insert_idx(when, *key);
+
+ let next_deadline = self.next_deadline();
+ if let (Some(ref mut delay), Some(deadline)) = (&mut self.delay, next_deadline) {
+ // This should awaken us if necessary (ie, if already expired)
+ delay.as_mut().reset(deadline);
+ }
+ }
+
+ /// Shrink the capacity of the slab, which `DelayQueue` uses internally for storage allocation.
+ /// This function is not guaranteed to, and in most cases, won't decrease the capacity of the slab
+ /// to the number of elements still contained in it, because elements cannot be moved to a different
+ /// index. To decrease the capacity to the size of the slab use [`compact`].
+ ///
+ /// This function can take O(n) time even when the capacity cannot be reduced or the allocation is
+ /// shrunk in place. Repeated calls run in O(1) though.
+ ///
+ /// [`compact`]: method@Self::compact
+ pub fn shrink_to_fit(&mut self) {
+ self.slab.shrink_to_fit();
+ }
+
+ /// Shrink the capacity of the slab, which `DelayQueue` uses internally for storage allocation,
+ /// to the number of elements that are contained in it.
+ ///
+ /// This methods runs in O(n).
+ ///
+ /// # Examples
+ ///
+ /// Basic usage
+ ///
+ /// ```rust
+ /// use tokio_util::time::DelayQueue;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::with_capacity(10);
+ ///
+ /// let key1 = delay_queue.insert(5, Duration::from_secs(5));
+ /// let key2 = delay_queue.insert(10, Duration::from_secs(10));
+ /// let key3 = delay_queue.insert(15, Duration::from_secs(15));
+ ///
+ /// delay_queue.remove(&key2);
+ ///
+ /// delay_queue.compact();
+ /// assert_eq!(delay_queue.capacity(), 2);
+ /// # }
+ /// ```
+ pub fn compact(&mut self) {
+ self.slab.compact();
+ }
+
+ /// Returns the next time to poll as determined by the wheel
+ fn next_deadline(&mut self) -> Option<Instant> {
+ self.wheel
+ .poll_at()
+ .map(|poll_at| self.start + Duration::from_millis(poll_at))
+ }
+
+ /// Sets the delay of the item associated with `key` to expire after
+ /// `timeout`.
+ ///
+ /// This function is identical to `reset_at` but takes a `Duration` instead
+ /// of an `Instant`.
+ ///
+ /// The item remains in the queue but the delay is set to expire after
+ /// `timeout`. If `timeout` is zero, then the item is immediately made
+ /// available to the caller.
+ ///
+ /// # Panics
+ ///
+ /// This function panics if `timeout` is greater than the maximum supported
+ /// duration or if `key` is not contained by the queue.
+ ///
+ /// # Examples
+ ///
+ /// Basic usage
+ ///
+ /// ```rust
+ /// use tokio_util::time::DelayQueue;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::new();
+ /// let key = delay_queue.insert("foo", Duration::from_secs(5));
+ ///
+ /// // "foo" is scheduled to be returned in 5 seconds
+ ///
+ /// delay_queue.reset(&key, Duration::from_secs(10));
+ ///
+ /// // "foo"is now scheduled to be returned in 10 seconds
+ /// # }
+ /// ```
+ pub fn reset(&mut self, key: &Key, timeout: Duration) {
+ self.reset_at(key, Instant::now() + timeout);
+ }
+
+ /// Clears the queue, removing all items.
+ ///
+ /// After calling `clear`, [`poll_expired`] will return `Ok(Ready(None))`.
+ ///
+ /// Note that this method has no effect on the allocated capacity.
+ ///
+ /// [`poll_expired`]: method@Self::poll_expired
+ ///
+ /// # Examples
+ ///
+ /// ```rust
+ /// use tokio_util::time::DelayQueue;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::new();
+ ///
+ /// delay_queue.insert("foo", Duration::from_secs(5));
+ ///
+ /// assert!(!delay_queue.is_empty());
+ ///
+ /// delay_queue.clear();
+ ///
+ /// assert!(delay_queue.is_empty());
+ /// # }
+ /// ```
+ pub fn clear(&mut self) {
+ self.slab.clear();
+ self.expired = Stack::default();
+ self.wheel = Wheel::new();
+ self.delay = None;
+ }
+
+ /// Returns the number of elements the queue can hold without reallocating.
+ ///
+ /// # Examples
+ ///
+ /// ```rust
+ /// use tokio_util::time::DelayQueue;
+ ///
+ /// let delay_queue: DelayQueue<i32> = DelayQueue::with_capacity(10);
+ /// assert_eq!(delay_queue.capacity(), 10);
+ /// ```
+ pub fn capacity(&self) -> usize {
+ self.slab.capacity()
+ }
+
+ /// Returns the number of elements currently in the queue.
+ ///
+ /// # Examples
+ ///
+ /// ```rust
+ /// use tokio_util::time::DelayQueue;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue: DelayQueue<i32> = DelayQueue::with_capacity(10);
+ /// assert_eq!(delay_queue.len(), 0);
+ /// delay_queue.insert(3, Duration::from_secs(5));
+ /// assert_eq!(delay_queue.len(), 1);
+ /// # }
+ /// ```
+ pub fn len(&self) -> usize {
+ self.slab.len()
+ }
+
+ /// Reserves capacity for at least `additional` more items to be queued
+ /// without allocating.
+ ///
+ /// `reserve` does nothing if the queue already has sufficient capacity for
+ /// `additional` more values. If more capacity is required, a new segment of
+ /// memory will be allocated and all existing values will be copied into it.
+ /// As such, if the queue is already very large, a call to `reserve` can end
+ /// up being expensive.
+ ///
+ /// The queue may reserve more than `additional` extra space in order to
+ /// avoid frequent reallocations.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the new capacity exceeds the maximum number of entries the
+ /// queue can contain.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio_util::time::DelayQueue;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::new();
+ ///
+ /// delay_queue.insert("hello", Duration::from_secs(10));
+ /// delay_queue.reserve(10);
+ ///
+ /// assert!(delay_queue.capacity() >= 11);
+ /// # }
+ /// ```
+ pub fn reserve(&mut self, additional: usize) {
+ self.slab.reserve(additional);
+ }
+
+ /// Returns `true` if there are no items in the queue.
+ ///
+ /// Note that this function returns `false` even if all items have not yet
+ /// expired and a call to `poll` will return `Poll::Pending`.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio_util::time::DelayQueue;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::new();
+ /// assert!(delay_queue.is_empty());
+ ///
+ /// delay_queue.insert("hello", Duration::from_secs(5));
+ /// assert!(!delay_queue.is_empty());
+ /// # }
+ /// ```
+ pub fn is_empty(&self) -> bool {
+ self.slab.is_empty()
+ }
+
+ /// Polls the queue, returning the index of the next slot in the slab that
+ /// should be returned.
+ ///
+ /// A slot should be returned when the associated deadline has been reached.
+ fn poll_idx(&mut self, cx: &mut task::Context<'_>) -> Poll<Option<Key>> {
+ use self::wheel::Stack;
+
+ let expired = self.expired.pop(&mut self.slab);
+
+ if expired.is_some() {
+ return Poll::Ready(expired);
+ }
+
+ loop {
+ if let Some(ref mut delay) = self.delay {
+ if !delay.is_elapsed() {
+ ready!(Pin::new(&mut *delay).poll(cx));
+ }
+
+ let now = crate::time::ms(delay.deadline() - self.start, crate::time::Round::Down);
+
+ self.wheel_now = now;
+ }
+
+ // We poll the wheel to get the next value out before finding the next deadline.
+ let wheel_idx = self.wheel.poll(self.wheel_now, &mut self.slab);
+
+ self.delay = self.next_deadline().map(|when| Box::pin(sleep_until(when)));
+
+ if let Some(idx) = wheel_idx {
+ return Poll::Ready(Some(idx));
+ }
+
+ if self.delay.is_none() {
+ return Poll::Ready(None);
+ }
+ }
+ }
+
+ fn normalize_deadline(&self, when: Instant) -> u64 {
+ let when = if when < self.start {
+ 0
+ } else {
+ crate::time::ms(when - self.start, crate::time::Round::Up)
+ };
+
+ cmp::max(when, self.wheel.elapsed())
+ }
+}
+
+// We never put `T` in a `Pin`...
+impl<T> Unpin for DelayQueue<T> {}
+
+impl<T> Default for DelayQueue<T> {
+ fn default() -> DelayQueue<T> {
+ DelayQueue::new()
+ }
+}
+
+impl<T> futures_core::Stream for DelayQueue<T> {
+ // DelayQueue seems much more specific, where a user may care that it
+ // has reached capacity, so return those errors instead of panicking.
+ type Item = Expired<T>;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
+ DelayQueue::poll_expired(self.get_mut(), cx)
+ }
+}
+
+impl<T> wheel::Stack for Stack<T> {
+ type Owned = Key;
+ type Borrowed = Key;
+ type Store = SlabStorage<T>;
+
+ fn is_empty(&self) -> bool {
+ self.head.is_none()
+ }
+
+ fn push(&mut self, item: Self::Owned, store: &mut Self::Store) {
+ // Ensure the entry is not already in a stack.
+ debug_assert!(store[item].next.is_none());
+ debug_assert!(store[item].prev.is_none());
+
+ // Remove the old head entry
+ let old = self.head.take();
+
+ if let Some(idx) = old {
+ store[idx].prev = Some(item);
+ }
+
+ store[item].next = old;
+ self.head = Some(item);
+ }
+
+ fn pop(&mut self, store: &mut Self::Store) -> Option<Self::Owned> {
+ if let Some(key) = self.head {
+ self.head = store[key].next;
+
+ if let Some(idx) = self.head {
+ store[idx].prev = None;
+ }
+
+ store[key].next = None;
+ debug_assert!(store[key].prev.is_none());
+
+ Some(key)
+ } else {
+ None
+ }
+ }
+
+ fn remove(&mut self, item: &Self::Borrowed, store: &mut Self::Store) {
+ let key = *item;
+ assert!(store.contains(item));
+
+ // Ensure that the entry is in fact contained by the stack
+ debug_assert!({
+ // This walks the full linked list even if an entry is found.
+ let mut next = self.head;
+ let mut contains = false;
+
+ while let Some(idx) = next {
+ let data = &store[idx];
+
+ if idx == *item {
+ debug_assert!(!contains);
+ contains = true;
+ }
+
+ next = data.next;
+ }
+
+ contains
+ });
+
+ if let Some(next) = store[key].next {
+ store[next].prev = store[key].prev;
+ }
+
+ if let Some(prev) = store[key].prev {
+ store[prev].next = store[key].next;
+ } else {
+ self.head = store[key].next;
+ }
+
+ store[key].next = None;
+ store[key].prev = None;
+ }
+
+ fn when(item: &Self::Borrowed, store: &Self::Store) -> u64 {
+ store[*item].when
+ }
+}
+
+impl<T> Default for Stack<T> {
+ fn default() -> Stack<T> {
+ Stack {
+ head: None,
+ _p: PhantomData,
+ }
+ }
+}
+
+impl Key {
+ pub(crate) fn new(index: usize) -> Key {
+ Key { index }
+ }
+}
+
+impl KeyInternal {
+ pub(crate) fn new(index: usize) -> KeyInternal {
+ KeyInternal { index }
+ }
+}
+
+impl From<Key> for KeyInternal {
+ fn from(item: Key) -> Self {
+ KeyInternal::new(item.index)
+ }
+}
+
+impl From<KeyInternal> for Key {
+ fn from(item: KeyInternal) -> Self {
+ Key::new(item.index)
+ }
+}
+
+impl<T> Expired<T> {
+ /// Returns a reference to the inner value.
+ pub fn get_ref(&self) -> &T {
+ &self.data
+ }
+
+ /// Returns a mutable reference to the inner value.
+ pub fn get_mut(&mut self) -> &mut T {
+ &mut self.data
+ }
+
+ /// Consumes `self` and returns the inner value.
+ pub fn into_inner(self) -> T {
+ self.data
+ }
+
+ /// Returns the deadline that the expiration was set to.
+ pub fn deadline(&self) -> Instant {
+ self.deadline
+ }
+
+ /// Returns the key that the expiration is indexed by.
+ pub fn key(&self) -> Key {
+ self.key
+ }
+}
diff --git a/third_party/rust/tokio-util/src/time/mod.rs b/third_party/rust/tokio-util/src/time/mod.rs
new file mode 100644
index 0000000000..2d34008360
--- /dev/null
+++ b/third_party/rust/tokio-util/src/time/mod.rs
@@ -0,0 +1,47 @@
+//! Additional utilities for tracking time.
+//!
+//! This module provides additional utilities for executing code after a set period
+//! of time. Currently there is only one:
+//!
+//! * `DelayQueue`: A queue where items are returned once the requested delay
+//! has expired.
+//!
+//! This type must be used from within the context of the `Runtime`.
+
+use std::time::Duration;
+
+mod wheel;
+
+pub mod delay_queue;
+
+#[doc(inline)]
+pub use delay_queue::DelayQueue;
+
+// ===== Internal utils =====
+
+enum Round {
+ Up,
+ Down,
+}
+
+/// Convert a `Duration` to milliseconds, rounding up and saturating at
+/// `u64::MAX`.
+///
+/// The saturating is fine because `u64::MAX` milliseconds are still many
+/// million years.
+#[inline]
+fn ms(duration: Duration, round: Round) -> u64 {
+ const NANOS_PER_MILLI: u32 = 1_000_000;
+ const MILLIS_PER_SEC: u64 = 1_000;
+
+ // Round up.
+ let millis = match round {
+ Round::Up => (duration.subsec_nanos() + NANOS_PER_MILLI - 1) / NANOS_PER_MILLI,
+ Round::Down => duration.subsec_millis(),
+ };
+
+ duration
+ .as_secs()
+ .saturating_mul(MILLIS_PER_SEC)
+ .saturating_add(u64::from(millis))
+}
diff --git a/third_party/rust/tokio-util/src/time/wheel/level.rs b/third_party/rust/tokio-util/src/time/wheel/level.rs
new file mode 100644
index 0000000000..8ea30af30f
--- /dev/null
+++ b/third_party/rust/tokio-util/src/time/wheel/level.rs
@@ -0,0 +1,253 @@
+use crate::time::wheel::Stack;
+
+use std::fmt;
+
+/// Wheel for a single level in the timer. This wheel contains 64 slots.
+pub(crate) struct Level<T> {
+ level: usize,
+
+ /// Bit field tracking which slots currently contain entries.
+ ///
+ /// Using a bit field to track slots that contain entries allows avoiding a
+ /// scan to find entries. This field is updated when entries are added or
+ /// removed from a slot.
+ ///
+ /// The least-significant bit represents slot zero.
+ occupied: u64,
+
+ /// Slots
+ slot: [T; LEVEL_MULT],
+}
+
+/// Indicates when a slot must be processed next.
+#[derive(Debug)]
+pub(crate) struct Expiration {
+ /// The level containing the slot.
+ pub(crate) level: usize,
+
+ /// The slot index.
+ pub(crate) slot: usize,
+
+ /// The instant at which the slot needs to be processed.
+ pub(crate) deadline: u64,
+}
+
+/// Level multiplier.
+///
+/// Being a power of 2 is very important.
+const LEVEL_MULT: usize = 64;
+
+impl<T: Stack> Level<T> {
+ pub(crate) fn new(level: usize) -> Level<T> {
+ // Rust's derived implementations for arrays require that the value
+ // contained by the array be `Copy`. So, here we have to manually
+ // initialize every single slot.
+ macro_rules! s {
+ () => {
+ T::default()
+ };
+ }
+
+ Level {
+ level,
+ occupied: 0,
+ slot: [
+ // It does not look like the necessary traits are
+ // derived for [T; 64].
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ ],
+ }
+ }
+
+ /// Finds the slot that needs to be processed next and returns the slot and
+ /// `Instant` at which this slot must be processed.
+ pub(crate) fn next_expiration(&self, now: u64) -> Option<Expiration> {
+ // Use the `occupied` bit field to get the index of the next slot that
+ // needs to be processed.
+ let slot = match self.next_occupied_slot(now) {
+ Some(slot) => slot,
+ None => return None,
+ };
+
+ // From the slot index, calculate the `Instant` at which it needs to be
+ // processed. This value *must* be in the future with respect to `now`.
+
+ let level_range = level_range(self.level);
+ let slot_range = slot_range(self.level);
+
+ // TODO: This can probably be simplified w/ power of 2 math
+ let level_start = now - (now % level_range);
+ let deadline = level_start + slot as u64 * slot_range;
+
+ debug_assert!(
+ deadline >= now,
+ "deadline={}; now={}; level={}; slot={}; occupied={:b}",
+ deadline,
+ now,
+ self.level,
+ slot,
+ self.occupied
+ );
+
+ Some(Expiration {
+ level: self.level,
+ slot,
+ deadline,
+ })
+ }
+
+ fn next_occupied_slot(&self, now: u64) -> Option<usize> {
+ if self.occupied == 0 {
+ return None;
+ }
+
+ // Get the slot for now using Maths
+ let now_slot = (now / slot_range(self.level)) as usize;
+ let occupied = self.occupied.rotate_right(now_slot as u32);
+ let zeros = occupied.trailing_zeros() as usize;
+ let slot = (zeros + now_slot) % 64;
+
+ Some(slot)
+ }
+
+ pub(crate) fn add_entry(&mut self, when: u64, item: T::Owned, store: &mut T::Store) {
+ let slot = slot_for(when, self.level);
+
+ self.slot[slot].push(item, store);
+ self.occupied |= occupied_bit(slot);
+ }
+
+ pub(crate) fn remove_entry(&mut self, when: u64, item: &T::Borrowed, store: &mut T::Store) {
+ let slot = slot_for(when, self.level);
+
+ self.slot[slot].remove(item, store);
+
+ if self.slot[slot].is_empty() {
+ // The bit is currently set
+ debug_assert!(self.occupied & occupied_bit(slot) != 0);
+
+ // Unset the bit
+ self.occupied ^= occupied_bit(slot);
+ }
+ }
+
+ pub(crate) fn pop_entry_slot(&mut self, slot: usize, store: &mut T::Store) -> Option<T::Owned> {
+ let ret = self.slot[slot].pop(store);
+
+ if ret.is_some() && self.slot[slot].is_empty() {
+ // The bit is currently set
+ debug_assert!(self.occupied & occupied_bit(slot) != 0);
+
+ self.occupied ^= occupied_bit(slot);
+ }
+
+ ret
+ }
+}
+
+impl<T> fmt::Debug for Level<T> {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt.debug_struct("Level")
+ .field("occupied", &self.occupied)
+ .finish()
+ }
+}
+
+fn occupied_bit(slot: usize) -> u64 {
+ 1 << slot
+}
+
+fn slot_range(level: usize) -> u64 {
+ LEVEL_MULT.pow(level as u32) as u64
+}
+
+fn level_range(level: usize) -> u64 {
+ LEVEL_MULT as u64 * slot_range(level)
+}
+
+/// Convert a duration (milliseconds) and a level to a slot position
+fn slot_for(duration: u64, level: usize) -> usize {
+ ((duration >> (level * 6)) % LEVEL_MULT as u64) as usize
+}
+
+#[cfg(all(test, not(loom)))]
+mod test {
+ use super::*;
+
+ #[test]
+ fn test_slot_for() {
+ for pos in 0..64 {
+ assert_eq!(pos as usize, slot_for(pos, 0));
+ }
+
+ for level in 1..5 {
+ for pos in level..64 {
+ let a = pos * 64_usize.pow(level as u32);
+ assert_eq!(pos as usize, slot_for(a as u64, level));
+ }
+ }
+ }
+}
diff --git a/third_party/rust/tokio-util/src/time/wheel/mod.rs b/third_party/rust/tokio-util/src/time/wheel/mod.rs
new file mode 100644
index 0000000000..4191e401df
--- /dev/null
+++ b/third_party/rust/tokio-util/src/time/wheel/mod.rs
@@ -0,0 +1,314 @@
+mod level;
+pub(crate) use self::level::Expiration;
+use self::level::Level;
+
+mod stack;
+pub(crate) use self::stack::Stack;
+
+use std::borrow::Borrow;
+use std::fmt::Debug;
+use std::usize;
+
+/// Timing wheel implementation.
+///
+/// This type provides the hashed timing wheel implementation that backs `Timer`
+/// and `DelayQueue`.
+///
+/// The structure is generic over `T: Stack`. This allows handling timeout data
+/// being stored on the heap or in a slab. In order to support the latter case,
+/// the slab must be passed into each function allowing the implementation to
+/// lookup timer entries.
+///
+/// See `Timer` documentation for some implementation notes.
+#[derive(Debug)]
+pub(crate) struct Wheel<T> {
+ /// The number of milliseconds elapsed since the wheel started.
+ elapsed: u64,
+
+ /// Timer wheel.
+ ///
+ /// Levels:
+ ///
+ /// * 1 ms slots / 64 ms range
+ /// * 64 ms slots / ~ 4 sec range
+ /// * ~ 4 sec slots / ~ 4 min range
+ /// * ~ 4 min slots / ~ 4 hr range
+ /// * ~ 4 hr slots / ~ 12 day range
+ /// * ~ 12 day slots / ~ 2 yr range
+ levels: Vec<Level<T>>,
+}
+
+/// Number of levels. Each level has 64 slots. By using 6 levels with 64 slots
+/// each, the timer is able to track time up to 2 years into the future with a
+/// precision of 1 millisecond.
+const NUM_LEVELS: usize = 6;
+
+/// The maximum duration of a delay
+const MAX_DURATION: u64 = (1 << (6 * NUM_LEVELS)) - 1;
+
+#[derive(Debug)]
+pub(crate) enum InsertError {
+ Elapsed,
+ Invalid,
+}
+
+impl<T> Wheel<T>
+where
+ T: Stack,
+{
+ /// Create a new timing wheel
+ pub(crate) fn new() -> Wheel<T> {
+ let levels = (0..NUM_LEVELS).map(Level::new).collect();
+
+ Wheel { elapsed: 0, levels }
+ }
+
+ /// Return the number of milliseconds that have elapsed since the timing
+ /// wheel's creation.
+ pub(crate) fn elapsed(&self) -> u64 {
+ self.elapsed
+ }
+
+ /// Insert an entry into the timing wheel.
+ ///
+ /// # Arguments
+ ///
+ /// * `when`: is the instant at which the entry should be fired. It is
+ /// represented as the number of milliseconds since the creation
+ /// of the timing wheel.
+ ///
+ /// * `item`: The item to insert into the wheel.
+ ///
+ /// * `store`: The slab or `()` when using heap storage.
+ ///
+ /// # Return
+ ///
+ /// Returns `Ok` when the item is successfully inserted, `Err` otherwise.
+ ///
+ /// `Err(Elapsed)` indicates that `when` represents an instant that has
+ /// already passed. In this case, the caller should fire the timeout
+ /// immediately.
+ ///
+ /// `Err(Invalid)` indicates an invalid `when` argument as been supplied.
+ pub(crate) fn insert(
+ &mut self,
+ when: u64,
+ item: T::Owned,
+ store: &mut T::Store,
+ ) -> Result<(), (T::Owned, InsertError)> {
+ if when <= self.elapsed {
+ return Err((item, InsertError::Elapsed));
+ } else if when - self.elapsed > MAX_DURATION {
+ return Err((item, InsertError::Invalid));
+ }
+
+ // Get the level at which the entry should be stored
+ let level = self.level_for(when);
+
+ self.levels[level].add_entry(when, item, store);
+
+ debug_assert!({
+ self.levels[level]
+ .next_expiration(self.elapsed)
+ .map(|e| e.deadline >= self.elapsed)
+ .unwrap_or(true)
+ });
+
+ Ok(())
+ }
+
+ /// Remove `item` from the timing wheel.
+ pub(crate) fn remove(&mut self, item: &T::Borrowed, store: &mut T::Store) {
+ let when = T::when(item, store);
+
+ assert!(
+ self.elapsed <= when,
+ "elapsed={}; when={}",
+ self.elapsed,
+ when
+ );
+
+ let level = self.level_for(when);
+
+ self.levels[level].remove_entry(when, item, store);
+ }
+
+ /// Instant at which to poll
+ pub(crate) fn poll_at(&self) -> Option<u64> {
+ self.next_expiration().map(|expiration| expiration.deadline)
+ }
+
+ /// Advances the timer up to the instant represented by `now`.
+ pub(crate) fn poll(&mut self, now: u64, store: &mut T::Store) -> Option<T::Owned> {
+ loop {
+ let expiration = self.next_expiration().and_then(|expiration| {
+ if expiration.deadline > now {
+ None
+ } else {
+ Some(expiration)
+ }
+ });
+
+ match expiration {
+ Some(ref expiration) => {
+ if let Some(item) = self.poll_expiration(expiration, store) {
+ return Some(item);
+ }
+
+ self.set_elapsed(expiration.deadline);
+ }
+ None => {
+ // in this case the poll did not indicate an expiration
+ // _and_ we were not able to find a next expiration in
+ // the current list of timers. advance to the poll's
+ // current time and do nothing else.
+ self.set_elapsed(now);
+ return None;
+ }
+ }
+ }
+ }
+
+ /// Returns the instant at which the next timeout expires.
+ fn next_expiration(&self) -> Option<Expiration> {
+ // Check all levels
+ for level in 0..NUM_LEVELS {
+ if let Some(expiration) = self.levels[level].next_expiration(self.elapsed) {
+ // There cannot be any expirations at a higher level that happen
+ // before this one.
+ debug_assert!(self.no_expirations_before(level + 1, expiration.deadline));
+
+ return Some(expiration);
+ }
+ }
+
+ None
+ }
+
+ /// Used for debug assertions
+ fn no_expirations_before(&self, start_level: usize, before: u64) -> bool {
+ let mut res = true;
+
+ for l2 in start_level..NUM_LEVELS {
+ if let Some(e2) = self.levels[l2].next_expiration(self.elapsed) {
+ if e2.deadline < before {
+ res = false;
+ }
+ }
+ }
+
+ res
+ }
+
+ /// iteratively find entries that are between the wheel's current
+ /// time and the expiration time. for each in that population either
+ /// return it for notification (in the case of the last level) or tier
+ /// it down to the next level (in all other cases).
+ pub(crate) fn poll_expiration(
+ &mut self,
+ expiration: &Expiration,
+ store: &mut T::Store,
+ ) -> Option<T::Owned> {
+ while let Some(item) = self.pop_entry(expiration, store) {
+ if expiration.level == 0 {
+ debug_assert_eq!(T::when(item.borrow(), store), expiration.deadline);
+
+ return Some(item);
+ } else {
+ let when = T::when(item.borrow(), store);
+
+ let next_level = expiration.level - 1;
+
+ self.levels[next_level].add_entry(when, item, store);
+ }
+ }
+
+ None
+ }
+
+ fn set_elapsed(&mut self, when: u64) {
+ assert!(
+ self.elapsed <= when,
+ "elapsed={:?}; when={:?}",
+ self.elapsed,
+ when
+ );
+
+ if when > self.elapsed {
+ self.elapsed = when;
+ }
+ }
+
+ fn pop_entry(&mut self, expiration: &Expiration, store: &mut T::Store) -> Option<T::Owned> {
+ self.levels[expiration.level].pop_entry_slot(expiration.slot, store)
+ }
+
+ fn level_for(&self, when: u64) -> usize {
+ level_for(self.elapsed, when)
+ }
+}
+
+fn level_for(elapsed: u64, when: u64) -> usize {
+ const SLOT_MASK: u64 = (1 << 6) - 1;
+
+ // Mask in the trailing bits ignored by the level calculation in order to cap
+ // the possible leading zeros
+ let masked = elapsed ^ when | SLOT_MASK;
+
+ let leading_zeros = masked.leading_zeros() as usize;
+ let significant = 63 - leading_zeros;
+ significant / 6
+}
+
+#[cfg(all(test, not(loom)))]
+mod test {
+ use super::*;
+
+ #[test]
+ fn test_level_for() {
+ for pos in 0..64 {
+ assert_eq!(
+ 0,
+ level_for(0, pos),
+ "level_for({}) -- binary = {:b}",
+ pos,
+ pos
+ );
+ }
+
+ for level in 1..5 {
+ for pos in level..64 {
+ let a = pos * 64_usize.pow(level as u32);
+ assert_eq!(
+ level,
+ level_for(0, a as u64),
+ "level_for({}) -- binary = {:b}",
+ a,
+ a
+ );
+
+ if pos > level {
+ let a = a - 1;
+ assert_eq!(
+ level,
+ level_for(0, a as u64),
+ "level_for({}) -- binary = {:b}",
+ a,
+ a
+ );
+ }
+
+ if pos < 64 {
+ let a = a + 1;
+ assert_eq!(
+ level,
+ level_for(0, a as u64),
+ "level_for({}) -- binary = {:b}",
+ a,
+ a
+ );
+ }
+ }
+ }
+ }
+}
diff --git a/third_party/rust/tokio-util/src/time/wheel/stack.rs b/third_party/rust/tokio-util/src/time/wheel/stack.rs
new file mode 100644
index 0000000000..c87adcafda
--- /dev/null
+++ b/third_party/rust/tokio-util/src/time/wheel/stack.rs
@@ -0,0 +1,28 @@
+use std::borrow::Borrow;
+use std::cmp::Eq;
+use std::hash::Hash;
+
+/// Abstracts the stack operations needed to track timeouts.
+pub(crate) trait Stack: Default {
+ /// Type of the item stored in the stack
+ type Owned: Borrow<Self::Borrowed>;
+
+ /// Borrowed item
+ type Borrowed: Eq + Hash;
+
+ /// Item storage, this allows a slab to be used instead of just the heap
+ type Store;
+
+ /// Returns `true` if the stack is empty
+ fn is_empty(&self) -> bool;
+
+ /// Push an item onto the stack
+ fn push(&mut self, item: Self::Owned, store: &mut Self::Store);
+
+ /// Pop an item from the stack
+ fn pop(&mut self, store: &mut Self::Store) -> Option<Self::Owned>;
+
+ fn remove(&mut self, item: &Self::Borrowed, store: &mut Self::Store);
+
+ fn when(item: &Self::Borrowed, store: &Self::Store) -> u64;
+}
diff --git a/third_party/rust/tokio-util/src/udp/frame.rs b/third_party/rust/tokio-util/src/udp/frame.rs
new file mode 100644
index 0000000000..d900fd7691
--- /dev/null
+++ b/third_party/rust/tokio-util/src/udp/frame.rs
@@ -0,0 +1,245 @@
+use crate::codec::{Decoder, Encoder};
+
+use futures_core::Stream;
+use tokio::{io::ReadBuf, net::UdpSocket};
+
+use bytes::{BufMut, BytesMut};
+use futures_core::ready;
+use futures_sink::Sink;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use std::{
+ borrow::Borrow,
+ net::{Ipv4Addr, SocketAddr, SocketAddrV4},
+};
+use std::{io, mem::MaybeUninit};
+
+/// A unified [`Stream`] and [`Sink`] interface to an underlying `UdpSocket`, using
+/// the `Encoder` and `Decoder` traits to encode and decode frames.
+///
+/// Raw UDP sockets work with datagrams, but higher-level code usually wants to
+/// batch these into meaningful chunks, called "frames". This method layers
+/// framing on top of this socket by using the `Encoder` and `Decoder` traits to
+/// handle encoding and decoding of messages frames. Note that the incoming and
+/// outgoing frame types may be distinct.
+///
+/// This function returns a *single* object that is both [`Stream`] and [`Sink`];
+/// grouping this into a single object is often useful for layering things which
+/// require both read and write access to the underlying object.
+///
+/// If you want to work more directly with the streams and sink, consider
+/// calling [`split`] on the `UdpFramed` returned by this method, which will break
+/// them into separate objects, allowing them to interact more easily.
+///
+/// [`Stream`]: futures_core::Stream
+/// [`Sink`]: futures_sink::Sink
+/// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split
+#[must_use = "sinks do nothing unless polled"]
+#[derive(Debug)]
+pub struct UdpFramed<C, T = UdpSocket> {
+ socket: T,
+ codec: C,
+ rd: BytesMut,
+ wr: BytesMut,
+ out_addr: SocketAddr,
+ flushed: bool,
+ is_readable: bool,
+ current_addr: Option<SocketAddr>,
+}
+
+const INITIAL_RD_CAPACITY: usize = 64 * 1024;
+const INITIAL_WR_CAPACITY: usize = 8 * 1024;
+
+impl<C, T> Unpin for UdpFramed<C, T> {}
+
+impl<C, T> Stream for UdpFramed<C, T>
+where
+ T: Borrow<UdpSocket>,
+ C: Decoder,
+{
+ type Item = Result<(C::Item, SocketAddr), C::Error>;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ let pin = self.get_mut();
+
+ pin.rd.reserve(INITIAL_RD_CAPACITY);
+
+ loop {
+ // Are there still bytes left in the read buffer to decode?
+ if pin.is_readable {
+ if let Some(frame) = pin.codec.decode_eof(&mut pin.rd)? {
+ let current_addr = pin
+ .current_addr
+ .expect("will always be set before this line is called");
+
+ return Poll::Ready(Some(Ok((frame, current_addr))));
+ }
+
+ // if this line has been reached then decode has returned `None`.
+ pin.is_readable = false;
+ pin.rd.clear();
+ }
+
+ // We're out of data. Try and fetch more data to decode
+ let addr = unsafe {
+ // Convert `&mut [MaybeUnit<u8>]` to `&mut [u8]` because we will be
+ // writing to it via `poll_recv_from` and therefore initializing the memory.
+ let buf = &mut *(pin.rd.chunk_mut() as *mut _ as *mut [MaybeUninit<u8>]);
+ let mut read = ReadBuf::uninit(buf);
+ let ptr = read.filled().as_ptr();
+ let res = ready!(pin.socket.borrow().poll_recv_from(cx, &mut read));
+
+ assert_eq!(ptr, read.filled().as_ptr());
+ let addr = res?;
+ pin.rd.advance_mut(read.filled().len());
+ addr
+ };
+
+ pin.current_addr = Some(addr);
+ pin.is_readable = true;
+ }
+ }
+}
+
+impl<I, C, T> Sink<(I, SocketAddr)> for UdpFramed<C, T>
+where
+ T: Borrow<UdpSocket>,
+ C: Encoder<I>,
+{
+ type Error = C::Error;
+
+ fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ if !self.flushed {
+ match self.poll_flush(cx)? {
+ Poll::Ready(()) => {}
+ Poll::Pending => return Poll::Pending,
+ }
+ }
+
+ Poll::Ready(Ok(()))
+ }
+
+ fn start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error> {
+ let (frame, out_addr) = item;
+
+ let pin = self.get_mut();
+
+ pin.codec.encode(frame, &mut pin.wr)?;
+ pin.out_addr = out_addr;
+ pin.flushed = false;
+
+ Ok(())
+ }
+
+ fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ if self.flushed {
+ return Poll::Ready(Ok(()));
+ }
+
+ let Self {
+ ref socket,
+ ref mut out_addr,
+ ref mut wr,
+ ..
+ } = *self;
+
+ let n = ready!(socket.borrow().poll_send_to(cx, wr, *out_addr))?;
+
+ let wrote_all = n == self.wr.len();
+ self.wr.clear();
+ self.flushed = true;
+
+ let res = if wrote_all {
+ Ok(())
+ } else {
+ Err(io::Error::new(
+ io::ErrorKind::Other,
+ "failed to write entire datagram to socket",
+ )
+ .into())
+ };
+
+ Poll::Ready(res)
+ }
+
+ fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ ready!(self.poll_flush(cx))?;
+ Poll::Ready(Ok(()))
+ }
+}
+
+impl<C, T> UdpFramed<C, T>
+where
+ T: Borrow<UdpSocket>,
+{
+ /// Create a new `UdpFramed` backed by the given socket and codec.
+ ///
+ /// See struct level documentation for more details.
+ pub fn new(socket: T, codec: C) -> UdpFramed<C, T> {
+ Self {
+ socket,
+ codec,
+ out_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
+ rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY),
+ wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY),
+ flushed: true,
+ is_readable: false,
+ current_addr: None,
+ }
+ }
+
+ /// Returns a reference to the underlying I/O stream wrapped by `Framed`.
+ ///
+ /// # Note
+ ///
+ /// Care should be taken to not tamper with the underlying stream of data
+ /// coming in as it may corrupt the stream of frames otherwise being worked
+ /// with.
+ pub fn get_ref(&self) -> &T {
+ &self.socket
+ }
+
+ /// Returns a mutable reference to the underlying I/O stream wrapped by `Framed`.
+ ///
+ /// # Note
+ ///
+ /// Care should be taken to not tamper with the underlying stream of data
+ /// coming in as it may corrupt the stream of frames otherwise being worked
+ /// with.
+ pub fn get_mut(&mut self) -> &mut T {
+ &mut self.socket
+ }
+
+ /// Returns a reference to the underlying codec wrapped by
+ /// `Framed`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying codec
+ /// as it may corrupt the stream of frames otherwise being worked with.
+ pub fn codec(&self) -> &C {
+ &self.codec
+ }
+
+ /// Returns a mutable reference to the underlying codec wrapped by
+ /// `UdpFramed`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying codec
+ /// as it may corrupt the stream of frames otherwise being worked with.
+ pub fn codec_mut(&mut self) -> &mut C {
+ &mut self.codec
+ }
+
+ /// Returns a reference to the read buffer.
+ pub fn read_buffer(&self) -> &BytesMut {
+ &self.rd
+ }
+
+ /// Returns a mutable reference to the read buffer.
+ pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
+ &mut self.rd
+ }
+
+ /// Consumes the `Framed`, returning its underlying I/O stream.
+ pub fn into_inner(self) -> T {
+ self.socket
+ }
+}
diff --git a/third_party/rust/tokio-util/src/udp/mod.rs b/third_party/rust/tokio-util/src/udp/mod.rs
new file mode 100644
index 0000000000..f88ea030aa
--- /dev/null
+++ b/third_party/rust/tokio-util/src/udp/mod.rs
@@ -0,0 +1,4 @@
+//! UDP framing
+
+mod frame;
+pub use frame::UdpFramed;