diff options
Diffstat (limited to 'third_party/rust/tokio-util')
24 files changed, 4369 insertions, 0 deletions
diff --git a/third_party/rust/tokio-util/.cargo-checksum.json b/third_party/rust/tokio-util/.cargo-checksum.json new file mode 100644 index 0000000000..527e194c06 --- /dev/null +++ b/third_party/rust/tokio-util/.cargo-checksum.json @@ -0,0 +1 @@ +{"files":{"CHANGELOG.md":"da9a61af4bf03f2ca997231b5551f506fd973ef20795ea91dc295d84f81c3527","Cargo.toml":"d892a316aed002bdf3fe230275f13aec5068c899d61b16722ecdd316bd9ac429","LICENSE":"898b1ae9821e98daf8964c8d6c7f61641f5f5aa78ad500020771c0939ee0dea1","README.md":"e895cbba8345655607ebd830f0101e66d2e6e9287ad3b9ddf697cca230738053","src/cfg.rs":"29a5a0b96eacc46982076673a3e35b0579348db0d29c0a4d41257bb43477fd96","src/codec/bytes_codec.rs":"deaa207773963f05f11c63d0fe5815356c3a931e4c42522334cbee54bc9950aa","src/codec/decoder.rs":"20e31de370bc334709d9e9fee687a11781ad56a7ef0c8b848f961fd62ac680f0","src/codec/encoder.rs":"6b145bce207cefd1766a312e3d91974b7a8d4fbabdb5b56482e91d319d84c964","src/codec/framed.rs":"cc66cb4629b305dc57217d3651273eb238fc1ce7ce382b3f9547b9dd671afa05","src/codec/framed_read.rs":"111063a9770da2fccbb22dcd9d8d2745dbcfcfc275347b6bed64913b765f0929","src/codec/framed_write.rs":"bd48f0b163cf144ef0561b0046fa7f0902058af61a0584f20a7c96e0046c1691","src/codec/length_delimited.rs":"88c53b174a579bb2b2aba20050b7e2a014f0174d2dd3494fbc9c26476dcb4a6e","src/codec/lines_codec.rs":"7ef2002d27f490c3e4fee458158283589c65cee08abcd6f3fb6a9bede22a0b26","src/codec/mod.rs":"b9dd36bc37615ef25d16720e1a4d467c31ab53f72c5df414e3a8a7af387ea361","src/lib.rs":"ca00e2725fc9df394fb4a9e9dd43016c31613dc337b8d8e7ebc622a975d7d08a","src/udp/frame.rs":"50538c8397ff85bba55a6336d33598a6ecc0995eab533d360745bbc672d3d076","src/udp/mod.rs":"699abcf4e12d8180a9e3f2b986edee6e2989ca995ec07c8be1d62d2c263c4a1e","tests/codecs.rs":"d9d3d8519306ace00d4080fa98cae985e199674db5bca388f2eb17a748bf6e5d","tests/framed.rs":"87e75368a922f96515cfefc1146ce4c22e421775778b84eca0e68c0fe5f0eda7","tests/framed_read.rs":"6ce94a7a6cd9fd879a52689c629c0d8c056561a465f7a121f0ccacef918351b6","tes
ts/framed_write.rs":"38ee68e620f829943164db0fd012644035dad8ef71c44a7ae3858f0ad64a050c","tests/length_delimited.rs":"3bad5a560b7bf081ebc03b34b190db58035b30138ba2f7c5cc64ae8fd2507700","tests/udp.rs":"428edc3a33438bcea3e08db58b85e9d1869b192872bf7f92fdaff314ff134d68"},"package":"571da51182ec208780505a32528fc5512a8fe1443ab960b3f2f3ef093cd16930"}
\ No newline at end of file diff --git a/third_party/rust/tokio-util/CHANGELOG.md b/third_party/rust/tokio-util/CHANGELOG.md new file mode 100644 index 0000000000..48022e3435 --- /dev/null +++ b/third_party/rust/tokio-util/CHANGELOG.md @@ -0,0 +1,3 @@ +# 0.2.0 (November 26, 2019) + +- Initial release diff --git a/third_party/rust/tokio-util/Cargo.toml b/third_party/rust/tokio-util/Cargo.toml new file mode 100644 index 0000000000..954d45552f --- /dev/null +++ b/third_party/rust/tokio-util/Cargo.toml @@ -0,0 +1,58 @@ +# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO +# +# When uploading crates to the registry Cargo will automatically +# "normalize" Cargo.toml files for maximal compatibility +# with all versions of Cargo and also rewrite `path` dependencies +# to registry (e.g., crates.io) dependencies +# +# If you believe there's an error in this file please file an +# issue against the rust-lang/cargo repository. If you're +# editing this file be aware that the upstream Cargo.toml +# will likely look very different (and much more reasonable) + +[package] +edition = "2018" +name = "tokio-util" +version = "0.2.0" +authors = ["Tokio Contributors <team@tokio.rs>"] +description = "Additional utilities for working with Tokio.\n" +homepage = "https://tokio.rs" +documentation = "https://docs.rs/tokio-util/0.2.0/tokio_util" +categories = ["asynchronous"] +license = "MIT" +repository = "https://github.com/tokio-rs/tokio" +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] +[dependencies.bytes] +version = "0.5.0" + +[dependencies.futures-core] +version = "0.3.0" + +[dependencies.futures-sink] +version = "0.3.0" + +[dependencies.log] +version = "0.4" + +[dependencies.pin-project-lite] +version = "0.1.1" + +[dependencies.tokio] +version = "0.2.0" +[dev-dependencies.futures] +version = "0.3.0" + +[dev-dependencies.tokio] +version = "0.2.0" +features = ["full"] + +[dev-dependencies.tokio-test] +version = "0.2.0" + +[features] +codec = [] +default = 
[] +full = ["codec", "udp"] +udp = ["tokio/udp"] diff --git a/third_party/rust/tokio-util/LICENSE b/third_party/rust/tokio-util/LICENSE new file mode 100644 index 0000000000..cdb28b4b56 --- /dev/null +++ b/third_party/rust/tokio-util/LICENSE @@ -0,0 +1,25 @@ +Copyright (c) 2019 Tokio Contributors + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/third_party/rust/tokio-util/README.md b/third_party/rust/tokio-util/README.md new file mode 100644 index 0000000000..11b2b1841a --- /dev/null +++ b/third_party/rust/tokio-util/README.md @@ -0,0 +1,13 @@ +# tokio-util + +Utilities for encoding and decoding frames. + +## License + +This project is licensed under the [MIT license](LICENSE). + +### Contribution + +Unless you explicitly state otherwise, any contribution intentionally submitted +for inclusion in Tokio by you, shall be licensed as MIT, without any additional +terms or conditions. 
diff --git a/third_party/rust/tokio-util/src/cfg.rs b/third_party/rust/tokio-util/src/cfg.rs new file mode 100644 index 0000000000..13fabd3638 --- /dev/null +++ b/third_party/rust/tokio-util/src/cfg.rs @@ -0,0 +1,19 @@ +macro_rules! cfg_codec { + ($($item:item)*) => { + $( + #[cfg(feature = "codec")] + #[cfg_attr(docsrs, doc(cfg(feature = "codec")))] + $item + )* + } +} + +macro_rules! cfg_udp { + ($($item:item)*) => { + $( + #[cfg(all(feature = "udp", feature = "codec"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "udp", feature = "codec"))))] + $item + )* + } +} diff --git a/third_party/rust/tokio-util/src/codec/bytes_codec.rs b/third_party/rust/tokio-util/src/codec/bytes_codec.rs new file mode 100644 index 0000000000..a7d424e9e6 --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/bytes_codec.rs @@ -0,0 +1,41 @@ +use crate::codec::decoder::Decoder; +use crate::codec::encoder::Encoder; + +use bytes::{BufMut, Bytes, BytesMut}; +use std::io; + +/// A simple `Codec` implementation that just ships bytes around. +#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Default)] +pub struct BytesCodec(()); + +impl BytesCodec { + /// Creates a new `BytesCodec` for shipping around raw bytes. 
+ pub fn new() -> BytesCodec { + BytesCodec(()) + } +} + +impl Decoder for BytesCodec { + type Item = BytesMut; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> { + if !buf.is_empty() { + let len = buf.len(); + Ok(Some(buf.split_to(len))) + } else { + Ok(None) + } + } +} + +impl Encoder for BytesCodec { + type Item = Bytes; + type Error = io::Error; + + fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> { + buf.reserve(data.len()); + buf.put(data); + Ok(()) + } +} diff --git a/third_party/rust/tokio-util/src/codec/decoder.rs b/third_party/rust/tokio-util/src/codec/decoder.rs new file mode 100644 index 0000000000..dfe5f8ee1a --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/decoder.rs @@ -0,0 +1,154 @@ +use crate::codec::encoder::Encoder; +use crate::codec::Framed; + +use tokio::io::{AsyncRead, AsyncWrite}; + +use bytes::BytesMut; +use std::io; + +/// Decoding of frames via buffers. +/// +/// This trait is used when constructing an instance of `Framed` or +/// `FramedRead`. An implementation of `Decoder` takes a byte stream that has +/// already been buffered in `src` and decodes the data into a stream of +/// `Self::Item` frames. +/// +/// Implementations are able to track state on `self`, which enables +/// implementing stateful streaming parsers. In many cases, though, this type +/// will simply be a unit struct (e.g. `struct HttpDecoder`). +pub trait Decoder { + /// The type of decoded frames. + type Item; + + /// The type of unrecoverable frame decoding errors. + /// + /// If an individual message is ill-formed but can be ignored without + /// interfering with the processing of future messages, it may be more + /// useful to report the failure as an `Item`. 
+ /// + /// `From<io::Error>` is required in the interest of making `Error` suitable + /// for returning directly from a `FramedRead`, and to enable the default + /// implementation of `decode_eof` to yield an `io::Error` when the decoder + /// fails to consume all available data. + /// + /// Note that implementors of this trait can simply indicate `type Error = + /// io::Error` to use I/O errors as this type. + type Error: From<io::Error>; + + /// Attempts to decode a frame from the provided buffer of bytes. + /// + /// This method is called by `FramedRead` whenever bytes are ready to be + /// parsed. The provided buffer of bytes is what's been read so far, and + /// this instance of `Decoder` can determine whether an entire frame is in + /// the buffer and is ready to be returned. + /// + /// If an entire frame is available, then this instance will remove those + /// bytes from the buffer provided and return them as a decoded + /// frame. Note that removing bytes from the provided buffer doesn't always + /// necessarily copy the bytes, so this should be an efficient operation in + /// most circumstances. + /// + /// If the bytes look valid, but a frame isn't fully available yet, then + /// `Ok(None)` is returned. This indicates to the `Framed` instance that + /// it needs to read some more bytes before calling this method again. + /// + /// Note that the bytes provided may be empty. If a previous call to + /// `decode` consumed all the bytes in the buffer then `decode` will be + /// called again until it returns `Ok(None)`, indicating that more bytes need to + /// be read. + /// + /// Finally, if the bytes in the buffer are malformed then an error is + /// returned indicating why. This informs `Framed` that the stream is now + /// corrupt and should be terminated. 
+ /// + /// # Buffer management + /// + /// Before returning from the function, implementations should ensure that + /// the buffer has appropriate capacity in anticipation of future calls to + /// `decode`. Failing to do so leads to inefficiency. + /// + /// For example, if frames have a fixed length, or if the length of the + /// current frame is known from a header, a possible buffer management + /// strategy is: + /// + /// ```no_run + /// # use std::io; + /// # + /// # use bytes::BytesMut; + /// # use tokio_util::codec::Decoder; + /// # + /// # struct MyCodec; + /// # + /// impl Decoder for MyCodec { + /// // ... + /// # type Item = BytesMut; + /// # type Error = io::Error; + /// + /// fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> { + /// // ... + /// + /// // Reserve enough to complete decoding of the current frame. + /// let current_frame_len: usize = 1000; // Example. + /// // And to start decoding the next frame. + /// let next_frame_header_len: usize = 10; // Example. + /// src.reserve(current_frame_len + next_frame_header_len); + /// + /// return Ok(None); + /// } + /// } + /// ``` + /// + /// An optimal buffer management strategy minimizes reallocations and + /// over-allocations. + fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error>; + + /// A default method available to be called when there are no more bytes + /// available to be read from the underlying I/O. + /// + /// This method defaults to calling `decode` and returns an error if + /// `Ok(None)` is returned while there is unconsumed data in `buf`. + /// Typically this doesn't need to be implemented unless the framing + /// protocol differs near the end of the stream. + /// + /// Note that the `buf` argument may be empty. If a previous call to + /// `decode_eof` consumed all the bytes in the buffer, `decode_eof` will be + /// called again until it returns `None`, indicating that there are no more + /// frames to yield. 
This behavior enables returning finalization frames + /// that may not be based on inbound data. + fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> { + match self.decode(buf)? { + Some(frame) => Ok(Some(frame)), + None => { + if buf.is_empty() { + Ok(None) + } else { + Err(io::Error::new(io::ErrorKind::Other, "bytes remaining on stream").into()) + } + } + } + } + + /// Provides a `Stream` and `Sink` interface for reading and writing to this + /// `Io` object, using `Decode` and `Encode` to read and write the raw data. + /// + /// Raw I/O objects work with byte sequences, but higher-level code usually + /// wants to batch these into meaningful chunks, called "frames". This + /// method layers framing on top of an I/O object, by using the `Codec` + /// traits to handle encoding and decoding of messages frames. Note that + /// the incoming and outgoing frame types may be distinct. + /// + /// This function returns a *single* object that is both `Stream` and + /// `Sink`; grouping this into a single object is often useful for layering + /// things like gzip or TLS, which require both read and write access to the + /// underlying object. + /// + /// If you want to work more directly with the streams and sink, consider + /// calling `split` on the `Framed` returned by this method, which will + /// break them into separate objects, allowing them to interact more easily. + fn framed<T: AsyncRead + AsyncWrite + Sized>(self, io: T) -> Framed<T, Self> + where + Self: Encoder + Sized, + { + Framed::new(io, self) + } +} diff --git a/third_party/rust/tokio-util/src/codec/encoder.rs b/third_party/rust/tokio-util/src/codec/encoder.rs new file mode 100644 index 0000000000..76fa9dbae0 --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/encoder.rs @@ -0,0 +1,22 @@ +use bytes::BytesMut; +use std::io; + +/// Trait of helper objects to write out messages as bytes, for use with +/// `FramedWrite`. 
+pub trait Encoder { + /// The type of items consumed by the `Encoder` + type Item; + + /// The type of encoding errors. + /// + /// `FramedWrite` requires `Encoder`s errors to implement `From<io::Error>` + /// in the interest letting it return `Error`s directly. + type Error: From<io::Error>; + + /// Encodes a frame into the buffer provided. + /// + /// This method will encode `item` into the byte buffer provided by `dst`. + /// The `dst` provided is an internal buffer of the `Framed` instance and + /// will be written out when possible. + fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error>; +} diff --git a/third_party/rust/tokio-util/src/codec/framed.rs b/third_party/rust/tokio-util/src/codec/framed.rs new file mode 100644 index 0000000000..0c5ef9f6fa --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/framed.rs @@ -0,0 +1,371 @@ +use crate::codec::decoder::Decoder; +use crate::codec::encoder::Encoder; +use crate::codec::framed_read::{framed_read2, framed_read2_with_buffer, FramedRead2}; +use crate::codec::framed_write::{framed_write2, framed_write2_with_buffer, FramedWrite2}; + +use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite}; + +use bytes::BytesMut; +use futures_core::Stream; +use futures_sink::Sink; +use pin_project_lite::pin_project; +use std::fmt; +use std::io::{self, BufRead, Read, Write}; +use std::mem::MaybeUninit; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// A unified `Stream` and `Sink` interface to an underlying I/O object, using + /// the `Encoder` and `Decoder` traits to encode and decode frames. + /// + /// You can create a `Framed` instance by using the `AsyncRead::framed` adapter. + pub struct Framed<T, U> { + #[pin] + inner: FramedRead2<FramedWrite2<Fuse<T, U>>>, + } +} + +pin_project! 
{ + pub(crate) struct Fuse<T, U> { + #[pin] + pub(crate) io: T, + pub(crate) codec: U, + } +} + +/// Abstracts over `FramedRead2` being either `FramedRead2<FramedWrite2<Fuse<T, U>>>` or +/// `FramedRead2<Fuse<T, U>>` and lets the io and codec parts be extracted in either case. +pub(crate) trait ProjectFuse { + type Io; + type Codec; + + fn project(self: Pin<&mut Self>) -> Fuse<Pin<&mut Self::Io>, &mut Self::Codec>; +} + +impl<T, U> ProjectFuse for Fuse<T, U> { + type Io = T; + type Codec = U; + + fn project(self: Pin<&mut Self>) -> Fuse<Pin<&mut Self::Io>, &mut Self::Codec> { + let self_ = self.project(); + Fuse { + io: self_.io, + codec: self_.codec, + } + } +} + +impl<T, U> Framed<T, U> +where + T: AsyncRead + AsyncWrite, + U: Decoder + Encoder, +{ + /// Provides a `Stream` and `Sink` interface for reading and writing to this + /// `Io` object, using `Decode` and `Encode` to read and write the raw data. + /// + /// Raw I/O objects work with byte sequences, but higher-level code usually + /// wants to batch these into meaningful chunks, called "frames". This + /// method layers framing on top of an I/O object, by using the `Codec` + /// traits to handle encoding and decoding of messages frames. Note that + /// the incoming and outgoing frame types may be distinct. + /// + /// This function returns a *single* object that is both `Stream` and + /// `Sink`; grouping this into a single object is often useful for layering + /// things like gzip or TLS, which require both read and write access to the + /// underlying object. + /// + /// If you want to work more directly with the streams and sink, consider + /// calling `split` on the `Framed` returned by this method, which will + /// break them into separate objects, allowing them to interact more easily. 
+ pub fn new(inner: T, codec: U) -> Framed<T, U> { + Framed { + inner: framed_read2(framed_write2(Fuse { io: inner, codec })), + } + } +} + +impl<T, U> Framed<T, U> { + /// Provides a `Stream` and `Sink` interface for reading and writing to this + /// `Io` object, using `Decoder` and `Encoder` to read and write the raw data. + /// + /// Raw I/O objects work with byte sequences, but higher-level code usually + /// wants to batch these into meaningful chunks, called "frames". This + /// method layers framing on top of an I/O object, by using the `Codec` + /// traits to handle encoding and decoding of message frames. Note that + /// the incoming and outgoing frame types may be distinct. + /// + /// This function returns a *single* object that is both `Stream` and + /// `Sink`; grouping this into a single object is often useful for layering + /// things like gzip or TLS, which require both read and write access to the + /// underlying object. + /// + /// This function takes a `FramedParts` holding the I/O stream, the codec, a + /// read buffer, and a write buffer. These fields + /// can be obtained from an existing `Framed` with the `into_parts` method. + /// + /// If you want to work more directly with the streams and sink, consider + /// calling `split` on the `Framed` returned by this method, which will + /// break them into separate objects, allowing them to interact more easily. + pub fn from_parts(parts: FramedParts<T, U>) -> Framed<T, U> { + Framed { + inner: framed_read2_with_buffer( + framed_write2_with_buffer( + Fuse { + io: parts.io, + codec: parts.codec, + }, + parts.write_buf, + ), + parts.read_buf, + ), + } + } + + /// Returns a reference to the underlying I/O stream wrapped by + /// `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. 
+ pub fn get_ref(&self) -> &T { + &self.inner.get_ref().get_ref().io + } + + /// Returns a mutable reference to the underlying I/O stream wrapped by + /// `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_mut(&mut self) -> &mut T { + &mut self.inner.get_mut().get_mut().io + } + + /// Returns a reference to the underlying codec wrapped by + /// `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying codec + /// as it may corrupt the stream of frames otherwise being worked with. + pub fn codec(&self) -> &U { + &self.inner.get_ref().get_ref().codec + } + + /// Returns a mutable reference to the underlying codec wrapped by + /// `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying codec + /// as it may corrupt the stream of frames otherwise being worked with. + pub fn codec_mut(&mut self) -> &mut U { + &mut self.inner.get_mut().get_mut().codec + } + + /// Returns a reference to the read buffer. + pub fn read_buffer(&self) -> &BytesMut { + self.inner.buffer() + } + + /// Consumes the `Framed`, returning its underlying I/O stream. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn into_inner(self) -> T { + self.inner.into_inner().into_inner().io + } + + /// Consumes the `Framed`, returning its underlying I/O stream, the buffer + /// with unprocessed data, and the codec. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. 
+ pub fn into_parts(self) -> FramedParts<T, U> { + let (inner, read_buf) = self.inner.into_parts(); + let (inner, write_buf) = inner.into_parts(); + + FramedParts { + io: inner.io, + codec: inner.codec, + read_buf, + write_buf, + _priv: (), + } + } +} + +impl<T, U> Stream for Framed<T, U> +where + T: AsyncRead, + U: Decoder, +{ + type Item = Result<U::Item, U::Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + self.project().inner.poll_next(cx) + } +} + +impl<T, I, U> Sink<I> for Framed<T, U> +where + T: AsyncWrite, + U: Encoder<Item = I>, + U::Error: From<io::Error>, +{ + type Error = U::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.get_pin_mut().poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + self.project().inner.get_pin_mut().start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.get_pin_mut().poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.get_pin_mut().poll_close(cx) + } +} + +impl<T, U> fmt::Debug for Framed<T, U> +where + T: fmt::Debug, + U: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Framed") + .field("io", &self.inner.get_ref().get_ref().io) + .field("codec", &self.inner.get_ref().get_ref().codec) + .finish() + } +} + +// ===== impl Fuse ===== + +impl<T: Read, U> Read for Fuse<T, U> { + fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> { + self.io.read(dst) + } +} + +impl<T: BufRead, U> BufRead for Fuse<T, U> { + fn fill_buf(&mut self) -> io::Result<&[u8]> { + self.io.fill_buf() + } + + fn consume(&mut self, amt: usize) { + self.io.consume(amt) + } +} + +impl<T: AsyncRead, U> AsyncRead for Fuse<T, U> { + unsafe fn prepare_uninitialized_buffer(&self, 
buf: &mut [MaybeUninit<u8>]) -> bool { + self.io.prepare_uninitialized_buffer(buf) + } + + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll<Result<usize, io::Error>> { + self.project().io.poll_read(cx, buf) + } +} + +impl<T: AsyncBufRead, U> AsyncBufRead for Fuse<T, U> { + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { + self.project().io.poll_fill_buf(cx) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + self.project().io.consume(amt) + } +} + +impl<T: Write, U> Write for Fuse<T, U> { + fn write(&mut self, src: &[u8]) -> io::Result<usize> { + self.io.write(src) + } + + fn flush(&mut self) -> io::Result<()> { + self.io.flush() + } +} + +impl<T: AsyncWrite, U> AsyncWrite for Fuse<T, U> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, io::Error>> { + self.project().io.poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + self.project().io.poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + self.project().io.poll_shutdown(cx) + } +} + +impl<T, U: Decoder> Decoder for Fuse<T, U> { + type Item = U::Item; + type Error = U::Error; + + fn decode(&mut self, buffer: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> { + self.codec.decode(buffer) + } + + fn decode_eof(&mut self, buffer: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> { + self.codec.decode_eof(buffer) + } +} + +impl<T, U: Encoder> Encoder for Fuse<T, U> { + type Item = U::Item; + type Error = U::Error; + + fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> { + self.codec.encode(item, dst) + } +} + +/// `FramedParts` contains an export of the data of a Framed transport. +/// It can be used to construct a new `Framed` with a different codec. 
+/// It contains all current buffers and the inner transport. +#[derive(Debug)] +pub struct FramedParts<T, U> { + /// The inner transport used to read bytes to and write bytes to + pub io: T, + + /// The codec + pub codec: U, + + /// The buffer with read but unprocessed data. + pub read_buf: BytesMut, + + /// A buffer with unprocessed data which are not written yet. + pub write_buf: BytesMut, + + /// This private field allows us to add additional fields in the future in a + /// backwards compatible way. + _priv: (), +} + +impl<T, U> FramedParts<T, U> { + /// Create a new, default, `FramedParts` + pub fn new(io: T, codec: U) -> FramedParts<T, U> { + FramedParts { + io, + codec, + read_buf: BytesMut::new(), + write_buf: BytesMut::new(), + _priv: (), + } + } +} diff --git a/third_party/rust/tokio-util/src/codec/framed_read.rs b/third_party/rust/tokio-util/src/codec/framed_read.rs new file mode 100644 index 0000000000..bd1f625b0c --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/framed_read.rs @@ -0,0 +1,288 @@ +use crate::codec::framed::{Fuse, ProjectFuse}; +use crate::codec::Decoder; + +use tokio::io::AsyncRead; + +use bytes::BytesMut; +use futures_core::Stream; +use futures_sink::Sink; +use log::trace; +use pin_project_lite::pin_project; +use std::fmt; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// A `Stream` of messages decoded from an `AsyncRead`. + pub struct FramedRead<T, D> { + #[pin] + inner: FramedRead2<Fuse<T, D>>, + } +} + +pin_project! { + pub(crate) struct FramedRead2<T> { + #[pin] + inner: T, + eof: bool, + is_readable: bool, + buffer: BytesMut, + } +} + +const INITIAL_CAPACITY: usize = 8 * 1024; + +// ===== impl FramedRead ===== + +impl<T, D> FramedRead<T, D> +where + T: AsyncRead, + D: Decoder, +{ + /// Creates a new `FramedRead` with the given `decoder`. 
+ pub fn new(inner: T, decoder: D) -> FramedRead<T, D> { + FramedRead { + inner: framed_read2(Fuse { + io: inner, + codec: decoder, + }), + } + } +} + +impl<T, D> FramedRead<T, D> { + /// Returns a reference to the underlying I/O stream wrapped by + /// `FramedRead`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_ref(&self) -> &T { + &self.inner.inner.io + } + + /// Returns a mutable reference to the underlying I/O stream wrapped by + /// `FramedRead`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_mut(&mut self) -> &mut T { + &mut self.inner.inner.io + } + + /// Consumes the `FramedRead`, returning its underlying I/O stream. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn into_inner(self) -> T { + self.inner.inner.io + } + + /// Returns a reference to the underlying decoder. + pub fn decoder(&self) -> &D { + &self.inner.inner.codec + } + + /// Returns a mutable reference to the underlying decoder. + pub fn decoder_mut(&mut self) -> &mut D { + &mut self.inner.inner.codec + } + + /// Returns a reference to the read buffer. 
+ pub fn read_buffer(&self) -> &BytesMut { + &self.inner.buffer + } +} + +impl<T, D> Stream for FramedRead<T, D> +where + T: AsyncRead, + D: Decoder, +{ + type Item = Result<D::Item, D::Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + self.project().inner.poll_next(cx) + } +} + +// This impl just defers to the underlying T: Sink +impl<T, I, D> Sink<I> for FramedRead<T, D> +where + T: Sink<I>, +{ + type Error = T::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project() + .inner + .project() + .inner + .project() + .io + .poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + self.project() + .inner + .project() + .inner + .project() + .io + .start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project() + .inner + .project() + .inner + .project() + .io + .poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project() + .inner + .project() + .inner + .project() + .io + .poll_close(cx) + } +} + +impl<T, D> fmt::Debug for FramedRead<T, D> +where + T: fmt::Debug, + D: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("FramedRead") + .field("inner", &self.inner.inner.io) + .field("decoder", &self.inner.inner.codec) + .field("eof", &self.inner.eof) + .field("is_readable", &self.inner.is_readable) + .field("buffer", &self.inner.buffer) + .finish() + } +} + +// ===== impl FramedRead2 ===== + +pub(crate) fn framed_read2<T>(inner: T) -> FramedRead2<T> { + FramedRead2 { + inner, + eof: false, + is_readable: false, + buffer: BytesMut::with_capacity(INITIAL_CAPACITY), + } +} + +pub(crate) fn framed_read2_with_buffer<T>(inner: T, mut buf: BytesMut) -> FramedRead2<T> { + if buf.capacity() < INITIAL_CAPACITY { + let bytes_to_reserve = 
INITIAL_CAPACITY - buf.capacity(); + buf.reserve(bytes_to_reserve); + } + FramedRead2 { + inner, + eof: false, + is_readable: !buf.is_empty(), + buffer: buf, + } +} + +impl<T> FramedRead2<T> { + pub(crate) fn get_ref(&self) -> &T { + &self.inner + } + + pub(crate) fn into_inner(self) -> T { + self.inner + } + + pub(crate) fn into_parts(self) -> (T, BytesMut) { + (self.inner, self.buffer) + } + + pub(crate) fn get_mut(&mut self) -> &mut T { + &mut self.inner + } + + pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> { + self.project().inner + } + + pub(crate) fn buffer(&self) -> &BytesMut { + &self.buffer + } +} + +impl<T> Stream for FramedRead2<T> +where + T: ProjectFuse + AsyncRead, + T::Codec: Decoder, +{ + type Item = Result<<T::Codec as Decoder>::Item, <T::Codec as Decoder>::Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + let mut pinned = self.project(); + loop { + // Repeatedly call `decode` or `decode_eof` as long as it is + // "readable". Readable is defined as not having returned `None`. If + // the upstream has returned EOF, and the decoder is no longer + // readable, it can be assumed that the decoder will never become + // readable again, at which point the stream is terminated. + if *pinned.is_readable { + if *pinned.eof { + let frame = pinned + .inner + .as_mut() + .project() + .codec + .decode_eof(&mut pinned.buffer)?; + return Poll::Ready(frame.map(Ok)); + } + + trace!("attempting to decode a frame"); + + if let Some(frame) = pinned + .inner + .as_mut() + .project() + .codec + .decode(&mut pinned.buffer)? + { + trace!("frame decoded from buffer"); + return Poll::Ready(Some(Ok(frame))); + } + + *pinned.is_readable = false; + } + + assert!(!*pinned.eof); + + // Otherwise, try to read more data and try again. 
Make sure we've + // got room for at least one byte to read to ensure that we don't + // get a spurious 0 that looks like EOF + pinned.buffer.reserve(1); + let bytect = match pinned + .inner + .as_mut() + .poll_read_buf(cx, &mut pinned.buffer)? + { + Poll::Ready(ct) => ct, + Poll::Pending => return Poll::Pending, + }; + if bytect == 0 { + *pinned.eof = true; + } + + *pinned.is_readable = true; + } + } +} diff --git a/third_party/rust/tokio-util/src/codec/framed_write.rs b/third_party/rust/tokio-util/src/codec/framed_write.rs new file mode 100644 index 0000000000..9aed7ea3ce --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/framed_write.rs @@ -0,0 +1,321 @@ +use crate::codec::decoder::Decoder; +use crate::codec::encoder::Encoder; +use crate::codec::framed::{Fuse, ProjectFuse}; + +use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite}; + +use bytes::BytesMut; +use futures_core::{ready, Stream}; +use futures_sink::Sink; +use log::trace; +use pin_project_lite::pin_project; +use std::fmt; +use std::io::{self, BufRead, Read}; +use std::mem::MaybeUninit; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// A `Sink` of frames encoded to an `AsyncWrite`. + pub struct FramedWrite<T, E> { + #[pin] + inner: FramedWrite2<Fuse<T, E>>, + } +} + +pin_project! { + pub(crate) struct FramedWrite2<T> { + #[pin] + inner: T, + buffer: BytesMut, + } +} + +const INITIAL_CAPACITY: usize = 8 * 1024; +const BACKPRESSURE_BOUNDARY: usize = INITIAL_CAPACITY; + +impl<T, E> FramedWrite<T, E> +where + T: AsyncWrite, + E: Encoder, +{ + /// Creates a new `FramedWrite` with the given `encoder`. + pub fn new(inner: T, encoder: E) -> FramedWrite<T, E> { + FramedWrite { + inner: framed_write2(Fuse { + io: inner, + codec: encoder, + }), + } + } +} + +impl<T, E> FramedWrite<T, E> { + /// Returns a reference to the underlying I/O stream wrapped by + /// `FramedWrite`. 
+ /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_ref(&self) -> &T { + &self.inner.inner.io + } + + /// Returns a mutable reference to the underlying I/O stream wrapped by + /// `FramedWrite`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_mut(&mut self) -> &mut T { + &mut self.inner.inner.io + } + + /// Consumes the `FramedWrite`, returning its underlying I/O stream. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn into_inner(self) -> T { + self.inner.inner.io + } + + /// Returns a reference to the underlying decoder. + pub fn encoder(&self) -> &E { + &self.inner.inner.codec + } + + /// Returns a mutable reference to the underlying decoder. 
+ pub fn encoder_mut(&mut self) -> &mut E { + &mut self.inner.inner.codec + } +} + +// This impl just defers to the underlying FramedWrite2 +impl<T, I, E> Sink<I> for FramedWrite<T, E> +where + T: AsyncWrite, + E: Encoder<Item = I>, + E::Error: From<io::Error>, +{ + type Error = E::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + self.project().inner.start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.poll_close(cx) + } +} + +impl<T, D> Stream for FramedWrite<T, D> +where + T: Stream, +{ + type Item = T::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + self.project() + .inner + .project() + .inner + .project() + .io + .poll_next(cx) + } +} + +impl<T, U> fmt::Debug for FramedWrite<T, U> +where + T: fmt::Debug, + U: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("FramedWrite") + .field("inner", &self.inner.get_ref().io) + .field("encoder", &self.inner.get_ref().codec) + .field("buffer", &self.inner.buffer) + .finish() + } +} + +// ===== impl FramedWrite2 ===== + +pub(crate) fn framed_write2<T>(inner: T) -> FramedWrite2<T> { + FramedWrite2 { + inner, + buffer: BytesMut::with_capacity(INITIAL_CAPACITY), + } +} + +pub(crate) fn framed_write2_with_buffer<T>(inner: T, mut buf: BytesMut) -> FramedWrite2<T> { + if buf.capacity() < INITIAL_CAPACITY { + let bytes_to_reserve = INITIAL_CAPACITY - buf.capacity(); + buf.reserve(bytes_to_reserve); + } + FramedWrite2 { inner, buffer: buf } +} + +impl<T> FramedWrite2<T> { + pub(crate) fn get_ref(&self) -> &T { + &self.inner + } + + 
pub(crate) fn into_inner(self) -> T { + self.inner + } + + pub(crate) fn into_parts(self) -> (T, BytesMut) { + (self.inner, self.buffer) + } + + pub(crate) fn get_mut(&mut self) -> &mut T { + &mut self.inner + } +} + +impl<I, T> Sink<I> for FramedWrite2<T> +where + T: ProjectFuse + AsyncWrite, + T::Codec: Encoder<Item = I>, +{ + type Error = <T::Codec as Encoder>::Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + // If the buffer is already over 8KiB, then attempt to flush it. If after flushing it's + // *still* over 8KiB, then apply backpressure (reject the send). + if self.buffer.len() >= BACKPRESSURE_BOUNDARY { + match self.as_mut().poll_flush(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Ready(Ok(())) => (), + }; + + if self.buffer.len() >= BACKPRESSURE_BOUNDARY { + return Poll::Pending; + } + } + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + let mut pinned = self.project(); + pinned + .inner + .project() + .codec + .encode(item, &mut pinned.buffer)?; + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + trace!("flushing framed transport"); + let mut pinned = self.project(); + + while !pinned.buffer.is_empty() { + trace!("writing; remaining={}", pinned.buffer.len()); + + let buf = &pinned.buffer; + let n = ready!(pinned.inner.as_mut().poll_write(cx, &buf))?; + + if n == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "failed to \ + write frame to transport", + ) + .into())); + } + + // TODO: Add a way to `bytes` to do this w/o returning the drained data. 
+ let _ = pinned.buffer.split_to(n); + } + + // Try flushing the underlying IO + ready!(pinned.inner.poll_flush(cx))?; + + trace!("framed transport flushed"); + Poll::Ready(Ok(())) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + ready!(self.as_mut().poll_flush(cx))?; + ready!(self.project().inner.poll_shutdown(cx))?; + + Poll::Ready(Ok(())) + } +} + +impl<T: Decoder> Decoder for FramedWrite2<T> { + type Item = T::Item; + type Error = T::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result<Option<T::Item>, T::Error> { + self.inner.decode(src) + } + + fn decode_eof(&mut self, src: &mut BytesMut) -> Result<Option<T::Item>, T::Error> { + self.inner.decode_eof(src) + } +} + +impl<T: Read> Read for FramedWrite2<T> { + fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> { + self.inner.read(dst) + } +} + +impl<T: BufRead> BufRead for FramedWrite2<T> { + fn fill_buf(&mut self) -> io::Result<&[u8]> { + self.inner.fill_buf() + } + + fn consume(&mut self, amt: usize) { + self.inner.consume(amt) + } +} + +impl<T: AsyncRead> AsyncRead for FramedWrite2<T> { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool { + self.inner.prepare_uninitialized_buffer(buf) + } + + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll<Result<usize, io::Error>> { + self.project().inner.poll_read(cx, buf) + } +} + +impl<T: AsyncBufRead> AsyncBufRead for FramedWrite2<T> { + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { + self.project().inner.poll_fill_buf(cx) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + self.project().inner.consume(amt) + } +} + +impl<T> ProjectFuse for FramedWrite2<T> +where + T: ProjectFuse, +{ + type Io = T::Io; + type Codec = T::Codec; + + fn project(self: Pin<&mut Self>) -> Fuse<Pin<&mut Self::Io>, &mut Self::Codec> { + self.project().inner.project() + } +} diff --git 
a/third_party/rust/tokio-util/src/codec/length_delimited.rs b/third_party/rust/tokio-util/src/codec/length_delimited.rs new file mode 100644 index 0000000000..01ba2aec05 --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/length_delimited.rs @@ -0,0 +1,963 @@ +//! Frame a stream of bytes based on a length prefix +//! +//! Many protocols delimit their frames by prefacing frame data with a +//! frame head that specifies the length of the frame. The +//! `length_delimited` module provides utilities for handling the length +//! based framing. This allows the consumer to work with entire frames +//! without having to worry about buffering or other framing logic. +//! +//! # Getting started +//! +//! If implementing a protocol from scratch, using length delimited framing +//! is an easy way to get started. [`LengthDelimitedCodec::new()`] will +//! return a length delimited codec using default configuration values. +//! This can then be used to construct a framer to adapt a full-duplex +//! byte stream into a stream of frames. +//! +//! ``` +//! use tokio::io::{AsyncRead, AsyncWrite}; +//! use tokio_util::codec::{Framed, LengthDelimitedCodec}; +//! +//! fn bind_transport<T: AsyncRead + AsyncWrite>(io: T) +//! -> Framed<T, LengthDelimitedCodec> +//! { +//! Framed::new(io, LengthDelimitedCodec::new()) +//! } +//! # pub fn main() {} +//! ``` +//! +//! The returned transport implements `Sink + Stream` for `BytesMut`. It +//! encodes the frame with a big-endian `u32` header denoting the frame +//! payload length: +//! +//! ```text +//! +----------+--------------------------------+ +//! | len: u32 | frame payload | +//! +----------+--------------------------------+ +//! ``` +//! +//! Specifically, given the following: +//! +//! ``` +//! use tokio::prelude::*; +//! use tokio_util::codec::{Framed, LengthDelimitedCodec}; +//! +//! use futures::SinkExt; +//! use bytes::Bytes; +//! +//! async fn write_frame<T>(io: T) -> Result<(), Box<dyn std::error::Error>> +//! where +//! 
T: AsyncRead + AsyncWrite + Unpin, +//! { +//! let mut transport = Framed::new(io, LengthDelimitedCodec::new()); +//! let frame = Bytes::from("hello world"); +//! +//! transport.send(frame).await?; +//! Ok(()) +//! } +//! ``` +//! +//! The encoded frame will look like this: +//! +//! ```text +//! +---- len: u32 ----+---- data ----+ +//! | \x00\x00\x00\x0b | hello world | +//! +------------------+--------------+ +//! ``` +//! +//! # Decoding +//! +//! [`FramedRead`] adapts an [`AsyncRead`] into a `Stream` of [`BytesMut`], +//! such that each yielded [`BytesMut`] value contains the contents of an +//! entire frame. There are many configuration parameters enabling +//! [`FramedRead`] to handle a wide range of protocols. Here are some +//! examples that will cover the various options at a high level. +//! +//! ## Example 1 +//! +//! The following will parse a `u16` length field at offset 0, including the +//! frame head in the yielded `BytesMut`. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read<T: AsyncRead>(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(0) // default value +//! .length_field_length(2) +//! .length_adjustment(0) // default value +//! .num_skip(0) // Do not strip frame header +//! .new_read(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT DECODED +//! +-- len ---+--- Payload ---+ +-- len ---+--- Payload ---+ +//! | \x00\x0B | Hello world | --> | \x00\x0B | Hello world | +//! +----------+---------------+ +----------+---------------+ +//! ``` +//! +//! The value of the length field is 11 (`\x0B`) which represents the length +//! of the payload, `hello world`. By default, [`FramedRead`] assumes that +//! the length field represents the number of bytes that **follows** the +//! length field. Thus, the entire frame has a length of 13: 2 bytes for the +//! 
frame head + 11 bytes for the payload. +//! +//! ## Example 2 +//! +//! The following will parse a `u16` length field at offset 0, omitting the +//! frame head in the yielded `BytesMut`. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read<T: AsyncRead>(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(0) // default value +//! .length_field_length(2) +//! .length_adjustment(0) // default value +//! // `num_skip` is not needed, the default is to skip +//! .new_read(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT DECODED +//! +-- len ---+--- Payload ---+ +--- Payload ---+ +//! | \x00\x0B | Hello world | --> | Hello world | +//! +----------+---------------+ +---------------+ +//! ``` +//! +//! This is similar to the first example, the only difference is that the +//! frame head is **not** included in the yielded `BytesMut` value. +//! +//! ## Example 3 +//! +//! The following will parse a `u16` length field at offset 0, including the +//! frame head in the yielded `BytesMut`. In this case, the length field +//! **includes** the frame head length. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read<T: AsyncRead>(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(0) // default value +//! .length_field_length(2) +//! .length_adjustment(-2) // size of head +//! .num_skip(0) +//! .new_read(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT DECODED +//! +-- len ---+--- Payload ---+ +-- len ---+--- Payload ---+ +//! | \x00\x0D | Hello world | --> | \x00\x0D | Hello world | +//! +----------+---------------+ +----------+---------------+ +//! ``` +//! +//! In most cases, the length field represents the length of the payload +//! 
only, as shown in the previous examples. However, in some protocols the +//! length field represents the length of the whole frame, including the +//! head. In such cases, we specify a negative `length_adjustment` to adjust +//! the value provided in the frame head to represent the payload length. +//! +//! ## Example 4 +//! +//! The following will parse a 3 byte length field at offset 0 in a 5 byte +//! frame head, including the frame head in the yielded `BytesMut`. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read<T: AsyncRead>(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(0) // default value +//! .length_field_length(3) +//! .length_adjustment(2) // remaining head +//! .num_skip(0) +//! .new_read(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT +//! +---- len -----+- head -+--- Payload ---+ +//! | \x00\x00\x0B | \xCAFE | Hello world | +//! +--------------+--------+---------------+ +//! +//! DECODED +//! +---- len -----+- head -+--- Payload ---+ +//! | \x00\x00\x0B | \xCAFE | Hello world | +//! +--------------+--------+---------------+ +//! ``` +//! +//! A more advanced example that shows a case where there is extra frame +//! head data between the length field and the payload. In such cases, it is +//! usually desirable to include the frame head as part of the yielded +//! `BytesMut`. This lets consumers of the length delimited framer +//! process the frame head as needed. +//! +//! The positive `length_adjustment` value lets `FramedRead` factor in the +//! additional head into the frame length calculation. +//! +//! ## Example 5 +//! +//! The following will parse a `u16` length field at offset 1 of a 4 byte +//! frame head. The first byte and the length field will be omitted from the +//! yielded `BytesMut`, but the trailing byte of the frame head will be +//! included. +//! +//! 
``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read<T: AsyncRead>(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(1) // length of hdr1 +//! .length_field_length(2) +//! .length_adjustment(1) // length of hdr2 +//! .num_skip(3) // length of hdr1 + LEN +//! .new_read(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT +//! +- hdr1 -+-- len ---+- hdr2 -+--- Payload ---+ +//! | \xCA | \x00\x0B | \xFE | Hello world | +//! +--------+----------+--------+---------------+ +//! +//! DECODED +//! +- hdr2 -+--- Payload ---+ +//! | \xFE | Hello world | +//! +--------+---------------+ +//! ``` +//! +//! The length field is situated in the middle of the frame head. In this +//! case, the first byte in the frame head could be a version or some other +//! identifier that is not needed for processing. On the other hand, the +//! second half of the head is needed. +//! +//! `length_field_offset` indicates how many bytes to skip before starting +//! to read the length field. `length_adjustment` is the number of bytes to +//! skip starting at the end of the length field. In this case, it is the +//! second half of the head. +//! +//! ## Example 6 +//! +//! The following will parse a `u16` length field at offset 1 of a 4 byte +//! frame head. The first byte and the length field will be omitted from the +//! yielded `BytesMut`, but the trailing byte of the frame head will be +//! included. In this case, the length field **includes** the frame head +//! length. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read<T: AsyncRead>(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(1) // length of hdr1 +//! .length_field_length(2) +//! .length_adjustment(-3) // length of hdr1 + LEN, negative +//! .num_skip(3) +//! .new_read(io); +//! # } +//! 
``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT +//! +- hdr1 -+-- len ---+- hdr2 -+--- Payload ---+ +//! | \xCA | \x00\x0F | \xFE | Hello world | +//! +--------+----------+--------+---------------+ +//! +//! DECODED +//! +- hdr2 -+--- Payload ---+ +//! | \xFE | Hello world | +//! +--------+---------------+ +//! ``` +//! +//! Similar to the example above, the difference is that the length field +//! represents the length of the entire frame instead of just the payload. +//! The length of `hdr1` and `len` must be counted in `length_adjustment`. +//! Note that the length of `hdr2` does **not** need to be explicitly set +//! anywhere because it already is factored into the total frame length that +//! is read from the byte stream. +//! +//! # Encoding +//! +//! [`FramedWrite`] adapts an [`AsyncWrite`] into a `Sink` of [`BytesMut`], +//! such that each submitted [`BytesMut`] is prefaced by a length field. +//! There are fewer configuration options than [`FramedRead`]. Given +//! protocols that have more complex frame heads, an encoder should probably +//! be written by hand using [`Encoder`]. +//! +//! Here is a simple example, given a `FramedWrite` with the following +//! configuration: +//! +//! ``` +//! # use tokio::io::AsyncWrite; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn write_frame<T: AsyncWrite>(io: T) { +//! # let _ = +//! LengthDelimitedCodec::builder() +//! .length_field_length(2) +//! .new_write(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! A payload of `hello world` will be encoded as: +//! +//! ```text +//! +- len: u16 -+---- data ----+ +//! | \x00\x0b | hello world | +//! +------------+--------------+ +//! ``` +//! +//! [`LengthDelimitedCodec::new()`]: struct.LengthDelimitedCodec.html#method.new +//! [`FramedRead`]: struct.FramedRead.html +//! [`FramedWrite`]: struct.FramedWrite.html +//! [`AsyncRead`]: ../../trait.AsyncRead.html +//! [`AsyncWrite`]: ../../trait.AsyncWrite.html +//! 
[`Encoder`]: ../trait.Encoder.html +//! [`BytesMut`]: https://docs.rs/bytes/0.4/bytes/struct.BytesMut.html + +use crate::codec::{Decoder, Encoder, Framed, FramedRead, FramedWrite}; + +use tokio::io::{AsyncRead, AsyncWrite}; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::error::Error as StdError; +use std::io::{self, Cursor}; +use std::{cmp, fmt}; + +/// Configure length delimited `LengthDelimitedCodec`s. +/// +/// `Builder` enables constructing configured length delimited codecs. Note +/// that not all configuration settings apply to both encoding and decoding. See +/// the documentation for specific methods for more detail. +#[derive(Debug, Clone, Copy)] +pub struct Builder { + // Maximum frame length + max_frame_len: usize, + + // Number of bytes representing the field length + length_field_len: usize, + + // Number of bytes in the header before the length field + length_field_offset: usize, + + // Adjust the length specified in the header field by this amount + length_adjustment: isize, + + // Total number of bytes to skip before reading the payload, if not set, + // `length_field_len + length_field_offset` + num_skip: Option<usize>, + + // Length field byte order (little or big endian) + length_field_is_big_endian: bool, +} + +/// An error when the number of bytes read is more than max frame length. +pub struct LengthDelimitedCodecError { + _priv: (), +} + +/// A codec for frames delimited by a frame head specifying their lengths. +/// +/// This allows the consumer to work with entire frames without having to worry +/// about buffering or other framing logic. +/// +/// See [module level] documentation for more detail. 
///
/// [module level]: index.html
#[derive(Debug)]
pub struct LengthDelimitedCodec {
    // Configuration values
    builder: Builder,

    // Read state
    state: DecodeState,
}

#[derive(Debug, Clone, Copy)]
enum DecodeState {
    // Waiting for a complete frame head.
    Head,
    // The head has been consumed; waiting for this many payload bytes.
    Data(usize),
}

// ===== impl LengthDelimitedCodec ======

impl LengthDelimitedCodec {
    /// Creates a new `LengthDelimitedCodec` with the default configuration values.
    pub fn new() -> Self {
        Self {
            builder: Builder::new(),
            state: DecodeState::Head,
        }
    }

    /// Creates a new length delimited codec builder with default configuration
    /// values.
    pub fn builder() -> Builder {
        Builder::new()
    }

    /// Returns the current max frame setting
    ///
    /// This is the largest size this codec will accept from the wire. Larger
    /// frames will be rejected.
    pub fn max_frame_length(&self) -> usize {
        self.builder.max_frame_len
    }

    /// Updates the max frame setting.
    ///
    /// The change takes effect the next time a frame is decoded. In other
    /// words, if a frame is currently in process of being decoded with a frame
    /// size greater than `val` but less than the max frame length in effect
    /// before calling this function, then the frame will be allowed.
    pub fn set_max_frame_length(&mut self, val: usize) {
        self.builder.max_frame_length(val);
    }

    /// Parses the frame head at the front of `src`.
    ///
    /// Returns `Ok(None)` while fewer than `num_head_bytes()` bytes are
    /// buffered. Otherwise reads the length field, checks it against
    /// `max_frame_len`, applies `length_adjustment`, drops `num_skip` bytes
    /// from the front of `src` and returns the number of bytes that make up
    /// the frame to yield.
    ///
    /// # Errors
    ///
    /// * `InvalidData` wrapping a `LengthDelimitedCodecError` when the length
    ///   field exceeds `max_frame_len`.
    /// * `InvalidInput` when applying `length_adjustment` would under- or
    ///   overflow.
    fn decode_head(&mut self, src: &mut BytesMut) -> io::Result<Option<usize>> {
        let head_len = self.builder.num_head_bytes();
        let field_len = self.builder.length_field_len;

        if src.len() < head_len {
            // Not enough data
            return Ok(None);
        }

        let n = {
            // Read through a cursor so that nothing is consumed from `src`
            // until the head has been validated.
            let mut src = Cursor::new(&mut *src);

            // Skip the required bytes
            src.advance(self.builder.length_field_offset);

            // match endianness
            let n = if self.builder.length_field_is_big_endian {
                src.get_uint(field_len)
            } else {
                src.get_uint_le(field_len)
            };

            if n > self.builder.max_frame_len as u64 {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    LengthDelimitedCodecError { _priv: () },
                ));
            }

            // The check above ensures there is no overflow
            let n = n as usize;

            // Adjust `n` with bounds checking. `wrapping_abs` yields the
            // magnitude even for `isize::MIN`, where plain negation overflows.
            let adjustment = self.builder.length_adjustment;
            let magnitude = adjustment.wrapping_abs() as usize;
            let adjusted = if adjustment < 0 {
                n.checked_sub(magnitude)
            } else {
                n.checked_add(magnitude)
            };

            adjusted.ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "provided length would overflow after adjustment",
                )
            })?
        };

        let num_skip = self.builder.get_num_skip();

        if num_skip > 0 {
            let _ = src.split_to(num_skip);
        }

        // Ensure that the buffer has enough space to read the incoming
        // payload. `reserve` counts from the current length, so bytes that are
        // already buffered must not be reserved a second time.
        src.reserve(n.saturating_sub(src.len()));

        Ok(Some(n))
    }

    /// Splits an `n` byte frame off the front of `src`, or returns `Ok(None)`
    /// while fewer than `n` bytes are buffered.
    fn decode_data(&self, n: usize, src: &mut BytesMut) -> io::Result<Option<BytesMut>> {
        // At this point, the buffer has already had the required capacity
        // reserved. All there is to do is read.
        if src.len() < n {
            return Ok(None);
        }

        Ok(Some(src.split_to(n)))
    }
}

impl Decoder for LengthDelimitedCodec {
    type Item = BytesMut;
    type Error = io::Error;

    fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<BytesMut>> {
        let n = match self.state {
            DecodeState::Head => match self.decode_head(src)? {
                Some(n) => {
                    // Remember the frame length so that a partial payload does
                    // not cause the head to be parsed again on the next call.
                    self.state = DecodeState::Data(n);
                    n
                }
                None => return Ok(None),
            },
            DecodeState::Data(n) => n,
        };

        match self.decode_data(n, src)? {
            Some(data) => {
                // Update the decode state
                self.state = DecodeState::Head;

                // Make sure the buffer has enough space to read the next head,
                // taking into account what is already buffered.
                src.reserve(self.builder.num_head_bytes().saturating_sub(src.len()));

                Ok(Some(data))
            }
            None => Ok(None),
        }
    }
}

impl Encoder for LengthDelimitedCodec {
    type Item = Bytes;
    type Error = io::Error;

    /// Appends the length field followed by `data` to `dst`.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidInput` error when `data` is longer than
    /// `max_frame_len`, when applying `length_adjustment` would under- or
    /// overflow, or when the adjusted length cannot be represented in
    /// `length_field_len` bytes.
    fn encode(&mut self, data: Bytes, dst: &mut BytesMut) -> Result<(), io::Error> {
        let n = data.len();

        if n > self.builder.max_frame_len {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                LengthDelimitedCodecError { _priv: () },
            ));
        }

        // Adjust `n` with bounds checking. This is the inverse of the
        // adjustment applied when decoding.
        let adjustment = self.builder.length_adjustment;
        let magnitude = adjustment.wrapping_abs() as usize;
        let n = if adjustment < 0 {
            n.checked_add(magnitude)
        } else {
            n.checked_sub(magnitude)
        };

        let n = n.ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "provided length would overflow after adjustment",
            )
        })?;

        // Writing a value wider than the length field would put a truncated
        // length on the wire and corrupt the frame, so reject it instead. A
        // field of 8 bytes can hold any `u64`.
        let field_len = self.builder.length_field_len;
        if field_len < 8 && (n as u64) >> (8 * field_len) != 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                LengthDelimitedCodecError { _priv: () },
            ));
        }

        // Reserve capacity in the destination buffer to fit the frame and
        // length field (plus adjustment).
        dst.reserve(field_len + n);

        if self.builder.length_field_is_big_endian {
            dst.put_uint(n as u64, field_len);
        } else {
            dst.put_uint_le(n as u64, field_len);
        }

        // Write the frame to the buffer
        dst.extend_from_slice(&data[..]);

        Ok(())
    }
}

impl Default for LengthDelimitedCodec {
    fn default() -> Self {
        Self::new()
    }
}

// ===== impl Builder =====

impl Builder {
    /// Creates a new length delimited codec builder with default configuration
    /// values.
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::io::AsyncRead;
    /// use tokio_util::codec::LengthDelimitedCodec;
    ///
    /// # fn bind_read<T: AsyncRead>(io: T) {
    /// LengthDelimitedCodec::builder()
    ///     .length_field_offset(0)
    ///     .length_field_length(2)
    ///     .length_adjustment(0)
    ///     .num_skip(0)
    ///     .new_read(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    pub fn new() -> Builder {
        Builder {
            // Default max frame length of 8MB
            max_frame_len: 8 * 1_024 * 1_024,

            // Default byte length of 4
            length_field_len: 4,

            // Default to the header field being at the start of the header.
            length_field_offset: 0,

            length_adjustment: 0,

            // Total number of bytes to skip before reading the payload, if not set,
            // `length_field_len + length_field_offset`
            num_skip: None,

            // Default to reading the length field in network (big) endian.
            length_field_is_big_endian: true,
        }
    }

    /// Read the length field as a big endian integer
    ///
    /// This is the default setting.
    ///
    /// This configuration option applies to both encoding and decoding.
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::io::AsyncRead;
    /// use tokio_util::codec::LengthDelimitedCodec;
    ///
    /// # fn bind_read<T: AsyncRead>(io: T) {
    /// LengthDelimitedCodec::builder()
    ///     .big_endian()
    ///     .new_read(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    pub fn big_endian(&mut self) -> &mut Self {
        self.length_field_is_big_endian = true;
        self
    }

    /// Read the length field as a little endian integer
    ///
    /// The default setting is big endian.
    ///
    /// This configuration option applies to both encoding and decoding.
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::io::AsyncRead;
    /// use tokio_util::codec::LengthDelimitedCodec;
    ///
    /// # fn bind_read<T: AsyncRead>(io: T) {
    /// LengthDelimitedCodec::builder()
    ///     .little_endian()
    ///     .new_read(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    pub fn little_endian(&mut self) -> &mut Self {
        self.length_field_is_big_endian = false;
        self
    }

    /// Read the length field as a native endian integer
    ///
    /// The byte order is picked at compile time from the target's endianness.
    ///
    /// The default setting is big endian.
    ///
    /// This configuration option applies to both encoding and decoding.
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::io::AsyncRead;
    /// use tokio_util::codec::LengthDelimitedCodec;
    ///
    /// # fn bind_read<T: AsyncRead>(io: T) {
    /// LengthDelimitedCodec::builder()
    ///     .native_endian()
    ///     .new_read(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    pub fn native_endian(&mut self) -> &mut Self {
        if cfg!(target_endian = "big") {
            self.big_endian()
        } else {
            self.little_endian()
        }
    }

    /// Sets the max frame length
    ///
    /// This configuration option applies to both encoding and decoding. The
    /// default value is 8MB.
    ///
    /// When decoding, the length field read from the byte stream is checked
    /// against this setting **before** any adjustments are applied. When
    /// encoding, the length of the submitted payload is checked against this
    /// setting.
    ///
    /// When frames exceed the max length, an `io::Error` with the custom value
    /// of the `LengthDelimitedCodecError` type will be returned.
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::io::AsyncRead;
    /// use tokio_util::codec::LengthDelimitedCodec;
    ///
    /// # fn bind_read<T: AsyncRead>(io: T) {
    /// LengthDelimitedCodec::builder()
    ///     .max_frame_length(8 * 1024)
    ///     .new_read(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    pub fn max_frame_length(&mut self, val: usize) -> &mut Self {
        self.max_frame_len = val;
        self
    }

    /// Sets the number of bytes used to represent the length field
    ///
    /// The default value is `4`. The max value is `8`.
    ///
    /// This configuration option applies to both encoding and decoding.
    ///
    /// # Panics
    ///
    /// Panics if `val` is `0` or greater than `8`.
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::io::AsyncRead;
    /// use tokio_util::codec::LengthDelimitedCodec;
    ///
    /// # fn bind_read<T: AsyncRead>(io: T) {
    /// LengthDelimitedCodec::builder()
    ///     .length_field_length(4)
    ///     .new_read(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    pub fn length_field_length(&mut self, val: usize) -> &mut Self {
        assert!(val > 0 && val <= 8, "invalid length field length");
        self.length_field_len = val;
        self
    }

    /// Sets the number of bytes in the header before the length field
    ///
    /// This configuration option only applies to decoding.
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::io::AsyncRead;
    /// use tokio_util::codec::LengthDelimitedCodec;
    ///
    /// # fn bind_read<T: AsyncRead>(io: T) {
    /// LengthDelimitedCodec::builder()
    ///     .length_field_offset(1)
    ///     .new_read(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    pub fn length_field_offset(&mut self, val: usize) -> &mut Self {
        self.length_field_offset = val;
        self
    }

    /// Delta between the payload length specified in the header and the real
    /// payload length
    ///
    /// This configuration option applies to both encoding and decoding: when
    /// decoding, `val` is added to the value read from the length field; when
    /// encoding, `val` is subtracted from the payload length before the length
    /// field is written.
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::io::AsyncRead;
    /// use tokio_util::codec::LengthDelimitedCodec;
    ///
    /// # fn bind_read<T: AsyncRead>(io: T) {
    /// LengthDelimitedCodec::builder()
    ///     .length_adjustment(-2)
    ///     .new_read(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    pub fn length_adjustment(&mut self, val: isize) -> &mut Self {
        self.length_adjustment = val;
        self
    }

    /// Sets the number of bytes to skip before reading the payload
    ///
    /// Default value is `length_field_len + length_field_offset`
    ///
    /// This configuration option only applies to decoding
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::io::AsyncRead;
    /// use tokio_util::codec::LengthDelimitedCodec;
    ///
    /// # fn bind_read<T: AsyncRead>(io: T) {
    /// LengthDelimitedCodec::builder()
    ///     .num_skip(4)
    ///     .new_read(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    pub fn num_skip(&mut self, val: usize) -> &mut Self {
        self.num_skip = Some(val);
        self
    }

    /// Create a configured length delimited `LengthDelimitedCodec`
    ///
    /// The codec receives a copy of the current settings, so the builder can
    /// be reused afterwards.
    ///
    /// # Examples
    ///
    /// ```
    /// use tokio_util::codec::LengthDelimitedCodec;
    /// # pub fn main() {
    /// LengthDelimitedCodec::builder()
    ///     .length_field_offset(0)
    ///     .length_field_length(2)
    ///     .length_adjustment(0)
    ///     .num_skip(0)
    ///     .new_codec();
    /// # }
    /// ```
    pub fn new_codec(&self) -> LengthDelimitedCodec {
        LengthDelimitedCodec {
            builder: *self,
            state: DecodeState::Head,
        }
    }

    /// Create a configured length delimited `FramedRead`
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::io::AsyncRead;
    /// use tokio_util::codec::LengthDelimitedCodec;
    ///
    /// # fn bind_read<T: AsyncRead>(io: T) {
    /// LengthDelimitedCodec::builder()
    ///     .length_field_offset(0)
    ///     .length_field_length(2)
    ///     .length_adjustment(0)
    ///     .num_skip(0)
    ///     .new_read(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    pub fn new_read<T>(&self, upstream: T) -> FramedRead<T, LengthDelimitedCodec>
    where
        T: AsyncRead,
    {
        FramedRead::new(upstream, self.new_codec())
    }

    /// Create a configured length delimited `FramedWrite`
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::io::AsyncWrite;
    /// # use tokio_util::codec::LengthDelimitedCodec;
    /// # fn write_frame<T: AsyncWrite>(io: T) {
    /// LengthDelimitedCodec::builder()
    ///     .length_field_length(2)
    ///     .new_write(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    pub fn new_write<T>(&self, inner: T) -> FramedWrite<T, LengthDelimitedCodec>
    where
        T: AsyncWrite,
    {
        FramedWrite::new(inner, self.new_codec())
    }

    /// Create a configured length delimited `Framed`
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::io::{AsyncRead, AsyncWrite};
    /// # use tokio_util::codec::LengthDelimitedCodec;
    /// # fn write_frame<T: AsyncRead + AsyncWrite>(io: T) {
    /// # let _ =
    /// LengthDelimitedCodec::builder()
    ///     .length_field_length(2)
    ///     .new_framed(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    pub fn new_framed<T>(&self, inner: T) -> Framed<T, LengthDelimitedCodec>
    where
        T: AsyncRead + AsyncWrite,
    {
        Framed::new(inner, self.new_codec())
    }

    // Number of bytes that must be buffered before a frame head can be
    // processed: enough to cover the length field, and enough to cover an
    // explicitly configured `num_skip`, whichever is larger.
    fn num_head_bytes(&self) -> usize {
        let num = self.length_field_offset + self.length_field_len;
        cmp::max(num, self.num_skip.unwrap_or(0))
    }

    // Number of bytes dropped from the front of each frame: the configured
    // `num_skip`, or everything up to the end of the length field when unset.
    fn get_num_skip(&self) -> usize {
        self.num_skip
            .unwrap_or(self.length_field_offset + self.length_field_len)
    }
}

impl
Default for Builder { + fn default() -> Self { + Self::new() + } +} + +// ===== impl LengthDelimitedCodecError ===== + +impl fmt::Debug for LengthDelimitedCodecError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("LengthDelimitedCodecError").finish() + } +} + +impl fmt::Display for LengthDelimitedCodecError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("frame size too big") + } +} + +impl StdError for LengthDelimitedCodecError {} diff --git a/third_party/rust/tokio-util/src/codec/lines_codec.rs b/third_party/rust/tokio-util/src/codec/lines_codec.rs new file mode 100644 index 0000000000..8029956ff0 --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/lines_codec.rs @@ -0,0 +1,224 @@ +use crate::codec::decoder::Decoder; +use crate::codec::encoder::Encoder; + +use bytes::{Buf, BufMut, BytesMut}; +use std::{cmp, fmt, io, str, usize}; + +/// A simple `Codec` implementation that splits up data into lines. +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct LinesCodec { + // Stored index of the next index to examine for a `\n` character. + // This is used to optimize searching. + // For example, if `decode` was called with `abc`, it would hold `3`, + // because that is the next index to examine. + // The next time `decode` is called with `abcde\n`, the method will + // only look at `de\n` before returning. + next_index: usize, + + /// The maximum length for a given line. If `usize::MAX`, lines will be + /// read until a `\n` character is reached. + max_length: usize, + + /// Are we currently discarding the remainder of a line which was over + /// the length limit? + is_discarding: bool, +} + +impl LinesCodec { + /// Returns a `LinesCodec` for splitting up data into lines. + /// + /// # Note + /// + /// The returned `LinesCodec` will not have an upper bound on the length + /// of a buffered line. 
See the documentation for [`new_with_max_length`] + /// for information on why this could be a potential security risk. + /// + /// [`new_with_max_length`]: #method.new_with_max_length + pub fn new() -> LinesCodec { + LinesCodec { + next_index: 0, + max_length: usize::MAX, + is_discarding: false, + } + } + + /// Returns a `LinesCodec` with a maximum line length limit. + /// + /// If this is set, calls to `LinesCodec::decode` will return a + /// [`LengthError`] when a line exceeds the length limit. Subsequent calls + /// will discard up to `limit` bytes from that line until a newline + /// character is reached, returning `None` until the line over the limit + /// has been fully discarded. After that point, calls to `decode` will + /// function as normal. + /// + /// # Note + /// + /// Setting a length limit is highly recommended for any `LinesCodec` which + /// will be exposed to untrusted input. Otherwise, the size of the buffer + /// that holds the line currently being read is unbounded. An attacker could + /// exploit this unbounded buffer by sending an unbounded amount of input + /// without any `\n` characters, causing unbounded memory consumption. + /// + /// [`LengthError`]: ../struct.LengthError + pub fn new_with_max_length(max_length: usize) -> Self { + LinesCodec { + max_length, + ..LinesCodec::new() + } + } + + /// Returns the maximum line length when decoding. 
/// Strips one trailing `\r` byte from `s` when present; any other input
/// (including an empty slice) is handed back unchanged.
fn without_carriage_return(s: &[u8]) -> &[u8] {
    match s.split_last() {
        // Only the very last byte is inspected, so at most one `\r` is removed.
        Some((&b'\r', head)) => head,
        _ => s,
    }
}
+ let newline_index = offset + self.next_index; + self.next_index = 0; + let line = buf.split_to(newline_index + 1); + let line = &line[..line.len() - 1]; + let line = without_carriage_return(line); + let line = utf8(line)?; + return Ok(Some(line.to_string())); + } + (false, None) if buf.len() > self.max_length => { + // Reached the maximum length without finding a + // newline, return an error and start discarding on the + // next call. + self.is_discarding = true; + return Err(LinesCodecError::MaxLineLengthExceeded); + } + (false, None) => { + // We didn't find a line or reach the length limit, so the next + // call will resume searching at the current offset. + self.next_index = read_to; + return Ok(None); + } + } + } + } + + fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<String>, LinesCodecError> { + Ok(match self.decode(buf)? { + Some(frame) => Some(frame), + None => { + // No terminating newline - return remaining data, if any + if buf.is_empty() || buf == &b"\r"[..] { + None + } else { + let line = buf.split_to(buf.len()); + let line = without_carriage_return(&line); + let line = utf8(line)?; + self.next_index = 0; + Some(line.to_string()) + } + } + }) + } +} + +impl Encoder for LinesCodec { + type Item = String; + type Error = LinesCodecError; + + fn encode(&mut self, line: String, buf: &mut BytesMut) -> Result<(), LinesCodecError> { + buf.reserve(line.len() + 1); + buf.put(line.as_bytes()); + buf.put_u8(b'\n'); + Ok(()) + } +} + +impl Default for LinesCodec { + fn default() -> Self { + Self::new() + } +} + +/// An error occured while encoding or decoding a line. +#[derive(Debug)] +pub enum LinesCodecError { + /// The maximum line length was exceeded. + MaxLineLengthExceeded, + /// An IO error occured. 
+ Io(io::Error), +} + +impl fmt::Display for LinesCodecError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + LinesCodecError::MaxLineLengthExceeded => write!(f, "max line length exceeded"), + LinesCodecError::Io(e) => write!(f, "{}", e), + } + } +} + +impl From<io::Error> for LinesCodecError { + fn from(e: io::Error) -> LinesCodecError { + LinesCodecError::Io(e) + } +} + +impl std::error::Error for LinesCodecError {} diff --git a/third_party/rust/tokio-util/src/codec/mod.rs b/third_party/rust/tokio-util/src/codec/mod.rs new file mode 100644 index 0000000000..b162dd3a78 --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/mod.rs @@ -0,0 +1,34 @@ +//! Utilities for encoding and decoding frames. +//! +//! Contains adapters to go from streams of bytes, [`AsyncRead`] and +//! [`AsyncWrite`], to framed streams implementing [`Sink`] and [`Stream`]. +//! Framed streams are also known as transports. +//! +//! [`AsyncRead`]: https://docs.rs/tokio/*/tokio/io/trait.AsyncRead.html +//! [`AsyncWrite`]: https://docs.rs/tokio/*/tokio/io/trait.AsyncWrite.html +//! [`Sink`]: https://docs.rs/futures-sink/*/futures_sink/trait.Sink.html +//! 
[`Stream`]: https://docs.rs/futures-core/*/futures_core/stream/trait.Stream.html + +mod bytes_codec; +pub use self::bytes_codec::BytesCodec; + +mod decoder; +pub use self::decoder::Decoder; + +mod encoder; +pub use self::encoder::Encoder; + +mod framed; +pub use self::framed::{Framed, FramedParts}; + +mod framed_read; +pub use self::framed_read::FramedRead; + +mod framed_write; +pub use self::framed_write::FramedWrite; + +pub mod length_delimited; +pub use self::length_delimited::{LengthDelimitedCodec, LengthDelimitedCodecError}; + +mod lines_codec; +pub use self::lines_codec::{LinesCodec, LinesCodecError}; diff --git a/third_party/rust/tokio-util/src/lib.rs b/third_party/rust/tokio-util/src/lib.rs new file mode 100644 index 0000000000..4cb54dfb35 --- /dev/null +++ b/third_party/rust/tokio-util/src/lib.rs @@ -0,0 +1,26 @@ +#![doc(html_root_url = "https://docs.rs/tokio-util/0.2.0")] +#![warn( + missing_debug_implementations, + missing_docs, + rust_2018_idioms, + unreachable_pub +)] +#![deny(intra_doc_link_resolution_failure)] +#![doc(test( + no_crate_inject, + attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables)) +))] +#![cfg_attr(docsrs, feature(doc_cfg))] + +//! Utilities for working with Tokio. + +#[macro_use] +mod cfg; + +cfg_codec! { + pub mod codec; +} + +cfg_udp! 
{ + pub mod udp; +} diff --git a/third_party/rust/tokio-util/src/udp/frame.rs b/third_party/rust/tokio-util/src/udp/frame.rs new file mode 100644 index 0000000000..a6c6f22070 --- /dev/null +++ b/third_party/rust/tokio-util/src/udp/frame.rs @@ -0,0 +1,181 @@ +use crate::codec::{Decoder, Encoder}; + +use tokio::net::UdpSocket; + +use bytes::{BufMut, BytesMut}; +use futures_core::{ready, Stream}; +use futures_sink::Sink; +use std::io; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// A unified `Stream` and `Sink` interface to an underlying `UdpSocket`, using +/// the `Encoder` and `Decoder` traits to encode and decode frames. +/// +/// Raw UDP sockets work with datagrams, but higher-level code usually wants to +/// batch these into meaningful chunks, called "frames". This method layers +/// framing on top of this socket by using the `Encoder` and `Decoder` traits to +/// handle encoding and decoding of messages frames. Note that the incoming and +/// outgoing frame types may be distinct. +/// +/// This function returns a *single* object that is both `Stream` and `Sink`; +/// grouping this into a single object is often useful for layering things which +/// require both read and write access to the underlying object. +/// +/// If you want to work more directly with the streams and sink, consider +/// calling `split` on the `UdpFramed` returned by this method, which will break +/// them into separate objects, allowing them to interact more easily. 
// Receiving half: each successful `poll_recv_from` is handed to the codec as
// one self-contained buffer, and the decoded frame is paired with the sender
// address.
impl<C: Decoder + Unpin> Stream for UdpFramed<C> {
    type Item = Result<(C::Item, SocketAddr), C::Error>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let pin = self.get_mut();

        // Make sure there is room to receive into before every read.
        pin.rd.reserve(INITIAL_RD_CAPACITY);

        let (_n, addr) = unsafe {
            // Read into the buffer without having to initialize the memory.
            //
            // SAFETY: we know tokio::net::UdpSocket never reads from the memory
            // during a recv, and `advance_mut(n)` below only exposes the `n`
            // bytes that the socket reported as written.
            let res = {
                let bytes = &mut *(pin.rd.bytes_mut() as *mut _ as *mut [u8]);
                // Returns `Poll::Pending` from `poll_next` if the socket is
                // not ready yet.
                ready!(Pin::new(&mut pin.socket).poll_recv_from(cx, bytes))
            };

            let (n, addr) = res?;
            pin.rd.advance_mut(n);
            (n, addr)
        };

        let frame_res = pin.codec.decode(&mut pin.rd);
        // The read buffer is emptied whether or not decoding succeeded, so
        // bytes the codec left unconsumed are not carried over to the next
        // poll.
        pin.rd.clear();
        let frame = frame_res?;
        // NOTE(review): when the codec returns `Ok(None)` this becomes
        // `Poll::Ready(None)`, i.e. the stream reports termination - confirm
        // this is intended for datagrams that do not decode to a frame.
        let result = frame.map(|frame| Ok((frame, addr))); // frame -> (frame, addr)

        Poll::Ready(result)
    }
}
impl<C> UdpFramed<C> {
    /// Create a new `UdpFramed` backed by the given socket and codec.
    ///
    /// The read and write buffers are pre-allocated with
    /// `INITIAL_RD_CAPACITY` and `INITIAL_WR_CAPACITY` bytes respectively, the
    /// outgoing address starts as the placeholder `0.0.0.0:0`, and the value
    /// starts in the "flushed" state.
    ///
    /// See struct level documentation for more details.
    pub fn new(socket: UdpSocket, codec: C) -> UdpFramed<C> {
        UdpFramed {
            socket,
            codec,
            // Placeholder; overwritten by `start_send` with the real target.
            out_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
            rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY),
            wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY),
            flushed: true,
        }
    }

    /// Returns a reference to the underlying `UdpSocket` wrapped by this
    /// `UdpFramed`.
    ///
    /// # Note
    ///
    /// Care should be taken to not tamper with the underlying stream of data
    /// coming in as it may corrupt the stream of frames otherwise being worked
    /// with.
    pub fn get_ref(&self) -> &UdpSocket {
        &self.socket
    }

    /// Returns a mutable reference to the underlying `UdpSocket` wrapped by
    /// this `UdpFramed`.
    ///
    /// # Note
    ///
    /// Care should be taken to not tamper with the underlying stream of data
    /// coming in as it may corrupt the stream of frames otherwise being worked
    /// with.
    pub fn get_mut(&mut self) -> &mut UdpSocket {
        &mut self.socket
    }

    /// Consumes the `UdpFramed`, returning its underlying `UdpSocket`.
    ///
    /// The codec and both internal buffers are dropped.
    pub fn into_inner(self) -> UdpSocket {
        self.socket
    }
}
codec.decode(buf).unwrap().unwrap()); + assert_eq!("line 3", codec.decode(buf).unwrap().unwrap()); + assert_eq!("", codec.decode(buf).unwrap().unwrap()); + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode_eof(buf).unwrap()); + buf.put_slice(b"k"); + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!("\rk", codec.decode_eof(buf).unwrap().unwrap()); + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode_eof(buf).unwrap()); +} + +#[test] +fn lines_decoder_max_length() { + const MAX_LENGTH: usize = 6; + + let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"line 1 is too long\nline 2\nline 3\r\nline 4\n\r\n\r"); + + assert!(codec.decode(buf).is_err()); + + let line = codec.decode(buf).unwrap().unwrap(); + assert!( + line.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + line, + MAX_LENGTH + ); + assert_eq!("line 2", line); + + assert!(codec.decode(buf).is_err()); + + let line = codec.decode(buf).unwrap().unwrap(); + assert!( + line.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + line, + MAX_LENGTH + ); + assert_eq!("line 4", line); + + let line = codec.decode(buf).unwrap().unwrap(); + assert!( + line.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + line, + MAX_LENGTH + ); + assert_eq!("", line); + + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode_eof(buf).unwrap()); + buf.put_slice(b"k"); + assert_eq!(None, codec.decode(buf).unwrap()); + + let line = codec.decode_eof(buf).unwrap().unwrap(); + assert!( + line.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + line, + MAX_LENGTH + ); + assert_eq!("\rk", line); + + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode_eof(buf).unwrap()); + + // Line that's one character too long. This could cause an out of bounds + // error if we peek at the next characters using slice indexing. 
// An over-long line arriving in three bursts: the first two stay within the
// limit, so the decoder keeps waiting; the third pushes the buffered data past
// `MAX_LENGTH` before any newline is seen and must produce an error.
#[test]
fn lines_decoder_max_length_bursts() {
    const MAX_LENGTH: usize = 10;

    let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH);
    let buf = &mut BytesMut::new();

    buf.reserve(200);
    // 5 bytes buffered, no newline yet.
    buf.put_slice(b"line ");
    assert_eq!(None, codec.decode(buf).unwrap());
    // 10 bytes buffered: exactly at the limit, still not over it.
    buf.put_slice(b"too l");
    assert_eq!(None, codec.decode(buf).unwrap());
    // 14 bytes buffered; the newline sits at index 13, beyond the
    // `MAX_LENGTH + 1` bytes the decoder is willing to search.
    buf.put_slice(b"ong\n");
    assert!(codec.decode(buf).is_err());
}
+ let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"aa"); + assert!(codec.decode(buf).is_err()); + buf.put_slice(b"a"); + assert!(codec.decode(buf).is_err()); +} + +#[test] +fn lines_encoder() { + let mut codec = LinesCodec::new(); + let mut buf = BytesMut::new(); + + codec.encode(String::from("line 1"), &mut buf).unwrap(); + assert_eq!("line 1\n", buf); + + codec.encode(String::from("line 2"), &mut buf).unwrap(); + assert_eq!("line 1\nline 2\n", buf); +} diff --git a/third_party/rust/tokio-util/tests/framed.rs b/third_party/rust/tokio-util/tests/framed.rs new file mode 100644 index 0000000000..b98df7368d --- /dev/null +++ b/third_party/rust/tokio-util/tests/framed.rs @@ -0,0 +1,97 @@ +#![warn(rust_2018_idioms)] + +use tokio::prelude::*; +use tokio_test::assert_ok; +use tokio_util::codec::{Decoder, Encoder, Framed, FramedParts}; + +use bytes::{Buf, BufMut, BytesMut}; +use futures::StreamExt; +use std::io::{self, Read}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +const INITIAL_CAPACITY: usize = 8 * 1024; + +/// Encode and decode u32 values. 
+struct U32Codec; + +impl Decoder for U32Codec { + type Item = u32; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u32>> { + if buf.len() < 4 { + return Ok(None); + } + + let n = buf.split_to(4).get_u32(); + Ok(Some(n)) + } +} + +impl Encoder for U32Codec { + type Item = u32; + type Error = io::Error; + + fn encode(&mut self, item: u32, dst: &mut BytesMut) -> io::Result<()> { + // Reserve space + dst.reserve(4); + dst.put_u32(item); + Ok(()) + } +} + +/// This value should never be used +struct DontReadIntoThis; + +impl Read for DontReadIntoThis { + fn read(&mut self, _: &mut [u8]) -> io::Result<usize> { + Err(io::Error::new( + io::ErrorKind::Other, + "Read into something you weren't supposed to.", + )) + } +} + +impl AsyncRead for DontReadIntoThis { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut [u8], + ) -> Poll<io::Result<usize>> { + unreachable!() + } +} + +#[tokio::test] +async fn can_read_from_existing_buf() { + let mut parts = FramedParts::new(DontReadIntoThis, U32Codec); + parts.read_buf = BytesMut::from(&[0, 0, 0, 42][..]); + + let mut framed = Framed::from_parts(parts); + let num = assert_ok!(framed.next().await.unwrap()); + + assert_eq!(num, 42); +} + +#[test] +fn external_buf_grows_to_init() { + let mut parts = FramedParts::new(DontReadIntoThis, U32Codec); + parts.read_buf = BytesMut::from(&[0, 0, 0, 42][..]); + + let framed = Framed::from_parts(parts); + let FramedParts { read_buf, .. } = framed.into_parts(); + + assert_eq!(read_buf.capacity(), INITIAL_CAPACITY); +} + +#[test] +fn external_buf_does_not_shrink() { + let mut parts = FramedParts::new(DontReadIntoThis, U32Codec); + parts.read_buf = BytesMut::from(&vec![0; INITIAL_CAPACITY * 2][..]); + + let framed = Framed::from_parts(parts); + let FramedParts { read_buf, .. 
} = framed.into_parts(); + + assert_eq!(read_buf.capacity(), INITIAL_CAPACITY * 2); +} diff --git a/third_party/rust/tokio-util/tests/framed_read.rs b/third_party/rust/tokio-util/tests/framed_read.rs new file mode 100644 index 0000000000..06caa0a4d0 --- /dev/null +++ b/third_party/rust/tokio-util/tests/framed_read.rs @@ -0,0 +1,295 @@ +#![warn(rust_2018_idioms)] + +use tokio::io::AsyncRead; +use tokio_test::assert_ready; +use tokio_test::task; +use tokio_util::codec::{Decoder, FramedRead}; + +use bytes::{Buf, BytesMut}; +use futures::Stream; +use std::collections::VecDeque; +use std::io; +use std::pin::Pin; +use std::task::Poll::{Pending, Ready}; +use std::task::{Context, Poll}; + +macro_rules! mock { + ($($x:expr,)*) => {{ + let mut v = VecDeque::new(); + v.extend(vec![$($x),*]); + Mock { calls: v } + }}; +} + +macro_rules! assert_read { + ($e:expr, $n:expr) => {{ + let val = assert_ready!($e); + assert_eq!(val.unwrap().unwrap(), $n); + }}; +} + +macro_rules! pin { + ($id:ident) => { + Pin::new(&mut $id) + }; +} + +struct U32Decoder; + +impl Decoder for U32Decoder { + type Item = u32; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u32>> { + if buf.len() < 4 { + return Ok(None); + } + + let n = buf.split_to(4).get_u32(); + Ok(Some(n)) + } +} + +#[test] +fn read_multi_frame_in_packet() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 0); + assert_read!(pin!(framed).poll_next(cx), 1); + assert_read!(pin!(framed).poll_next(cx), 2); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + +#[test] +fn read_multi_frame_across_packets() { + let mut task = task::spawn(()); + let mock = mock! 
{ + Ok(b"\x00\x00\x00\x00".to_vec()), + Ok(b"\x00\x00\x00\x01".to_vec()), + Ok(b"\x00\x00\x00\x02".to_vec()), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 0); + assert_read!(pin!(framed).poll_next(cx), 1); + assert_read!(pin!(framed).poll_next(cx), 2); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + +#[test] +fn read_not_ready() { + let mut task = task::spawn(()); + let mock = mock! { + Err(io::Error::new(io::ErrorKind::WouldBlock, "")), + Ok(b"\x00\x00\x00\x00".to_vec()), + Ok(b"\x00\x00\x00\x01".to_vec()), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert!(pin!(framed).poll_next(cx).is_pending()); + assert_read!(pin!(framed).poll_next(cx), 0); + assert_read!(pin!(framed).poll_next(cx), 1); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + +#[test] +fn read_partial_then_not_ready() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00".to_vec()), + Err(io::Error::new(io::ErrorKind::WouldBlock, "")), + Ok(b"\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert!(pin!(framed).poll_next(cx).is_pending()); + assert_read!(pin!(framed).poll_next(cx), 0); + assert_read!(pin!(framed).poll_next(cx), 1); + assert_read!(pin!(framed).poll_next(cx), 2); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + +#[test] +fn read_err() { + let mut task = task::spawn(()); + let mock = mock! { + Err(io::Error::new(io::ErrorKind::Other, "")), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert_eq!( + io::ErrorKind::Other, + assert_ready!(pin!(framed).poll_next(cx)) + .unwrap() + .unwrap_err() + .kind() + ) + }); +} + +#[test] +fn read_partial_then_err() { + let mut task = task::spawn(()); + let mock = mock! 
{ + Ok(b"\x00\x00".to_vec()), + Err(io::Error::new(io::ErrorKind::Other, "")), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert_eq!( + io::ErrorKind::Other, + assert_ready!(pin!(framed).poll_next(cx)) + .unwrap() + .unwrap_err() + .kind() + ) + }); +} + +#[test] +fn read_partial_would_block_then_err() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00".to_vec()), + Err(io::Error::new(io::ErrorKind::WouldBlock, "")), + Err(io::Error::new(io::ErrorKind::Other, "")), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert!(pin!(framed).poll_next(cx).is_pending()); + assert_eq!( + io::ErrorKind::Other, + assert_ready!(pin!(framed).poll_next(cx)) + .unwrap() + .unwrap_err() + .kind() + ) + }); +} + +#[test] +fn huge_size() { + let mut task = task::spawn(()); + let data = [0; 32 * 1024]; + let mut framed = FramedRead::new(Slice(&data[..]), BigDecoder); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 0); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); + + struct BigDecoder; + + impl Decoder for BigDecoder { + type Item = u32; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u32>> { + if buf.len() < 32 * 1024 { + return Ok(None); + } + buf.split_to(32 * 1024); + Ok(Some(0)) + } + } +} + +#[test] +fn data_remaining_is_error() { + let mut task = task::spawn(()); + let slice = Slice(&[0; 5]); + let mut framed = FramedRead::new(slice, U32Decoder); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 0); + assert!(assert_ready!(pin!(framed).poll_next(cx)).unwrap().is_err()); + }); +} + +#[test] +fn multi_frames_on_eof() { + let mut task = task::spawn(()); + struct MyDecoder(Vec<u32>); + + impl Decoder for MyDecoder { + type Item = u32; + type Error = io::Error; + + fn decode(&mut self, _buf: &mut BytesMut) -> io::Result<Option<u32>> { + unreachable!(); + } + + fn decode_eof(&mut 
self, _buf: &mut BytesMut) -> io::Result<Option<u32>> { + if self.0.is_empty() { + return Ok(None); + } + + Ok(Some(self.0.remove(0))) + } + } + + let mut framed = FramedRead::new(mock!(), MyDecoder(vec![0, 1, 2, 3])); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 0); + assert_read!(pin!(framed).poll_next(cx), 1); + assert_read!(pin!(framed).poll_next(cx), 2); + assert_read!(pin!(framed).poll_next(cx), 3); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + +// ===== Mock ====== + +struct Mock { + calls: VecDeque<io::Result<Vec<u8>>>, +} + +impl AsyncRead for Mock { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll<io::Result<usize>> { + use io::ErrorKind::WouldBlock; + + match self.calls.pop_front() { + Some(Ok(data)) => { + debug_assert!(buf.len() >= data.len()); + buf[..data.len()].copy_from_slice(&data[..]); + Ready(Ok(data.len())) + } + Some(Err(ref e)) if e.kind() == WouldBlock => Pending, + Some(Err(e)) => Ready(Err(e)), + None => Ready(Ok(0)), + } + } +} + +// TODO this newtype is necessary because `&[u8]` does not currently implement `AsyncRead` +struct Slice<'a>(&'a [u8]); + +impl AsyncRead for Slice<'_> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.0).poll_read(cx, buf) + } +} diff --git a/third_party/rust/tokio-util/tests/framed_write.rs b/third_party/rust/tokio-util/tests/framed_write.rs new file mode 100644 index 0000000000..706e6792fe --- /dev/null +++ b/third_party/rust/tokio-util/tests/framed_write.rs @@ -0,0 +1,173 @@ +#![warn(rust_2018_idioms)] + +use tokio::io::AsyncWrite; +use tokio_test::{assert_ready, task}; +use tokio_util::codec::{Encoder, FramedWrite}; + +use bytes::{BufMut, BytesMut}; +use futures_sink::Sink; +use std::collections::VecDeque; +use std::io::{self, Write}; +use std::pin::Pin; +use std::task::Poll::{Pending, Ready}; +use 
std::task::{Context, Poll}; + +macro_rules! mock { + ($($x:expr,)*) => {{ + let mut v = VecDeque::new(); + v.extend(vec![$($x),*]); + Mock { calls: v } + }}; +} + +macro_rules! pin { + ($id:ident) => { + Pin::new(&mut $id) + }; +} + +struct U32Encoder; + +impl Encoder for U32Encoder { + type Item = u32; + type Error = io::Error; + + fn encode(&mut self, item: u32, dst: &mut BytesMut) -> io::Result<()> { + // Reserve space + dst.reserve(4); + dst.put_u32(item); + Ok(()) + } +} + +#[test] +fn write_multi_frame_in_packet() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()), + }; + let mut framed = FramedWrite::new(mock, U32Encoder); + + task.enter(|cx, _| { + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(0).is_ok()); + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(1).is_ok()); + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(2).is_ok()); + + // Nothing written yet + assert_eq!(1, framed.get_ref().calls.len()); + + // Flush the writes + assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok()); + + assert_eq!(0, framed.get_ref().calls.len()); + }); +} + +#[test] +fn write_hits_backpressure() { + const ITER: usize = 2 * 1024; + + let mut mock = mock! 
{ + // Block the `ITER`th write + Err(io::Error::new(io::ErrorKind::WouldBlock, "not ready")), + Ok(b"".to_vec()), + }; + + for i in 0..=ITER { + let mut b = BytesMut::with_capacity(4); + b.put_u32(i as u32); + + // Append to the end + match mock.calls.back_mut().unwrap() { + &mut Ok(ref mut data) => { + // Write in 2kb chunks + if data.len() < ITER { + data.extend_from_slice(&b[..]); + continue; + } // else fall through and create a new buffer + } + _ => unreachable!(), + } + + // Push a new new chunk + mock.calls.push_back(Ok(b[..].to_vec())); + } + // 1 'wouldblock', 4 * 2KB buffers, 1 b-byte buffer + assert_eq!(mock.calls.len(), 6); + + let mut task = task::spawn(()); + let mut framed = FramedWrite::new(mock, U32Encoder); + task.enter(|cx, _| { + // Send 8KB. This fills up FramedWrite2 buffer + for i in 0..ITER { + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(i as u32).is_ok()); + } + + // Now we poll_ready which forces a flush. The mock pops the front message + // and decides to block. + assert!(pin!(framed).poll_ready(cx).is_pending()); + + // We poll again, forcing another flush, which this time succeeds + // The whole 8KB buffer is flushed + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + + // Send more data. 
This matches the final message expected by the mock + assert!(pin!(framed).start_send(ITER as u32).is_ok()); + + // Flush the rest of the buffer + assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok()); + + // Ensure the mock is empty + assert_eq!(0, framed.get_ref().calls.len()); + }) +} + +// // ===== Mock ====== + +struct Mock { + calls: VecDeque<io::Result<Vec<u8>>>, +} + +impl Write for Mock { + fn write(&mut self, src: &[u8]) -> io::Result<usize> { + match self.calls.pop_front() { + Some(Ok(data)) => { + assert!(src.len() >= data.len()); + assert_eq!(&data[..], &src[..data.len()]); + Ok(data.len()) + } + Some(Err(e)) => Err(e), + None => panic!("unexpected write; {:?}", src), + } + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +impl AsyncWrite for Mock { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, io::Error>> { + match Pin::get_mut(self).write(buf) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Pending, + other => Ready(other), + } + } + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + match Pin::get_mut(self).flush() { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Pending, + other => Ready(other), + } + } + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + unimplemented!() + } +} diff --git a/third_party/rust/tokio-util/tests/length_delimited.rs b/third_party/rust/tokio-util/tests/length_delimited.rs new file mode 100644 index 0000000000..6c5199167b --- /dev/null +++ b/third_party/rust/tokio-util/tests/length_delimited.rs @@ -0,0 +1,760 @@ +#![warn(rust_2018_idioms)] + +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_test::task; +use tokio_test::{ + assert_err, assert_ok, assert_pending, assert_ready, assert_ready_err, assert_ready_ok, +}; +use tokio_util::codec::*; + +use bytes::{BufMut, Bytes, BytesMut}; +use futures::{pin_mut, Sink, Stream}; +use 
std::collections::VecDeque; +use std::io; +use std::pin::Pin; +use std::task::Poll::*; +use std::task::{Context, Poll}; + +macro_rules! mock { + ($($x:expr,)*) => {{ + let mut v = VecDeque::new(); + v.extend(vec![$($x),*]); + Mock { calls: v } + }}; +} + +macro_rules! assert_next_eq { + ($io:ident, $expect:expr) => {{ + task::spawn(()).enter(|cx, _| { + let res = assert_ready!($io.as_mut().poll_next(cx)); + match res { + Some(Ok(v)) => assert_eq!(v, $expect.as_ref()), + Some(Err(e)) => panic!("error = {:?}", e), + None => panic!("none"), + } + }); + }}; +} + +macro_rules! assert_next_pending { + ($io:ident) => {{ + task::spawn(()).enter(|cx, _| match $io.as_mut().poll_next(cx) { + Ready(Some(Ok(v))) => panic!("value = {:?}", v), + Ready(Some(Err(e))) => panic!("error = {:?}", e), + Ready(None) => panic!("done"), + Pending => {} + }); + }}; +} + +macro_rules! assert_next_err { + ($io:ident) => {{ + task::spawn(()).enter(|cx, _| match $io.as_mut().poll_next(cx) { + Ready(Some(Ok(v))) => panic!("value = {:?}", v), + Ready(Some(Err(_))) => {} + Ready(None) => panic!("done"), + Pending => panic!("pending"), + }); + }}; +} + +macro_rules! assert_done { + ($io:ident) => {{ + task::spawn(()).enter(|cx, _| { + let res = assert_ready!($io.as_mut().poll_next(cx)); + match res { + Some(Ok(v)) => panic!("value = {:?}", v), + Some(Err(e)) => panic!("error = {:?}", e), + None => {} + } + }); + }}; +} + +#[test] +fn read_empty_io_yields_nothing() { + let io = Box::pin(FramedRead::new(mock!(), LengthDelimitedCodec::new())); + pin_mut!(io); + + assert_done!(io); +} + +#[test] +fn read_single_frame_one_packet() { + let io = FramedRead::new( + mock! { + data(b"\x00\x00\x00\x09abcdefghi"), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_done!(io); +} + +#[test] +fn read_single_frame_one_packet_little_endian() { + let io = length_delimited::Builder::new() + .little_endian() + .new_read(mock! 
{ + data(b"\x09\x00\x00\x00abcdefghi"), + }); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_done!(io); +} + +#[test] +fn read_single_frame_one_packet_native_endian() { + let d = if cfg!(target_endian = "big") { + b"\x00\x00\x00\x09abcdefghi" + } else { + b"\x09\x00\x00\x00abcdefghi" + }; + let io = length_delimited::Builder::new() + .native_endian() + .new_read(mock! { + data(d), + }); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_done!(io); +} + +#[test] +fn read_single_multi_frame_one_packet() { + let mut d: Vec<u8> = vec![]; + d.extend_from_slice(b"\x00\x00\x00\x09abcdefghi"); + d.extend_from_slice(b"\x00\x00\x00\x03123"); + d.extend_from_slice(b"\x00\x00\x00\x0bhello world"); + + let io = FramedRead::new( + mock! { + data(&d), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_next_eq!(io, b"123"); + assert_next_eq!(io, b"hello world"); + assert_done!(io); +} + +#[test] +fn read_single_frame_multi_packet() { + let io = FramedRead::new( + mock! { + data(b"\x00\x00"), + data(b"\x00\x09abc"), + data(b"defghi"), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_done!(io); +} + +#[test] +fn read_multi_frame_multi_packet() { + let io = FramedRead::new( + mock! { + data(b"\x00\x00"), + data(b"\x00\x09abc"), + data(b"defghi"), + data(b"\x00\x00\x00\x0312"), + data(b"3\x00\x00\x00\x0bhello world"), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_next_eq!(io, b"123"); + assert_next_eq!(io, b"hello world"); + assert_done!(io); +} + +#[test] +fn read_single_frame_multi_packet_wait() { + let io = FramedRead::new( + mock! 
{ + data(b"\x00\x00"), + Pending, + data(b"\x00\x09abc"), + Pending, + data(b"defghi"), + Pending, + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_pending!(io); + assert_next_pending!(io); + assert_next_eq!(io, b"abcdefghi"); + assert_next_pending!(io); + assert_done!(io); +} + +#[test] +fn read_multi_frame_multi_packet_wait() { + let io = FramedRead::new( + mock! { + data(b"\x00\x00"), + Pending, + data(b"\x00\x09abc"), + Pending, + data(b"defghi"), + Pending, + data(b"\x00\x00\x00\x0312"), + Pending, + data(b"3\x00\x00\x00\x0bhello world"), + Pending, + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_pending!(io); + assert_next_pending!(io); + assert_next_eq!(io, b"abcdefghi"); + assert_next_pending!(io); + assert_next_pending!(io); + assert_next_eq!(io, b"123"); + assert_next_eq!(io, b"hello world"); + assert_next_pending!(io); + assert_done!(io); +} + +#[test] +fn read_incomplete_head() { + let io = FramedRead::new( + mock! { + data(b"\x00\x00"), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_err!(io); +} + +#[test] +fn read_incomplete_head_multi() { + let io = FramedRead::new( + mock! { + Pending, + data(b"\x00"), + Pending, + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_pending!(io); + assert_next_pending!(io); + assert_next_err!(io); +} + +#[test] +fn read_incomplete_payload() { + let io = FramedRead::new( + mock! { + data(b"\x00\x00\x00\x09ab"), + Pending, + data(b"cd"), + Pending, + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_pending!(io); + assert_next_pending!(io); + assert_next_err!(io); +} + +#[test] +fn read_max_frame_len() { + let io = length_delimited::Builder::new() + .max_frame_length(5) + .new_read(mock! { + data(b"\x00\x00\x00\x09abcdefghi"), + }); + pin_mut!(io); + + assert_next_err!(io); +} + +#[test] +fn read_update_max_frame_len_at_rest() { + let io = length_delimited::Builder::new().new_read(mock! 
{ + data(b"\x00\x00\x00\x09abcdefghi"), + data(b"\x00\x00\x00\x09abcdefghi"), + }); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + io.decoder_mut().set_max_frame_length(5); + assert_next_err!(io); +} + +#[test] +fn read_update_max_frame_len_in_flight() { + let io = length_delimited::Builder::new().new_read(mock! { + data(b"\x00\x00\x00\x09abcd"), + Pending, + data(b"efghi"), + data(b"\x00\x00\x00\x09abcdefghi"), + }); + pin_mut!(io); + + assert_next_pending!(io); + io.decoder_mut().set_max_frame_length(5); + assert_next_eq!(io, b"abcdefghi"); + assert_next_err!(io); +} + +#[test] +fn read_one_byte_length_field() { + let io = length_delimited::Builder::new() + .length_field_length(1) + .new_read(mock! { + data(b"\x09abcdefghi"), + }); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_done!(io); +} + +#[test] +fn read_header_offset() { + let io = length_delimited::Builder::new() + .length_field_length(2) + .length_field_offset(4) + .new_read(mock! { + data(b"zzzz\x00\x09abcdefghi"), + }); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_done!(io); +} + +#[test] +fn read_single_multi_frame_one_packet_skip_none_adjusted() { + let mut d: Vec<u8> = vec![]; + d.extend_from_slice(b"xx\x00\x09abcdefghi"); + d.extend_from_slice(b"yy\x00\x03123"); + d.extend_from_slice(b"zz\x00\x0bhello world"); + + let io = length_delimited::Builder::new() + .length_field_length(2) + .length_field_offset(2) + .num_skip(0) + .length_adjustment(4) + .new_read(mock! 
{ + data(&d), + }); + pin_mut!(io); + + assert_next_eq!(io, b"xx\x00\x09abcdefghi"); + assert_next_eq!(io, b"yy\x00\x03123"); + assert_next_eq!(io, b"zz\x00\x0bhello world"); + assert_done!(io); +} + +#[test] +fn read_single_multi_frame_one_packet_length_includes_head() { + let mut d: Vec<u8> = vec![]; + d.extend_from_slice(b"\x00\x0babcdefghi"); + d.extend_from_slice(b"\x00\x05123"); + d.extend_from_slice(b"\x00\x0dhello world"); + + let io = length_delimited::Builder::new() + .length_field_length(2) + .length_adjustment(-2) + .new_read(mock! { + data(&d), + }); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_next_eq!(io, b"123"); + assert_next_eq!(io, b"hello world"); + assert_done!(io); +} + +#[test] +fn write_single_frame_length_adjusted() { + let io = length_delimited::Builder::new() + .length_adjustment(-2) + .new_write(mock! { + data(b"\x00\x00\x00\x0b"), + data(b"abcdefghi"), + flush(), + }); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + assert_ready_ok!(io.as_mut().poll_flush(cx)); + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_nothing_yields_nothing() { + let io = FramedWrite::new(mock!(), LengthDelimitedCodec::new()); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.poll_flush(cx)); + }); +} + +#[test] +fn write_single_frame_one_packet() { + let io = FramedWrite::new( + mock! { + data(b"\x00\x00\x00\x09"), + data(b"abcdefghi"), + flush(), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + assert_ready_ok!(io.as_mut().poll_flush(cx)); + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_single_multi_frame_one_packet() { + let io = FramedWrite::new( + mock! 
{ + data(b"\x00\x00\x00\x09"), + data(b"abcdefghi"), + data(b"\x00\x00\x00\x03"), + data(b"123"), + data(b"\x00\x00\x00\x0b"), + data(b"hello world"), + flush(), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("123"))); + + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("hello world"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_single_multi_frame_multi_packet() { + let io = FramedWrite::new( + mock! { + data(b"\x00\x00\x00\x09"), + data(b"abcdefghi"), + flush(), + data(b"\x00\x00\x00\x03"), + data(b"123"), + flush(), + data(b"\x00\x00\x00\x0b"), + data(b"hello world"), + flush(), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("123"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("hello world"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_single_frame_would_block() { + let io = FramedWrite::new( + mock! 
{ + Pending, + data(b"\x00\x00"), + Pending, + data(b"\x00\x09"), + data(b"abcdefghi"), + flush(), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + + assert_pending!(io.as_mut().poll_flush(cx)); + assert_pending!(io.as_mut().poll_flush(cx)); + assert_ready_ok!(io.as_mut().poll_flush(cx)); + + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_single_frame_little_endian() { + let io = length_delimited::Builder::new() + .little_endian() + .new_write(mock! { + data(b"\x09\x00\x00\x00"), + data(b"abcdefghi"), + flush(), + }); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_single_frame_with_short_length_field() { + let io = length_delimited::Builder::new() + .length_field_length(1) + .new_write(mock! { + data(b"\x09"), + data(b"abcdefghi"), + flush(), + }); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_max_frame_len() { + let io = length_delimited::Builder::new() + .max_frame_length(5) + .new_write(mock! {}); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_err!(io.as_mut().start_send(Bytes::from("abcdef"))); + + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_update_max_frame_len_at_rest() { + let io = length_delimited::Builder::new().new_write(mock! 
{ + data(b"\x00\x00\x00\x06"), + data(b"abcdef"), + flush(), + }); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdef"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + + io.encoder_mut().set_max_frame_length(5); + + assert_err!(io.as_mut().start_send(Bytes::from("abcdef"))); + + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_update_max_frame_len_in_flight() { + let io = length_delimited::Builder::new().new_write(mock! { + data(b"\x00\x00\x00\x06"), + data(b"ab"), + Pending, + data(b"cdef"), + flush(), + }); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdef"))); + + assert_pending!(io.as_mut().poll_flush(cx)); + + io.encoder_mut().set_max_frame_length(5); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + + assert_err!(io.as_mut().start_send(Bytes::from("abcdef"))); + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_zero() { + let io = length_delimited::Builder::new().new_write(mock! {}); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdef"))); + + assert_ready_err!(io.as_mut().poll_flush(cx)); + + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn encode_overflow() { + // Test reproducing tokio-rs/tokio#681. + let mut codec = length_delimited::Builder::new().new_codec(); + let mut buf = BytesMut::with_capacity(1024); + + // Put some data into the buffer without resizing it to hold more. + let some_as = std::iter::repeat(b'a').take(1024).collect::<Vec<_>>(); + buf.put_slice(&some_as[..]); + + // Trying to encode the length header should resize the buffer if it won't fit. 
+ codec.encode(Bytes::from("hello"), &mut buf).unwrap(); +} + +// ===== Test utils ===== + +struct Mock { + calls: VecDeque<Poll<io::Result<Op>>>, +} + +enum Op { + Data(Vec<u8>), + Flush, +} + +use self::Op::*; + +impl AsyncRead for Mock { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + dst: &mut [u8], + ) -> Poll<io::Result<usize>> { + match self.calls.pop_front() { + Some(Ready(Ok(Op::Data(data)))) => { + debug_assert!(dst.len() >= data.len()); + dst[..data.len()].copy_from_slice(&data[..]); + Ready(Ok(data.len())) + } + Some(Ready(Ok(_))) => panic!(), + Some(Ready(Err(e))) => Ready(Err(e)), + Some(Pending) => Pending, + None => Ready(Ok(0)), + } + } +} + +impl AsyncWrite for Mock { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + src: &[u8], + ) -> Poll<Result<usize, io::Error>> { + match self.calls.pop_front() { + Some(Ready(Ok(Op::Data(data)))) => { + let len = data.len(); + assert!(src.len() >= len, "expect={:?}; actual={:?}", data, src); + assert_eq!(&data[..], &src[..len]); + Ready(Ok(len)) + } + Some(Ready(Ok(_))) => panic!(), + Some(Ready(Err(e))) => Ready(Err(e)), + Some(Pending) => Pending, + None => Ready(Ok(0)), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + match self.calls.pop_front() { + Some(Ready(Ok(Op::Flush))) => Ready(Ok(())), + Some(Ready(Ok(_))) => panic!(), + Some(Ready(Err(e))) => Ready(Err(e)), + Some(Pending) => Pending, + None => Ready(Ok(())), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + Ready(Ok(())) + } +} + +impl<'a> From<&'a [u8]> for Op { + fn from(src: &'a [u8]) -> Op { + Op::Data(src.into()) + } +} + +impl From<Vec<u8>> for Op { + fn from(src: Vec<u8>) -> Op { + Op::Data(src) + } +} + +fn data(bytes: &[u8]) -> Poll<io::Result<Op>> { + Ready(Ok(bytes.into())) +} + +fn flush() -> Poll<io::Result<Op>> { + Ready(Ok(Flush)) +} diff --git 
a/third_party/rust/tokio-util/tests/udp.rs b/third_party/rust/tokio-util/tests/udp.rs new file mode 100644 index 0000000000..af8002bd80 --- /dev/null +++ b/third_party/rust/tokio-util/tests/udp.rs @@ -0,0 +1,79 @@ +use tokio::net::UdpSocket; +use tokio_util::codec::{Decoder, Encoder}; +use tokio_util::udp::UdpFramed; + +use bytes::{BufMut, BytesMut}; +use futures::future::try_join; +use futures::future::FutureExt; +use futures::sink::SinkExt; +use futures::stream::StreamExt; +use std::io; + +#[tokio::test] +async fn send_framed() -> std::io::Result<()> { + let mut a_soc = UdpSocket::bind("127.0.0.1:0").await?; + let mut b_soc = UdpSocket::bind("127.0.0.1:0").await?; + + let a_addr = a_soc.local_addr()?; + let b_addr = b_soc.local_addr()?; + + // test sending & receiving bytes + { + let mut a = UdpFramed::new(a_soc, ByteCodec); + let mut b = UdpFramed::new(b_soc, ByteCodec); + + let msg = b"4567".to_vec(); + + let send = a.send((msg.clone(), b_addr)); + let recv = b.next().map(|e| e.unwrap()); + let (_, received) = try_join(send, recv).await.unwrap(); + + let (data, addr) = received; + assert_eq!(msg, data); + assert_eq!(a_addr, addr); + + a_soc = a.into_inner(); + b_soc = b.into_inner(); + } + + // test sending & receiving an empty message + { + let mut a = UdpFramed::new(a_soc, ByteCodec); + let mut b = UdpFramed::new(b_soc, ByteCodec); + + let msg = b"".to_vec(); + + let send = a.send((msg.clone(), b_addr)); + let recv = b.next().map(|e| e.unwrap()); + let (_, received) = try_join(send, recv).await.unwrap(); + + let (data, addr) = received; + assert_eq!(msg, data); + assert_eq!(a_addr, addr); + } + + Ok(()) +} + +pub struct ByteCodec; + +impl Decoder for ByteCodec { + type Item = Vec<u8>; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Vec<u8>>, io::Error> { + let len = buf.len(); + Ok(Some(buf.split_to(len).to_vec())) + } +} + +impl Encoder for ByteCodec { + type Item = Vec<u8>; + type Error = io::Error; + + fn encode(&mut 
self, data: Vec<u8>, buf: &mut BytesMut) -> Result<(), io::Error> { + buf.reserve(data.len()); + buf.put_slice(&data); + Ok(()) + } +} |