diff options
Diffstat (limited to 'third_party/rust/tokio-util')
65 files changed, 12959 insertions, 0 deletions
diff --git a/third_party/rust/tokio-util/.cargo-checksum.json b/third_party/rust/tokio-util/.cargo-checksum.json new file mode 100644 index 0000000000..d6d315e372 --- /dev/null +++ b/third_party/rust/tokio-util/.cargo-checksum.json @@ -0,0 +1 @@ +{"files":{"CHANGELOG.md":"06d030881733c323d8c9cc92a0b8241947aaeaa979e2ea7ec1a42c91394c83c5","Cargo.toml":"1c8557b55277e84909c7e0ed2fcadb1a3c8590ea342ec1ed605aa4b9e7d0ca3e","LICENSE":"697fc7385b1b0593f77d00db6e3ae8c146c2ccef505c4f09327bbedf952bfe35","README.md":"91c8da557ba5fbfb5a9b6e58d5efd5fd1700dd836509cf017628155c249e192c","src/cfg.rs":"800248e35ac58cbff4327959990e83783cf0e6dd82fec4ccf3fd55038a92115c","src/codec/any_delimiter_codec.rs":"66a4c3aee5328ef8a8be20a58d6ce388bda2394bc70e4800cf52e95760a22e09","src/codec/bytes_codec.rs":"e8f14a93415768f5a8736cbcc0a1684742e3be6df1a6b60cb95bd146544eee74","src/codec/decoder.rs":"c3f6c5197f80412684c15f905fd5172e0ffe3b8bad9589e8bfa16fe5f1a92f81","src/codec/encoder.rs":"e4544af47cdde075d1238ddee9555037912089bf25ce51cb4dd6f5d4f76ecf70","src/codec/framed.rs":"3137b9f2480429d00ee6fb2b567da78f41bb0dcba775e14956c0d93409cbbe38","src/codec/framed_impl.rs":"1bdbbebd724d2734c887cb1c90c78061ea212e9d2c0a9e91c390f81381395a03","src/codec/framed_read.rs":"367cc5517513c8fe256d8175b0497402f8b18dbf596aa19e6aa0c714772f0619","src/codec/framed_write.rs":"ca6a714e94c9778a7f01624299d2a58fa310dfe486c19e056d5d12c62d6be4fb","src/codec/length_delimited.rs":"4d2a62dfa2c9cf1a7242c5b04283cb3e88817844fee78c491e5aa76e5bf279da","src/codec/lines_codec.rs":"912302c500ea224e9002936e50505a5a29911e13d8c615b74954553f22d59826","src/codec/mod.rs":"95a11e3018bb850ff5934ef48cbc5056bdf2758d130cfa357ad406f0531946ae","src/compat.rs":"719b3a4ee8534647ae72df2d1a7b4937c60a9ee41e018fa7305dc6d5b3b41ed6","src/context.rs":"45a23756c6ce6b834da0f1817f556cc5bdd16412ddfc1dc9082da8a56819741a","src/either.rs":"25e022d51a44490e175b525d4493dd9e6fa51bd03aa27b763be9509eb7c4c0ee","src/io/mod.rs":"e2bf2cc05d6b57fa3cafcf95f5eb73996edd090fc012a99a1c4ad915276b80c2","src/io/read_buf.rs":"7043c2fbec74e548395eb2f12073c41c1b30e2f2c283b30eddfb5a16125387d0","src/io/reader_stream.rs":"98d0819ef38f94d56d7085a82d29fd83bde92a9178bebfe73c4533d0022b3d94","src/io/stream_reader.rs":"f36f95178b61f8498929dfc53416558037667d913d1765082e77a0b45460ea77","src/io/sync_bridge.rs":"06eed8295906e1a746e071433c185424d96f75e988a0c4dcb70e3efc7cc50513","src/lib.rs":"769afc23670c71d441233a5c0dffe12799518fcfffbdf2073ba798d0e3e4a104","src/loom.rs":"9028ba504687ad5ec98f3c69387bc2424ec23324b1eed3df5082a9cf23c6502e","src/net/mod.rs":"a75736c1d71408b4f5fb0f0bdcb535cc113e430a2479e01e5fca733ef3fcb15b","src/net/unix/mod.rs":"d667dbb53d7003a15a4705ca0654b35be7165b56ac0d631a23e90d18027a1a90","src/sync/cancellation_token.rs":"8bc59e142a2b3576ccdfb248957c627b28cd0de5d2aa20fccf74d1cdd163fe13","src/sync/cancellation_token/guard.rs":"6582ba676c4126b4fa292221f94ddcb5fd6d3e0e2a794819fa424a8f661bff8a","src/sync/cancellation_token/tree_node.rs":"4b46e5c3247387abdc421dc58c0bcc31166ef873a4847933fa354cb78eac58cc","src/sync/mod.rs":"385fbc1c98c330644adb76a3333b6a5c1f644a00ab5735d84293bd2ef878c18e","src/sync/mpsc.rs":"b0c6af8395ae5779c31ea08a307aa37c4138953af1e69d6b6f94efa485eb1da1","src/sync/poll_semaphore.rs":"817b520a5bb3b84bff6008a06ef0f6d5d256574a3ca9bf89d01a609d669790c3","src/sync/reusable_box.rs":"9b486884a036e9af3683945523714ce93db5d309454fd3ff198ccc357c8ef0c4","src/sync/tests/loom_cancellation_token.rs":"6393c5a12f09abef9300be4b66bb039bf02a63a04d6175fb7cfe68464784bdbd","src/sync/tests/mod.rs":"01ba4719c80b6fe911b091a7c05124b64eeece964e09c058ef8f9805daca546b","src/task/mod.rs":"f61140aaff3e34b7005b0e88846e9af4440ec2e37045c0323d93e7d954f747b4","src/task/spawn_pinned.rs":"87519f35a28142decdbbe8d7598da60250359de3622e99eac91f6d5879609119","src/time/delay_queue.rs":"ab93368c84aaa5ae3d52e1a3939a0aa70e2d5ce9406f4e888666fddb7e119c7d","src/time/mod.rs":"a76126419f30539dcceb72477251de79b38fe38278ef08b0743f1f716294fd9d","src/time/wheel/level.rs":"75f6b29212e0a58aa3196c78688cc9e23272d8eae3fc829d8b381ab49de2372b","src/time/wheel/mod.rs":"4d97f3d7130553adabb1866831962e2665fabc323b3781693bfc803387ef25a4","src/time/wheel/stack.rs":"648d3c071e9754a820343c53de0cbe9c07e47276bb04155e67873c276ed13a61","src/udp/frame.rs":"5afa77955b497c0e2812705f8cd9517b5439847c1381d2e3939eab28c489a578","src/udp/mod.rs":"869302c0c15eb41f7079ef12ce03fa7c889a83152282ab0b7faf73d1c092ed4e","tests/_require_full.rs":"f8dedb6ad88884209b074ff6f5faa4979528b32653b45ab8047d2ebb28e19122","tests/codecs.rs":"493df228f9dce98de69e0afa7be491d6fe4588e7a381a7910c28f248d2d038a6","tests/context.rs":"917f80db694b54db07e6d1660aa5210272efda71cc0203f062dfecd81a8289a7","tests/framed.rs":"4e808fbc8d601138ff787b3603a377c23b3f42b4a7b882f9a1eb8cf1234c89e3","tests/framed_read.rs":"df41071388645518cc6b0700b75dd678219214b04de42605a0122f007f4ed281","tests/framed_stream.rs":"c3118fc5db62f225ad6d97f8b32ac03812b3b68cdab7a94d189f4a4d9337f13f","tests/framed_write.rs":"1b311ae6d79616e41f20b6213f8585a9c65830578eddea2a012ba72a3a359611","tests/io_reader_stream.rs":"1c9f79782c5574c5e489e86205bdb63b332fc0e8560fc4c2d602dfc2d2741a5e","tests/io_stream_reader.rs":"d86e225eafbd196be3124147b1275bed384459e7ec3e8cb21775906253f75086","tests/io_sync_bridge.rs":"7852a934bbe497822423c8a75c4c0f3cd651f2e3d11fc7c35a01f723752b7e6c","tests/length_delimited.rs":"6bb4714c29b8b76ccdaddb59b1ea51f73499e8e223e8b08a62107a9190af4ba5","tests/mpsc.rs":"4f4c4edaaa295cb61d05900408a463345b790aec9b228260ffab41d90fecc4b5","tests/poll_semaphore.rs":"a04ffcf40cd0b65d8809ad4e881b579c20b3ef7de49d8c4094fe455fc3d1887d","tests/reusable_box.rs":"f12e98533443fd6c53ea586185c7e349a95c595bfd00930d764510592a5274cd","tests/spawn_pinned.rs":"f92e8a700c71074a29649036d17034b810da3c181b3afef8c33de04152fe2a12","tests/sync_cancellation_token.rs":"71c3f431384fc4313213f30893d44ec38582f712c855a5c9cd385d01f3e21c2c","tests/time_delay_queue.rs":"b522aff22601513cbdfe848802a8959ac3f27d78823fa430c3a9e6c6560024dd","tests/udp.rs":"c2f8d90eeae9d3b7f107c12f3723d54ba591ec9e879893e195ec13ecfcb4db27"},"package":"f988a1a1adc2fb21f9c12aa96441da33a1728193ae0b95d2be22dbd17fcb4e5c"}
\ No newline at end of file diff --git a/third_party/rust/tokio-util/CHANGELOG.md b/third_party/rust/tokio-util/CHANGELOG.md new file mode 100644 index 0000000000..d200cb380f --- /dev/null +++ b/third_party/rust/tokio-util/CHANGELOG.md @@ -0,0 +1,247 @@ +# 0.7.2 (May 14, 2022) + +This release contains a rewrite of `CancellationToken` that fixes a memory leak. ([#4652]) + +[#4652]: https://github.com/tokio-rs/tokio/pull/4652 + +# 0.7.1 (February 21, 2022) + +### Added + +- codec: add `length_field_type` to `LengthDelimitedCodec` builder ([#4508]) +- io: add `StreamReader::into_inner_with_chunk()` ([#4559]) + +### Changed + +- switch from log to tracing ([#4539]) + +### Fixed + +- sync: fix waker update condition in `CancellationToken` ([#4497]) +- bumped tokio dependency to 1.6 to satisfy minimum requirements ([#4490]) + +[#4490]: https://github.com/tokio-rs/tokio/pull/4490 +[#4497]: https://github.com/tokio-rs/tokio/pull/4497 +[#4508]: https://github.com/tokio-rs/tokio/pull/4508 +[#4539]: https://github.com/tokio-rs/tokio/pull/4539 +[#4559]: https://github.com/tokio-rs/tokio/pull/4559 + +# 0.7.0 (February 9, 2022) + +### Added + +- task: add `spawn_pinned` ([#3370]) +- time: add `shrink_to_fit` and `compact` methods to `DelayQueue` ([#4170]) +- codec: improve `Builder::max_frame_length` docs ([#4352]) +- codec: add mutable reference getters for codecs to pinned `Framed` ([#4372]) +- net: add generic trait to combine `UnixListener` and `TcpListener` ([#4385]) +- codec: implement `Framed::map_codec` ([#4427]) +- codec: implement `Encoder<BytesMut>` for `BytesCodec` ([#4465]) + +### Changed + +- sync: add lifetime parameter to `ReusableBoxFuture` ([#3762]) +- sync: refactored `PollSender<T>` to fix a subtly broken `Sink<T>` implementation ([#4214]) +- time: remove error case from the infallible `DelayQueue::poll_elapsed` ([#4241]) + +[#3370]: https://github.com/tokio-rs/tokio/pull/3370 +[#4170]: https://github.com/tokio-rs/tokio/pull/4170 +[#4352]: https://github.com/tokio-rs/tokio/pull/4352 +[#4372]: https://github.com/tokio-rs/tokio/pull/4372 +[#4385]: https://github.com/tokio-rs/tokio/pull/4385 +[#4427]: https://github.com/tokio-rs/tokio/pull/4427 +[#4465]: https://github.com/tokio-rs/tokio/pull/4465 +[#3762]: https://github.com/tokio-rs/tokio/pull/3762 +[#4214]: https://github.com/tokio-rs/tokio/pull/4214 +[#4241]: https://github.com/tokio-rs/tokio/pull/4241 + +# 0.6.10 (May 14, 2021) + +This is a backport for the memory leak in `CancellationToken` that was originally fixed in 0.7.2. ([#4652]) + +[#4652]: https://github.com/tokio-rs/tokio/pull/4652 + +# 0.6.9 (October 29, 2021) + +### Added + +- codec: implement `Clone` for `LengthDelimitedCodec` ([#4089]) +- io: add `SyncIoBridge` ([#4146]) + +### Fixed + +- time: update deadline on removal in `DelayQueue` ([#4178]) +- codec: Update stream impl for Framed to return None after Err ([#4166]) + +[#4089]: https://github.com/tokio-rs/tokio/pull/4089 +[#4146]: https://github.com/tokio-rs/tokio/pull/4146 +[#4166]: https://github.com/tokio-rs/tokio/pull/4166 +[#4178]: https://github.com/tokio-rs/tokio/pull/4178 + +# 0.6.8 (September 3, 2021) + +### Added + +- sync: add drop guard for `CancellationToken` ([#3839]) +- compact: added `AsyncSeek` compat ([#4078]) +- time: expose `Key` used in `DelayQueue`'s `Expired` ([#4081]) +- io: add `with_capacity` to `ReaderStream` ([#4086]) + +### Fixed + +- codec: remove unnecessary `doc(cfg(...))` ([#3989]) + +[#3839]: https://github.com/tokio-rs/tokio/pull/3839 +[#4078]: https://github.com/tokio-rs/tokio/pull/4078 +[#4081]: https://github.com/tokio-rs/tokio/pull/4081 +[#4086]: https://github.com/tokio-rs/tokio/pull/4086 +[#3989]: https://github.com/tokio-rs/tokio/pull/3989 + +# 0.6.7 (May 14, 2021) + +### Added + +- udp: make `UdpFramed` take `Borrow<UdpSocket>` ([#3451]) +- compat: implement `AsRawFd`/`AsRawHandle` for `Compat<T>` ([#3765]) + +[#3451]: https://github.com/tokio-rs/tokio/pull/3451 +[#3765]: https://github.com/tokio-rs/tokio/pull/3765 + +# 0.6.6 (April 12, 2021) + +### Added + +- util: makes `Framed` and `FramedStream` resumable after eof ([#3272]) +- util: add `PollSemaphore::{add_permits, available_permits}` ([#3683]) + +### Fixed + +- chore: avoid allocation if `PollSemaphore` is unused ([#3634]) + +[#3272]: https://github.com/tokio-rs/tokio/pull/3272 +[#3634]: https://github.com/tokio-rs/tokio/pull/3634 +[#3683]: https://github.com/tokio-rs/tokio/pull/3683 + +# 0.6.5 (March 20, 2021) + +### Fixed + +- util: annotate time module as requiring `time` feature ([#3606]) + +[#3606]: https://github.com/tokio-rs/tokio/pull/3606 + +# 0.6.4 (March 9, 2021) + +### Added + +- codec: `AnyDelimiter` codec ([#3406]) +- sync: add pollable `mpsc::Sender` ([#3490]) + +### Fixed + +- codec: `LinesCodec` should only return `MaxLineLengthExceeded` once per line ([#3556]) +- sync: fuse PollSemaphore ([#3578]) + +[#3406]: https://github.com/tokio-rs/tokio/pull/3406 +[#3490]: https://github.com/tokio-rs/tokio/pull/3490 +[#3556]: https://github.com/tokio-rs/tokio/pull/3556 +[#3578]: https://github.com/tokio-rs/tokio/pull/3578 + +# 0.6.3 (January 31, 2021) + +### Added + +- sync: add `ReusableBoxFuture` utility ([#3464]) + +### Changed + +- sync: use `ReusableBoxFuture` for `PollSemaphore` ([#3463]) +- deps: remove `async-stream` dependency ([#3463]) +- deps: remove `tokio-stream` dependency ([#3487]) + +# 0.6.2 (January 21, 2021) + +### Added + +- sync: add pollable `Semaphore` ([#3444]) + +### Fixed + +- time: fix panics on updating `DelayQueue` entries ([#3270]) + +# 0.6.1 (January 12, 2021) + +### Added + +- codec: `get_ref()`, `get_mut()`, `get_pin_mut()` and `into_inner()` for + `Framed`, `FramedRead`, `FramedWrite` and `StreamReader` ([#3364]). +- codec: `write_buffer()` and `write_buffer_mut()` for `Framed` and + `FramedWrite` ([#3387]). + +# 0.6.0 (December 23, 2020) + +### Changed +- depend on `tokio` 1.0. + +### Added +- rt: add constructors to `TokioContext` (#3221). + +# 0.5.1 (December 3, 2020) + +### Added +- io: `poll_read_buf` util fn (#2972). +- io: `poll_write_buf` util fn with vectored write support (#3156). + +# 0.5.0 (October 30, 2020) + +### Changed +- io: update `bytes` to 0.6 (#3071). + +# 0.4.0 (October 15, 2020) + +### Added +- sync: `CancellationToken` for coordinating task cancellation (#2747). +- rt: `TokioContext` sets the Tokio runtime for the duration of a future (#2791) +- io: `StreamReader`/`ReaderStream` map between `AsyncRead` values and `Stream` + of bytes (#2788). +- time: `DelayQueue` to manage many delays (#2897). + +# 0.3.1 (March 18, 2020) + +### Fixed + +- Adjust minimum-supported Tokio version to v0.2.5 to account for an internal + dependency on features in that version of Tokio. ([#2326]) + +# 0.3.0 (March 4, 2020) + +### Changed + +- **Breaking Change**: Change `Encoder` trait to take a generic `Item` parameter, which allows + codec writers to pass references into `Framed` and `FramedWrite` types. ([#1746]) + +### Added + +- Add futures-io/tokio::io compatibility layer. ([#2117]) +- Add `Framed::with_capacity`. ([#2215]) + +### Fixed + +- Use advance over split_to when data is not needed. ([#2198]) + +# 0.2.0 (November 26, 2019) + +- Initial release + +[#3487]: https://github.com/tokio-rs/tokio/pull/3487 +[#3464]: https://github.com/tokio-rs/tokio/pull/3464 +[#3463]: https://github.com/tokio-rs/tokio/pull/3463 +[#3444]: https://github.com/tokio-rs/tokio/pull/3444 +[#3387]: https://github.com/tokio-rs/tokio/pull/3387 +[#3364]: https://github.com/tokio-rs/tokio/pull/3364 +[#3270]: https://github.com/tokio-rs/tokio/pull/3270 +[#2326]: https://github.com/tokio-rs/tokio/pull/2326 +[#2215]: https://github.com/tokio-rs/tokio/pull/2215 +[#2198]: https://github.com/tokio-rs/tokio/pull/2198 +[#2117]: https://github.com/tokio-rs/tokio/pull/2117 +[#1746]: https://github.com/tokio-rs/tokio/pull/1746 diff --git a/third_party/rust/tokio-util/Cargo.toml b/third_party/rust/tokio-util/Cargo.toml new file mode 100644 index 0000000000..001b0b78b1 --- /dev/null +++ b/third_party/rust/tokio-util/Cargo.toml @@ -0,0 +1,112 @@ +# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO +# +# When uploading crates to the registry Cargo will automatically +# "normalize" Cargo.toml files for maximal compatibility +# with all versions of Cargo and also rewrite `path` dependencies +# to registry (e.g., crates.io) dependencies. +# +# If you are reading this file be aware that the original Cargo.toml +# will likely look very different (and much more reasonable). +# See Cargo.toml.orig for the original contents. + +[package] +edition = "2018" +rust-version = "1.49" +name = "tokio-util" +version = "0.7.2" +authors = ["Tokio Contributors <team@tokio.rs>"] +description = """ +Additional utilities for working with Tokio. +""" +homepage = "https://tokio.rs" +categories = ["asynchronous"] +license = "MIT" +repository = "https://github.com/tokio-rs/tokio" + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = [ + "--cfg", + "docsrs", +] + +[dependencies.bytes] +version = "1.0.0" + +[dependencies.futures-core] +version = "0.3.0" + +[dependencies.futures-io] +version = "0.3.0" +optional = true + +[dependencies.futures-sink] +version = "0.3.0" + +[dependencies.futures-util] +version = "0.3.0" +optional = true + +[dependencies.pin-project-lite] +version = "0.2.0" + +[dependencies.slab] +version = "0.4.4" +optional = true + +[dependencies.tokio] +version = "1.7.0" +features = ["sync"] + +[dependencies.tracing] +version = "0.1.25" +optional = true + +[dev-dependencies.async-stream] +version = "0.3.0" + +[dev-dependencies.futures] +version = "0.3.0" + +[dev-dependencies.futures-test] +version = "0.3.5" + +[dev-dependencies.tokio] +version = "1.0.0" +features = ["full"] + +[dev-dependencies.tokio-stream] +version = "0.1" + +[dev-dependencies.tokio-test] +version = "0.4.0" + +[features] +__docs_rs = ["futures-util"] +codec = ["tracing"] +compat = ["futures-io"] +default = [] +full = [ + "codec", + "compat", + "io-util", + "time", + "net", + "rt", +] +io = [] +io-util = [ + "io", + "tokio/rt", + "tokio/io-util", +] +net = ["tokio/net"] +rt = [ + "tokio/rt", + "tokio/sync", + "futures-util", +] +time = [ + "tokio/time", + "slab", +] diff --git a/third_party/rust/tokio-util/LICENSE b/third_party/rust/tokio-util/LICENSE new file mode 100644 index 0000000000..8af5baf01e --- /dev/null +++ b/third_party/rust/tokio-util/LICENSE @@ -0,0 +1,25 @@ +Copyright (c) 2022 Tokio Contributors + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/third_party/rust/tokio-util/README.md b/third_party/rust/tokio-util/README.md new file mode 100644 index 0000000000..0d74f36d9a --- /dev/null +++ b/third_party/rust/tokio-util/README.md @@ -0,0 +1,13 @@ +# tokio-util + +Utilities for working with Tokio. + +## License + +This project is licensed under the [MIT license](LICENSE). + +### Contribution + +Unless you explicitly state otherwise, any contribution intentionally submitted +for inclusion in Tokio by you, shall be licensed as MIT, without any additional +terms or conditions. diff --git a/third_party/rust/tokio-util/src/cfg.rs b/third_party/rust/tokio-util/src/cfg.rs new file mode 100644 index 0000000000..4035255aff --- /dev/null +++ b/third_party/rust/tokio-util/src/cfg.rs @@ -0,0 +1,71 @@ +macro_rules! cfg_codec { + ($($item:item)*) => { + $( + #[cfg(feature = "codec")] + #[cfg_attr(docsrs, doc(cfg(feature = "codec")))] + $item + )* + } +} + +macro_rules! cfg_compat { + ($($item:item)*) => { + $( + #[cfg(feature = "compat")] + #[cfg_attr(docsrs, doc(cfg(feature = "compat")))] + $item + )* + } +} + +macro_rules! cfg_net { + ($($item:item)*) => { + $( + #[cfg(all(feature = "net", feature = "codec"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "net", feature = "codec"))))] + $item + )* + } +} + +macro_rules! cfg_io { + ($($item:item)*) => { + $( + #[cfg(feature = "io")] + #[cfg_attr(docsrs, doc(cfg(feature = "io")))] + $item + )* + } +} + +cfg_io! { + macro_rules! cfg_io_util { + ($($item:item)*) => { + $( + #[cfg(feature = "io-util")] + #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] + $item + )* + } + } +} + +macro_rules! cfg_rt { + ($($item:item)*) => { + $( + #[cfg(feature = "rt")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + $item + )* + } +} + +macro_rules! cfg_time { + ($($item:item)*) => { + $( + #[cfg(feature = "time")] + #[cfg_attr(docsrs, doc(cfg(feature = "time")))] + $item + )* + } +} diff --git a/third_party/rust/tokio-util/src/codec/any_delimiter_codec.rs b/third_party/rust/tokio-util/src/codec/any_delimiter_codec.rs new file mode 100644 index 0000000000..3dbfd456b0 --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/any_delimiter_codec.rs @@ -0,0 +1,263 @@ +use crate::codec::decoder::Decoder; +use crate::codec::encoder::Encoder; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::{cmp, fmt, io, str, usize}; + +const DEFAULT_SEEK_DELIMITERS: &[u8] = b",;\n\r"; +const DEFAULT_SEQUENCE_WRITER: &[u8] = b","; +/// A simple [`Decoder`] and [`Encoder`] implementation that splits up data into chunks based on any character in the given delimiter string. +/// +/// [`Decoder`]: crate::codec::Decoder +/// [`Encoder`]: crate::codec::Encoder +/// +/// # Example +/// Decode string of bytes containing various different delimiters. +/// +/// [`BytesMut`]: bytes::BytesMut +/// [`Error`]: std::io::Error +/// +/// ``` +/// use tokio_util::codec::{AnyDelimiterCodec, Decoder}; +/// use bytes::{BufMut, BytesMut}; +/// +/// # +/// # #[tokio::main(flavor = "current_thread")] +/// # async fn main() -> Result<(), std::io::Error> { +/// let mut codec = AnyDelimiterCodec::new(b",;\r\n".to_vec(),b";".to_vec()); +/// let buf = &mut BytesMut::new(); +/// buf.reserve(200); +/// buf.put_slice(b"chunk 1,chunk 2;chunk 3\n\r"); +/// assert_eq!("chunk 1", codec.decode(buf).unwrap().unwrap()); +/// assert_eq!("chunk 2", codec.decode(buf).unwrap().unwrap()); +/// assert_eq!("chunk 3", codec.decode(buf).unwrap().unwrap()); +/// assert_eq!("", codec.decode(buf).unwrap().unwrap()); +/// assert_eq!(None, codec.decode(buf).unwrap()); +/// # Ok(()) +/// # } +/// ``` +/// +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct AnyDelimiterCodec { + // Stored index of the next index to examine for the delimiter character. + // This is used to optimize searching. + // For example, if `decode` was called with `abc` and the delimiter is '{}', it would hold `3`, + // because that is the next index to examine. + // The next time `decode` is called with `abcde}`, the method will + // only look at `de}` before returning. + next_index: usize, + + /// The maximum length for a given chunk. If `usize::MAX`, chunks will be + /// read until a delimiter character is reached. + max_length: usize, + + /// Are we currently discarding the remainder of a chunk which was over + /// the length limit? + is_discarding: bool, + + /// The bytes that are using for search during decode + seek_delimiters: Vec<u8>, + + /// The bytes that are using for encoding + sequence_writer: Vec<u8>, +} + +impl AnyDelimiterCodec { + /// Returns a `AnyDelimiterCodec` for splitting up data into chunks. + /// + /// # Note + /// + /// The returned `AnyDelimiterCodec` will not have an upper bound on the length + /// of a buffered chunk. See the documentation for [`new_with_max_length`] + /// for information on why this could be a potential security risk. + /// + /// [`new_with_max_length`]: crate::codec::AnyDelimiterCodec::new_with_max_length() + pub fn new(seek_delimiters: Vec<u8>, sequence_writer: Vec<u8>) -> AnyDelimiterCodec { + AnyDelimiterCodec { + next_index: 0, + max_length: usize::MAX, + is_discarding: false, + seek_delimiters, + sequence_writer, + } + } + + /// Returns a `AnyDelimiterCodec` with a maximum chunk length limit. + /// + /// If this is set, calls to `AnyDelimiterCodec::decode` will return a + /// [`AnyDelimiterCodecError`] when a chunk exceeds the length limit. Subsequent calls + /// will discard up to `limit` bytes from that chunk until a delimiter + /// character is reached, returning `None` until the delimiter over the limit + /// has been fully discarded. After that point, calls to `decode` will + /// function as normal. + /// + /// # Note + /// + /// Setting a length limit is highly recommended for any `AnyDelimiterCodec` which + /// will be exposed to untrusted input. Otherwise, the size of the buffer + /// that holds the chunk currently being read is unbounded. An attacker could + /// exploit this unbounded buffer by sending an unbounded amount of input + /// without any delimiter characters, causing unbounded memory consumption. + /// + /// [`AnyDelimiterCodecError`]: crate::codec::AnyDelimiterCodecError + pub fn new_with_max_length( + seek_delimiters: Vec<u8>, + sequence_writer: Vec<u8>, + max_length: usize, + ) -> Self { + AnyDelimiterCodec { + max_length, + ..AnyDelimiterCodec::new(seek_delimiters, sequence_writer) + } + } + + /// Returns the maximum chunk length when decoding. + /// + /// ``` + /// use std::usize; + /// use tokio_util::codec::AnyDelimiterCodec; + /// + /// let codec = AnyDelimiterCodec::new(b",;\n".to_vec(), b";".to_vec()); + /// assert_eq!(codec.max_length(), usize::MAX); + /// ``` + /// ``` + /// use tokio_util::codec::AnyDelimiterCodec; + /// + /// let codec = AnyDelimiterCodec::new_with_max_length(b",;\n".to_vec(), b";".to_vec(), 256); + /// assert_eq!(codec.max_length(), 256); + /// ``` + pub fn max_length(&self) -> usize { + self.max_length + } +} + +impl Decoder for AnyDelimiterCodec { + type Item = Bytes; + type Error = AnyDelimiterCodecError; + + fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Bytes>, AnyDelimiterCodecError> { + loop { + // Determine how far into the buffer we'll search for a delimiter. If + // there's no max_length set, we'll read to the end of the buffer. + let read_to = cmp::min(self.max_length.saturating_add(1), buf.len()); + + let new_chunk_offset = buf[self.next_index..read_to].iter().position(|b| { + self.seek_delimiters + .iter() + .any(|delimiter| *b == *delimiter) + }); + + match (self.is_discarding, new_chunk_offset) { + (true, Some(offset)) => { + // If we found a new chunk, discard up to that offset and + // then stop discarding. On the next iteration, we'll try + // to read a chunk normally. + buf.advance(offset + self.next_index + 1); + self.is_discarding = false; + self.next_index = 0; + } + (true, None) => { + // Otherwise, we didn't find a new chunk, so we'll discard + // everything we read. On the next iteration, we'll continue + // discarding up to max_len bytes unless we find a new chunk. + buf.advance(read_to); + self.next_index = 0; + if buf.is_empty() { + return Ok(None); + } + } + (false, Some(offset)) => { + // Found a chunk! + let new_chunk_index = offset + self.next_index; + self.next_index = 0; + let mut chunk = buf.split_to(new_chunk_index + 1); + chunk.truncate(chunk.len() - 1); + let chunk = chunk.freeze(); + return Ok(Some(chunk)); + } + (false, None) if buf.len() > self.max_length => { + // Reached the maximum length without finding a + // new chunk, return an error and start discarding on the + // next call. + self.is_discarding = true; + return Err(AnyDelimiterCodecError::MaxChunkLengthExceeded); + } + (false, None) => { + // We didn't find a chunk or reach the length limit, so the next + // call will resume searching at the current offset. + self.next_index = read_to; + return Ok(None); + } + } + } + } + + fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Bytes>, AnyDelimiterCodecError> { + Ok(match self.decode(buf)? { + Some(frame) => Some(frame), + None => { + // return remaining data, if any + if buf.is_empty() { + None + } else { + let chunk = buf.split_to(buf.len()); + self.next_index = 0; + Some(chunk.freeze()) + } + } + }) + } +} + +impl<T> Encoder<T> for AnyDelimiterCodec +where + T: AsRef<str>, +{ + type Error = AnyDelimiterCodecError; + + fn encode(&mut self, chunk: T, buf: &mut BytesMut) -> Result<(), AnyDelimiterCodecError> { + let chunk = chunk.as_ref(); + buf.reserve(chunk.len() + 1); + buf.put(chunk.as_bytes()); + buf.put(self.sequence_writer.as_ref()); + + Ok(()) + } +} + +impl Default for AnyDelimiterCodec { + fn default() -> Self { + Self::new( + DEFAULT_SEEK_DELIMITERS.to_vec(), + DEFAULT_SEQUENCE_WRITER.to_vec(), + ) + } +} + +/// An error occurred while encoding or decoding a chunk. +#[derive(Debug)] +pub enum AnyDelimiterCodecError { + /// The maximum chunk length was exceeded. + MaxChunkLengthExceeded, + /// An IO error occurred. + Io(io::Error), +} + +impl fmt::Display for AnyDelimiterCodecError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AnyDelimiterCodecError::MaxChunkLengthExceeded => { + write!(f, "max chunk length exceeded") + } + AnyDelimiterCodecError::Io(e) => write!(f, "{}", e), + } + } +} + +impl From<io::Error> for AnyDelimiterCodecError { + fn from(e: io::Error) -> AnyDelimiterCodecError { + AnyDelimiterCodecError::Io(e) + } +} + +impl std::error::Error for AnyDelimiterCodecError {} diff --git a/third_party/rust/tokio-util/src/codec/bytes_codec.rs b/third_party/rust/tokio-util/src/codec/bytes_codec.rs new file mode 100644 index 0000000000..ceab228b94 --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/bytes_codec.rs @@ -0,0 +1,86 @@ +use crate::codec::decoder::Decoder; +use crate::codec::encoder::Encoder; + +use bytes::{BufMut, Bytes, BytesMut}; +use std::io; + +/// A simple [`Decoder`] and [`Encoder`] implementation that just ships bytes around. +/// +/// [`Decoder`]: crate::codec::Decoder +/// [`Encoder`]: crate::codec::Encoder +/// +/// # Example +/// +/// Turn an [`AsyncRead`] into a stream of `Result<`[`BytesMut`]`, `[`Error`]`>`. +/// +/// [`AsyncRead`]: tokio::io::AsyncRead +/// [`BytesMut`]: bytes::BytesMut +/// [`Error`]: std::io::Error +/// +/// ``` +/// # mod hidden { +/// # #[allow(unused_imports)] +/// use tokio::fs::File; +/// # } +/// use tokio::io::AsyncRead; +/// use tokio_util::codec::{FramedRead, BytesCodec}; +/// +/// # enum File {} +/// # impl File { +/// # async fn open(_name: &str) -> Result<impl AsyncRead, std::io::Error> { +/// # use std::io::Cursor; +/// # Ok(Cursor::new(vec![0, 1, 2, 3, 4, 5])) +/// # } +/// # } +/// # +/// # #[tokio::main(flavor = "current_thread")] +/// # async fn main() -> Result<(), std::io::Error> { +/// let my_async_read = File::open("filename.txt").await?; +/// let my_stream_of_bytes = FramedRead::new(my_async_read, BytesCodec::new()); +/// # Ok(()) +/// # } +/// ``` +/// +#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Default)] +pub struct BytesCodec(()); + +impl BytesCodec { + /// Creates a new `BytesCodec` for shipping around raw bytes. + pub fn new() -> BytesCodec { + BytesCodec(()) + } +} + +impl Decoder for BytesCodec { + type Item = BytesMut; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> { + if !buf.is_empty() { + let len = buf.len(); + Ok(Some(buf.split_to(len))) + } else { + Ok(None) + } + } +} + +impl Encoder<Bytes> for BytesCodec { + type Error = io::Error; + + fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> { + buf.reserve(data.len()); + buf.put(data); + Ok(()) + } +} + +impl Encoder<BytesMut> for BytesCodec { + type Error = io::Error; + + fn encode(&mut self, data: BytesMut, buf: &mut BytesMut) -> Result<(), io::Error> { + buf.reserve(data.len()); + buf.put(data); + Ok(()) + } +} diff --git a/third_party/rust/tokio-util/src/codec/decoder.rs b/third_party/rust/tokio-util/src/codec/decoder.rs new file mode 100644 index 0000000000..c5927783d1 --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/decoder.rs @@ -0,0 +1,184 @@ +use crate::codec::Framed; + +use tokio::io::{AsyncRead, AsyncWrite}; + +use bytes::BytesMut; +use std::io; + +/// Decoding of frames via buffers. +/// +/// This trait is used when constructing an instance of [`Framed`] or +/// [`FramedRead`]. An implementation of `Decoder` takes a byte stream that has +/// already been buffered in `src` and decodes the data into a stream of +/// `Self::Item` frames. +/// +/// Implementations are able to track state on `self`, which enables +/// implementing stateful streaming parsers. In many cases, though, this type +/// will simply be a unit struct (e.g. `struct HttpDecoder`). +/// +/// For some underlying data-sources, namely files and FIFOs, +/// it's possible to temporarily read 0 bytes by reaching EOF. +/// +/// In these cases `decode_eof` will be called until it signals +/// fullfillment of all closing frames by returning `Ok(None)`. +/// After that, repeated attempts to read from the [`Framed`] or [`FramedRead`] +/// will not invoke `decode` or `decode_eof` again, until data can be read +/// during a retry. +/// +/// It is up to the Decoder to keep track of a restart after an EOF, +/// and to decide how to handle such an event by, for example, +/// allowing frames to cross EOF boundaries, re-emitting opening frames, or +/// resetting the entire internal state. +/// +/// [`Framed`]: crate::codec::Framed +/// [`FramedRead`]: crate::codec::FramedRead +pub trait Decoder { + /// The type of decoded frames. + type Item; + + /// The type of unrecoverable frame decoding errors. + /// + /// If an individual message is ill-formed but can be ignored without + /// interfering with the processing of future messages, it may be more + /// useful to report the failure as an `Item`. + /// + /// `From<io::Error>` is required in the interest of making `Error` suitable + /// for returning directly from a [`FramedRead`], and to enable the default + /// implementation of `decode_eof` to yield an `io::Error` when the decoder + /// fails to consume all available data. + /// + /// Note that implementors of this trait can simply indicate `type Error = + /// io::Error` to use I/O errors as this type. + /// + /// [`FramedRead`]: crate::codec::FramedRead + type Error: From<io::Error>; + + /// Attempts to decode a frame from the provided buffer of bytes. + /// + /// This method is called by [`FramedRead`] whenever bytes are ready to be + /// parsed. The provided buffer of bytes is what's been read so far, and + /// this instance of `Decode` can determine whether an entire frame is in + /// the buffer and is ready to be returned. + /// + /// If an entire frame is available, then this instance will remove those + /// bytes from the buffer provided and return them as a decoded + /// frame. Note that removing bytes from the provided buffer doesn't always + /// necessarily copy the bytes, so this should be an efficient operation in + /// most circumstances. + /// + /// If the bytes look valid, but a frame isn't fully available yet, then + /// `Ok(None)` is returned. This indicates to the [`Framed`] instance that + /// it needs to read some more bytes before calling this method again. + /// + /// Note that the bytes provided may be empty. If a previous call to + /// `decode` consumed all the bytes in the buffer then `decode` will be + /// called again until it returns `Ok(None)`, indicating that more bytes need to + /// be read. + /// + /// Finally, if the bytes in the buffer are malformed then an error is + /// returned indicating why. This informs [`Framed`] that the stream is now + /// corrupt and should be terminated. + /// + /// [`Framed`]: crate::codec::Framed + /// [`FramedRead`]: crate::codec::FramedRead + /// + /// # Buffer management + /// + /// Before returning from the function, implementations should ensure that + /// the buffer has appropriate capacity in anticipation of future calls to + /// `decode`. Failing to do so leads to inefficiency. + /// + /// For example, if frames have a fixed length, or if the length of the + /// current frame is known from a header, a possible buffer management + /// strategy is: + /// + /// ```no_run + /// # use std::io; + /// # + /// # use bytes::BytesMut; + /// # use tokio_util::codec::Decoder; + /// # + /// # struct MyCodec; + /// # + /// impl Decoder for MyCodec { + /// // ... + /// # type Item = BytesMut; + /// # type Error = io::Error; + /// + /// fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> { + /// // ... + /// + /// // Reserve enough to complete decoding of the current frame. + /// let current_frame_len: usize = 1000; // Example. + /// // And to start decoding the next frame. + /// let next_frame_header_len: usize = 10; // Example. + /// src.reserve(current_frame_len + next_frame_header_len); + /// + /// return Ok(None); + /// } + /// } + /// ``` + /// + /// An optimal buffer management strategy minimizes reallocations and + /// over-allocations. + fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error>; + + /// A default method available to be called when there are no more bytes + /// available to be read from the underlying I/O. + /// + /// This method defaults to calling `decode` and returns an error if + /// `Ok(None)` is returned while there is unconsumed data in `buf`. + /// Typically this doesn't need to be implemented unless the framing + /// protocol differs near the end of the stream, or if you need to construct + /// frames _across_ eof boundaries on sources that can be resumed. + /// + /// Note that the `buf` argument may be empty. If a previous call to + /// `decode_eof` consumed all the bytes in the buffer, `decode_eof` will be + /// called again until it returns `None`, indicating that there are no more + /// frames to yield. This behavior enables returning finalization frames + /// that may not be based on inbound data. + /// + /// Once `None` has been returned, `decode_eof` won't be called again until + /// an attempt to resume the stream has been made, where the underlying stream + /// actually returned more data. + fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> { + match self.decode(buf)? { + Some(frame) => Ok(Some(frame)), + None => { + if buf.is_empty() { + Ok(None) + } else { + Err(io::Error::new(io::ErrorKind::Other, "bytes remaining on stream").into()) + } + } + } + } + + /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this + /// `Io` object, using `Decode` and `Encode` to read and write the raw data. + /// + /// Raw I/O objects work with byte sequences, but higher-level code usually + /// wants to batch these into meaningful chunks, called "frames". This + /// method layers framing on top of an I/O object, by using the `Codec` + /// traits to handle encoding and decoding of messages frames. Note that + /// the incoming and outgoing frame types may be distinct. + /// + /// This function returns a *single* object that is both `Stream` and + /// `Sink`; grouping this into a single object is often useful for layering + /// things like gzip or TLS, which require both read and write access to the + /// underlying object. + /// + /// If you want to work more directly with the streams and sink, consider + /// calling `split` on the [`Framed`] returned by this method, which will + /// break them into separate objects, allowing them to interact more easily. + /// + /// [`Stream`]: futures_core::Stream + /// [`Sink`]: futures_sink::Sink + /// [`Framed`]: crate::codec::Framed + fn framed<T: AsyncRead + AsyncWrite + Sized>(self, io: T) -> Framed<T, Self> + where + Self: Sized, + { + Framed::new(io, self) + } +} diff --git a/third_party/rust/tokio-util/src/codec/encoder.rs b/third_party/rust/tokio-util/src/codec/encoder.rs new file mode 100644 index 0000000000..770a10fa9b --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/encoder.rs @@ -0,0 +1,25 @@ +use bytes::BytesMut; +use std::io; + +/// Trait of helper objects to write out messages as bytes, for use with +/// [`FramedWrite`]. +/// +/// [`FramedWrite`]: crate::codec::FramedWrite +pub trait Encoder<Item> { + /// The type of encoding errors. + /// + /// [`FramedWrite`] requires `Encoder`s errors to implement `From<io::Error>` + /// in the interest letting it return `Error`s directly. + /// + /// [`FramedWrite`]: crate::codec::FramedWrite + type Error: From<io::Error>; + + /// Encodes a frame into the buffer provided. + /// + /// This method will encode `item` into the byte buffer provided by `dst`. + /// The `dst` provided is an internal buffer of the [`FramedWrite`] instance and + /// will be written out when possible. + /// + /// [`FramedWrite`]: crate::codec::FramedWrite + fn encode(&mut self, item: Item, dst: &mut BytesMut) -> Result<(), Self::Error>; +} diff --git a/third_party/rust/tokio-util/src/codec/framed.rs b/third_party/rust/tokio-util/src/codec/framed.rs new file mode 100644 index 0000000000..d89b8b6dc3 --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/framed.rs @@ -0,0 +1,373 @@ +use crate::codec::decoder::Decoder; +use crate::codec::encoder::Encoder; +use crate::codec::framed_impl::{FramedImpl, RWFrames, ReadFrame, WriteFrame}; + +use futures_core::Stream; +use tokio::io::{AsyncRead, AsyncWrite}; + +use bytes::BytesMut; +use futures_sink::Sink; +use pin_project_lite::pin_project; +use std::fmt; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// A unified [`Stream`] and [`Sink`] interface to an underlying I/O object, using + /// the `Encoder` and `Decoder` traits to encode and decode frames. + /// + /// You can create a `Framed` instance by using the [`Decoder::framed`] adapter, or + /// by using the `new` function seen below. + /// + /// [`Stream`]: futures_core::Stream + /// [`Sink`]: futures_sink::Sink + /// [`AsyncRead`]: tokio::io::AsyncRead + /// [`Decoder::framed`]: crate::codec::Decoder::framed() + pub struct Framed<T, U> { + #[pin] + inner: FramedImpl<T, U, RWFrames> + } +} + +impl<T, U> Framed<T, U> +where + T: AsyncRead + AsyncWrite, +{ + /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this + /// I/O object, using [`Decoder`] and [`Encoder`] to read and write the raw data. + /// + /// Raw I/O objects work with byte sequences, but higher-level code usually + /// wants to batch these into meaningful chunks, called "frames". This + /// method layers framing on top of an I/O object, by using the codec + /// traits to handle encoding and decoding of messages frames. Note that + /// the incoming and outgoing frame types may be distinct. + /// + /// This function returns a *single* object that is both [`Stream`] and + /// [`Sink`]; grouping this into a single object is often useful for layering + /// things like gzip or TLS, which require both read and write access to the + /// underlying object. + /// + /// If you want to work more directly with the streams and sink, consider + /// calling [`split`] on the `Framed` returned by this method, which will + /// break them into separate objects, allowing them to interact more easily. + /// + /// Note that, for some byte sources, the stream can be resumed after an EOF + /// by reading from it, even after it has returned `None`. Repeated attempts + /// to do so, without new data available, continue to return `None` without + /// creating more (closing) frames. + /// + /// [`Stream`]: futures_core::Stream + /// [`Sink`]: futures_sink::Sink + /// [`Decode`]: crate::codec::Decoder + /// [`Encoder`]: crate::codec::Encoder + /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split + pub fn new(inner: T, codec: U) -> Framed<T, U> { + Framed { + inner: FramedImpl { + inner, + codec, + state: Default::default(), + }, + } + } + + /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this + /// I/O object, using [`Decoder`] and [`Encoder`] to read and write the raw data, + /// with a specific read buffer initial capacity. + /// + /// Raw I/O objects work with byte sequences, but higher-level code usually + /// wants to batch these into meaningful chunks, called "frames". This + /// method layers framing on top of an I/O object, by using the codec + /// traits to handle encoding and decoding of messages frames. Note that + /// the incoming and outgoing frame types may be distinct. + /// + /// This function returns a *single* object that is both [`Stream`] and + /// [`Sink`]; grouping this into a single object is often useful for layering + /// things like gzip or TLS, which require both read and write access to the + /// underlying object. + /// + /// If you want to work more directly with the streams and sink, consider + /// calling [`split`] on the `Framed` returned by this method, which will + /// break them into separate objects, allowing them to interact more easily. + /// + /// [`Stream`]: futures_core::Stream + /// [`Sink`]: futures_sink::Sink + /// [`Decode`]: crate::codec::Decoder + /// [`Encoder`]: crate::codec::Encoder + /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split + pub fn with_capacity(inner: T, codec: U, capacity: usize) -> Framed<T, U> { + Framed { + inner: FramedImpl { + inner, + codec, + state: RWFrames { + read: ReadFrame { + eof: false, + is_readable: false, + buffer: BytesMut::with_capacity(capacity), + has_errored: false, + }, + write: WriteFrame::default(), + }, + }, + } + } +} + +impl<T, U> Framed<T, U> { + /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this + /// I/O object, using [`Decoder`] and [`Encoder`] to read and write the raw data. + /// + /// Raw I/O objects work with byte sequences, but higher-level code usually + /// wants to batch these into meaningful chunks, called "frames". This + /// method layers framing on top of an I/O object, by using the `Codec` + /// traits to handle encoding and decoding of messages frames. Note that + /// the incoming and outgoing frame types may be distinct. + /// + /// This function returns a *single* object that is both [`Stream`] and + /// [`Sink`]; grouping this into a single object is often useful for layering + /// things like gzip or TLS, which require both read and write access to the + /// underlying object. + /// + /// This objects takes a stream and a readbuffer and a writebuffer. These field + /// can be obtained from an existing `Framed` with the [`into_parts`] method. + /// + /// If you want to work more directly with the streams and sink, consider + /// calling [`split`] on the `Framed` returned by this method, which will + /// break them into separate objects, allowing them to interact more easily. + /// + /// [`Stream`]: futures_core::Stream + /// [`Sink`]: futures_sink::Sink + /// [`Decoder`]: crate::codec::Decoder + /// [`Encoder`]: crate::codec::Encoder + /// [`into_parts`]: crate::codec::Framed::into_parts() + /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split + pub fn from_parts(parts: FramedParts<T, U>) -> Framed<T, U> { + Framed { + inner: FramedImpl { + inner: parts.io, + codec: parts.codec, + state: RWFrames { + read: parts.read_buf.into(), + write: parts.write_buf.into(), + }, + }, + } + } + + /// Returns a reference to the underlying I/O stream wrapped by + /// `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_ref(&self) -> &T { + &self.inner.inner + } + + /// Returns a mutable reference to the underlying I/O stream wrapped by + /// `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_mut(&mut self) -> &mut T { + &mut self.inner.inner + } + + /// Returns a pinned mutable reference to the underlying I/O stream wrapped by + /// `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> { + self.project().inner.project().inner + } + + /// Returns a reference to the underlying codec wrapped by + /// `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying codec + /// as it may corrupt the stream of frames otherwise being worked with. + pub fn codec(&self) -> &U { + &self.inner.codec + } + + /// Returns a mutable reference to the underlying codec wrapped by + /// `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying codec + /// as it may corrupt the stream of frames otherwise being worked with. + pub fn codec_mut(&mut self) -> &mut U { + &mut self.inner.codec + } + + /// Maps the codec `U` to `C`, preserving the read and write buffers + /// wrapped by `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying codec + /// as it may corrupt the stream of frames otherwise being worked with. + pub fn map_codec<C, F>(self, map: F) -> Framed<T, C> + where + F: FnOnce(U) -> C, + { + // This could be potentially simplified once rust-lang/rust#86555 hits stable + let parts = self.into_parts(); + Framed::from_parts(FramedParts { + io: parts.io, + codec: map(parts.codec), + read_buf: parts.read_buf, + write_buf: parts.write_buf, + _priv: (), + }) + } + + /// Returns a mutable reference to the underlying codec wrapped by + /// `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying codec + /// as it may corrupt the stream of frames otherwise being worked with. + pub fn codec_pin_mut(self: Pin<&mut Self>) -> &mut U { + self.project().inner.project().codec + } + + /// Returns a reference to the read buffer. + pub fn read_buffer(&self) -> &BytesMut { + &self.inner.state.read.buffer + } + + /// Returns a mutable reference to the read buffer. + pub fn read_buffer_mut(&mut self) -> &mut BytesMut { + &mut self.inner.state.read.buffer + } + + /// Returns a reference to the write buffer. + pub fn write_buffer(&self) -> &BytesMut { + &self.inner.state.write.buffer + } + + /// Returns a mutable reference to the write buffer. + pub fn write_buffer_mut(&mut self) -> &mut BytesMut { + &mut self.inner.state.write.buffer + } + + /// Consumes the `Framed`, returning its underlying I/O stream. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn into_inner(self) -> T { + self.inner.inner + } + + /// Consumes the `Framed`, returning its underlying I/O stream, the buffer + /// with unprocessed data, and the codec. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn into_parts(self) -> FramedParts<T, U> { + FramedParts { + io: self.inner.inner, + codec: self.inner.codec, + read_buf: self.inner.state.read.buffer, + write_buf: self.inner.state.write.buffer, + _priv: (), + } + } +} + +// This impl just defers to the underlying FramedImpl +impl<T, U> Stream for Framed<T, U> +where + T: AsyncRead, + U: Decoder, +{ + type Item = Result<U::Item, U::Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + self.project().inner.poll_next(cx) + } +} + +// This impl just defers to the underlying FramedImpl +impl<T, I, U> Sink<I> for Framed<T, U> +where + T: AsyncWrite, + U: Encoder<I>, + U::Error: From<io::Error>, +{ + type Error = U::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + self.project().inner.start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.poll_close(cx) + } +} + +impl<T, U> fmt::Debug for Framed<T, U> +where + T: fmt::Debug, + U: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Framed") + .field("io", self.get_ref()) + .field("codec", self.codec()) + .finish() + } +} + +/// `FramedParts` contains an export of the data of a Framed transport. +/// It can be used to construct a new [`Framed`] with a different codec. +/// It contains all current buffers and the inner transport. +/// +/// [`Framed`]: crate::codec::Framed +#[derive(Debug)] +#[allow(clippy::manual_non_exhaustive)] +pub struct FramedParts<T, U> { + /// The inner transport used to read bytes to and write bytes to + pub io: T, + + /// The codec + pub codec: U, + + /// The buffer with read but unprocessed data. + pub read_buf: BytesMut, + + /// A buffer with unprocessed data which are not written yet. + pub write_buf: BytesMut, + + /// This private field allows us to add additional fields in the future in a + /// backwards compatible way. + _priv: (), +} + +impl<T, U> FramedParts<T, U> { + /// Create a new, default, `FramedParts` + pub fn new<I>(io: T, codec: U) -> FramedParts<T, U> + where + U: Encoder<I>, + { + FramedParts { + io, + codec, + read_buf: BytesMut::new(), + write_buf: BytesMut::new(), + _priv: (), + } + } +} diff --git a/third_party/rust/tokio-util/src/codec/framed_impl.rs b/third_party/rust/tokio-util/src/codec/framed_impl.rs new file mode 100644 index 0000000000..ce1a6db873 --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/framed_impl.rs @@ -0,0 +1,308 @@ +use crate::codec::decoder::Decoder; +use crate::codec::encoder::Encoder; + +use futures_core::Stream; +use tokio::io::{AsyncRead, AsyncWrite}; + +use bytes::BytesMut; +use futures_core::ready; +use futures_sink::Sink; +use pin_project_lite::pin_project; +use std::borrow::{Borrow, BorrowMut}; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tracing::trace; + +pin_project! { + #[derive(Debug)] + pub(crate) struct FramedImpl<T, U, State> { + #[pin] + pub(crate) inner: T, + pub(crate) state: State, + pub(crate) codec: U, + } +} + +const INITIAL_CAPACITY: usize = 8 * 1024; +const BACKPRESSURE_BOUNDARY: usize = INITIAL_CAPACITY; + +#[derive(Debug)] +pub(crate) struct ReadFrame { + pub(crate) eof: bool, + pub(crate) is_readable: bool, + pub(crate) buffer: BytesMut, + pub(crate) has_errored: bool, +} + +pub(crate) struct WriteFrame { + pub(crate) buffer: BytesMut, +} + +#[derive(Default)] +pub(crate) struct RWFrames { + pub(crate) read: ReadFrame, + pub(crate) write: WriteFrame, +} + +impl Default for ReadFrame { + fn default() -> Self { + Self { + eof: false, + is_readable: false, + buffer: BytesMut::with_capacity(INITIAL_CAPACITY), + has_errored: false, + } + } +} + +impl Default for WriteFrame { + fn default() -> Self { + Self { + buffer: BytesMut::with_capacity(INITIAL_CAPACITY), + } + } +} + +impl From<BytesMut> for ReadFrame { + fn from(mut buffer: BytesMut) -> Self { + let size = buffer.capacity(); + if size < INITIAL_CAPACITY { + buffer.reserve(INITIAL_CAPACITY - size); + } + + Self { + buffer, + is_readable: size > 0, + eof: false, + has_errored: false, + } + } +} + +impl From<BytesMut> for WriteFrame { + fn from(mut buffer: BytesMut) -> Self { + let size = buffer.capacity(); + if size < INITIAL_CAPACITY { + buffer.reserve(INITIAL_CAPACITY - size); + } + + Self { buffer } + } +} + +impl Borrow<ReadFrame> for RWFrames { + fn borrow(&self) -> &ReadFrame { + &self.read + } +} +impl BorrowMut<ReadFrame> for RWFrames { + fn borrow_mut(&mut self) -> &mut ReadFrame { + &mut self.read + } +} +impl Borrow<WriteFrame> for RWFrames { + fn borrow(&self) -> &WriteFrame { + &self.write + } +} +impl BorrowMut<WriteFrame> for RWFrames { + fn borrow_mut(&mut self) -> &mut WriteFrame { + &mut self.write + } +} +impl<T, U, R> Stream for FramedImpl<T, U, R> +where + T: AsyncRead, + U: Decoder, + R: BorrowMut<ReadFrame>, +{ + type Item = Result<U::Item, U::Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + use crate::util::poll_read_buf; + + let mut pinned = self.project(); + let state: &mut ReadFrame = pinned.state.borrow_mut(); + // The following loops implements a state machine with each state corresponding + // to a combination of the `is_readable` and `eof` flags. States persist across + // loop entries and most state transitions occur with a return. + // + // The initial state is `reading`. + // + // | state | eof | is_readable | has_errored | + // |---------|-------|-------------|-------------| + // | reading | false | false | false | + // | framing | false | true | false | + // | pausing | true | true | false | + // | paused | true | false | false | + // | errored | <any> | <any> | true | + // `decode_eof` returns Err + // ┌────────────────────────────────────────────────────────┐ + // `decode_eof` returns │ │ + // `Ok(Some)` │ │ + // ┌─────┐ │ `decode_eof` returns After returning │ + // Read 0 bytes ├─────▼──┴┐ `Ok(None)` ┌────────┐ ◄───┐ `None` ┌───▼─────┐ + // ┌────────────────►│ Pausing ├───────────────────────►│ Paused ├─┐ └───────────┤ Errored │ + // │ └─────────┘ └─┬──▲───┘ │ └───▲───▲─┘ + // Pending read │ │ │ │ │ │ + // ┌──────┐ │ `decode` returns `Some` │ └─────┘ │ │ + // │ │ │ ┌──────┐ │ Pending │ │ + // │ ┌────▼──┴─┐ Read n>0 bytes ┌┴──────▼─┐ read n>0 bytes │ read │ │ + // └─┤ Reading ├───────────────►│ Framing │◄────────────────────────┘ │ │ + // └──┬─▲────┘ └─────┬──┬┘ │ │ + // │ │ │ │ `decode` returns Err │ │ + // │ └───decode` returns `None`──┘ └───────────────────────────────────────────────────────┘ │ + // │ read returns Err │ + // └────────────────────────────────────────────────────────────────────────────────────────────┘ + loop { + // Return `None` if we have encountered an error from the underlying decoder + // See: https://github.com/tokio-rs/tokio/issues/3976 + if state.has_errored { + // preparing has_errored -> paused + trace!("Returning None and setting paused"); + state.is_readable = false; + state.has_errored = false; + return Poll::Ready(None); + } + + // Repeatedly call `decode` or `decode_eof` while the buffer is "readable", + // i.e. it _might_ contain data consumable as a frame or closing frame. + // Both signal that there is no such data by returning `None`. + // + // If `decode` couldn't read a frame and the upstream source has returned eof, + // `decode_eof` will attempt to decode the remaining bytes as closing frames. + // + // If the underlying AsyncRead is resumable, we may continue after an EOF, + // but must finish emitting all of it's associated `decode_eof` frames. + // Furthermore, we don't want to emit any `decode_eof` frames on retried + // reads after an EOF unless we've actually read more data. + if state.is_readable { + // pausing or framing + if state.eof { + // pausing + let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| { + trace!("Got an error, going to errored state"); + state.has_errored = true; + err + })?; + if frame.is_none() { + state.is_readable = false; // prepare pausing -> paused + } + // implicit pausing -> pausing or pausing -> paused + return Poll::Ready(frame.map(Ok)); + } + + // framing + trace!("attempting to decode a frame"); + + if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| { + trace!("Got an error, going to errored state"); + state.has_errored = true; + op + })? { + trace!("frame decoded from buffer"); + // implicit framing -> framing + return Poll::Ready(Some(Ok(frame))); + } + + // framing -> reading + state.is_readable = false; + } + // reading or paused + // If we can't build a frame yet, try to read more data and try again. + // Make sure we've got room for at least one byte to read to ensure + // that we don't get a spurious 0 that looks like EOF. + state.buffer.reserve(1); + let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err( + |err| { + trace!("Got an error, going to errored state"); + state.has_errored = true; + err + }, + )? { + Poll::Ready(ct) => ct, + // implicit reading -> reading or implicit paused -> paused + Poll::Pending => return Poll::Pending, + }; + if bytect == 0 { + if state.eof { + // We're already at an EOF, and since we've reached this path + // we're also not readable. This implies that we've already finished + // our `decode_eof` handling, so we can simply return `None`. + // implicit paused -> paused + return Poll::Ready(None); + } + // prepare reading -> paused + state.eof = true; + } else { + // prepare paused -> framing or noop reading -> framing + state.eof = false; + } + + // paused -> framing or reading -> framing or reading -> pausing + state.is_readable = true; + } + } +} + +impl<T, I, U, W> Sink<I> for FramedImpl<T, U, W> +where + T: AsyncWrite, + U: Encoder<I>, + U::Error: From<io::Error>, + W: BorrowMut<WriteFrame>, +{ + type Error = U::Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + if self.state.borrow().buffer.len() >= BACKPRESSURE_BOUNDARY { + self.as_mut().poll_flush(cx) + } else { + Poll::Ready(Ok(())) + } + } + + fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + let pinned = self.project(); + pinned + .codec + .encode(item, &mut pinned.state.borrow_mut().buffer)?; + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + use crate::util::poll_write_buf; + trace!("flushing framed transport"); + let mut pinned = self.project(); + + while !pinned.state.borrow_mut().buffer.is_empty() { + let WriteFrame { buffer } = pinned.state.borrow_mut(); + trace!(remaining = buffer.len(), "writing;"); + + let n = ready!(poll_write_buf(pinned.inner.as_mut(), cx, buffer))?; + + if n == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "failed to \ + write frame to transport", + ) + .into())); + } + } + + // Try flushing the underlying IO + ready!(pinned.inner.poll_flush(cx))?; + + trace!("framed transport flushed"); + Poll::Ready(Ok(())) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + ready!(self.as_mut().poll_flush(cx))?; + ready!(self.project().inner.poll_shutdown(cx))?; + + Poll::Ready(Ok(())) + } +} diff --git a/third_party/rust/tokio-util/src/codec/framed_read.rs b/third_party/rust/tokio-util/src/codec/framed_read.rs new file mode 100644 index 0000000000..184c567b49 --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/framed_read.rs @@ -0,0 +1,199 @@ +use crate::codec::framed_impl::{FramedImpl, ReadFrame}; +use crate::codec::Decoder; + +use futures_core::Stream; +use tokio::io::AsyncRead; + +use bytes::BytesMut; +use futures_sink::Sink; +use pin_project_lite::pin_project; +use std::fmt; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// A [`Stream`] of messages decoded from an [`AsyncRead`]. + /// + /// [`Stream`]: futures_core::Stream + /// [`AsyncRead`]: tokio::io::AsyncRead + pub struct FramedRead<T, D> { + #[pin] + inner: FramedImpl<T, D, ReadFrame>, + } +} + +// ===== impl FramedRead ===== + +impl<T, D> FramedRead<T, D> +where + T: AsyncRead, + D: Decoder, +{ + /// Creates a new `FramedRead` with the given `decoder`. + pub fn new(inner: T, decoder: D) -> FramedRead<T, D> { + FramedRead { + inner: FramedImpl { + inner, + codec: decoder, + state: Default::default(), + }, + } + } + + /// Creates a new `FramedRead` with the given `decoder` and a buffer of `capacity` + /// initial size. + pub fn with_capacity(inner: T, decoder: D, capacity: usize) -> FramedRead<T, D> { + FramedRead { + inner: FramedImpl { + inner, + codec: decoder, + state: ReadFrame { + eof: false, + is_readable: false, + buffer: BytesMut::with_capacity(capacity), + has_errored: false, + }, + }, + } + } +} + +impl<T, D> FramedRead<T, D> { + /// Returns a reference to the underlying I/O stream wrapped by + /// `FramedRead`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_ref(&self) -> &T { + &self.inner.inner + } + + /// Returns a mutable reference to the underlying I/O stream wrapped by + /// `FramedRead`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_mut(&mut self) -> &mut T { + &mut self.inner.inner + } + + /// Returns a pinned mutable reference to the underlying I/O stream wrapped by + /// `FramedRead`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> { + self.project().inner.project().inner + } + + /// Consumes the `FramedRead`, returning its underlying I/O stream. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn into_inner(self) -> T { + self.inner.inner + } + + /// Returns a reference to the underlying decoder. + pub fn decoder(&self) -> &D { + &self.inner.codec + } + + /// Returns a mutable reference to the underlying decoder. + pub fn decoder_mut(&mut self) -> &mut D { + &mut self.inner.codec + } + + /// Maps the decoder `D` to `C`, preserving the read buffer + /// wrapped by `Framed`. + pub fn map_decoder<C, F>(self, map: F) -> FramedRead<T, C> + where + F: FnOnce(D) -> C, + { + // This could be potentially simplified once rust-lang/rust#86555 hits stable + let FramedImpl { + inner, + state, + codec, + } = self.inner; + FramedRead { + inner: FramedImpl { + inner, + state, + codec: map(codec), + }, + } + } + + /// Returns a mutable reference to the underlying decoder. + pub fn decoder_pin_mut(self: Pin<&mut Self>) -> &mut D { + self.project().inner.project().codec + } + + /// Returns a reference to the read buffer. + pub fn read_buffer(&self) -> &BytesMut { + &self.inner.state.buffer + } + + /// Returns a mutable reference to the read buffer. + pub fn read_buffer_mut(&mut self) -> &mut BytesMut { + &mut self.inner.state.buffer + } +} + +// This impl just defers to the underlying FramedImpl +impl<T, D> Stream for FramedRead<T, D> +where + T: AsyncRead, + D: Decoder, +{ + type Item = Result<D::Item, D::Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + self.project().inner.poll_next(cx) + } +} + +// This impl just defers to the underlying T: Sink +impl<T, I, D> Sink<I> for FramedRead<T, D> +where + T: Sink<I>, +{ + type Error = T::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.project().inner.poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + self.project().inner.project().inner.start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.project().inner.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.project().inner.poll_close(cx) + } +} + +impl<T, D> fmt::Debug for FramedRead<T, D> +where + T: fmt::Debug, + D: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("FramedRead") + .field("inner", &self.get_ref()) + .field("decoder", &self.decoder()) + .field("eof", &self.inner.state.eof) + .field("is_readable", &self.inner.state.is_readable) + .field("buffer", &self.read_buffer()) + .finish() + } +} diff --git a/third_party/rust/tokio-util/src/codec/framed_write.rs b/third_party/rust/tokio-util/src/codec/framed_write.rs new file mode 100644 index 0000000000..aa4cec9820 --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/framed_write.rs @@ -0,0 +1,178 @@ +use crate::codec::encoder::Encoder; +use crate::codec::framed_impl::{FramedImpl, WriteFrame}; + +use futures_core::Stream; +use tokio::io::AsyncWrite; + +use bytes::BytesMut; +use futures_sink::Sink; +use pin_project_lite::pin_project; +use std::fmt; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// A [`Sink`] of frames encoded to an `AsyncWrite`. + /// + /// [`Sink`]: futures_sink::Sink + pub struct FramedWrite<T, E> { + #[pin] + inner: FramedImpl<T, E, WriteFrame>, + } +} + +impl<T, E> FramedWrite<T, E> +where + T: AsyncWrite, +{ + /// Creates a new `FramedWrite` with the given `encoder`. + pub fn new(inner: T, encoder: E) -> FramedWrite<T, E> { + FramedWrite { + inner: FramedImpl { + inner, + codec: encoder, + state: WriteFrame::default(), + }, + } + } +} + +impl<T, E> FramedWrite<T, E> { + /// Returns a reference to the underlying I/O stream wrapped by + /// `FramedWrite`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_ref(&self) -> &T { + &self.inner.inner + } + + /// Returns a mutable reference to the underlying I/O stream wrapped by + /// `FramedWrite`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_mut(&mut self) -> &mut T { + &mut self.inner.inner + } + + /// Returns a pinned mutable reference to the underlying I/O stream wrapped by + /// `FramedWrite`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> { + self.project().inner.project().inner + } + + /// Consumes the `FramedWrite`, returning its underlying I/O stream. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn into_inner(self) -> T { + self.inner.inner + } + + /// Returns a reference to the underlying encoder. + pub fn encoder(&self) -> &E { + &self.inner.codec + } + + /// Returns a mutable reference to the underlying encoder. + pub fn encoder_mut(&mut self) -> &mut E { + &mut self.inner.codec + } + + /// Maps the encoder `E` to `C`, preserving the write buffer + /// wrapped by `Framed`. + pub fn map_encoder<C, F>(self, map: F) -> FramedWrite<T, C> + where + F: FnOnce(E) -> C, + { + // This could be potentially simplified once rust-lang/rust#86555 hits stable + let FramedImpl { + inner, + state, + codec, + } = self.inner; + FramedWrite { + inner: FramedImpl { + inner, + state, + codec: map(codec), + }, + } + } + + /// Returns a mutable reference to the underlying encoder. + pub fn encoder_pin_mut(self: Pin<&mut Self>) -> &mut E { + self.project().inner.project().codec + } + + /// Returns a reference to the write buffer. + pub fn write_buffer(&self) -> &BytesMut { + &self.inner.state.buffer + } + + /// Returns a mutable reference to the write buffer. + pub fn write_buffer_mut(&mut self) -> &mut BytesMut { + &mut self.inner.state.buffer + } +} + +// This impl just defers to the underlying FramedImpl +impl<T, I, E> Sink<I> for FramedWrite<T, E> +where + T: AsyncWrite, + E: Encoder<I>, + E::Error: From<io::Error>, +{ + type Error = E::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + self.project().inner.start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.project().inner.poll_close(cx) + } +} + +// This impl just defers to the underlying T: Stream +impl<T, D> Stream for FramedWrite<T, D> +where + T: Stream, +{ + type Item = T::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + self.project().inner.project().inner.poll_next(cx) + } +} + +impl<T, U> fmt::Debug for FramedWrite<T, U> +where + T: fmt::Debug, + U: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("FramedWrite") + .field("inner", &self.get_ref()) + .field("encoder", &self.encoder()) + .field("buffer", &self.inner.state.buffer) + .finish() + } +} diff --git a/third_party/rust/tokio-util/src/codec/length_delimited.rs b/third_party/rust/tokio-util/src/codec/length_delimited.rs new file mode 100644 index 0000000000..93d2f180d0 --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/length_delimited.rs @@ -0,0 +1,1047 @@ +//! Frame a stream of bytes based on a length prefix +//! +//! Many protocols delimit their frames by prefacing frame data with a +//! frame head that specifies the length of the frame. The +//! `length_delimited` module provides utilities for handling the length +//! based framing. This allows the consumer to work with entire frames +//! without having to worry about buffering or other framing logic. +//! +//! # Getting started +//! +//! If implementing a protocol from scratch, using length delimited framing +//! is an easy way to get started. [`LengthDelimitedCodec::new()`] will +//! return a length delimited codec using default configuration values. +//! This can then be used to construct a framer to adapt a full-duplex +//! byte stream into a stream of frames. +//! +//! ``` +//! use tokio::io::{AsyncRead, AsyncWrite}; +//! use tokio_util::codec::{Framed, LengthDelimitedCodec}; +//! +//! fn bind_transport<T: AsyncRead + AsyncWrite>(io: T) +//! -> Framed<T, LengthDelimitedCodec> +//! { +//! Framed::new(io, LengthDelimitedCodec::new()) +//! } +//! # pub fn main() {} +//! ``` +//! +//! The returned transport implements `Sink + Stream` for `BytesMut`. It +//! encodes the frame with a big-endian `u32` header denoting the frame +//! payload length: +//! +//! ```text +//! +----------+--------------------------------+ +//! | len: u32 | frame payload | +//! +----------+--------------------------------+ +//! ``` +//! +//! Specifically, given the following: +//! +//! ``` +//! use tokio::io::{AsyncRead, AsyncWrite}; +//! use tokio_util::codec::{Framed, LengthDelimitedCodec}; +//! +//! use futures::SinkExt; +//! use bytes::Bytes; +//! +//! async fn write_frame<T>(io: T) -> Result<(), Box<dyn std::error::Error>> +//! where +//! T: AsyncRead + AsyncWrite + Unpin, +//! { +//! let mut transport = Framed::new(io, LengthDelimitedCodec::new()); +//! let frame = Bytes::from("hello world"); +//! +//! transport.send(frame).await?; +//! Ok(()) +//! } +//! ``` +//! +//! The encoded frame will look like this: +//! +//! ```text +//! +---- len: u32 ----+---- data ----+ +//! | \x00\x00\x00\x0b | hello world | +//! +------------------+--------------+ +//! ``` +//! +//! # Decoding +//! +//! [`FramedRead`] adapts an [`AsyncRead`] into a `Stream` of [`BytesMut`], +//! such that each yielded [`BytesMut`] value contains the contents of an +//! entire frame. There are many configuration parameters enabling +//! [`FramedRead`] to handle a wide range of protocols. Here are some +//! examples that will cover the various options at a high level. +//! +//! ## Example 1 +//! +//! The following will parse a `u16` length field at offset 0, including the +//! frame head in the yielded `BytesMut`. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read<T: AsyncRead>(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(0) // default value +//! .length_field_type::<u16>() +//! .length_adjustment(0) // default value +//! .num_skip(0) // Do not strip frame header +//! .new_read(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT DECODED +//! +-- len ---+--- Payload ---+ +-- len ---+--- Payload ---+ +//! | \x00\x0B | Hello world | --> | \x00\x0B | Hello world | +//! +----------+---------------+ +----------+---------------+ +//! ``` +//! +//! The value of the length field is 11 (`\x0B`) which represents the length +//! of the payload, `hello world`. By default, [`FramedRead`] assumes that +//! the length field represents the number of bytes that **follows** the +//! length field. Thus, the entire frame has a length of 13: 2 bytes for the +//! frame head + 11 bytes for the payload. +//! +//! ## Example 2 +//! +//! The following will parse a `u16` length field at offset 0, omitting the +//! frame head in the yielded `BytesMut`. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read<T: AsyncRead>(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(0) // default value +//! .length_field_type::<u16>() +//! .length_adjustment(0) // default value +//! // `num_skip` is not needed, the default is to skip +//! .new_read(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT DECODED +//! +-- len ---+--- Payload ---+ +--- Payload ---+ +//! | \x00\x0B | Hello world | --> | Hello world | +//! +----------+---------------+ +---------------+ +//! ``` +//! +//! This is similar to the first example, the only difference is that the +//! frame head is **not** included in the yielded `BytesMut` value. +//! +//! ## Example 3 +//! +//! The following will parse a `u16` length field at offset 0, including the +//! frame head in the yielded `BytesMut`. In this case, the length field +//! **includes** the frame head length. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read<T: AsyncRead>(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(0) // default value +//! .length_field_type::<u16>() +//! .length_adjustment(-2) // size of head +//! .num_skip(0) +//! .new_read(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT DECODED +//! +-- len ---+--- Payload ---+ +-- len ---+--- Payload ---+ +//! | \x00\x0D | Hello world | --> | \x00\x0D | Hello world | +//! +----------+---------------+ +----------+---------------+ +//! ``` +//! +//! In most cases, the length field represents the length of the payload +//! only, as shown in the previous examples. However, in some protocols the +//! length field represents the length of the whole frame, including the +//! head. In such cases, we specify a negative `length_adjustment` to adjust +//! the value provided in the frame head to represent the payload length. +//! +//! ## Example 4 +//! +//! The following will parse a 3 byte length field at offset 0 in a 5 byte +//! frame head, including the frame head in the yielded `BytesMut`. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read<T: AsyncRead>(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(0) // default value +//! .length_field_length(3) +//! .length_adjustment(2) // remaining head +//! .num_skip(0) +//! .new_read(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT +//! +---- len -----+- head -+--- Payload ---+ +//! | \x00\x00\x0B | \xCAFE | Hello world | +//! +--------------+--------+---------------+ +//! +//! DECODED +//! +---- len -----+- head -+--- Payload ---+ +//! | \x00\x00\x0B | \xCAFE | Hello world | +//! +--------------+--------+---------------+ +//! ``` +//! +//! A more advanced example that shows a case where there is extra frame +//! head data between the length field and the payload. In such cases, it is +//! usually desirable to include the frame head as part of the yielded +//! `BytesMut`. This lets consumers of the length delimited framer to +//! process the frame head as needed. +//! +//! The positive `length_adjustment` value lets `FramedRead` factor in the +//! additional head into the frame length calculation. +//! +//! ## Example 5 +//! +//! The following will parse a `u16` length field at offset 1 of a 4 byte +//! frame head. The first byte and the length field will be omitted from the +//! yielded `BytesMut`, but the trailing 2 bytes of the frame head will be +//! included. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read<T: AsyncRead>(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(1) // length of hdr1 +//! .length_field_type::<u16>() +//! .length_adjustment(1) // length of hdr2 +//! .num_skip(3) // length of hdr1 + LEN +//! .new_read(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT +//! +- hdr1 -+-- len ---+- hdr2 -+--- Payload ---+ +//! | \xCA | \x00\x0B | \xFE | Hello world | +//! +--------+----------+--------+---------------+ +//! +//! DECODED +//! +- hdr2 -+--- Payload ---+ +//! | \xFE | Hello world | +//! +--------+---------------+ +//! ``` +//! +//! The length field is situated in the middle of the frame head. In this +//! case, the first byte in the frame head could be a version or some other +//! identifier that is not needed for processing. On the other hand, the +//! second half of the head is needed. +//! +//! `length_field_offset` indicates how many bytes to skip before starting +//! to read the length field. `length_adjustment` is the number of bytes to +//! skip starting at the end of the length field. In this case, it is the +//! second half of the head. +//! +//! ## Example 6 +//! +//! The following will parse a `u16` length field at offset 1 of a 4 byte +//! frame head. The first byte and the length field will be omitted from the +//! yielded `BytesMut`, but the trailing 2 bytes of the frame head will be +//! included. In this case, the length field **includes** the frame head +//! length. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read<T: AsyncRead>(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(1) // length of hdr1 +//! .length_field_type::<u16>() +//! .length_adjustment(-3) // length of hdr1 + LEN, negative +//! .num_skip(3) +//! .new_read(io); +//! # } +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT +//! +- hdr1 -+-- len ---+- hdr2 -+--- Payload ---+ +//! | \xCA | \x00\x0F | \xFE | Hello world | +//! +--------+----------+--------+---------------+ +//! +//! DECODED +//! +- hdr2 -+--- Payload ---+ +//! | \xFE | Hello world | +//! +--------+---------------+ +//! ``` +//! +//! Similar to the example above, the difference is that the length field +//! represents the length of the entire frame instead of just the payload. +//! The length of `hdr1` and `len` must be counted in `length_adjustment`. +//! Note that the length of `hdr2` does **not** need to be explicitly set +//! anywhere because it already is factored into the total frame length that +//! is read from the byte stream. +//! +//! ## Example 7 +//! +//! The following will parse a 3 byte length field at offset 0 in a 4 byte +//! frame head, excluding the 4th byte from the yielded `BytesMut`. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read<T: AsyncRead>(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(0) // default value +//! .length_field_length(3) +//! .length_adjustment(0) // default value +//! .num_skip(4) // skip the first 4 bytes +//! .new_read(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT DECODED +//! +------- len ------+--- Payload ---+ +--- Payload ---+ +//! | \x00\x00\x0B\xFF | Hello world | => | Hello world | +//! +------------------+---------------+ +---------------+ +//! ``` +//! +//! A simple example where there are unused bytes between the length field +//! and the payload. +//! +//! # Encoding +//! +//! [`FramedWrite`] adapts an [`AsyncWrite`] into a `Sink` of [`BytesMut`], +//! such that each submitted [`BytesMut`] is prefaced by a length field. +//! There are fewer configuration options than [`FramedRead`]. Given +//! protocols that have more complex frame heads, an encoder should probably +//! be written by hand using [`Encoder`]. +//! +//! Here is a simple example, given a `FramedWrite` with the following +//! configuration: +//! +//! ``` +//! # use tokio::io::AsyncWrite; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn write_frame<T: AsyncWrite>(io: T) { +//! # let _ = +//! LengthDelimitedCodec::builder() +//! .length_field_type::<u16>() +//! .new_write(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! A payload of `hello world` will be encoded as: +//! +//! ```text +//! +- len: u16 -+---- data ----+ +//! | \x00\x0b | hello world | +//! +------------+--------------+ +//! ``` +//! +//! [`LengthDelimitedCodec::new()`]: method@LengthDelimitedCodec::new +//! [`FramedRead`]: struct@FramedRead +//! [`FramedWrite`]: struct@FramedWrite +//! [`AsyncRead`]: trait@tokio::io::AsyncRead +//! [`AsyncWrite`]: trait@tokio::io::AsyncWrite +//! [`Encoder`]: trait@Encoder +//! [`BytesMut`]: bytes::BytesMut + +use crate::codec::{Decoder, Encoder, Framed, FramedRead, FramedWrite}; + +use tokio::io::{AsyncRead, AsyncWrite}; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::error::Error as StdError; +use std::io::{self, Cursor}; +use std::{cmp, fmt, mem}; + +/// Configure length delimited `LengthDelimitedCodec`s. +/// +/// `Builder` enables constructing configured length delimited codecs. Note +/// that not all configuration settings apply to both encoding and decoding. See +/// the documentation for specific methods for more detail. +#[derive(Debug, Clone, Copy)] +pub struct Builder { + // Maximum frame length + max_frame_len: usize, + + // Number of bytes representing the field length + length_field_len: usize, + + // Number of bytes in the header before the length field + length_field_offset: usize, + + // Adjust the length specified in the header field by this amount + length_adjustment: isize, + + // Total number of bytes to skip before reading the payload, if not set, + // `length_field_len + length_field_offset` + num_skip: Option<usize>, + + // Length field byte order (little or big endian) + length_field_is_big_endian: bool, +} + +/// An error when the number of bytes read is more than max frame length. +pub struct LengthDelimitedCodecError { + _priv: (), +} + +/// A codec for frames delimited by a frame head specifying their lengths. +/// +/// This allows the consumer to work with entire frames without having to worry +/// about buffering or other framing logic. +/// +/// See [module level] documentation for more detail. +/// +/// [module level]: index.html +#[derive(Debug, Clone)] +pub struct LengthDelimitedCodec { + // Configuration values + builder: Builder, + + // Read state + state: DecodeState, +} + +#[derive(Debug, Clone, Copy)] +enum DecodeState { + Head, + Data(usize), +} + +// ===== impl LengthDelimitedCodec ====== + +impl LengthDelimitedCodec { + /// Creates a new `LengthDelimitedCodec` with the default configuration values. + pub fn new() -> Self { + Self { + builder: Builder::new(), + state: DecodeState::Head, + } + } + + /// Creates a new length delimited codec builder with default configuration + /// values. + pub fn builder() -> Builder { + Builder::new() + } + + /// Returns the current max frame setting + /// + /// This is the largest size this codec will accept from the wire. Larger + /// frames will be rejected. + pub fn max_frame_length(&self) -> usize { + self.builder.max_frame_len + } + + /// Updates the max frame setting. + /// + /// The change takes effect the next time a frame is decoded. In other + /// words, if a frame is currently in process of being decoded with a frame + /// size greater than `val` but less than the max frame length in effect + /// before calling this function, then the frame will be allowed. + pub fn set_max_frame_length(&mut self, val: usize) { + self.builder.max_frame_length(val); + } + + fn decode_head(&mut self, src: &mut BytesMut) -> io::Result<Option<usize>> { + let head_len = self.builder.num_head_bytes(); + let field_len = self.builder.length_field_len; + + if src.len() < head_len { + // Not enough data + return Ok(None); + } + + let n = { + let mut src = Cursor::new(&mut *src); + + // Skip the required bytes + src.advance(self.builder.length_field_offset); + + // match endianness + let n = if self.builder.length_field_is_big_endian { + src.get_uint(field_len) + } else { + src.get_uint_le(field_len) + }; + + if n > self.builder.max_frame_len as u64 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + LengthDelimitedCodecError { _priv: () }, + )); + } + + // The check above ensures there is no overflow + let n = n as usize; + + // Adjust `n` with bounds checking + let n = if self.builder.length_adjustment < 0 { + n.checked_sub(-self.builder.length_adjustment as usize) + } else { + n.checked_add(self.builder.length_adjustment as usize) + }; + + // Error handling + match n { + Some(n) => n, + None => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "provided length would overflow after adjustment", + )); + } + } + }; + + let num_skip = self.builder.get_num_skip(); + + if num_skip > 0 { + src.advance(num_skip); + } + + // Ensure that the buffer has enough space to read the incoming + // payload + src.reserve(n); + + Ok(Some(n)) + } + + fn decode_data(&self, n: usize, src: &mut BytesMut) -> Option<BytesMut> { + // At this point, the buffer has already had the required capacity + // reserved. All there is to do is read. + if src.len() < n { + return None; + } + + Some(src.split_to(n)) + } +} + +impl Decoder for LengthDelimitedCodec { + type Item = BytesMut; + type Error = io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<BytesMut>> { + let n = match self.state { + DecodeState::Head => match self.decode_head(src)? { + Some(n) => { + self.state = DecodeState::Data(n); + n + } + None => return Ok(None), + }, + DecodeState::Data(n) => n, + }; + + match self.decode_data(n, src) { + Some(data) => { + // Update the decode state + self.state = DecodeState::Head; + + // Make sure the buffer has enough space to read the next head + src.reserve(self.builder.num_head_bytes()); + + Ok(Some(data)) + } + None => Ok(None), + } + } +} + +impl Encoder<Bytes> for LengthDelimitedCodec { + type Error = io::Error; + + fn encode(&mut self, data: Bytes, dst: &mut BytesMut) -> Result<(), io::Error> { + let n = data.len(); + + if n > self.builder.max_frame_len { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + LengthDelimitedCodecError { _priv: () }, + )); + } + + // Adjust `n` with bounds checking + let n = if self.builder.length_adjustment < 0 { + n.checked_add(-self.builder.length_adjustment as usize) + } else { + n.checked_sub(self.builder.length_adjustment as usize) + }; + + let n = n.ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "provided length would overflow after adjustment", + ) + })?; + + // Reserve capacity in the destination buffer to fit the frame and + // length field (plus adjustment). + dst.reserve(self.builder.length_field_len + n); + + if self.builder.length_field_is_big_endian { + dst.put_uint(n as u64, self.builder.length_field_len); + } else { + dst.put_uint_le(n as u64, self.builder.length_field_len); + } + + // Write the frame to the buffer + dst.extend_from_slice(&data[..]); + + Ok(()) + } +} + +impl Default for LengthDelimitedCodec { + fn default() -> Self { + Self::new() + } +} + +// ===== impl Builder ===== + +mod builder { + /// Types that can be used with `Builder::length_field_type`. + pub trait LengthFieldType {} + + impl LengthFieldType for u8 {} + impl LengthFieldType for u16 {} + impl LengthFieldType for u32 {} + impl LengthFieldType for u64 {} + + #[cfg(any( + target_pointer_width = "8", + target_pointer_width = "16", + target_pointer_width = "32", + target_pointer_width = "64", + ))] + impl LengthFieldType for usize {} +} + +impl Builder { + /// Creates a new length delimited codec builder with default configuration + /// values. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read<T: AsyncRead>(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_field_offset(0) + /// .length_field_type::<u16>() + /// .length_adjustment(0) + /// .num_skip(0) + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn new() -> Builder { + Builder { + // Default max frame length of 8MB + max_frame_len: 8 * 1_024 * 1_024, + + // Default byte length of 4 + length_field_len: 4, + + // Default to the header field being at the start of the header. + length_field_offset: 0, + + length_adjustment: 0, + + // Total number of bytes to skip before reading the payload, if not set, + // `length_field_len + length_field_offset` + num_skip: None, + + // Default to reading the length field in network (big) endian. + length_field_is_big_endian: true, + } + } + + /// Read the length field as a big endian integer + /// + /// This is the default setting. + /// + /// This configuration option applies to both encoding and decoding. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read<T: AsyncRead>(io: T) { + /// LengthDelimitedCodec::builder() + /// .big_endian() + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn big_endian(&mut self) -> &mut Self { + self.length_field_is_big_endian = true; + self + } + + /// Read the length field as a little endian integer + /// + /// The default setting is big endian. + /// + /// This configuration option applies to both encoding and decoding. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read<T: AsyncRead>(io: T) { + /// LengthDelimitedCodec::builder() + /// .little_endian() + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn little_endian(&mut self) -> &mut Self { + self.length_field_is_big_endian = false; + self + } + + /// Read the length field as a native endian integer + /// + /// The default setting is big endian. + /// + /// This configuration option applies to both encoding and decoding. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read<T: AsyncRead>(io: T) { + /// LengthDelimitedCodec::builder() + /// .native_endian() + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn native_endian(&mut self) -> &mut Self { + if cfg!(target_endian = "big") { + self.big_endian() + } else { + self.little_endian() + } + } + + /// Sets the max frame length in bytes + /// + /// This configuration option applies to both encoding and decoding. The + /// default value is 8MB. + /// + /// When decoding, the length field read from the byte stream is checked + /// against this setting **before** any adjustments are applied. When + /// encoding, the length of the submitted payload is checked against this + /// setting. + /// + /// When frames exceed the max length, an `io::Error` with the custom value + /// of the `LengthDelimitedCodecError` type will be returned. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read<T: AsyncRead>(io: T) { + /// LengthDelimitedCodec::builder() + /// .max_frame_length(8 * 1024 * 1024) + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn max_frame_length(&mut self, val: usize) -> &mut Self { + self.max_frame_len = val; + self + } + + /// Sets the unsigned integer type used to represent the length field. + /// + /// The default type is [`u32`]. The max type is [`u64`] (or [`usize`] on + /// 64-bit targets). + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read<T: AsyncRead>(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_field_type::<u32>() + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + /// + /// Unlike [`Builder::length_field_length`], this does not fail at runtime + /// and instead produces a compile error: + /// + /// ```compile_fail + /// # use tokio::io::AsyncRead; + /// # use tokio_util::codec::LengthDelimitedCodec; + /// # fn bind_read<T: AsyncRead>(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_field_type::<u128>() + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn length_field_type<T: builder::LengthFieldType>(&mut self) -> &mut Self { + self.length_field_length(mem::size_of::<T>()) + } + + /// Sets the number of bytes used to represent the length field + /// + /// The default value is `4`. The max value is `8`. + /// + /// This configuration option applies to both encoding and decoding. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read<T: AsyncRead>(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_field_length(4) + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn length_field_length(&mut self, val: usize) -> &mut Self { + assert!(val > 0 && val <= 8, "invalid length field length"); + self.length_field_len = val; + self + } + + /// Sets the number of bytes in the header before the length field + /// + /// This configuration option only applies to decoding. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read<T: AsyncRead>(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_field_offset(1) + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn length_field_offset(&mut self, val: usize) -> &mut Self { + self.length_field_offset = val; + self + } + + /// Delta between the payload length specified in the header and the real + /// payload length + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read<T: AsyncRead>(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_adjustment(-2) + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn length_adjustment(&mut self, val: isize) -> &mut Self { + self.length_adjustment = val; + self + } + + /// Sets the number of bytes to skip before reading the payload + /// + /// Default value is `length_field_len + length_field_offset` + /// + /// This configuration option only applies to decoding + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read<T: AsyncRead>(io: T) { + /// LengthDelimitedCodec::builder() + /// .num_skip(4) + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn num_skip(&mut self, val: usize) -> &mut Self { + self.num_skip = Some(val); + self + } + + /// Create a configured length delimited `LengthDelimitedCodec` + /// + /// # Examples + /// + /// ``` + /// use tokio_util::codec::LengthDelimitedCodec; + /// # pub fn main() { + /// LengthDelimitedCodec::builder() + /// .length_field_offset(0) + /// .length_field_type::<u16>() + /// .length_adjustment(0) + /// .num_skip(0) + /// .new_codec(); + /// # } + /// ``` + pub fn new_codec(&self) -> LengthDelimitedCodec { + LengthDelimitedCodec { + builder: *self, + state: DecodeState::Head, + } + } + + /// Create a configured length delimited `FramedRead` + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read<T: AsyncRead>(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_field_offset(0) + /// .length_field_type::<u16>() + /// .length_adjustment(0) + /// .num_skip(0) + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn new_read<T>(&self, upstream: T) -> FramedRead<T, LengthDelimitedCodec> + where + T: AsyncRead, + { + FramedRead::new(upstream, self.new_codec()) + } + + /// Create a configured length delimited `FramedWrite` + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncWrite; + /// # use tokio_util::codec::LengthDelimitedCodec; + /// # fn write_frame<T: AsyncWrite>(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_field_type::<u16>() + /// .new_write(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn new_write<T>(&self, inner: T) -> FramedWrite<T, LengthDelimitedCodec> + where + T: AsyncWrite, + { + FramedWrite::new(inner, self.new_codec()) + } + + /// Create a configured length delimited `Framed` + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::{AsyncRead, AsyncWrite}; + /// # use tokio_util::codec::LengthDelimitedCodec; + /// # fn write_frame<T: AsyncRead + AsyncWrite>(io: T) { + /// # let _ = + /// LengthDelimitedCodec::builder() + /// .length_field_type::<u16>() + /// .new_framed(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn new_framed<T>(&self, inner: T) -> Framed<T, LengthDelimitedCodec> + where + T: AsyncRead + AsyncWrite, + { + Framed::new(inner, self.new_codec()) + } + + fn num_head_bytes(&self) -> usize { + let num = self.length_field_offset + self.length_field_len; + cmp::max(num, self.num_skip.unwrap_or(0)) + } + + fn get_num_skip(&self) -> usize { + self.num_skip + .unwrap_or(self.length_field_offset + self.length_field_len) + } +} + +impl Default for Builder { + fn default() -> Self { + Self::new() + } +} + +// ===== impl LengthDelimitedCodecError ===== + +impl fmt::Debug for LengthDelimitedCodecError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("LengthDelimitedCodecError").finish() + } +} + +impl fmt::Display for LengthDelimitedCodecError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("frame size too big") + } +} + +impl StdError for LengthDelimitedCodecError {} diff --git a/third_party/rust/tokio-util/src/codec/lines_codec.rs b/third_party/rust/tokio-util/src/codec/lines_codec.rs new file mode 100644 index 0000000000..7a0a8f0454 --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/lines_codec.rs @@ -0,0 +1,230 @@ +use crate::codec::decoder::Decoder; +use crate::codec::encoder::Encoder; + +use bytes::{Buf, BufMut, BytesMut}; +use std::{cmp, fmt, io, str, usize}; + +/// A simple [`Decoder`] and [`Encoder`] implementation that splits up data into lines. +/// +/// [`Decoder`]: crate::codec::Decoder +/// [`Encoder`]: crate::codec::Encoder +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct LinesCodec { + // Stored index of the next index to examine for a `\n` character. + // This is used to optimize searching. + // For example, if `decode` was called with `abc`, it would hold `3`, + // because that is the next index to examine. + // The next time `decode` is called with `abcde\n`, the method will + // only look at `de\n` before returning. + next_index: usize, + + /// The maximum length for a given line. If `usize::MAX`, lines will be + /// read until a `\n` character is reached. + max_length: usize, + + /// Are we currently discarding the remainder of a line which was over + /// the length limit? + is_discarding: bool, +} + +impl LinesCodec { + /// Returns a `LinesCodec` for splitting up data into lines. + /// + /// # Note + /// + /// The returned `LinesCodec` will not have an upper bound on the length + /// of a buffered line. See the documentation for [`new_with_max_length`] + /// for information on why this could be a potential security risk. + /// + /// [`new_with_max_length`]: crate::codec::LinesCodec::new_with_max_length() + pub fn new() -> LinesCodec { + LinesCodec { + next_index: 0, + max_length: usize::MAX, + is_discarding: false, + } + } + + /// Returns a `LinesCodec` with a maximum line length limit. + /// + /// If this is set, calls to `LinesCodec::decode` will return a + /// [`LinesCodecError`] when a line exceeds the length limit. Subsequent calls + /// will discard up to `limit` bytes from that line until a newline + /// character is reached, returning `None` until the line over the limit + /// has been fully discarded. After that point, calls to `decode` will + /// function as normal. + /// + /// # Note + /// + /// Setting a length limit is highly recommended for any `LinesCodec` which + /// will be exposed to untrusted input. Otherwise, the size of the buffer + /// that holds the line currently being read is unbounded. An attacker could + /// exploit this unbounded buffer by sending an unbounded amount of input + /// without any `\n` characters, causing unbounded memory consumption. + /// + /// [`LinesCodecError`]: crate::codec::LinesCodecError + pub fn new_with_max_length(max_length: usize) -> Self { + LinesCodec { + max_length, + ..LinesCodec::new() + } + } + + /// Returns the maximum line length when decoding. + /// + /// ``` + /// use std::usize; + /// use tokio_util::codec::LinesCodec; + /// + /// let codec = LinesCodec::new(); + /// assert_eq!(codec.max_length(), usize::MAX); + /// ``` + /// ``` + /// use tokio_util::codec::LinesCodec; + /// + /// let codec = LinesCodec::new_with_max_length(256); + /// assert_eq!(codec.max_length(), 256); + /// ``` + pub fn max_length(&self) -> usize { + self.max_length + } +} + +fn utf8(buf: &[u8]) -> Result<&str, io::Error> { + str::from_utf8(buf) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Unable to decode input as UTF8")) +} + +fn without_carriage_return(s: &[u8]) -> &[u8] { + if let Some(&b'\r') = s.last() { + &s[..s.len() - 1] + } else { + s + } +} + +impl Decoder for LinesCodec { + type Item = String; + type Error = LinesCodecError; + + fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<String>, LinesCodecError> { + loop { + // Determine how far into the buffer we'll search for a newline. If + // there's no max_length set, we'll read to the end of the buffer. + let read_to = cmp::min(self.max_length.saturating_add(1), buf.len()); + + let newline_offset = buf[self.next_index..read_to] + .iter() + .position(|b| *b == b'\n'); + + match (self.is_discarding, newline_offset) { + (true, Some(offset)) => { + // If we found a newline, discard up to that offset and + // then stop discarding. On the next iteration, we'll try + // to read a line normally. + buf.advance(offset + self.next_index + 1); + self.is_discarding = false; + self.next_index = 0; + } + (true, None) => { + // Otherwise, we didn't find a newline, so we'll discard + // everything we read. On the next iteration, we'll continue + // discarding up to max_len bytes unless we find a newline. + buf.advance(read_to); + self.next_index = 0; + if buf.is_empty() { + return Ok(None); + } + } + (false, Some(offset)) => { + // Found a line! + let newline_index = offset + self.next_index; + self.next_index = 0; + let line = buf.split_to(newline_index + 1); + let line = &line[..line.len() - 1]; + let line = without_carriage_return(line); + let line = utf8(line)?; + return Ok(Some(line.to_string())); + } + (false, None) if buf.len() > self.max_length => { + // Reached the maximum length without finding a + // newline, return an error and start discarding on the + // next call. + self.is_discarding = true; + return Err(LinesCodecError::MaxLineLengthExceeded); + } + (false, None) => { + // We didn't find a line or reach the length limit, so the next + // call will resume searching at the current offset. + self.next_index = read_to; + return Ok(None); + } + } + } + } + + fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<String>, LinesCodecError> { + Ok(match self.decode(buf)? { + Some(frame) => Some(frame), + None => { + // No terminating newline - return remaining data, if any + if buf.is_empty() || buf == &b"\r"[..] { + None + } else { + let line = buf.split_to(buf.len()); + let line = without_carriage_return(&line); + let line = utf8(line)?; + self.next_index = 0; + Some(line.to_string()) + } + } + }) + } +} + +impl<T> Encoder<T> for LinesCodec +where + T: AsRef<str>, +{ + type Error = LinesCodecError; + + fn encode(&mut self, line: T, buf: &mut BytesMut) -> Result<(), LinesCodecError> { + let line = line.as_ref(); + buf.reserve(line.len() + 1); + buf.put(line.as_bytes()); + buf.put_u8(b'\n'); + Ok(()) + } +} + +impl Default for LinesCodec { + fn default() -> Self { + Self::new() + } +} + +/// An error occurred while encoding or decoding a line. +#[derive(Debug)] +pub enum LinesCodecError { + /// The maximum line length was exceeded. + MaxLineLengthExceeded, + /// An IO error occurred. + Io(io::Error), +} + +impl fmt::Display for LinesCodecError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + LinesCodecError::MaxLineLengthExceeded => write!(f, "max line length exceeded"), + LinesCodecError::Io(e) => write!(f, "{}", e), + } + } +} + +impl From<io::Error> for LinesCodecError { + fn from(e: io::Error) -> LinesCodecError { + LinesCodecError::Io(e) + } +} + +impl std::error::Error for LinesCodecError {} diff --git a/third_party/rust/tokio-util/src/codec/mod.rs b/third_party/rust/tokio-util/src/codec/mod.rs new file mode 100644 index 0000000000..2295176bdc --- /dev/null +++ b/third_party/rust/tokio-util/src/codec/mod.rs @@ -0,0 +1,290 @@ +//! Adaptors from AsyncRead/AsyncWrite to Stream/Sink +//! +//! Raw I/O objects work with byte sequences, but higher-level code usually +//! wants to batch these into meaningful chunks, called "frames". +//! +//! This module contains adapters to go from streams of bytes, [`AsyncRead`] and +//! [`AsyncWrite`], to framed streams implementing [`Sink`] and [`Stream`]. +//! Framed streams are also known as transports. +//! +//! # The Decoder trait +//! +//! A [`Decoder`] is used together with [`FramedRead`] or [`Framed`] to turn an +//! [`AsyncRead`] into a [`Stream`]. The job of the decoder trait is to specify +//! how sequences of bytes are turned into a sequence of frames, and to +//! determine where the boundaries between frames are. The job of the +//! `FramedRead` is to repeatedly switch between reading more data from the IO +//! resource, and asking the decoder whether we have received enough data to +//! decode another frame of data. +//! +//! The main method on the `Decoder` trait is the [`decode`] method. This method +//! takes as argument the data that has been read so far, and when it is called, +//! it will be in one of the following situations: +//! +//! 1. The buffer contains less than a full frame. +//! 2. The buffer contains exactly a full frame. +//! 3. The buffer contains more than a full frame. +//! +//! In the first situation, the decoder should return `Ok(None)`. +//! +//! In the second situation, the decoder should clear the provided buffer and +//! return `Ok(Some(the_decoded_frame))`. +//! +//! In the third situation, the decoder should use a method such as [`split_to`] +//! or [`advance`] to modify the buffer such that the frame is removed from the +//! buffer, but any data in the buffer after that frame should still remain in +//! the buffer. The decoder should also return `Ok(Some(the_decoded_frame))` in +//! this case. +//! +//! Finally the decoder may return an error if the data is invalid in some way. +//! The decoder should _not_ return an error just because it has yet to receive +//! a full frame. +//! +//! It is guaranteed that, from one call to `decode` to another, the provided +//! buffer will contain the exact same data as before, except that if more data +//! has arrived through the IO resource, that data will have been appended to +//! the buffer. This means that reading frames from a `FramedRead` is +//! essentially equivalent to the following loop: +//! +//! ```no_run +//! use tokio::io::AsyncReadExt; +//! # // This uses async_stream to create an example that compiles. +//! # fn foo() -> impl futures_core::Stream<Item = std::io::Result<bytes::BytesMut>> { async_stream::try_stream! { +//! # use tokio_util::codec::Decoder; +//! # let mut decoder = tokio_util::codec::BytesCodec::new(); +//! # let io_resource = &mut &[0u8, 1, 2, 3][..]; +//! +//! let mut buf = bytes::BytesMut::new(); +//! loop { +//! // The read_buf call will append to buf rather than overwrite existing data. +//! let len = io_resource.read_buf(&mut buf).await?; +//! +//! if len == 0 { +//! while let Some(frame) = decoder.decode_eof(&mut buf)? { +//! yield frame; +//! } +//! break; +//! } +//! +//! while let Some(frame) = decoder.decode(&mut buf)? { +//! yield frame; +//! } +//! } +//! # }} +//! ``` +//! The example above uses `yield` whenever the `Stream` produces an item. +//! +//! ## Example decoder +//! +//! As an example, consider a protocol that can be used to send strings where +//! each frame is a four byte integer that contains the length of the frame, +//! followed by that many bytes of string data. The decoder fails with an error +//! if the string data is not valid utf-8 or too long. +//! +//! Such a decoder can be written like this: +//! ``` +//! use tokio_util::codec::Decoder; +//! use bytes::{BytesMut, Buf}; +//! +//! struct MyStringDecoder {} +//! +//! const MAX: usize = 8 * 1024 * 1024; +//! +//! impl Decoder for MyStringDecoder { +//! type Item = String; +//! type Error = std::io::Error; +//! +//! fn decode( +//! &mut self, +//! src: &mut BytesMut +//! ) -> Result<Option<Self::Item>, Self::Error> { +//! if src.len() < 4 { +//! // Not enough data to read length marker. +//! return Ok(None); +//! } +//! +//! // Read length marker. +//! let mut length_bytes = [0u8; 4]; +//! length_bytes.copy_from_slice(&src[..4]); +//! let length = u32::from_le_bytes(length_bytes) as usize; +//! +//! // Check that the length is not too large to avoid a denial of +//! // service attack where the server runs out of memory. +//! if length > MAX { +//! return Err(std::io::Error::new( +//! std::io::ErrorKind::InvalidData, +//! format!("Frame of length {} is too large.", length) +//! )); +//! } +//! +//! if src.len() < 4 + length { +//! // The full string has not yet arrived. +//! // +//! // We reserve more space in the buffer. This is not strictly +//! // necessary, but is a good idea performance-wise. +//! src.reserve(4 + length - src.len()); +//! +//! // We inform the Framed that we need more bytes to form the next +//! // frame. +//! return Ok(None); +//! } +//! +//! // Use advance to modify src such that it no longer contains +//! // this frame. +//! let data = src[4..4 + length].to_vec(); +//! src.advance(4 + length); +//! +//! // Convert the data to a string, or fail if it is not valid utf-8. +//! match String::from_utf8(data) { +//! Ok(string) => Ok(Some(string)), +//! Err(utf8_error) => { +//! Err(std::io::Error::new( +//! std::io::ErrorKind::InvalidData, +//! utf8_error.utf8_error(), +//! )) +//! }, +//! } +//! } +//! } +//! ``` +//! +//! # The Encoder trait +//! +//! An [`Encoder`] is used together with [`FramedWrite`] or [`Framed`] to turn +//! an [`AsyncWrite`] into a [`Sink`]. The job of the encoder trait is to +//! specify how frames are turned into a sequences of bytes. The job of the +//! `FramedWrite` is to take the resulting sequence of bytes and write it to the +//! IO resource. +//! +//! The main method on the `Encoder` trait is the [`encode`] method. This method +//! takes an item that is being written, and a buffer to write the item to. The +//! buffer may already contain data, and in this case, the encoder should append +//! the new frame the to buffer rather than overwrite the existing data. +//! +//! It is guaranteed that, from one call to `encode` to another, the provided +//! buffer will contain the exact same data as before, except that some of the +//! data may have been removed from the front of the buffer. Writing to a +//! `FramedWrite` is essentially equivalent to the following loop: +//! +//! ```no_run +//! use tokio::io::AsyncWriteExt; +//! use bytes::Buf; // for advance +//! # use tokio_util::codec::Encoder; +//! # async fn next_frame() -> bytes::Bytes { bytes::Bytes::new() } +//! # async fn no_more_frames() { } +//! # #[tokio::main] async fn main() -> std::io::Result<()> { +//! # let mut io_resource = tokio::io::sink(); +//! # let mut encoder = tokio_util::codec::BytesCodec::new(); +//! +//! const MAX: usize = 8192; +//! +//! let mut buf = bytes::BytesMut::new(); +//! loop { +//! tokio::select! { +//! num_written = io_resource.write(&buf), if !buf.is_empty() => { +//! buf.advance(num_written?); +//! }, +//! frame = next_frame(), if buf.len() < MAX => { +//! encoder.encode(frame, &mut buf)?; +//! }, +//! _ = no_more_frames() => { +//! io_resource.write_all(&buf).await?; +//! io_resource.shutdown().await?; +//! return Ok(()); +//! }, +//! } +//! } +//! # } +//! ``` +//! Here the `next_frame` method corresponds to any frames you write to the +//! `FramedWrite`. The `no_more_frames` method corresponds to closing the +//! `FramedWrite` with [`SinkExt::close`]. +//! +//! ## Example encoder +//! +//! As an example, consider a protocol that can be used to send strings where +//! each frame is a four byte integer that contains the length of the frame, +//! followed by that many bytes of string data. The encoder will fail if the +//! string is too long. +//! +//! Such an encoder can be written like this: +//! ``` +//! use tokio_util::codec::Encoder; +//! use bytes::BytesMut; +//! +//! struct MyStringEncoder {} +//! +//! const MAX: usize = 8 * 1024 * 1024; +//! +//! impl Encoder<String> for MyStringEncoder { +//! type Error = std::io::Error; +//! +//! fn encode(&mut self, item: String, dst: &mut BytesMut) -> Result<(), Self::Error> { +//! // Don't send a string if it is longer than the other end will +//! // accept. +//! if item.len() > MAX { +//! return Err(std::io::Error::new( +//! std::io::ErrorKind::InvalidData, +//! format!("Frame of length {} is too large.", item.len()) +//! )); +//! } +//! +//! // Convert the length into a byte array. +//! // The cast to u32 cannot overflow due to the length check above. +//! let len_slice = u32::to_le_bytes(item.len() as u32); +//! +//! // Reserve space in the buffer. +//! dst.reserve(4 + item.len()); +//! +//! // Write the length and string to the buffer. +//! dst.extend_from_slice(&len_slice); +//! dst.extend_from_slice(item.as_bytes()); +//! Ok(()) +//! } +//! } +//! ``` +//! +//! [`AsyncRead`]: tokio::io::AsyncRead +//! [`AsyncWrite`]: tokio::io::AsyncWrite +//! [`Stream`]: futures_core::Stream +//! [`Sink`]: futures_sink::Sink +//! [`SinkExt::close`]: https://docs.rs/futures/0.3/futures/sink/trait.SinkExt.html#method.close +//! [`FramedRead`]: struct@crate::codec::FramedRead +//! [`FramedWrite`]: struct@crate::codec::FramedWrite +//! [`Framed`]: struct@crate::codec::Framed +//! [`Decoder`]: trait@crate::codec::Decoder +//! [`decode`]: fn@crate::codec::Decoder::decode +//! [`encode`]: fn@crate::codec::Encoder::encode +//! [`split_to`]: fn@bytes::BytesMut::split_to +//! [`advance`]: fn@bytes::Buf::advance + +mod bytes_codec; +pub use self::bytes_codec::BytesCodec; + +mod decoder; +pub use self::decoder::Decoder; + +mod encoder; +pub use self::encoder::Encoder; + +mod framed_impl; +#[allow(unused_imports)] +pub(crate) use self::framed_impl::{FramedImpl, RWFrames, ReadFrame, WriteFrame}; + +mod framed; +pub use self::framed::{Framed, FramedParts}; + +mod framed_read; +pub use self::framed_read::FramedRead; + +mod framed_write; +pub use self::framed_write::FramedWrite; + +pub mod length_delimited; +pub use self::length_delimited::{LengthDelimitedCodec, LengthDelimitedCodecError}; + +mod lines_codec; +pub use self::lines_codec::{LinesCodec, LinesCodecError}; + +mod any_delimiter_codec; +pub use self::any_delimiter_codec::{AnyDelimiterCodec, AnyDelimiterCodecError}; diff --git a/third_party/rust/tokio-util/src/compat.rs b/third_party/rust/tokio-util/src/compat.rs new file mode 100644 index 0000000000..6a8802d969 --- /dev/null +++ b/third_party/rust/tokio-util/src/compat.rs @@ -0,0 +1,274 @@ +//! Compatibility between the `tokio::io` and `futures-io` versions of the +//! `AsyncRead` and `AsyncWrite` traits. +use futures_core::ready; +use pin_project_lite::pin_project; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// A compatibility layer that allows conversion between the + /// `tokio::io` and `futures-io` `AsyncRead` and `AsyncWrite` traits. + #[derive(Copy, Clone, Debug)] + pub struct Compat<T> { + #[pin] + inner: T, + seek_pos: Option<io::SeekFrom>, + } +} + +/// Extension trait that allows converting a type implementing +/// `futures_io::AsyncRead` to implement `tokio::io::AsyncRead`. +pub trait FuturesAsyncReadCompatExt: futures_io::AsyncRead { + /// Wraps `self` with a compatibility layer that implements + /// `tokio_io::AsyncRead`. + fn compat(self) -> Compat<Self> + where + Self: Sized, + { + Compat::new(self) + } +} + +impl<T: futures_io::AsyncRead> FuturesAsyncReadCompatExt for T {} + +/// Extension trait that allows converting a type implementing +/// `futures_io::AsyncWrite` to implement `tokio::io::AsyncWrite`. +pub trait FuturesAsyncWriteCompatExt: futures_io::AsyncWrite { + /// Wraps `self` with a compatibility layer that implements + /// `tokio::io::AsyncWrite`. + fn compat_write(self) -> Compat<Self> + where + Self: Sized, + { + Compat::new(self) + } +} + +impl<T: futures_io::AsyncWrite> FuturesAsyncWriteCompatExt for T {} + +/// Extension trait that allows converting a type implementing +/// `tokio::io::AsyncRead` to implement `futures_io::AsyncRead`. +pub trait TokioAsyncReadCompatExt: tokio::io::AsyncRead { + /// Wraps `self` with a compatibility layer that implements + /// `futures_io::AsyncRead`. + fn compat(self) -> Compat<Self> + where + Self: Sized, + { + Compat::new(self) + } +} + +impl<T: tokio::io::AsyncRead> TokioAsyncReadCompatExt for T {} + +/// Extension trait that allows converting a type implementing +/// `tokio::io::AsyncWrite` to implement `futures_io::AsyncWrite`. +pub trait TokioAsyncWriteCompatExt: tokio::io::AsyncWrite { + /// Wraps `self` with a compatibility layer that implements + /// `futures_io::AsyncWrite`. + fn compat_write(self) -> Compat<Self> + where + Self: Sized, + { + Compat::new(self) + } +} + +impl<T: tokio::io::AsyncWrite> TokioAsyncWriteCompatExt for T {} + +// === impl Compat === + +impl<T> Compat<T> { + fn new(inner: T) -> Self { + Self { + inner, + seek_pos: None, + } + } + + /// Get a reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object + /// contained within. + pub fn get_ref(&self) -> &T { + &self.inner + } + + /// Get a mutable reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object + /// contained within. + pub fn get_mut(&mut self) -> &mut T { + &mut self.inner + } + + /// Returns the wrapped item. + pub fn into_inner(self) -> T { + self.inner + } +} + +impl<T> tokio::io::AsyncRead for Compat<T> +where + T: futures_io::AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + // We can't trust the inner type to not peak at the bytes, + // so we must defensively initialize the buffer. + let slice = buf.initialize_unfilled(); + let n = ready!(futures_io::AsyncRead::poll_read( + self.project().inner, + cx, + slice + ))?; + buf.advance(n); + Poll::Ready(Ok(())) + } +} + +impl<T> futures_io::AsyncRead for Compat<T> +where + T: tokio::io::AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + slice: &mut [u8], + ) -> Poll<io::Result<usize>> { + let mut buf = tokio::io::ReadBuf::new(slice); + ready!(tokio::io::AsyncRead::poll_read( + self.project().inner, + cx, + &mut buf + ))?; + Poll::Ready(Ok(buf.filled().len())) + } +} + +impl<T> tokio::io::AsyncBufRead for Compat<T> +where + T: futures_io::AsyncBufRead, +{ + fn poll_fill_buf<'a>( + self: Pin<&'a mut Self>, + cx: &mut Context<'_>, + ) -> Poll<io::Result<&'a [u8]>> { + futures_io::AsyncBufRead::poll_fill_buf(self.project().inner, cx) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + futures_io::AsyncBufRead::consume(self.project().inner, amt) + } +} + +impl<T> futures_io::AsyncBufRead for Compat<T> +where + T: tokio::io::AsyncBufRead, +{ + fn poll_fill_buf<'a>( + self: Pin<&'a mut Self>, + cx: &mut Context<'_>, + ) -> Poll<io::Result<&'a [u8]>> { + tokio::io::AsyncBufRead::poll_fill_buf(self.project().inner, cx) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + tokio::io::AsyncBufRead::consume(self.project().inner, amt) + } +} + +impl<T> tokio::io::AsyncWrite for Compat<T> +where + T: futures_io::AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + futures_io::AsyncWrite::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + futures_io::AsyncWrite::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + futures_io::AsyncWrite::poll_close(self.project().inner, cx) + } +} + +impl<T> futures_io::AsyncWrite for Compat<T> +where + T: tokio::io::AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) + } +} + +impl<T: tokio::io::AsyncSeek> futures_io::AsyncSeek for Compat<T> { + fn poll_seek( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + pos: io::SeekFrom, + ) -> Poll<io::Result<u64>> { + if self.seek_pos != Some(pos) { + self.as_mut().project().inner.start_seek(pos)?; + *self.as_mut().project().seek_pos = Some(pos); + } + let res = ready!(self.as_mut().project().inner.poll_complete(cx)); + *self.as_mut().project().seek_pos = None; + Poll::Ready(res.map(|p| p as u64)) + } +} + +impl<T: futures_io::AsyncSeek> tokio::io::AsyncSeek for Compat<T> { + fn start_seek(mut self: Pin<&mut Self>, pos: io::SeekFrom) -> io::Result<()> { + *self.as_mut().project().seek_pos = Some(pos); + Ok(()) + } + + fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { + let pos = match self.seek_pos { + None => { + // tokio 1.x AsyncSeek recommends calling poll_complete before start_seek. + // We don't have to guarantee that the value returned by + // poll_complete called without start_seek is correct, + // so we'll return 0. + return Poll::Ready(Ok(0)); + } + Some(pos) => pos, + }; + let res = ready!(self.as_mut().project().inner.poll_seek(cx, pos)); + *self.as_mut().project().seek_pos = None; + Poll::Ready(res.map(|p| p as u64)) + } +} + +#[cfg(unix)] +impl<T: std::os::unix::io::AsRawFd> std::os::unix::io::AsRawFd for Compat<T> { + fn as_raw_fd(&self) -> std::os::unix::io::RawFd { + self.inner.as_raw_fd() + } +} + +#[cfg(windows)] +impl<T: std::os::windows::io::AsRawHandle> std::os::windows::io::AsRawHandle for Compat<T> { + fn as_raw_handle(&self) -> std::os::windows::io::RawHandle { + self.inner.as_raw_handle() + } +} diff --git a/third_party/rust/tokio-util/src/context.rs b/third_party/rust/tokio-util/src/context.rs new file mode 100644 index 0000000000..a7a5e02949 --- /dev/null +++ b/third_party/rust/tokio-util/src/context.rs @@ -0,0 +1,190 @@ +//! Tokio context aware futures utilities. +//! +//! This module includes utilities around integrating tokio with other runtimes +//! by allowing the context to be attached to futures. This allows spawning +//! futures on other executors while still using tokio to drive them. This +//! can be useful if you need to use a tokio based library in an executor/runtime +//! that does not provide a tokio context. + +use pin_project_lite::pin_project; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::runtime::{Handle, Runtime}; + +pin_project! { + /// `TokioContext` allows running futures that must be inside Tokio's + /// context on a non-Tokio runtime. + /// + /// It contains a [`Handle`] to the runtime. A handle to the runtime can be + /// obtain by calling the [`Runtime::handle()`] method. + /// + /// Note that the `TokioContext` wrapper only works if the `Runtime` it is + /// connected to has not yet been destroyed. You must keep the `Runtime` + /// alive until the future has finished executing. + /// + /// **Warning:** If `TokioContext` is used together with a [current thread] + /// runtime, that runtime must be inside a call to `block_on` for the + /// wrapped future to work. For this reason, it is recommended to use a + /// [multi thread] runtime, even if you configure it to only spawn one + /// worker thread. + /// + /// # Examples + /// + /// This example creates two runtimes, but only [enables time] on one of + /// them. It then uses the context of the runtime with the timer enabled to + /// execute a [`sleep`] future on the runtime with timing disabled. + /// ``` + /// use tokio::time::{sleep, Duration}; + /// use tokio_util::context::RuntimeExt; + /// + /// // This runtime has timers enabled. + /// let rt = tokio::runtime::Builder::new_multi_thread() + /// .enable_all() + /// .build() + /// .unwrap(); + /// + /// // This runtime has timers disabled. + /// let rt2 = tokio::runtime::Builder::new_multi_thread() + /// .build() + /// .unwrap(); + /// + /// // Wrap the sleep future in the context of rt. + /// let fut = rt.wrap(async { sleep(Duration::from_millis(2)).await }); + /// + /// // Execute the future on rt2. + /// rt2.block_on(fut); + /// ``` + /// + /// [`Handle`]: struct@tokio::runtime::Handle + /// [`Runtime::handle()`]: fn@tokio::runtime::Runtime::handle + /// [`RuntimeExt`]: trait@crate::context::RuntimeExt + /// [`new_static`]: fn@Self::new_static + /// [`sleep`]: fn@tokio::time::sleep + /// [current thread]: fn@tokio::runtime::Builder::new_current_thread + /// [enables time]: fn@tokio::runtime::Builder::enable_time + /// [multi thread]: fn@tokio::runtime::Builder::new_multi_thread + pub struct TokioContext<F> { + #[pin] + inner: F, + handle: Handle, + } +} + +impl<F> TokioContext<F> { + /// Associate the provided future with the context of the runtime behind + /// the provided `Handle`. + /// + /// This constructor uses a `'static` lifetime to opt-out of checking that + /// the runtime still exists. + /// + /// # Examples + /// + /// This is the same as the example above, but uses the `new` constructor + /// rather than [`RuntimeExt::wrap`]. + /// + /// [`RuntimeExt::wrap`]: fn@RuntimeExt::wrap + /// + /// ``` + /// use tokio::time::{sleep, Duration}; + /// use tokio_util::context::TokioContext; + /// + /// // This runtime has timers enabled. + /// let rt = tokio::runtime::Builder::new_multi_thread() + /// .enable_all() + /// .build() + /// .unwrap(); + /// + /// // This runtime has timers disabled. + /// let rt2 = tokio::runtime::Builder::new_multi_thread() + /// .build() + /// .unwrap(); + /// + /// let fut = TokioContext::new( + /// async { sleep(Duration::from_millis(2)).await }, + /// rt.handle().clone(), + /// ); + /// + /// // Execute the future on rt2. + /// rt2.block_on(fut); + /// ``` + pub fn new(future: F, handle: Handle) -> TokioContext<F> { + TokioContext { + inner: future, + handle, + } + } + + /// Obtain a reference to the handle inside this `TokioContext`. + pub fn handle(&self) -> &Handle { + &self.handle + } + + /// Remove the association between the Tokio runtime and the wrapped future. + pub fn into_inner(self) -> F { + self.inner + } +} + +impl<F: Future> Future for TokioContext<F> { + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + let handle = me.handle; + let fut = me.inner; + + let _enter = handle.enter(); + fut.poll(cx) + } +} + +/// Extension trait that simplifies bundling a `Handle` with a `Future`. +pub trait RuntimeExt { + /// Create a [`TokioContext`] that wraps the provided future and runs it in + /// this runtime's context. + /// + /// # Examples + /// + /// This example creates two runtimes, but only [enables time] on one of + /// them. It then uses the context of the runtime with the timer enabled to + /// execute a [`sleep`] future on the runtime with timing disabled. + /// + /// ``` + /// use tokio::time::{sleep, Duration}; + /// use tokio_util::context::RuntimeExt; + /// + /// // This runtime has timers enabled. + /// let rt = tokio::runtime::Builder::new_multi_thread() + /// .enable_all() + /// .build() + /// .unwrap(); + /// + /// // This runtime has timers disabled. + /// let rt2 = tokio::runtime::Builder::new_multi_thread() + /// .build() + /// .unwrap(); + /// + /// // Wrap the sleep future in the context of rt. + /// let fut = rt.wrap(async { sleep(Duration::from_millis(2)).await }); + /// + /// // Execute the future on rt2. + /// rt2.block_on(fut); + /// ``` + /// + /// [`TokioContext`]: struct@crate::context::TokioContext + /// [`sleep`]: fn@tokio::time::sleep + /// [enables time]: fn@tokio::runtime::Builder::enable_time + fn wrap<F: Future>(&self, fut: F) -> TokioContext<F>; +} + +impl RuntimeExt for Runtime { + fn wrap<F: Future>(&self, fut: F) -> TokioContext<F> { + TokioContext { + inner: fut, + handle: self.handle().clone(), + } + } +} diff --git a/third_party/rust/tokio-util/src/either.rs b/third_party/rust/tokio-util/src/either.rs new file mode 100644 index 0000000000..9225e53ca6 --- /dev/null +++ b/third_party/rust/tokio-util/src/either.rs @@ -0,0 +1,188 @@ +//! Module defining an Either type. +use std::{ + future::Future, + io::SeekFrom, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf, Result}; + +/// Combines two different futures, streams, or sinks having the same associated types into a single type. +/// +/// This type implements common asynchronous traits such as [`Future`] and those in Tokio. +/// +/// [`Future`]: std::future::Future +/// +/// # Example +/// +/// The following code will not work: +/// +/// ```compile_fail +/// # fn some_condition() -> bool { true } +/// # async fn some_async_function() -> u32 { 10 } +/// # async fn other_async_function() -> u32 { 20 } +/// #[tokio::main] +/// async fn main() { +/// let result = if some_condition() { +/// some_async_function() +/// } else { +/// other_async_function() // <- Will print: "`if` and `else` have incompatible types" +/// }; +/// +/// println!("Result is {}", result.await); +/// } +/// ``` +/// +// This is because although the output types for both futures is the same, the exact future +// types are different, but the compiler must be able to choose a single type for the +// `result` variable. +/// +/// When the output type is the same, we can wrap each future in `Either` to avoid the +/// issue: +/// +/// ``` +/// use tokio_util::either::Either; +/// # fn some_condition() -> bool { true } +/// # async fn some_async_function() -> u32 { 10 } +/// # async fn other_async_function() -> u32 { 20 } +/// +/// #[tokio::main] +/// async fn main() { +/// let result = if some_condition() { +/// Either::Left(some_async_function()) +/// } else { +/// Either::Right(other_async_function()) +/// }; +/// +/// let value = result.await; +/// println!("Result is {}", value); +/// # assert_eq!(value, 10); +/// } +/// ``` +#[allow(missing_docs)] // Doc-comments for variants in this particular case don't make much sense. +#[derive(Debug, Clone)] +pub enum Either<L, R> { + Left(L), + Right(R), +} + +/// A small helper macro which reduces amount of boilerplate in the actual trait method implementation. +/// It takes an invocation of method as an argument (e.g. `self.poll(cx)`), and redirects it to either +/// enum variant held in `self`. +macro_rules! delegate_call { + ($self:ident.$method:ident($($args:ident),+)) => { + unsafe { + match $self.get_unchecked_mut() { + Self::Left(l) => Pin::new_unchecked(l).$method($($args),+), + Self::Right(r) => Pin::new_unchecked(r).$method($($args),+), + } + } + } +} + +impl<L, R, O> Future for Either<L, R> +where + L: Future<Output = O>, + R: Future<Output = O>, +{ + type Output = O; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + delegate_call!(self.poll(cx)) + } +} + +impl<L, R> AsyncRead for Either<L, R> +where + L: AsyncRead, + R: AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<Result<()>> { + delegate_call!(self.poll_read(cx, buf)) + } +} + +impl<L, R> AsyncBufRead for Either<L, R> +where + L: AsyncBufRead, + R: AsyncBufRead, +{ + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<&[u8]>> { + delegate_call!(self.poll_fill_buf(cx)) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + delegate_call!(self.consume(amt)) + } +} + +impl<L, R> AsyncSeek for Either<L, R> +where + L: AsyncSeek, + R: AsyncSeek, +{ + fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> Result<()> { + delegate_call!(self.start_seek(position)) + } + + fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<u64>> { + delegate_call!(self.poll_complete(cx)) + } +} + +impl<L, R> AsyncWrite for Either<L, R> +where + L: AsyncWrite, + R: AsyncWrite, +{ + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> { + delegate_call!(self.poll_write(cx, buf)) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> { + delegate_call!(self.poll_flush(cx)) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> { + delegate_call!(self.poll_shutdown(cx)) + } +} + +impl<L, R> futures_core::stream::Stream for Either<L, R> +where + L: futures_core::stream::Stream, + R: futures_core::stream::Stream<Item = L::Item>, +{ + type Item = L::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + delegate_call!(self.poll_next(cx)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::io::{repeat, AsyncReadExt, Repeat}; + use tokio_stream::{once, Once, StreamExt}; + + #[tokio::test] + async fn either_is_stream() { + let mut either: Either<Once<u32>, Once<u32>> = Either::Left(once(1)); + + assert_eq!(Some(1u32), either.next().await); + } + + #[tokio::test] + async fn either_is_async_read() { + let mut buffer = [0; 3]; + let mut either: Either<Repeat, Repeat> = Either::Right(repeat(0b101)); + + either.read_exact(&mut buffer).await.unwrap(); + assert_eq!(buffer, [0b101, 0b101, 0b101]); + } +} diff --git a/third_party/rust/tokio-util/src/io/mod.rs b/third_party/rust/tokio-util/src/io/mod.rs new file mode 100644 index 0000000000..eb48a21fb9 --- /dev/null +++ b/third_party/rust/tokio-util/src/io/mod.rs @@ -0,0 +1,24 @@ +//! Helpers for IO related tasks. +//! +//! The stream types are often used in combination with hyper or reqwest, as they +//! allow converting between a hyper [`Body`] and [`AsyncRead`]. +//! +//! The [`SyncIoBridge`] type converts from the world of async I/O +//! to synchronous I/O; this may often come up when using synchronous APIs +//! inside [`tokio::task::spawn_blocking`]. +//! +//! [`Body`]: https://docs.rs/hyper/0.13/hyper/struct.Body.html +//! [`AsyncRead`]: tokio::io::AsyncRead + +mod read_buf; +mod reader_stream; +mod stream_reader; +cfg_io_util! { + mod sync_bridge; + pub use self::sync_bridge::SyncIoBridge; +} + +pub use self::read_buf::read_buf; +pub use self::reader_stream::ReaderStream; +pub use self::stream_reader::StreamReader; +pub use crate::util::{poll_read_buf, poll_write_buf}; diff --git a/third_party/rust/tokio-util/src/io/read_buf.rs b/third_party/rust/tokio-util/src/io/read_buf.rs new file mode 100644 index 0000000000..d7938a3bc1 --- /dev/null +++ b/third_party/rust/tokio-util/src/io/read_buf.rs @@ -0,0 +1,65 @@ +use bytes::BufMut; +use std::future::Future; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::AsyncRead; + +/// Read data from an `AsyncRead` into an implementer of the [`BufMut`] trait. +/// +/// [`BufMut`]: bytes::BufMut +/// +/// # Example +/// +/// ``` +/// use bytes::{Bytes, BytesMut}; +/// use tokio_stream as stream; +/// use tokio::io::Result; +/// use tokio_util::io::{StreamReader, read_buf}; +/// # #[tokio::main] +/// # async fn main() -> std::io::Result<()> { +/// +/// // Create a reader from an iterator. This particular reader will always be +/// // ready. +/// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))])); +/// +/// let mut buf = BytesMut::new(); +/// let mut reads = 0; +/// +/// loop { +/// reads += 1; +/// let n = read_buf(&mut read, &mut buf).await?; +/// +/// if n == 0 { +/// break; +/// } +/// } +/// +/// // one or more reads might be necessary. +/// assert!(reads >= 1); +/// assert_eq!(&buf[..], &[0, 1, 2, 3]); +/// # Ok(()) +/// # } +/// ``` +pub async fn read_buf<R, B>(read: &mut R, buf: &mut B) -> io::Result<usize> +where + R: AsyncRead + Unpin, + B: BufMut, +{ + return ReadBufFn(read, buf).await; + + struct ReadBufFn<'a, R, B>(&'a mut R, &'a mut B); + + impl<'a, R, B> Future for ReadBufFn<'a, R, B> + where + R: AsyncRead + Unpin, + B: BufMut, + { + type Output = io::Result<usize>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = &mut *self; + crate::util::poll_read_buf(Pin::new(this.0), cx, this.1) + } + } +} diff --git a/third_party/rust/tokio-util/src/io/reader_stream.rs b/third_party/rust/tokio-util/src/io/reader_stream.rs new file mode 100644 index 0000000000..866c11408d --- /dev/null +++ b/third_party/rust/tokio-util/src/io/reader_stream.rs @@ -0,0 +1,118 @@ +use bytes::{Bytes, BytesMut}; +use futures_core::stream::Stream; +use pin_project_lite::pin_project; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::AsyncRead; + +const DEFAULT_CAPACITY: usize = 4096; + +pin_project! { + /// Convert an [`AsyncRead`] into a [`Stream`] of byte chunks. + /// + /// This stream is fused. It performs the inverse operation of + /// [`StreamReader`]. + /// + /// # Example + /// + /// ``` + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// use tokio_stream::StreamExt; + /// use tokio_util::io::ReaderStream; + /// + /// // Create a stream of data. + /// let data = b"hello, world!"; + /// let mut stream = ReaderStream::new(&data[..]); + /// + /// // Read all of the chunks into a vector. + /// let mut stream_contents = Vec::new(); + /// while let Some(chunk) = stream.next().await { + /// stream_contents.extend_from_slice(&chunk?); + /// } + /// + /// // Once the chunks are concatenated, we should have the + /// // original data. + /// assert_eq!(stream_contents, data); + /// # Ok(()) + /// # } + /// ``` + /// + /// [`AsyncRead`]: tokio::io::AsyncRead + /// [`StreamReader`]: crate::io::StreamReader + /// [`Stream`]: futures_core::Stream + #[derive(Debug)] + pub struct ReaderStream<R> { + // Reader itself. + // + // This value is `None` if the stream has terminated. + #[pin] + reader: Option<R>, + // Working buffer, used to optimize allocations. + buf: BytesMut, + capacity: usize, + } +} + +impl<R: AsyncRead> ReaderStream<R> { + /// Convert an [`AsyncRead`] into a [`Stream`] with item type + /// `Result<Bytes, std::io::Error>`. + /// + /// [`AsyncRead`]: tokio::io::AsyncRead + /// [`Stream`]: futures_core::Stream + pub fn new(reader: R) -> Self { + ReaderStream { + reader: Some(reader), + buf: BytesMut::new(), + capacity: DEFAULT_CAPACITY, + } + } + + /// Convert an [`AsyncRead`] into a [`Stream`] with item type + /// `Result<Bytes, std::io::Error>`, + /// with a specific read buffer initial capacity. + /// + /// [`AsyncRead`]: tokio::io::AsyncRead + /// [`Stream`]: futures_core::Stream + pub fn with_capacity(reader: R, capacity: usize) -> Self { + ReaderStream { + reader: Some(reader), + buf: BytesMut::with_capacity(capacity), + capacity, + } + } +} + +impl<R: AsyncRead> Stream for ReaderStream<R> { + type Item = std::io::Result<Bytes>; + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + use crate::util::poll_read_buf; + + let mut this = self.as_mut().project(); + + let reader = match this.reader.as_pin_mut() { + Some(r) => r, + None => return Poll::Ready(None), + }; + + if this.buf.capacity() == 0 { + this.buf.reserve(*this.capacity); + } + + match poll_read_buf(reader, cx, &mut this.buf) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(err)) => { + self.project().reader.set(None); + Poll::Ready(Some(Err(err))) + } + Poll::Ready(Ok(0)) => { + self.project().reader.set(None); + Poll::Ready(None) + } + Poll::Ready(Ok(_)) => { + let chunk = this.buf.split(); + Poll::Ready(Some(Ok(chunk.freeze()))) + } + } + } +} diff --git a/third_party/rust/tokio-util/src/io/stream_reader.rs b/third_party/rust/tokio-util/src/io/stream_reader.rs new file mode 100644 index 0000000000..05ae886557 --- /dev/null +++ b/third_party/rust/tokio-util/src/io/stream_reader.rs @@ -0,0 +1,203 @@ +use bytes::Buf; +use futures_core::stream::Stream; +use pin_project_lite::pin_project; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf}; + +pin_project! { + /// Convert a [`Stream`] of byte chunks into an [`AsyncRead`]. + /// + /// This type performs the inverse operation of [`ReaderStream`]. + /// + /// # Example + /// + /// ``` + /// use bytes::Bytes; + /// use tokio::io::{AsyncReadExt, Result}; + /// use tokio_util::io::StreamReader; + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// + /// // Create a stream from an iterator. + /// let stream = tokio_stream::iter(vec![ + /// Result::Ok(Bytes::from_static(&[0, 1, 2, 3])), + /// Result::Ok(Bytes::from_static(&[4, 5, 6, 7])), + /// Result::Ok(Bytes::from_static(&[8, 9, 10, 11])), + /// ]); + /// + /// // Convert it to an AsyncRead. + /// let mut read = StreamReader::new(stream); + /// + /// // Read five bytes from the stream. + /// let mut buf = [0; 5]; + /// read.read_exact(&mut buf).await?; + /// assert_eq!(buf, [0, 1, 2, 3, 4]); + /// + /// // Read the rest of the current chunk. + /// assert_eq!(read.read(&mut buf).await?, 3); + /// assert_eq!(&buf[..3], [5, 6, 7]); + /// + /// // Read the next chunk. + /// assert_eq!(read.read(&mut buf).await?, 4); + /// assert_eq!(&buf[..4], [8, 9, 10, 11]); + /// + /// // We have now reached the end. + /// assert_eq!(read.read(&mut buf).await?, 0); + /// + /// # Ok(()) + /// # } + /// ``` + /// + /// [`AsyncRead`]: tokio::io::AsyncRead + /// [`Stream`]: futures_core::Stream + /// [`ReaderStream`]: crate::io::ReaderStream + #[derive(Debug)] + pub struct StreamReader<S, B> { + #[pin] + inner: S, + chunk: Option<B>, + } +} + +impl<S, B, E> StreamReader<S, B> +where + S: Stream<Item = Result<B, E>>, + B: Buf, + E: Into<std::io::Error>, +{ + /// Convert a stream of byte chunks into an [`AsyncRead`](tokio::io::AsyncRead). + /// + /// The item should be a [`Result`] with the ok variant being something that + /// implements the [`Buf`] trait (e.g. `Vec<u8>` or `Bytes`). The error + /// should be convertible into an [io error]. + /// + /// [`Result`]: std::result::Result + /// [`Buf`]: bytes::Buf + /// [io error]: std::io::Error + pub fn new(stream: S) -> Self { + Self { + inner: stream, + chunk: None, + } + } + + /// Do we have a chunk and is it non-empty? + fn has_chunk(&self) -> bool { + if let Some(ref chunk) = self.chunk { + chunk.remaining() > 0 + } else { + false + } + } + + /// Consumes this `StreamReader`, returning a Tuple consisting + /// of the underlying stream and an Option of the interal buffer, + /// which is Some in case the buffer contains elements. + pub fn into_inner_with_chunk(self) -> (S, Option<B>) { + if self.has_chunk() { + (self.inner, self.chunk) + } else { + (self.inner, None) + } + } +} + +impl<S, B> StreamReader<S, B> { + /// Gets a reference to the underlying stream. + /// + /// It is inadvisable to directly read from the underlying stream. + pub fn get_ref(&self) -> &S { + &self.inner + } + + /// Gets a mutable reference to the underlying stream. + /// + /// It is inadvisable to directly read from the underlying stream. + pub fn get_mut(&mut self) -> &mut S { + &mut self.inner + } + + /// Gets a pinned mutable reference to the underlying stream. + /// + /// It is inadvisable to directly read from the underlying stream. + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> { + self.project().inner + } + + /// Consumes this `BufWriter`, returning the underlying stream. + /// + /// Note that any leftover data in the internal buffer is lost. + /// If you additionally want access to the internal buffer use + /// [`into_inner_with_chunk`]. + /// + /// [`into_inner_with_chunk`]: crate::io::StreamReader::into_inner_with_chunk + pub fn into_inner(self) -> S { + self.inner + } +} + +impl<S, B, E> AsyncRead for StreamReader<S, B> +where + S: Stream<Item = Result<B, E>>, + B: Buf, + E: Into<std::io::Error>, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + if buf.remaining() == 0 { + return Poll::Ready(Ok(())); + } + + let inner_buf = match self.as_mut().poll_fill_buf(cx) { + Poll::Ready(Ok(buf)) => buf, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => return Poll::Pending, + }; + let len = std::cmp::min(inner_buf.len(), buf.remaining()); + buf.put_slice(&inner_buf[..len]); + + self.consume(len); + Poll::Ready(Ok(())) + } +} + +impl<S, B, E> AsyncBufRead for StreamReader<S, B> +where + S: Stream<Item = Result<B, E>>, + B: Buf, + E: Into<std::io::Error>, +{ + fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { + loop { + if self.as_mut().has_chunk() { + // This unwrap is very sad, but it can't be avoided. + let buf = self.project().chunk.as_ref().unwrap().chunk(); + return Poll::Ready(Ok(buf)); + } else { + match self.as_mut().project().inner.poll_next(cx) { + Poll::Ready(Some(Ok(chunk))) => { + // Go around the loop in case the chunk is empty. + *self.as_mut().project().chunk = Some(chunk); + } + Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())), + Poll::Ready(None) => return Poll::Ready(Ok(&[])), + Poll::Pending => return Poll::Pending, + } + } + } + } + fn consume(self: Pin<&mut Self>, amt: usize) { + if amt > 0 { + self.project() + .chunk + .as_mut() + .expect("No chunk present") + .advance(amt); + } + } +} diff --git a/third_party/rust/tokio-util/src/io/sync_bridge.rs b/third_party/rust/tokio-util/src/io/sync_bridge.rs new file mode 100644 index 0000000000..9be9446a7d --- /dev/null +++ b/third_party/rust/tokio-util/src/io/sync_bridge.rs @@ -0,0 +1,103 @@ +use std::io::{Read, Write}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +/// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or +/// a [`tokio::io::AsyncWrite`] as a [`std::io::Write`]. +#[derive(Debug)] +pub struct SyncIoBridge<T> { + src: T, + rt: tokio::runtime::Handle, +} + +impl<T: AsyncRead + Unpin> Read for SyncIoBridge<T> { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> { + let src = &mut self.src; + self.rt.block_on(AsyncReadExt::read(src, buf)) + } + + fn read_to_end(&mut self, buf: &mut Vec<u8>) -> std::io::Result<usize> { + let src = &mut self.src; + self.rt.block_on(src.read_to_end(buf)) + } + + fn read_to_string(&mut self, buf: &mut String) -> std::io::Result<usize> { + let src = &mut self.src; + self.rt.block_on(src.read_to_string(buf)) + } + + fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> { + let src = &mut self.src; + // The AsyncRead trait returns the count, synchronous doesn't. + let _n = self.rt.block_on(src.read_exact(buf))?; + Ok(()) + } +} + +impl<T: AsyncWrite + Unpin> Write for SyncIoBridge<T> { + fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { + let src = &mut self.src; + self.rt.block_on(src.write(buf)) + } + + fn flush(&mut self) -> std::io::Result<()> { + let src = &mut self.src; + self.rt.block_on(src.flush()) + } + + fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> { + let src = &mut self.src; + self.rt.block_on(src.write_all(buf)) + } + + fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result<usize> { + let src = &mut self.src; + self.rt.block_on(src.write_vectored(bufs)) + } +} + +// Because https://doc.rust-lang.org/std/io/trait.Write.html#method.is_write_vectored is at the time +// of this writing still unstable, we expose this as part of a standalone method. +impl<T: AsyncWrite> SyncIoBridge<T> { + /// Determines if the underlying [`tokio::io::AsyncWrite`] target supports efficient vectored writes. + /// + /// See [`tokio::io::AsyncWrite::is_write_vectored`]. + pub fn is_write_vectored(&self) -> bool { + self.src.is_write_vectored() + } +} + +impl<T: Unpin> SyncIoBridge<T> { + /// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or + /// a [`tokio::io::AsyncWrite`] as a [`std::io::Write`]. + /// + /// When this struct is created, it captures a handle to the current thread's runtime with [`tokio::runtime::Handle::current`]. + /// It is hence OK to move this struct into a separate thread outside the runtime, as created + /// by e.g. [`tokio::task::spawn_blocking`]. + /// + /// Stated even more strongly: to make use of this bridge, you *must* move + /// it into a separate thread outside the runtime. The synchronous I/O will use the + /// underlying handle to block on the backing asynchronous source, via + /// [`tokio::runtime::Handle::block_on`]. As noted in the documentation for that + /// function, an attempt to `block_on` from an asynchronous execution context + /// will panic. + /// + /// # Wrapping `!Unpin` types + /// + /// Use e.g. `SyncIoBridge::new(Box::pin(src))`. + /// + /// # Panic + /// + /// This will panic if called outside the context of a Tokio runtime. + pub fn new(src: T) -> Self { + Self::new_with_handle(src, tokio::runtime::Handle::current()) + } + + /// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or + /// a [`tokio::io::AsyncWrite`] as a [`std::io::Write`]. + /// + /// This is the same as [`SyncIoBridge::new`], but allows passing an arbitrary handle and hence may + /// be initially invoked outside of an asynchronous context. + pub fn new_with_handle(src: T, rt: tokio::runtime::Handle) -> Self { + Self { src, rt } + } +} diff --git a/third_party/rust/tokio-util/src/lib.rs b/third_party/rust/tokio-util/src/lib.rs new file mode 100644 index 0000000000..fd14a8ac94 --- /dev/null +++ b/third_party/rust/tokio-util/src/lib.rs @@ -0,0 +1,201 @@ +#![allow(clippy::needless_doctest_main)] +#![warn( + missing_debug_implementations, + missing_docs, + rust_2018_idioms, + unreachable_pub +)] +#![doc(test( + no_crate_inject, + attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables)) +))] +#![cfg_attr(docsrs, feature(doc_cfg))] + +//! Utilities for working with Tokio. +//! +//! This crate is not versioned in lockstep with the core +//! [`tokio`] crate. However, `tokio-util` _will_ respect Rust's +//! semantic versioning policy, especially with regard to breaking changes. +//! +//! [`tokio`]: https://docs.rs/tokio + +#[macro_use] +mod cfg; + +mod loom; + +cfg_codec! { + pub mod codec; +} + +cfg_net! { + pub mod udp; + pub mod net; +} + +cfg_compat! { + pub mod compat; +} + +cfg_io! { + pub mod io; +} + +cfg_rt! { + pub mod context; + pub mod task; +} + +cfg_time! { + pub mod time; +} + +pub mod sync; + +pub mod either; + +#[cfg(any(feature = "io", feature = "codec"))] +mod util { + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + + use bytes::{Buf, BufMut}; + use futures_core::ready; + use std::io::{self, IoSlice}; + use std::mem::MaybeUninit; + use std::pin::Pin; + use std::task::{Context, Poll}; + + /// Try to read data from an `AsyncRead` into an implementer of the [`BufMut`] trait. + /// + /// [`BufMut`]: bytes::Buf + /// + /// # Example + /// + /// ``` + /// use bytes::{Bytes, BytesMut}; + /// use tokio_stream as stream; + /// use tokio::io::Result; + /// use tokio_util::io::{StreamReader, poll_read_buf}; + /// use futures::future::poll_fn; + /// use std::pin::Pin; + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// + /// // Create a reader from an iterator. This particular reader will always be + /// // ready. + /// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))])); + /// + /// let mut buf = BytesMut::new(); + /// let mut reads = 0; + /// + /// loop { + /// reads += 1; + /// let n = poll_fn(|cx| poll_read_buf(Pin::new(&mut read), cx, &mut buf)).await?; + /// + /// if n == 0 { + /// break; + /// } + /// } + /// + /// // one or more reads might be necessary. + /// assert!(reads >= 1); + /// assert_eq!(&buf[..], &[0, 1, 2, 3]); + /// # Ok(()) + /// # } + /// ``` + #[cfg_attr(not(feature = "io"), allow(unreachable_pub))] + pub fn poll_read_buf<T: AsyncRead, B: BufMut>( + io: Pin<&mut T>, + cx: &mut Context<'_>, + buf: &mut B, + ) -> Poll<io::Result<usize>> { + if !buf.has_remaining_mut() { + return Poll::Ready(Ok(0)); + } + + let n = { + let dst = buf.chunk_mut(); + let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) }; + let mut buf = ReadBuf::uninit(dst); + let ptr = buf.filled().as_ptr(); + ready!(io.poll_read(cx, &mut buf)?); + + // Ensure the pointer does not change from under us + assert_eq!(ptr, buf.filled().as_ptr()); + buf.filled().len() + }; + + // Safety: This is guaranteed to be the number of initialized (and read) + // bytes due to the invariants provided by `ReadBuf::filled`. + unsafe { + buf.advance_mut(n); + } + + Poll::Ready(Ok(n)) + } + + /// Try to write data from an implementer of the [`Buf`] trait to an + /// [`AsyncWrite`], advancing the buffer's internal cursor. + /// + /// This function will use [vectored writes] when the [`AsyncWrite`] supports + /// vectored writes. + /// + /// # Examples + /// + /// [`File`] implements [`AsyncWrite`] and [`Cursor<&[u8]>`] implements + /// [`Buf`]: + /// + /// ```no_run + /// use tokio_util::io::poll_write_buf; + /// use tokio::io; + /// use tokio::fs::File; + /// + /// use bytes::Buf; + /// use std::io::Cursor; + /// use std::pin::Pin; + /// use futures::future::poll_fn; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut file = File::create("foo.txt").await?; + /// let mut buf = Cursor::new(b"data to write"); + /// + /// // Loop until the entire contents of the buffer are written to + /// // the file. + /// while buf.has_remaining() { + /// poll_fn(|cx| poll_write_buf(Pin::new(&mut file), cx, &mut buf)).await?; + /// } + /// + /// Ok(()) + /// } + /// ``` + /// + /// [`Buf`]: bytes::Buf + /// [`AsyncWrite`]: tokio::io::AsyncWrite + /// [`File`]: tokio::fs::File + /// [vectored writes]: tokio::io::AsyncWrite::poll_write_vectored + #[cfg_attr(not(feature = "io"), allow(unreachable_pub))] + pub fn poll_write_buf<T: AsyncWrite, B: Buf>( + io: Pin<&mut T>, + cx: &mut Context<'_>, + buf: &mut B, + ) -> Poll<io::Result<usize>> { + const MAX_BUFS: usize = 64; + + if !buf.has_remaining() { + return Poll::Ready(Ok(0)); + } + + let n = if io.is_write_vectored() { + let mut slices = [IoSlice::new(&[]); MAX_BUFS]; + let cnt = buf.chunks_vectored(&mut slices); + ready!(io.poll_write_vectored(cx, &slices[..cnt]))? + } else { + ready!(io.poll_write(cx, buf.chunk()))? + }; + + buf.advance(n); + + Poll::Ready(Ok(n)) + } +} diff --git a/third_party/rust/tokio-util/src/loom.rs b/third_party/rust/tokio-util/src/loom.rs new file mode 100644 index 0000000000..dd03feaba1 --- /dev/null +++ b/third_party/rust/tokio-util/src/loom.rs @@ -0,0 +1 @@ +pub(crate) use std::sync; diff --git a/third_party/rust/tokio-util/src/net/mod.rs b/third_party/rust/tokio-util/src/net/mod.rs new file mode 100644 index 0000000000..4817e10d0f --- /dev/null +++ b/third_party/rust/tokio-util/src/net/mod.rs @@ -0,0 +1,97 @@ +//! TCP/UDP/Unix helpers for tokio. + +use crate::either::Either; +use std::future::Future; +use std::io::Result; +use std::pin::Pin; +use std::task::{Context, Poll}; + +#[cfg(unix)] +pub mod unix; + +/// A trait for a listener: `TcpListener` and `UnixListener`. +pub trait Listener { + /// The stream's type of this listener. + type Io: tokio::io::AsyncRead + tokio::io::AsyncWrite; + /// The socket address type of this listener. + type Addr; + + /// Polls to accept a new incoming connection to this listener. + fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<(Self::Io, Self::Addr)>>; + + /// Accepts a new incoming connection from this listener. + fn accept(&mut self) -> ListenerAcceptFut<'_, Self> + where + Self: Sized, + { + ListenerAcceptFut { listener: self } + } + + /// Returns the local address that this listener is bound to. + fn local_addr(&self) -> Result<Self::Addr>; +} + +impl Listener for tokio::net::TcpListener { + type Io = tokio::net::TcpStream; + type Addr = std::net::SocketAddr; + + fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<(Self::Io, Self::Addr)>> { + Self::poll_accept(self, cx) + } + + fn local_addr(&self) -> Result<Self::Addr> { + self.local_addr().map(Into::into) + } +} + +/// Future for accepting a new connection from a listener. +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct ListenerAcceptFut<'a, L> { + listener: &'a mut L, +} + +impl<'a, L> Future for ListenerAcceptFut<'a, L> +where + L: Listener, +{ + type Output = Result<(L::Io, L::Addr)>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + self.listener.poll_accept(cx) + } +} + +impl<L, R> Either<L, R> +where + L: Listener, + R: Listener, +{ + /// Accepts a new incoming connection from this listener. + pub async fn accept(&mut self) -> Result<Either<(L::Io, L::Addr), (R::Io, R::Addr)>> { + match self { + Either::Left(listener) => { + let (stream, addr) = listener.accept().await?; + Ok(Either::Left((stream, addr))) + } + Either::Right(listener) => { + let (stream, addr) = listener.accept().await?; + Ok(Either::Right((stream, addr))) + } + } + } + + /// Returns the local address that this listener is bound to. + pub fn local_addr(&self) -> Result<Either<L::Addr, R::Addr>> { + match self { + Either::Left(listener) => { + let addr = listener.local_addr()?; + Ok(Either::Left(addr)) + } + Either::Right(listener) => { + let addr = listener.local_addr()?; + Ok(Either::Right(addr)) + } + } + } +} diff --git a/third_party/rust/tokio-util/src/net/unix/mod.rs b/third_party/rust/tokio-util/src/net/unix/mod.rs new file mode 100644 index 0000000000..0b522c90a3 --- /dev/null +++ b/third_party/rust/tokio-util/src/net/unix/mod.rs @@ -0,0 +1,18 @@ +//! Unix domain socket helpers. + +use super::Listener; +use std::io::Result; +use std::task::{Context, Poll}; + +impl Listener for tokio::net::UnixListener { + type Io = tokio::net::UnixStream; + type Addr = tokio::net::unix::SocketAddr; + + fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<(Self::Io, Self::Addr)>> { + Self::poll_accept(self, cx) + } + + fn local_addr(&self) -> Result<Self::Addr> { + self.local_addr().map(Into::into) + } +} diff --git a/third_party/rust/tokio-util/src/sync/cancellation_token.rs b/third_party/rust/tokio-util/src/sync/cancellation_token.rs new file mode 100644 index 0000000000..2a6ef392bd --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/cancellation_token.rs @@ -0,0 +1,224 @@ +//! An asynchronously awaitable `CancellationToken`. +//! The token allows to signal a cancellation request to one or more tasks. +pub(crate) mod guard; +mod tree_node; + +use crate::loom::sync::Arc; +use core::future::Future; +use core::pin::Pin; +use core::task::{Context, Poll}; + +use guard::DropGuard; +use pin_project_lite::pin_project; + +/// A token which can be used to signal a cancellation request to one or more +/// tasks. +/// +/// Tasks can call [`CancellationToken::cancelled()`] in order to +/// obtain a Future which will be resolved when cancellation is requested. +/// +/// Cancellation can be requested through the [`CancellationToken::cancel`] method. +/// +/// # Examples +/// +/// ```no_run +/// use tokio::select; +/// use tokio_util::sync::CancellationToken; +/// +/// #[tokio::main] +/// async fn main() { +/// let token = CancellationToken::new(); +/// let cloned_token = token.clone(); +/// +/// let join_handle = tokio::spawn(async move { +/// // Wait for either cancellation or a very long time +/// select! { +/// _ = cloned_token.cancelled() => { +/// // The token was cancelled +/// 5 +/// } +/// _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => { +/// 99 +/// } +/// } +/// }); +/// +/// tokio::spawn(async move { +/// tokio::time::sleep(std::time::Duration::from_millis(10)).await; +/// token.cancel(); +/// }); +/// +/// assert_eq!(5, join_handle.await.unwrap()); +/// } +/// ``` +pub struct CancellationToken { + inner: Arc<tree_node::TreeNode>, +} + +pin_project! { + /// A Future that is resolved once the corresponding [`CancellationToken`] + /// is cancelled. + #[must_use = "futures do nothing unless polled"] + pub struct WaitForCancellationFuture<'a> { + cancellation_token: &'a CancellationToken, + #[pin] + future: tokio::sync::futures::Notified<'a>, + } +} + +// ===== impl CancellationToken ===== + +impl core::fmt::Debug for CancellationToken { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("CancellationToken") + .field("is_cancelled", &self.is_cancelled()) + .finish() + } +} + +impl Clone for CancellationToken { + fn clone(&self) -> Self { + tree_node::increase_handle_refcount(&self.inner); + CancellationToken { + inner: self.inner.clone(), + } + } +} + +impl Drop for CancellationToken { + fn drop(&mut self) { + tree_node::decrease_handle_refcount(&self.inner); + } +} + +impl Default for CancellationToken { + fn default() -> CancellationToken { + CancellationToken::new() + } +} + +impl CancellationToken { + /// Creates a new CancellationToken in the non-cancelled state. + pub fn new() -> CancellationToken { + CancellationToken { + inner: Arc::new(tree_node::TreeNode::new()), + } + } + + /// Creates a `CancellationToken` which will get cancelled whenever the + /// current token gets cancelled. + /// + /// If the current token is already cancelled, the child token will get + /// returned in cancelled state. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::select; + /// use tokio_util::sync::CancellationToken; + /// + /// #[tokio::main] + /// async fn main() { + /// let token = CancellationToken::new(); + /// let child_token = token.child_token(); + /// + /// let join_handle = tokio::spawn(async move { + /// // Wait for either cancellation or a very long time + /// select! { + /// _ = child_token.cancelled() => { + /// // The token was cancelled + /// 5 + /// } + /// _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => { + /// 99 + /// } + /// } + /// }); + /// + /// tokio::spawn(async move { + /// tokio::time::sleep(std::time::Duration::from_millis(10)).await; + /// token.cancel(); + /// }); + /// + /// assert_eq!(5, join_handle.await.unwrap()); + /// } + /// ``` + pub fn child_token(&self) -> CancellationToken { + CancellationToken { + inner: tree_node::child_node(&self.inner), + } + } + + /// Cancel the [`CancellationToken`] and all child tokens which had been + /// derived from it. + /// + /// This will wake up all tasks which are waiting for cancellation. + /// + /// Be aware that cancellation is not an atomic operation. It is possible + /// for another thread running in parallel with a call to `cancel` to first + /// receive `true` from `is_cancelled` on one child node, and then receive + /// `false` from `is_cancelled` on another child node. However, once the + /// call to `cancel` returns, all child nodes have been fully cancelled. + pub fn cancel(&self) { + tree_node::cancel(&self.inner); + } + + /// Returns `true` if the `CancellationToken` is cancelled. + pub fn is_cancelled(&self) -> bool { + tree_node::is_cancelled(&self.inner) + } + + /// Returns a `Future` that gets fulfilled when cancellation is requested. + /// + /// The future will complete immediately if the token is already cancelled + /// when this method is called. + /// + /// # Cancel safety + /// + /// This method is cancel safe. + pub fn cancelled(&self) -> WaitForCancellationFuture<'_> { + WaitForCancellationFuture { + cancellation_token: self, + future: self.inner.notified(), + } + } + + /// Creates a `DropGuard` for this token. + /// + /// Returned guard will cancel this token (and all its children) on drop + /// unless disarmed. + pub fn drop_guard(self) -> DropGuard { + DropGuard { inner: Some(self) } + } +} + +// ===== impl WaitForCancellationFuture ===== + +impl<'a> core::fmt::Debug for WaitForCancellationFuture<'a> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("WaitForCancellationFuture").finish() + } +} + +impl<'a> Future for WaitForCancellationFuture<'a> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let mut this = self.project(); + loop { + if this.cancellation_token.is_cancelled() { + return Poll::Ready(()); + } + + // No wakeups can be lost here because there is always a call to + // `is_cancelled` between the creation of the future and the call to + // `poll`, and the code that sets the cancelled flag does so before + // waking the `Notified`. + if this.future.as_mut().poll(cx).is_pending() { + return Poll::Pending; + } + + this.future.set(this.cancellation_token.inner.notified()); + } + } +} diff --git a/third_party/rust/tokio-util/src/sync/cancellation_token/guard.rs b/third_party/rust/tokio-util/src/sync/cancellation_token/guard.rs new file mode 100644 index 0000000000..54ed7ea2ed --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/cancellation_token/guard.rs @@ -0,0 +1,27 @@ +use crate::sync::CancellationToken; + +/// A wrapper for cancellation token which automatically cancels +/// it on drop. It is created using `drop_guard` method on the `CancellationToken`. +#[derive(Debug)] +pub struct DropGuard { + pub(super) inner: Option<CancellationToken>, +} + +impl DropGuard { + /// Returns stored cancellation token and removes this drop guard instance + /// (i.e. it will no longer cancel token). Other guards for this token + /// are not affected. + pub fn disarm(mut self) -> CancellationToken { + self.inner + .take() + .expect("`inner` can be only None in a destructor") + } +} + +impl Drop for DropGuard { + fn drop(&mut self) { + if let Some(inner) = &self.inner { + inner.cancel(); + } + } +} diff --git a/third_party/rust/tokio-util/src/sync/cancellation_token/tree_node.rs b/third_party/rust/tokio-util/src/sync/cancellation_token/tree_node.rs new file mode 100644 index 0000000000..b6cd698e23 --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/cancellation_token/tree_node.rs @@ -0,0 +1,373 @@ +//! This mod provides the logic for the inner tree structure of the CancellationToken. +//! +//! CancellationTokens are only light handles with references to TreeNode. +//! All the logic is actually implemented in the TreeNode. +//! +//! A TreeNode is part of the cancellation tree and may have one parent and an arbitrary number of +//! children. +//! +//! A TreeNode can receive the request to perform a cancellation through a CancellationToken. +//! This cancellation request will cancel the node and all of its descendants. +//! +//! As soon as a node cannot get cancelled any more (because it was already cancelled or it has no +//! more CancellationTokens pointing to it any more), it gets removed from the tree, to keep the +//! tree as small as possible. +//! +//! # Invariants +//! +//! Those invariants shall be true at any time. +//! +//! 1. A node that has no parents and no handles can no longer be cancelled. +//! This is important during both cancellation and refcounting. +//! +//! 2. If node B *is* or *was* a child of node A, then node B was created *after* node A. +//! This is important for deadlock safety, as it is used for lock order. +//! Node B can only become the child of node A in two ways: +//! - being created with `child_node()`, in which case it is trivially true that +//! node A already existed when node B was created +//! - being moved A->C->B to A->B because node C was removed in `decrease_handle_refcount()` +//! or `cancel()`. In this case the invariant still holds, as B was younger than C, and C +//! was younger than A, therefore B is also younger than A. +//! +//! 3. If two nodes are both unlocked and node A is the parent of node B, then node B is a child of +//! node A. It is important to always restore that invariant before dropping the lock of a node. +//! +//! # Deadlock safety +//! +//! We always lock in the order of creation time. We can prove this through invariant #2. +//! Specifically, through invariant #2, we know that we always have to lock a parent +//! before its child. +//! +use crate::loom::sync::{Arc, Mutex, MutexGuard}; + +/// A node of the cancellation tree structure +/// +/// The actual data it holds is wrapped inside a mutex for synchronization. +pub(crate) struct TreeNode { + inner: Mutex<Inner>, + waker: tokio::sync::Notify, +} +impl TreeNode { + pub(crate) fn new() -> Self { + Self { + inner: Mutex::new(Inner { + parent: None, + parent_idx: 0, + children: vec![], + is_cancelled: false, + num_handles: 1, + }), + waker: tokio::sync::Notify::new(), + } + } + + pub(crate) fn notified(&self) -> tokio::sync::futures::Notified<'_> { + self.waker.notified() + } +} + +/// The data contained inside a TreeNode. +/// +/// This struct exists so that the data of the node can be wrapped +/// in a Mutex. +struct Inner { + parent: Option<Arc<TreeNode>>, + parent_idx: usize, + children: Vec<Arc<TreeNode>>, + is_cancelled: bool, + num_handles: usize, +} + +/// Returns whether or not the node is cancelled +pub(crate) fn is_cancelled(node: &Arc<TreeNode>) -> bool { + node.inner.lock().unwrap().is_cancelled +} + +/// Creates a child node +pub(crate) fn child_node(parent: &Arc<TreeNode>) -> Arc<TreeNode> { + let mut locked_parent = parent.inner.lock().unwrap(); + + // Do not register as child if we are already cancelled. + // Cancelled trees can never be uncancelled and therefore + // need no connection to parents or children any more. + if locked_parent.is_cancelled { + return Arc::new(TreeNode { + inner: Mutex::new(Inner { + parent: None, + parent_idx: 0, + children: vec![], + is_cancelled: true, + num_handles: 1, + }), + waker: tokio::sync::Notify::new(), + }); + } + + let child = Arc::new(TreeNode { + inner: Mutex::new(Inner { + parent: Some(parent.clone()), + parent_idx: locked_parent.children.len(), + children: vec![], + is_cancelled: false, + num_handles: 1, + }), + waker: tokio::sync::Notify::new(), + }); + + locked_parent.children.push(child.clone()); + + child +} + +/// Disconnects the given parent from all of its children. +/// +/// Takes a reference to [Inner] to make sure the parent is already locked. +fn disconnect_children(node: &mut Inner) { + for child in std::mem::take(&mut node.children) { + let mut locked_child = child.inner.lock().unwrap(); + locked_child.parent_idx = 0; + locked_child.parent = None; + } +} + +/// Figures out the parent of the node and locks the node and its parent atomically. +/// +/// The basic principle of preventing deadlocks in the tree is +/// that we always lock the parent first, and then the child. +/// For more info look at *deadlock safety* and *invariant #2*. +/// +/// Sadly, it's impossible to figure out the parent of a node without +/// locking it. To then achieve locking order consistency, the node +/// has to be unlocked before the parent gets locked. +/// This leaves a small window where we already assume that we know the parent, +/// but neither the parent nor the node is locked. Therefore, the parent could change. +/// +/// To prevent that this problem leaks into the rest of the code, it is abstracted +/// in this function. +/// +/// The locked child and optionally its locked parent, if a parent exists, get passed +/// to the `func` argument via (node, None) or (node, Some(parent)). +fn with_locked_node_and_parent<F, Ret>(node: &Arc<TreeNode>, func: F) -> Ret +where + F: FnOnce(MutexGuard<'_, Inner>, Option<MutexGuard<'_, Inner>>) -> Ret, +{ + let mut potential_parent = { + let locked_node = node.inner.lock().unwrap(); + match locked_node.parent.clone() { + Some(parent) => parent, + // If we locked the node and its parent is `None`, we are in a valid state + // and can return. + None => return func(locked_node, None), + } + }; + + loop { + // Deadlock safety: + // + // Due to invariant #2, we know that we have to lock the parent first, and then the child. + // This is true even if the potential_parent is no longer the current parent or even its + // sibling, as the invariant still holds. + let locked_parent = potential_parent.inner.lock().unwrap(); + let locked_node = node.inner.lock().unwrap(); + + let actual_parent = match locked_node.parent.clone() { + Some(parent) => parent, + // If we locked the node and its parent is `None`, we are in a valid state + // and can return. + None => { + // Was the wrong parent, so unlock it before calling `func` + drop(locked_parent); + return func(locked_node, None); + } + }; + + // Loop until we managed to lock both the node and its parent + if Arc::ptr_eq(&actual_parent, &potential_parent) { + return func(locked_node, Some(locked_parent)); + } + + // Drop locked_parent before reassigning to potential_parent, + // as potential_parent is borrowed in it + drop(locked_node); + drop(locked_parent); + + potential_parent = actual_parent; + } +} + +/// Moves all children from `node` to `parent`. +/// +/// `parent` MUST have been a parent of the node when they both got locked, +/// otherwise there is a potential for a deadlock as invariant #2 would be violated. +/// +/// To aquire the locks for node and parent, use [with_locked_node_and_parent]. +fn move_children_to_parent(node: &mut Inner, parent: &mut Inner) { + // Pre-allocate in the parent, for performance + parent.children.reserve(node.children.len()); + + for child in std::mem::take(&mut node.children) { + { + let mut child_locked = child.inner.lock().unwrap(); + child_locked.parent = node.parent.clone(); + child_locked.parent_idx = parent.children.len(); + } + parent.children.push(child); + } +} + +/// Removes a child from the parent. +/// +/// `parent` MUST be the parent of `node`. +/// To aquire the locks for node and parent, use [with_locked_node_and_parent]. +fn remove_child(parent: &mut Inner, mut node: MutexGuard<'_, Inner>) { + // Query the position from where to remove a node + let pos = node.parent_idx; + node.parent = None; + node.parent_idx = 0; + + // Unlock node, so that only one child at a time is locked. + // Otherwise we would violate the lock order (see 'deadlock safety') as we + // don't know the creation order of the child nodes + drop(node); + + // If `node` is the last element in the list, we don't need any swapping + if parent.children.len() == pos + 1 { + parent.children.pop().unwrap(); + } else { + // If `node` is not the last element in the list, we need to + // replace it with the last element + let replacement_child = parent.children.pop().unwrap(); + replacement_child.inner.lock().unwrap().parent_idx = pos; + parent.children[pos] = replacement_child; + } + + let len = parent.children.len(); + if 4 * len <= parent.children.capacity() { + // equal to: + // parent.children.shrink_to(2 * len); + // but shrink_to was not yet stabilized in our minimal compatible version + let old_children = std::mem::replace(&mut parent.children, Vec::with_capacity(2 * len)); + parent.children.extend(old_children); + } +} + +/// Increases the reference count of handles. +pub(crate) fn increase_handle_refcount(node: &Arc<TreeNode>) { + let mut locked_node = node.inner.lock().unwrap(); + + // Once no handles are left over, the node gets detached from the tree. + // There should never be a new handle once all handles are dropped. + assert!(locked_node.num_handles > 0); + + locked_node.num_handles += 1; +} + +/// Decreases the reference count of handles. +/// +/// Once no handle is left, we can remove the node from the +/// tree and connect its parent directly to its children. +pub(crate) fn decrease_handle_refcount(node: &Arc<TreeNode>) { + let num_handles = { + let mut locked_node = node.inner.lock().unwrap(); + locked_node.num_handles -= 1; + locked_node.num_handles + }; + + if num_handles == 0 { + with_locked_node_and_parent(node, |mut node, parent| { + // Remove the node from the tree + match parent { + Some(mut parent) => { + // As we want to remove ourselves from the tree, + // we have to move the children to the parent, so that + // they still receive the cancellation event without us. + // Moving them does not violate invariant #1. + move_children_to_parent(&mut node, &mut parent); + + // Remove the node from the parent + remove_child(&mut parent, node); + } + None => { + // Due to invariant #1, we can assume that our + // children can no longer be cancelled through us. + // (as we now have neither a parent nor handles) + // Therefore we can disconnect them. + disconnect_children(&mut node); + } + } + }); + } +} + +/// Cancels a node and its children. +pub(crate) fn cancel(node: &Arc<TreeNode>) { + let mut locked_node = node.inner.lock().unwrap(); + + if locked_node.is_cancelled { + return; + } + + // One by one, adopt grandchildren and then cancel and detach the child + while let Some(child) = locked_node.children.pop() { + // This can't deadlock because the mutex we are already + // holding is the parent of child. + let mut locked_child = child.inner.lock().unwrap(); + + // Detach the child from node + // No need to modify node.children, as the child already got removed with `.pop` + locked_child.parent = None; + locked_child.parent_idx = 0; + + // If child is already cancelled, detaching is enough + if locked_child.is_cancelled { + continue; + } + + // Cancel or adopt grandchildren + while let Some(grandchild) = locked_child.children.pop() { + // This can't deadlock because the two mutexes we are already + // holding is the parent and grandparent of grandchild. + let mut locked_grandchild = grandchild.inner.lock().unwrap(); + + // Detach the grandchild + locked_grandchild.parent = None; + locked_grandchild.parent_idx = 0; + + // If grandchild is already cancelled, detaching is enough + if locked_grandchild.is_cancelled { + continue; + } + + // For performance reasons, only adopt grandchildren that have children. + // Otherwise, just cancel them right away, no need for another iteration. + if locked_grandchild.children.is_empty() { + // Cancel the grandchild + locked_grandchild.is_cancelled = true; + locked_grandchild.children = Vec::new(); + drop(locked_grandchild); + grandchild.waker.notify_waiters(); + } else { + // Otherwise, adopt grandchild + locked_grandchild.parent = Some(node.clone()); + locked_grandchild.parent_idx = locked_node.children.len(); + drop(locked_grandchild); + locked_node.children.push(grandchild); + } + } + + // Cancel the child + locked_child.is_cancelled = true; + locked_child.children = Vec::new(); + drop(locked_child); + child.waker.notify_waiters(); + + // Now the child is cancelled and detached and all its children are adopted. + // Just continue until all (including adopted) children are cancelled and detached. + } + + // Cancel the node itself. + locked_node.is_cancelled = true; + locked_node.children = Vec::new(); + drop(locked_node); + node.waker.notify_waiters(); +} diff --git a/third_party/rust/tokio-util/src/sync/mod.rs b/third_party/rust/tokio-util/src/sync/mod.rs new file mode 100644 index 0000000000..de392f0bb1 --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/mod.rs @@ -0,0 +1,13 @@ +//! Synchronization primitives + +mod cancellation_token; +pub use cancellation_token::{guard::DropGuard, CancellationToken, WaitForCancellationFuture}; + +mod mpsc; +pub use mpsc::{PollSendError, PollSender}; + +mod poll_semaphore; +pub use poll_semaphore::PollSemaphore; + +mod reusable_box; +pub use reusable_box::ReusableBoxFuture; diff --git a/third_party/rust/tokio-util/src/sync/mpsc.rs b/third_party/rust/tokio-util/src/sync/mpsc.rs new file mode 100644 index 0000000000..34a47c1891 --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/mpsc.rs @@ -0,0 +1,283 @@ +use futures_sink::Sink; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{fmt, mem}; +use tokio::sync::mpsc::OwnedPermit; +use tokio::sync::mpsc::Sender; + +use super::ReusableBoxFuture; + +/// Error returned by the `PollSender` when the channel is closed. +#[derive(Debug)] +pub struct PollSendError<T>(Option<T>); + +impl<T> PollSendError<T> { + /// Consumes the stored value, if any. + /// + /// If this error was encountered when calling `start_send`/`send_item`, this will be the item + /// that the caller attempted to send. Otherwise, it will be `None`. + pub fn into_inner(self) -> Option<T> { + self.0 + } +} + +impl<T> fmt::Display for PollSendError<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } +} + +impl<T: fmt::Debug> std::error::Error for PollSendError<T> {} + +#[derive(Debug)] +enum State<T> { + Idle(Sender<T>), + Acquiring, + ReadyToSend(OwnedPermit<T>), + Closed, +} + +/// A wrapper around [`mpsc::Sender`] that can be polled. +/// +/// [`mpsc::Sender`]: tokio::sync::mpsc::Sender +#[derive(Debug)] +pub struct PollSender<T> { + sender: Option<Sender<T>>, + state: State<T>, + acquire: ReusableBoxFuture<'static, Result<OwnedPermit<T>, PollSendError<T>>>, +} + +// Creates a future for acquiring a permit from the underlying channel. This is used to ensure +// there's capacity for a send to complete. +// +// By reusing the same async fn for both `Some` and `None`, we make sure every future passed to +// ReusableBoxFuture has the same underlying type, and hence the same size and alignment. +async fn make_acquire_future<T>( + data: Option<Sender<T>>, +) -> Result<OwnedPermit<T>, PollSendError<T>> { + match data { + Some(sender) => sender + .reserve_owned() + .await + .map_err(|_| PollSendError(None)), + None => unreachable!("this future should not be pollable in this state"), + } +} + +impl<T: Send + 'static> PollSender<T> { + /// Creates a new `PollSender`. + pub fn new(sender: Sender<T>) -> Self { + Self { + sender: Some(sender.clone()), + state: State::Idle(sender), + acquire: ReusableBoxFuture::new(make_acquire_future(None)), + } + } + + fn take_state(&mut self) -> State<T> { + mem::replace(&mut self.state, State::Closed) + } + + /// Attempts to prepare the sender to receive a value. + /// + /// This method must be called and return `Poll::Ready(Ok(()))` prior to each call to + /// `send_item`. + /// + /// This method returns `Poll::Ready` once the underlying channel is ready to receive a value, + /// by reserving a slot in the channel for the item to be sent. If this method returns + /// `Poll::Pending`, the current task is registered to be notified (via + /// `cx.waker().wake_by_ref()`) when `poll_reserve` should be called again. + /// + /// # Errors + /// + /// If the channel is closed, an error will be returned. This is a permanent state. + pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), PollSendError<T>>> { + loop { + let (result, next_state) = match self.take_state() { + State::Idle(sender) => { + // Start trying to acquire a permit to reserve a slot for our send, and + // immediately loop back around to poll it the first time. + self.acquire.set(make_acquire_future(Some(sender))); + (None, State::Acquiring) + } + State::Acquiring => match self.acquire.poll(cx) { + // Channel has capacity. + Poll::Ready(Ok(permit)) => { + (Some(Poll::Ready(Ok(()))), State::ReadyToSend(permit)) + } + // Channel is closed. + Poll::Ready(Err(e)) => (Some(Poll::Ready(Err(e))), State::Closed), + // Channel doesn't have capacity yet, so we need to wait. + Poll::Pending => (Some(Poll::Pending), State::Acquiring), + }, + // We're closed, either by choice or because the underlying sender was closed. + s @ State::Closed => (Some(Poll::Ready(Err(PollSendError(None)))), s), + // We're already ready to send an item. + s @ State::ReadyToSend(_) => (Some(Poll::Ready(Ok(()))), s), + }; + + self.state = next_state; + if let Some(result) = result { + return result; + } + } + } + + /// Sends an item to the channel. + /// + /// Before calling `send_item`, `poll_reserve` must be called with a successful return + /// value of `Poll::Ready(Ok(()))`. + /// + /// # Errors + /// + /// If the channel is closed, an error will be returned. This is a permanent state. + /// + /// # Panics + /// + /// If `poll_reserve` was not successfully called prior to calling `send_item`, then this method + /// will panic. + pub fn send_item(&mut self, value: T) -> Result<(), PollSendError<T>> { + let (result, next_state) = match self.take_state() { + State::Idle(_) | State::Acquiring => { + panic!("`send_item` called without first calling `poll_reserve`") + } + // We have a permit to send our item, so go ahead, which gets us our sender back. + State::ReadyToSend(permit) => (Ok(()), State::Idle(permit.send(value))), + // We're closed, either by choice or because the underlying sender was closed. + State::Closed => (Err(PollSendError(Some(value))), State::Closed), + }; + + // Handle deferred closing if `close` was called between `poll_reserve` and `send_item`. + self.state = if self.sender.is_some() { + next_state + } else { + State::Closed + }; + result + } + + /// Checks whether this sender is been closed. + /// + /// The underlying channel that this sender was wrapping may still be open. + pub fn is_closed(&self) -> bool { + matches!(self.state, State::Closed) || self.sender.is_none() + } + + /// Gets a reference to the `Sender` of the underlying channel. + /// + /// If `PollSender` has been closed, `None` is returned. The underlying channel that this sender + /// was wrapping may still be open. + pub fn get_ref(&self) -> Option<&Sender<T>> { + self.sender.as_ref() + } + + /// Closes this sender. + /// + /// No more messages will be able to be sent from this sender, but the underlying channel will + /// remain open until all senders have dropped, or until the [`Receiver`] closes the channel. + /// + /// If a slot was previously reserved by calling `poll_reserve`, then a final call can be made + /// to `send_item` in order to consume the reserved slot. After that, no further sends will be + /// possible. If you do not intend to send another item, you can release the reserved slot back + /// to the underlying sender by calling [`abort_send`]. + /// + /// [`abort_send`]: crate::sync::PollSender::abort_send + /// [`Receiver`]: tokio::sync::mpsc::Receiver + pub fn close(&mut self) { + // Mark ourselves officially closed by dropping our main sender. + self.sender = None; + + // If we're already idle, closed, or we haven't yet reserved a slot, we can quickly + // transition to the closed state. Otherwise, leave the existing permit in place for the + // caller if they want to complete the send. + match self.state { + State::Idle(_) => self.state = State::Closed, + State::Acquiring => { + self.acquire.set(make_acquire_future(None)); + self.state = State::Closed; + } + _ => {} + } + } + + /// Aborts the current in-progress send, if any. + /// + /// Returns `true` if a send was aborted. If the sender was closed prior to calling + /// `abort_send`, then the sender will remain in the closed state, otherwise the sender will be + /// ready to attempt another send. + pub fn abort_send(&mut self) -> bool { + // We may have been closed in the meantime, after a call to `poll_reserve` already + // succeeded. We'll check if `self.sender` is `None` to see if we should transition to the + // closed state when we actually abort a send, rather than resetting ourselves back to idle. + + let (result, next_state) = match self.take_state() { + // We're currently trying to reserve a slot to send into. + State::Acquiring => { + // Replacing the future drops the in-flight one. + self.acquire.set(make_acquire_future(None)); + + // If we haven't closed yet, we have to clone our stored sender since we have no way + // to get it back from the acquire future we just dropped. + let state = match self.sender.clone() { + Some(sender) => State::Idle(sender), + None => State::Closed, + }; + (true, state) + } + // We got the permit. If we haven't closed yet, get the sender back. + State::ReadyToSend(permit) => { + let state = if self.sender.is_some() { + State::Idle(permit.release()) + } else { + State::Closed + }; + (true, state) + } + s => (false, s), + }; + + self.state = next_state; + result + } +} + +impl<T> Clone for PollSender<T> { + /// Clones this `PollSender`. + /// + /// The resulting `PollSender` will have an initial state identical to calling `PollSender::new`. + fn clone(&self) -> PollSender<T> { + let (sender, state) = match self.sender.clone() { + Some(sender) => (Some(sender.clone()), State::Idle(sender)), + None => (None, State::Closed), + }; + + Self { + sender, + state, + // We don't use `make_acquire_future` here because our relaxed bounds on `T` are not + // compatible with the transitive bounds required by `Sender<T>`. + acquire: ReusableBoxFuture::new(async { unreachable!() }), + } + } +} + +impl<T: Send + 'static> Sink<T> for PollSender<T> { + type Error = PollSendError<T>; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Pin::into_inner(self).poll_reserve(cx) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { + Pin::into_inner(self).send_item(item) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Pin::into_inner(self).close(); + Poll::Ready(Ok(())) + } +} diff --git a/third_party/rust/tokio-util/src/sync/poll_semaphore.rs b/third_party/rust/tokio-util/src/sync/poll_semaphore.rs new file mode 100644 index 0000000000..d0b1dedc27 --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/poll_semaphore.rs @@ -0,0 +1,136 @@ +use futures_core::{ready, Stream}; +use std::fmt; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore, TryAcquireError}; + +use super::ReusableBoxFuture; + +/// A wrapper around [`Semaphore`] that provides a `poll_acquire` method. +/// +/// [`Semaphore`]: tokio::sync::Semaphore +pub struct PollSemaphore { + semaphore: Arc<Semaphore>, + permit_fut: Option<ReusableBoxFuture<'static, Result<OwnedSemaphorePermit, AcquireError>>>, +} + +impl PollSemaphore { + /// Create a new `PollSemaphore`. + pub fn new(semaphore: Arc<Semaphore>) -> Self { + Self { + semaphore, + permit_fut: None, + } + } + + /// Closes the semaphore. + pub fn close(&self) { + self.semaphore.close() + } + + /// Obtain a clone of the inner semaphore. + pub fn clone_inner(&self) -> Arc<Semaphore> { + self.semaphore.clone() + } + + /// Get back the inner semaphore. + pub fn into_inner(self) -> Arc<Semaphore> { + self.semaphore + } + + /// Poll to acquire a permit from the semaphore. + /// + /// This can return the following values: + /// + /// - `Poll::Pending` if a permit is not currently available. + /// - `Poll::Ready(Some(permit))` if a permit was acquired. + /// - `Poll::Ready(None)` if the semaphore has been closed. + /// + /// When this method returns `Poll::Pending`, the current task is scheduled + /// to receive a wakeup when a permit becomes available, or when the + /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. + pub fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> { + let permit_future = match self.permit_fut.as_mut() { + Some(fut) => fut, + None => { + // avoid allocations completely if we can grab a permit immediately + match Arc::clone(&self.semaphore).try_acquire_owned() { + Ok(permit) => return Poll::Ready(Some(permit)), + Err(TryAcquireError::Closed) => return Poll::Ready(None), + Err(TryAcquireError::NoPermits) => {} + } + + let next_fut = Arc::clone(&self.semaphore).acquire_owned(); + self.permit_fut + .get_or_insert(ReusableBoxFuture::new(next_fut)) + } + }; + + let result = ready!(permit_future.poll(cx)); + + let next_fut = Arc::clone(&self.semaphore).acquire_owned(); + permit_future.set(next_fut); + + match result { + Ok(permit) => Poll::Ready(Some(permit)), + Err(_closed) => { + self.permit_fut = None; + Poll::Ready(None) + } + } + } + + /// Returns the current number of available permits. + /// + /// This is equivalent to the [`Semaphore::available_permits`] method on the + /// `tokio::sync::Semaphore` type. + /// + /// [`Semaphore::available_permits`]: tokio::sync::Semaphore::available_permits + pub fn available_permits(&self) -> usize { + self.semaphore.available_permits() + } + + /// Adds `n` new permits to the semaphore. + /// + /// The maximum number of permits is `usize::MAX >> 3`, and this function + /// will panic if the limit is exceeded. + /// + /// This is equivalent to the [`Semaphore::add_permits`] method on the + /// `tokio::sync::Semaphore` type. + /// + /// [`Semaphore::add_permits`]: tokio::sync::Semaphore::add_permits + pub fn add_permits(&self, n: usize) { + self.semaphore.add_permits(n); + } +} + +impl Stream for PollSemaphore { + type Item = OwnedSemaphorePermit; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> { + Pin::into_inner(self).poll_acquire(cx) + } +} + +impl Clone for PollSemaphore { + fn clone(&self) -> PollSemaphore { + PollSemaphore::new(self.clone_inner()) + } +} + +impl fmt::Debug for PollSemaphore { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PollSemaphore") + .field("semaphore", &self.semaphore) + .finish() + } +} + +impl AsRef<Semaphore> for PollSemaphore { + fn as_ref(&self) -> &Semaphore { + &*self.semaphore + } +} diff --git a/third_party/rust/tokio-util/src/sync/reusable_box.rs b/third_party/rust/tokio-util/src/sync/reusable_box.rs new file mode 100644 index 0000000000..3204207db7 --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/reusable_box.rs @@ -0,0 +1,148 @@ +use std::alloc::Layout; +use std::future::Future; +use std::panic::AssertUnwindSafe; +use std::pin::Pin; +use std::ptr::{self, NonNull}; +use std::task::{Context, Poll}; +use std::{fmt, panic}; + +/// A reusable `Pin<Box<dyn Future<Output = T> + Send + 'a>>`. +/// +/// This type lets you replace the future stored in the box without +/// reallocating when the size and alignment permits this. +pub struct ReusableBoxFuture<'a, T> { + boxed: NonNull<dyn Future<Output = T> + Send + 'a>, +} + +impl<'a, T> ReusableBoxFuture<'a, T> { + /// Create a new `ReusableBoxFuture<T>` containing the provided future. + pub fn new<F>(future: F) -> Self + where + F: Future<Output = T> + Send + 'a, + { + let boxed: Box<dyn Future<Output = T> + Send + 'a> = Box::new(future); + + let boxed = NonNull::from(Box::leak(boxed)); + + Self { boxed } + } + + /// Replace the future currently stored in this box. + /// + /// This reallocates if and only if the layout of the provided future is + /// different from the layout of the currently stored future. + pub fn set<F>(&mut self, future: F) + where + F: Future<Output = T> + Send + 'a, + { + if let Err(future) = self.try_set(future) { + *self = Self::new(future); + } + } + + /// Replace the future currently stored in this box. + /// + /// This function never reallocates, but returns an error if the provided + /// future has a different size or alignment from the currently stored + /// future. + pub fn try_set<F>(&mut self, future: F) -> Result<(), F> + where + F: Future<Output = T> + Send + 'a, + { + // SAFETY: The pointer is not dangling. + let self_layout = { + let dyn_future: &(dyn Future<Output = T> + Send) = unsafe { self.boxed.as_ref() }; + Layout::for_value(dyn_future) + }; + + if Layout::new::<F>() == self_layout { + // SAFETY: We just checked that the layout of F is correct. + unsafe { + self.set_same_layout(future); + } + + Ok(()) + } else { + Err(future) + } + } + + /// Set the current future. + /// + /// # Safety + /// + /// This function requires that the layout of the provided future is the + /// same as `self.layout`. + unsafe fn set_same_layout<F>(&mut self, future: F) + where + F: Future<Output = T> + Send + 'a, + { + // Drop the existing future, catching any panics. + let result = panic::catch_unwind(AssertUnwindSafe(|| { + ptr::drop_in_place(self.boxed.as_ptr()); + })); + + // Overwrite the future behind the pointer. This is safe because the + // allocation was allocated with the same size and alignment as the type F. + let self_ptr: *mut F = self.boxed.as_ptr() as *mut F; + ptr::write(self_ptr, future); + + // Update the vtable of self.boxed. The pointer is not null because we + // just got it from self.boxed, which is not null. + self.boxed = NonNull::new_unchecked(self_ptr); + + // If the old future's destructor panicked, resume unwinding. + match result { + Ok(()) => {} + Err(payload) => { + panic::resume_unwind(payload); + } + } + } + + /// Get a pinned reference to the underlying future. + pub fn get_pin(&mut self) -> Pin<&mut (dyn Future<Output = T> + Send)> { + // SAFETY: The user of this box cannot move the box, and we do not move it + // either. + unsafe { Pin::new_unchecked(self.boxed.as_mut()) } + } + + /// Poll the future stored inside this box. + pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<T> { + self.get_pin().poll(cx) + } +} + +impl<T> Future for ReusableBoxFuture<'_, T> { + type Output = T; + + /// Poll the future stored inside this box. + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> { + Pin::into_inner(self).get_pin().poll(cx) + } +} + +// The future stored inside ReusableBoxFuture<'_, T> must be Send. +unsafe impl<T> Send for ReusableBoxFuture<'_, T> {} + +// The only method called on self.boxed is poll, which takes &mut self, so this +// struct being Sync does not permit any invalid access to the Future, even if +// the future is not Sync. +unsafe impl<T> Sync for ReusableBoxFuture<'_, T> {} + +// Just like a Pin<Box<dyn Future>> is always Unpin, so is this type. +impl<T> Unpin for ReusableBoxFuture<'_, T> {} + +impl<T> Drop for ReusableBoxFuture<'_, T> { + fn drop(&mut self) { + unsafe { + drop(Box::from_raw(self.boxed.as_ptr())); + } + } +} + +impl<T> fmt::Debug for ReusableBoxFuture<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ReusableBoxFuture").finish() + } +} diff --git a/third_party/rust/tokio-util/src/sync/tests/loom_cancellation_token.rs b/third_party/rust/tokio-util/src/sync/tests/loom_cancellation_token.rs new file mode 100644 index 0000000000..e9c9f3dd98 --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/tests/loom_cancellation_token.rs @@ -0,0 +1,155 @@ +use crate::sync::CancellationToken; + +use loom::{future::block_on, thread}; +use tokio_test::assert_ok; + +#[test] +fn cancel_token() { + loom::model(|| { + let token = CancellationToken::new(); + let token1 = token.clone(); + + let th1 = thread::spawn(move || { + block_on(async { + token1.cancelled().await; + }); + }); + + let th2 = thread::spawn(move || { + token.cancel(); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + }); +} + +#[test] +fn cancel_with_child() { + loom::model(|| { + let token = CancellationToken::new(); + let token1 = token.clone(); + let token2 = token.clone(); + let child_token = token.child_token(); + + let th1 = thread::spawn(move || { + block_on(async { + token1.cancelled().await; + }); + }); + + let th2 = thread::spawn(move || { + token2.cancel(); + }); + + let th3 = thread::spawn(move || { + block_on(async { + child_token.cancelled().await; + }); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn drop_token_no_child() { + loom::model(|| { + let token = CancellationToken::new(); + let token1 = token.clone(); + let token2 = token.clone(); + + let th1 = thread::spawn(move || { + drop(token1); + }); + + let th2 = thread::spawn(move || { + drop(token2); + }); + + let th3 = thread::spawn(move || { + drop(token); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn drop_token_with_childs() { + loom::model(|| { + let token1 = CancellationToken::new(); + let child_token1 = token1.child_token(); + let child_token2 = token1.child_token(); + + let th1 = thread::spawn(move || { + drop(token1); + }); + + let th2 = thread::spawn(move || { + drop(child_token1); + }); + + let th3 = thread::spawn(move || { + drop(child_token2); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn drop_and_cancel_token() { + loom::model(|| { + let token1 = CancellationToken::new(); + let token2 = token1.clone(); + let child_token = token1.child_token(); + + let th1 = thread::spawn(move || { + drop(token1); + }); + + let th2 = thread::spawn(move || { + token2.cancel(); + }); + + let th3 = thread::spawn(move || { + drop(child_token); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn cancel_parent_and_child() { + loom::model(|| { + let token1 = CancellationToken::new(); + let token2 = token1.clone(); + let child_token = token1.child_token(); + + let th1 = thread::spawn(move || { + drop(token1); + }); + + let th2 = thread::spawn(move || { + token2.cancel(); + }); + + let th3 = thread::spawn(move || { + child_token.cancel(); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} diff --git a/third_party/rust/tokio-util/src/sync/tests/mod.rs b/third_party/rust/tokio-util/src/sync/tests/mod.rs new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/third_party/rust/tokio-util/src/sync/tests/mod.rs @@ -0,0 +1 @@ + diff --git a/third_party/rust/tokio-util/src/task/mod.rs b/third_party/rust/tokio-util/src/task/mod.rs new file mode 100644 index 0000000000..5aa33df2dc --- /dev/null +++ b/third_party/rust/tokio-util/src/task/mod.rs @@ -0,0 +1,4 @@ +//! Extra utilities for spawning tasks + +mod spawn_pinned; +pub use spawn_pinned::LocalPoolHandle; diff --git a/third_party/rust/tokio-util/src/task/spawn_pinned.rs b/third_party/rust/tokio-util/src/task/spawn_pinned.rs new file mode 100644 index 0000000000..6f553e9d07 --- /dev/null +++ b/third_party/rust/tokio-util/src/task/spawn_pinned.rs @@ -0,0 +1,307 @@ +use futures_util::future::{AbortHandle, Abortable}; +use std::fmt; +use std::fmt::{Debug, Formatter}; +use std::future::Future; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use tokio::runtime::Builder; +use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; +use tokio::sync::oneshot; +use tokio::task::{spawn_local, JoinHandle, LocalSet}; + +/// A handle to a local pool, used for spawning `!Send` tasks. +#[derive(Clone)] +pub struct LocalPoolHandle { + pool: Arc<LocalPool>, +} + +impl LocalPoolHandle { + /// Create a new pool of threads to handle `!Send` tasks. Spawn tasks onto this + /// pool via [`LocalPoolHandle::spawn_pinned`]. + /// + /// # Panics + /// Panics if the pool size is less than one. + pub fn new(pool_size: usize) -> LocalPoolHandle { + assert!(pool_size > 0); + + let workers = (0..pool_size) + .map(|_| LocalWorkerHandle::new_worker()) + .collect(); + + let pool = Arc::new(LocalPool { workers }); + + LocalPoolHandle { pool } + } + + /// Spawn a task onto a worker thread and pin it there so it can't be moved + /// off of the thread. Note that the future is not [`Send`], but the + /// [`FnOnce`] which creates it is. + /// + /// # Examples + /// ``` + /// use std::rc::Rc; + /// use tokio_util::task::LocalPoolHandle; + /// + /// #[tokio::main] + /// async fn main() { + /// // Create the local pool + /// let pool = LocalPoolHandle::new(1); + /// + /// // Spawn a !Send future onto the pool and await it + /// let output = pool + /// .spawn_pinned(|| { + /// // Rc is !Send + !Sync + /// let local_data = Rc::new("test"); + /// + /// // This future holds an Rc, so it is !Send + /// async move { local_data.to_string() } + /// }) + /// .await + /// .unwrap(); + /// + /// assert_eq!(output, "test"); + /// } + /// ``` + pub fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output> + where + F: FnOnce() -> Fut, + F: Send + 'static, + Fut: Future + 'static, + Fut::Output: Send + 'static, + { + self.pool.spawn_pinned(create_task) + } +} + +impl Debug for LocalPoolHandle { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_str("LocalPoolHandle") + } +} + +struct LocalPool { + workers: Vec<LocalWorkerHandle>, +} + +impl LocalPool { + /// Spawn a `?Send` future onto a worker + fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output> + where + F: FnOnce() -> Fut, + F: Send + 'static, + Fut: Future + 'static, + Fut::Output: Send + 'static, + { + let (sender, receiver) = oneshot::channel(); + + let (worker, job_guard) = self.find_and_incr_least_burdened_worker(); + let worker_spawner = worker.spawner.clone(); + + // Spawn a future onto the worker's runtime so we can immediately return + // a join handle. + worker.runtime_handle.spawn(async move { + // Move the job guard into the task + let _job_guard = job_guard; + + // Propagate aborts via Abortable/AbortHandle + let (abort_handle, abort_registration) = AbortHandle::new_pair(); + let _abort_guard = AbortGuard(abort_handle); + + // Inside the future we can't run spawn_local yet because we're not + // in the context of a LocalSet. We need to send create_task to the + // LocalSet task for spawning. + let spawn_task = Box::new(move || { + // Once we're in the LocalSet context we can call spawn_local + let join_handle = + spawn_local( + async move { Abortable::new(create_task(), abort_registration).await }, + ); + + // Send the join handle back to the spawner. If sending fails, + // we assume the parent task was canceled, so cancel this task + // as well. + if let Err(join_handle) = sender.send(join_handle) { + join_handle.abort() + } + }); + + // Send the callback to the LocalSet task + if let Err(e) = worker_spawner.send(spawn_task) { + // Propagate the error as a panic in the join handle. + panic!("Failed to send job to worker: {}", e); + } + + // Wait for the task's join handle + let join_handle = match receiver.await { + Ok(handle) => handle, + Err(e) => { + // We sent the task successfully, but failed to get its + // join handle... We assume something happened to the worker + // and the task was not spawned. Propagate the error as a + // panic in the join handle. + panic!("Worker failed to send join handle: {}", e); + } + }; + + // Wait for the task to complete + let join_result = join_handle.await; + + match join_result { + Ok(Ok(output)) => output, + Ok(Err(_)) => { + // Pinned task was aborted. But that only happens if this + // task is aborted. So this is an impossible branch. + unreachable!( + "Reaching this branch means this task was previously \ + aborted but it continued running anyways" + ) + } + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else if e.is_cancelled() { + // No one else should have the join handle, so this is + // unexpected. Forward this error as a panic in the join + // handle. + panic!("spawn_pinned task was canceled: {}", e); + } else { + // Something unknown happened (not a panic or + // cancellation). Forward this error as a panic in the + // join handle. + panic!("spawn_pinned task failed: {}", e); + } + } + } + }) + } + + /// Find the worker with the least number of tasks, increment its task + /// count, and return its handle. Make sure to actually spawn a task on + /// the worker so the task count is kept consistent with load. + /// + /// A job count guard is also returned to ensure the task count gets + /// decremented when the job is done. + fn find_and_incr_least_burdened_worker(&self) -> (&LocalWorkerHandle, JobCountGuard) { + loop { + let (worker, task_count) = self + .workers + .iter() + .map(|worker| (worker, worker.task_count.load(Ordering::SeqCst))) + .min_by_key(|&(_, count)| count) + .expect("There must be more than one worker"); + + // Make sure the task count hasn't changed since when we choose this + // worker. Otherwise, restart the search. + if worker + .task_count + .compare_exchange( + task_count, + task_count + 1, + Ordering::SeqCst, + Ordering::Relaxed, + ) + .is_ok() + { + return (worker, JobCountGuard(Arc::clone(&worker.task_count))); + } + } + } +} + +/// Automatically decrements a worker's job count when a job finishes (when +/// this gets dropped). +struct JobCountGuard(Arc<AtomicUsize>); + +impl Drop for JobCountGuard { + fn drop(&mut self) { + // Decrement the job count + let previous_value = self.0.fetch_sub(1, Ordering::SeqCst); + debug_assert!(previous_value >= 1); + } +} + +/// Calls abort on the handle when dropped. +struct AbortGuard(AbortHandle); + +impl Drop for AbortGuard { + fn drop(&mut self) { + self.0.abort(); + } +} + +type PinnedFutureSpawner = Box<dyn FnOnce() + Send + 'static>; + +struct LocalWorkerHandle { + runtime_handle: tokio::runtime::Handle, + spawner: UnboundedSender<PinnedFutureSpawner>, + task_count: Arc<AtomicUsize>, +} + +impl LocalWorkerHandle { + /// Create a new worker for executing pinned tasks + fn new_worker() -> LocalWorkerHandle { + let (sender, receiver) = unbounded_channel(); + let runtime = Builder::new_current_thread() + .enable_all() + .build() + .expect("Failed to start a pinned worker thread runtime"); + let runtime_handle = runtime.handle().clone(); + let task_count = Arc::new(AtomicUsize::new(0)); + let task_count_clone = Arc::clone(&task_count); + + std::thread::spawn(|| Self::run(runtime, receiver, task_count_clone)); + + LocalWorkerHandle { + runtime_handle, + spawner: sender, + task_count, + } + } + + fn run( + runtime: tokio::runtime::Runtime, + mut task_receiver: UnboundedReceiver<PinnedFutureSpawner>, + task_count: Arc<AtomicUsize>, + ) { + let local_set = LocalSet::new(); + local_set.block_on(&runtime, async { + while let Some(spawn_task) = task_receiver.recv().await { + // Calls spawn_local(future) + (spawn_task)(); + } + }); + + // If there are any tasks on the runtime associated with a LocalSet task + // that has already completed, but whose output has not yet been + // reported, let that task complete. + // + // Since the task_count is decremented when the runtime task exits, + // reading that counter lets us know if any such tasks completed during + // the call to `block_on`. + // + // Tasks on the LocalSet can't complete during this loop since they're + // stored on the LocalSet and we aren't accessing it. + let mut previous_task_count = task_count.load(Ordering::SeqCst); + loop { + // This call will also run tasks spawned on the runtime. + runtime.block_on(tokio::task::yield_now()); + let new_task_count = task_count.load(Ordering::SeqCst); + if new_task_count == previous_task_count { + break; + } else { + previous_task_count = new_task_count; + } + } + + // It's now no longer possible for a task on the runtime to be + // associated with a LocalSet task that has completed. Drop both the + // LocalSet and runtime to let tasks on the runtime be cancelled if and + // only if they are still on the LocalSet. + // + // Drop the LocalSet task first so that anyone awaiting the runtime + // JoinHandle will see the cancelled error after the LocalSet task + // destructor has completed. + drop(local_set); + drop(runtime); + } +} diff --git a/third_party/rust/tokio-util/src/time/delay_queue.rs b/third_party/rust/tokio-util/src/time/delay_queue.rs new file mode 100644 index 0000000000..a0c5e5c5b0 --- /dev/null +++ b/third_party/rust/tokio-util/src/time/delay_queue.rs @@ -0,0 +1,1221 @@ +//! A queue of delayed elements. +//! +//! See [`DelayQueue`] for more details. +//! +//! [`DelayQueue`]: struct@DelayQueue + +use crate::time::wheel::{self, Wheel}; + +use futures_core::ready; +use tokio::time::{sleep_until, Duration, Instant, Sleep}; + +use core::ops::{Index, IndexMut}; +use slab::Slab; +use std::cmp; +use std::collections::HashMap; +use std::convert::From; +use std::fmt; +use std::fmt::Debug; +use std::future::Future; +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{self, Poll, Waker}; + +/// A queue of delayed elements. +/// +/// Once an element is inserted into the `DelayQueue`, it is yielded once the +/// specified deadline has been reached. +/// +/// # Usage +/// +/// Elements are inserted into `DelayQueue` using the [`insert`] or +/// [`insert_at`] methods. A deadline is provided with the item and a [`Key`] is +/// returned. The key is used to remove the entry or to change the deadline at +/// which it should be yielded back. +/// +/// Once delays have been configured, the `DelayQueue` is used via its +/// [`Stream`] implementation. [`poll_expired`] is called. If an entry has reached its +/// deadline, it is returned. If not, `Poll::Pending` is returned indicating that the +/// current task will be notified once the deadline has been reached. +/// +/// # `Stream` implementation +/// +/// Items are retrieved from the queue via [`DelayQueue::poll_expired`]. If no delays have +/// expired, no items are returned. In this case, `Poll::Pending` is returned and the +/// current task is registered to be notified once the next item's delay has +/// expired. +/// +/// If no items are in the queue, i.e. `is_empty()` returns `true`, then `poll` +/// returns `Poll::Ready(None)`. This indicates that the stream has reached an end. +/// However, if a new item is inserted *after*, `poll` will once again start +/// returning items or `Poll::Pending`. +/// +/// Items are returned ordered by their expirations. Items that are configured +/// to expire first will be returned first. There are no ordering guarantees +/// for items configured to expire at the same instant. Also note that delays are +/// rounded to the closest millisecond. +/// +/// # Implementation +/// +/// The [`DelayQueue`] is backed by a separate instance of a timer wheel similar to that used internally +/// by Tokio's standalone timer utilities such as [`sleep`]. Because of this, it offers the same +/// performance and scalability benefits. +/// +/// State associated with each entry is stored in a [`slab`]. This amortizes the cost of allocation, +/// and allows reuse of the memory allocated for expired entires. +/// +/// Capacity can be checked using [`capacity`] and allocated preemptively by using +/// the [`reserve`] method. +/// +/// # Usage +/// +/// Using `DelayQueue` to manage cache entries. +/// +/// ```rust,no_run +/// use tokio_util::time::{DelayQueue, delay_queue}; +/// +/// use futures::ready; +/// use std::collections::HashMap; +/// use std::task::{Context, Poll}; +/// use std::time::Duration; +/// # type CacheKey = String; +/// # type Value = String; +/// +/// struct Cache { +/// entries: HashMap<CacheKey, (Value, delay_queue::Key)>, +/// expirations: DelayQueue<CacheKey>, +/// } +/// +/// const TTL_SECS: u64 = 30; +/// +/// impl Cache { +/// fn insert(&mut self, key: CacheKey, value: Value) { +/// let delay = self.expirations +/// .insert(key.clone(), Duration::from_secs(TTL_SECS)); +/// +/// self.entries.insert(key, (value, delay)); +/// } +/// +/// fn get(&self, key: &CacheKey) -> Option<&Value> { +/// self.entries.get(key) +/// .map(|&(ref v, _)| v) +/// } +/// +/// fn remove(&mut self, key: &CacheKey) { +/// if let Some((_, cache_key)) = self.entries.remove(key) { +/// self.expirations.remove(&cache_key); +/// } +/// } +/// +/// fn poll_purge(&mut self, cx: &mut Context<'_>) -> Poll<()> { +/// while let Some(entry) = ready!(self.expirations.poll_expired(cx)) { +/// self.entries.remove(entry.get_ref()); +/// } +/// +/// Poll::Ready(()) +/// } +/// } +/// ``` +/// +/// [`insert`]: method@Self::insert +/// [`insert_at`]: method@Self::insert_at +/// [`Key`]: struct@Key +/// [`Stream`]: https://docs.rs/futures/0.1/futures/stream/trait.Stream.html +/// [`poll_expired`]: method@Self::poll_expired +/// [`Stream::poll_expired`]: method@Self::poll_expired +/// [`DelayQueue`]: struct@DelayQueue +/// [`sleep`]: fn@tokio::time::sleep +/// [`slab`]: slab +/// [`capacity`]: method@Self::capacity +/// [`reserve`]: method@Self::reserve +#[derive(Debug)] +pub struct DelayQueue<T> { + /// Stores data associated with entries + slab: SlabStorage<T>, + + /// Lookup structure tracking all delays in the queue + wheel: Wheel<Stack<T>>, + + /// Delays that were inserted when already expired. These cannot be stored + /// in the wheel + expired: Stack<T>, + + /// Delay expiring when the *first* item in the queue expires + delay: Option<Pin<Box<Sleep>>>, + + /// Wheel polling state + wheel_now: u64, + + /// Instant at which the timer starts + start: Instant, + + /// Waker that is invoked when we potentially need to reset the timer. + /// Because we lazily create the timer when the first entry is created, we + /// need to awaken any poller that polled us before that point. + waker: Option<Waker>, +} + +#[derive(Default)] +struct SlabStorage<T> { + inner: Slab<Data<T>>, + + // A `compact` call requires a re-mapping of the `Key`s that were changed + // during the `compact` call of the `slab`. Since the keys that were given out + // cannot be changed retroactively we need to keep track of these re-mappings. + // The keys of `key_map` correspond to the old keys that were given out and + // the values to the `Key`s that were re-mapped by the `compact` call. + key_map: HashMap<Key, KeyInternal>, + + // Index used to create new keys to hand out. + next_key_index: usize, + + // Whether `compact` has been called, necessary in order to decide whether + // to include keys in `key_map`. + compact_called: bool, +} + +impl<T> SlabStorage<T> { + pub(crate) fn with_capacity(capacity: usize) -> SlabStorage<T> { + SlabStorage { + inner: Slab::with_capacity(capacity), + key_map: HashMap::new(), + next_key_index: 0, + compact_called: false, + } + } + + // Inserts data into the inner slab and re-maps keys if necessary + pub(crate) fn insert(&mut self, val: Data<T>) -> Key { + let mut key = KeyInternal::new(self.inner.insert(val)); + let key_contained = self.key_map.contains_key(&key.into()); + + if key_contained { + // It's possible that a `compact` call creates capacitiy in `self.inner` in + // such a way that a `self.inner.insert` call creates a `key` which was + // previously given out during an `insert` call prior to the `compact` call. + // If `key` is contained in `self.key_map`, we have encountered this exact situation, + // We need to create a new key `key_to_give_out` and include the relation + // `key_to_give_out` -> `key` in `self.key_map`. + let key_to_give_out = self.create_new_key(); + assert!(!self.key_map.contains_key(&key_to_give_out.into())); + self.key_map.insert(key_to_give_out.into(), key); + key = key_to_give_out; + } else if self.compact_called { + // Include an identity mapping in `self.key_map` in order to allow us to + // panic if a key that was handed out is removed more than once. + self.key_map.insert(key.into(), key); + } + + key.into() + } + + // Re-map the key in case compact was previously called. + // Note: Since we include identity mappings in key_map after compact was called, + // we have information about all keys that were handed out. In the case in which + // compact was called and we try to remove a Key that was previously removed + // we can detect invalid keys if no key is found in `key_map`. This is necessary + // in order to prevent situations in which a previously removed key + // corresponds to a re-mapped key internally and which would then be incorrectly + // removed from the slab. + // + // Example to illuminate this problem: + // + // Let's assume our `key_map` is {1 -> 2, 2 -> 1} and we call remove(1). If we + // were to remove 1 again, we would not find it inside `key_map` anymore. + // If we were to imply from this that no re-mapping was necessary, we would + // incorrectly remove 1 from `self.slab.inner`, which corresponds to the + // handed-out key 2. + pub(crate) fn remove(&mut self, key: &Key) -> Data<T> { + let remapped_key = if self.compact_called { + match self.key_map.remove(key) { + Some(key_internal) => key_internal, + None => panic!("invalid key"), + } + } else { + (*key).into() + }; + + self.inner.remove(remapped_key.index) + } + + pub(crate) fn shrink_to_fit(&mut self) { + self.inner.shrink_to_fit(); + self.key_map.shrink_to_fit(); + } + + pub(crate) fn compact(&mut self) { + if !self.compact_called { + for (key, _) in self.inner.iter() { + self.key_map.insert(Key::new(key), KeyInternal::new(key)); + } + } + + let mut remapping = HashMap::new(); + self.inner.compact(|_, from, to| { + remapping.insert(from, to); + true + }); + + // At this point `key_map` contains a mapping for every element. + for internal_key in self.key_map.values_mut() { + if let Some(new_internal_key) = remapping.get(&internal_key.index) { + *internal_key = KeyInternal::new(*new_internal_key); + } + } + + if self.key_map.capacity() > 2 * self.key_map.len() { + self.key_map.shrink_to_fit(); + } + + self.compact_called = true; + } + + // Tries to re-map a `Key` that was given out to the user to its + // corresponding internal key. + fn remap_key(&self, key: &Key) -> Option<KeyInternal> { + let key_map = &self.key_map; + if self.compact_called { + key_map.get(&*key).copied() + } else { + Some((*key).into()) + } + } + + fn create_new_key(&mut self) -> KeyInternal { + while self.key_map.contains_key(&Key::new(self.next_key_index)) { + self.next_key_index = self.next_key_index.wrapping_add(1); + } + + KeyInternal::new(self.next_key_index) + } + + pub(crate) fn len(&self) -> usize { + self.inner.len() + } + + pub(crate) fn capacity(&self) -> usize { + self.inner.capacity() + } + + pub(crate) fn clear(&mut self) { + self.inner.clear(); + self.key_map.clear(); + self.compact_called = false; + } + + pub(crate) fn reserve(&mut self, additional: usize) { + self.inner.reserve(additional); + + if self.compact_called { + self.key_map.reserve(additional); + } + } + + pub(crate) fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + pub(crate) fn contains(&self, key: &Key) -> bool { + let remapped_key = self.remap_key(key); + + match remapped_key { + Some(internal_key) => self.inner.contains(internal_key.index), + None => false, + } + } +} + +impl<T> fmt::Debug for SlabStorage<T> +where + T: fmt::Debug, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + if fmt.alternate() { + fmt.debug_map().entries(self.inner.iter()).finish() + } else { + fmt.debug_struct("Slab") + .field("len", &self.len()) + .field("cap", &self.capacity()) + .finish() + } + } +} + +impl<T> Index<Key> for SlabStorage<T> { + type Output = Data<T>; + + fn index(&self, key: Key) -> &Self::Output { + let remapped_key = self.remap_key(&key); + + match remapped_key { + Some(internal_key) => &self.inner[internal_key.index], + None => panic!("Invalid index {}", key.index), + } + } +} + +impl<T> IndexMut<Key> for SlabStorage<T> { + fn index_mut(&mut self, key: Key) -> &mut Data<T> { + let remapped_key = self.remap_key(&key); + + match remapped_key { + Some(internal_key) => &mut self.inner[internal_key.index], + None => panic!("Invalid index {}", key.index), + } + } +} + +/// An entry in `DelayQueue` that has expired and been removed. +/// +/// Values are returned by [`DelayQueue::poll_expired`]. +/// +/// [`DelayQueue::poll_expired`]: method@DelayQueue::poll_expired +#[derive(Debug)] +pub struct Expired<T> { + /// The data stored in the queue + data: T, + + /// The expiration time + deadline: Instant, + + /// The key associated with the entry + key: Key, +} + +/// Token to a value stored in a `DelayQueue`. +/// +/// Instances of `Key` are returned by [`DelayQueue::insert`]. See [`DelayQueue`] +/// documentation for more details. +/// +/// [`DelayQueue`]: struct@DelayQueue +/// [`DelayQueue::insert`]: method@DelayQueue::insert +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct Key { + index: usize, +} + +// Whereas `Key` is given out to users that use `DelayQueue`, internally we use +// `KeyInternal` as the key type in order to make the logic of mapping between keys +// as a result of `compact` calls clearer. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +struct KeyInternal { + index: usize, +} + +#[derive(Debug)] +struct Stack<T> { + /// Head of the stack + head: Option<Key>, + _p: PhantomData<fn() -> T>, +} + +#[derive(Debug)] +struct Data<T> { + /// The data being stored in the queue and will be returned at the requested + /// instant. + inner: T, + + /// The instant at which the item is returned. + when: u64, + + /// Set to true when stored in the `expired` queue + expired: bool, + + /// Next entry in the stack + next: Option<Key>, + + /// Previous entry in the stack + prev: Option<Key>, +} + +/// Maximum number of entries the queue can handle +const MAX_ENTRIES: usize = (1 << 30) - 1; + +impl<T> DelayQueue<T> { + /// Creates a new, empty, `DelayQueue`. + /// + /// The queue will not allocate storage until items are inserted into it. + /// + /// # Examples + /// + /// ```rust + /// # use tokio_util::time::DelayQueue; + /// let delay_queue: DelayQueue<u32> = DelayQueue::new(); + /// ``` + pub fn new() -> DelayQueue<T> { + DelayQueue::with_capacity(0) + } + + /// Creates a new, empty, `DelayQueue` with the specified capacity. + /// + /// The queue will be able to hold at least `capacity` elements without + /// reallocating. If `capacity` is 0, the queue will not allocate for + /// storage. + /// + /// # Examples + /// + /// ```rust + /// # use tokio_util::time::DelayQueue; + /// # use std::time::Duration; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut delay_queue = DelayQueue::with_capacity(10); + /// + /// // These insertions are done without further allocation + /// for i in 0..10 { + /// delay_queue.insert(i, Duration::from_secs(i)); + /// } + /// + /// // This will make the queue allocate additional storage + /// delay_queue.insert(11, Duration::from_secs(11)); + /// # } + /// ``` + pub fn with_capacity(capacity: usize) -> DelayQueue<T> { + DelayQueue { + wheel: Wheel::new(), + slab: SlabStorage::with_capacity(capacity), + expired: Stack::default(), + delay: None, + wheel_now: 0, + start: Instant::now(), + waker: None, + } + } + + /// Inserts `value` into the queue set to expire at a specific instant in + /// time. + /// + /// This function is identical to `insert`, but takes an `Instant` instead + /// of a `Duration`. + /// + /// `value` is stored in the queue until `when` is reached. At which point, + /// `value` will be returned from [`poll_expired`]. If `when` has already been + /// reached, then `value` is immediately made available to poll. + /// + /// The return value represents the insertion and is used as an argument to + /// [`remove`] and [`reset`]. Note that [`Key`] is a token and is reused once + /// `value` is removed from the queue either by calling [`poll_expired`] after + /// `when` is reached or by calling [`remove`]. At this point, the caller + /// must take care to not use the returned [`Key`] again as it may reference + /// a different item in the queue. + /// + /// See [type] level documentation for more details. + /// + /// # Panics + /// + /// This function panics if `when` is too far in the future. + /// + /// # Examples + /// + /// Basic usage + /// + /// ```rust + /// use tokio::time::{Duration, Instant}; + /// use tokio_util::time::DelayQueue; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut delay_queue = DelayQueue::new(); + /// let key = delay_queue.insert_at( + /// "foo", Instant::now() + Duration::from_secs(5)); + /// + /// // Remove the entry + /// let item = delay_queue.remove(&key); + /// assert_eq!(*item.get_ref(), "foo"); + /// # } + /// ``` + /// + /// [`poll_expired`]: method@Self::poll_expired + /// [`remove`]: method@Self::remove + /// [`reset`]: method@Self::reset + /// [`Key`]: struct@Key + /// [type]: # + pub fn insert_at(&mut self, value: T, when: Instant) -> Key { + assert!(self.slab.len() < MAX_ENTRIES, "max entries exceeded"); + + // Normalize the deadline. Values cannot be set to expire in the past. + let when = self.normalize_deadline(when); + + // Insert the value in the store + let key = self.slab.insert(Data { + inner: value, + when, + expired: false, + next: None, + prev: None, + }); + + self.insert_idx(when, key); + + // Set a new delay if the current's deadline is later than the one of the new item + let should_set_delay = if let Some(ref delay) = self.delay { + let current_exp = self.normalize_deadline(delay.deadline()); + current_exp > when + } else { + true + }; + + if should_set_delay { + if let Some(waker) = self.waker.take() { + waker.wake(); + } + + let delay_time = self.start + Duration::from_millis(when); + if let Some(ref mut delay) = &mut self.delay { + delay.as_mut().reset(delay_time); + } else { + self.delay = Some(Box::pin(sleep_until(delay_time))); + } + } + + key + } + + /// Attempts to pull out the next value of the delay queue, registering the + /// current task for wakeup if the value is not yet available, and returning + /// `None` if the queue is exhausted. + pub fn poll_expired(&mut self, cx: &mut task::Context<'_>) -> Poll<Option<Expired<T>>> { + if !self + .waker + .as_ref() + .map(|w| w.will_wake(cx.waker())) + .unwrap_or(false) + { + self.waker = Some(cx.waker().clone()); + } + + let item = ready!(self.poll_idx(cx)); + Poll::Ready(item.map(|key| { + let data = self.slab.remove(&key); + debug_assert!(data.next.is_none()); + debug_assert!(data.prev.is_none()); + + Expired { + key, + data: data.inner, + deadline: self.start + Duration::from_millis(data.when), + } + })) + } + + /// Inserts `value` into the queue set to expire after the requested duration + /// elapses. + /// + /// This function is identical to `insert_at`, but takes a `Duration` + /// instead of an `Instant`. + /// + /// `value` is stored in the queue until `timeout` duration has + /// elapsed after `insert` was called. At that point, `value` will + /// be returned from [`poll_expired`]. If `timeout` is a `Duration` of + /// zero, then `value` is immediately made available to poll. + /// + /// The return value represents the insertion and is used as an + /// argument to [`remove`] and [`reset`]. Note that [`Key`] is a + /// token and is reused once `value` is removed from the queue + /// either by calling [`poll_expired`] after `timeout` has elapsed + /// or by calling [`remove`]. At this point, the caller must not + /// use the returned [`Key`] again as it may reference a different + /// item in the queue. + /// + /// See [type] level documentation for more details. + /// + /// # Panics + /// + /// This function panics if `timeout` is greater than the maximum + /// duration supported by the timer in the current `Runtime`. + /// + /// # Examples + /// + /// Basic usage + /// + /// ```rust + /// use tokio_util::time::DelayQueue; + /// use std::time::Duration; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut delay_queue = DelayQueue::new(); + /// let key = delay_queue.insert("foo", Duration::from_secs(5)); + /// + /// // Remove the entry + /// let item = delay_queue.remove(&key); + /// assert_eq!(*item.get_ref(), "foo"); + /// # } + /// ``` + /// + /// [`poll_expired`]: method@Self::poll_expired + /// [`remove`]: method@Self::remove + /// [`reset`]: method@Self::reset + /// [`Key`]: struct@Key + /// [type]: # + pub fn insert(&mut self, value: T, timeout: Duration) -> Key { + self.insert_at(value, Instant::now() + timeout) + } + + fn insert_idx(&mut self, when: u64, key: Key) { + use self::wheel::{InsertError, Stack}; + + // Register the deadline with the timer wheel + match self.wheel.insert(when, key, &mut self.slab) { + Ok(_) => {} + Err((_, InsertError::Elapsed)) => { + self.slab[key].expired = true; + // The delay is already expired, store it in the expired queue + self.expired.push(key, &mut self.slab); + } + Err((_, err)) => panic!("invalid deadline; err={:?}", err), + } + } + + /// Removes the key from the expired queue or the timer wheel + /// depending on its expiration status. + /// + /// # Panics + /// + /// Panics if the key is not contained in the expired queue or the wheel. + fn remove_key(&mut self, key: &Key) { + use crate::time::wheel::Stack; + + // Special case the `expired` queue + if self.slab[*key].expired { + self.expired.remove(key, &mut self.slab); + } else { + self.wheel.remove(key, &mut self.slab); + } + } + + /// Removes the item associated with `key` from the queue. + /// + /// There must be an item associated with `key`. The function returns the + /// removed item as well as the `Instant` at which it will the delay will + /// have expired. + /// + /// # Panics + /// + /// The function panics if `key` is not contained by the queue. + /// + /// # Examples + /// + /// Basic usage + /// + /// ```rust + /// use tokio_util::time::DelayQueue; + /// use std::time::Duration; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut delay_queue = DelayQueue::new(); + /// let key = delay_queue.insert("foo", Duration::from_secs(5)); + /// + /// // Remove the entry + /// let item = delay_queue.remove(&key); + /// assert_eq!(*item.get_ref(), "foo"); + /// # } + /// ``` + pub fn remove(&mut self, key: &Key) -> Expired<T> { + let prev_deadline = self.next_deadline(); + + self.remove_key(key); + let data = self.slab.remove(key); + + let next_deadline = self.next_deadline(); + if prev_deadline != next_deadline { + match (next_deadline, &mut self.delay) { + (None, _) => self.delay = None, + (Some(deadline), Some(delay)) => delay.as_mut().reset(deadline), + (Some(deadline), None) => self.delay = Some(Box::pin(sleep_until(deadline))), + } + } + + Expired { + key: Key::new(key.index), + data: data.inner, + deadline: self.start + Duration::from_millis(data.when), + } + } + + /// Sets the delay of the item associated with `key` to expire at `when`. + /// + /// This function is identical to `reset` but takes an `Instant` instead of + /// a `Duration`. + /// + /// The item remains in the queue but the delay is set to expire at `when`. + /// If `when` is in the past, then the item is immediately made available to + /// the caller. + /// + /// # Panics + /// + /// This function panics if `when` is too far in the future or if `key` is + /// not contained by the queue. + /// + /// # Examples + /// + /// Basic usage + /// + /// ```rust + /// use tokio::time::{Duration, Instant}; + /// use tokio_util::time::DelayQueue; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut delay_queue = DelayQueue::new(); + /// let key = delay_queue.insert("foo", Duration::from_secs(5)); + /// + /// // "foo" is scheduled to be returned in 5 seconds + /// + /// delay_queue.reset_at(&key, Instant::now() + Duration::from_secs(10)); + /// + /// // "foo" is now scheduled to be returned in 10 seconds + /// # } + /// ``` + pub fn reset_at(&mut self, key: &Key, when: Instant) { + self.remove_key(key); + + // Normalize the deadline. Values cannot be set to expire in the past. + let when = self.normalize_deadline(when); + + self.slab[*key].when = when; + self.slab[*key].expired = false; + + self.insert_idx(when, *key); + + let next_deadline = self.next_deadline(); + if let (Some(ref mut delay), Some(deadline)) = (&mut self.delay, next_deadline) { + // This should awaken us if necessary (ie, if already expired) + delay.as_mut().reset(deadline); + } + } + + /// Shrink the capacity of the slab, which `DelayQueue` uses internally for storage allocation. + /// This function is not guaranteed to, and in most cases, won't decrease the capacity of the slab + /// to the number of elements still contained in it, because elements cannot be moved to a different + /// index. To decrease the capacity to the size of the slab use [`compact`]. + /// + /// This function can take O(n) time even when the capacity cannot be reduced or the allocation is + /// shrunk in place. Repeated calls run in O(1) though. + /// + /// [`compact`]: method@Self::compact + pub fn shrink_to_fit(&mut self) { + self.slab.shrink_to_fit(); + } + + /// Shrink the capacity of the slab, which `DelayQueue` uses internally for storage allocation, + /// to the number of elements that are contained in it. + /// + /// This methods runs in O(n). + /// + /// # Examples + /// + /// Basic usage + /// + /// ```rust + /// use tokio_util::time::DelayQueue; + /// use std::time::Duration; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut delay_queue = DelayQueue::with_capacity(10); + /// + /// let key1 = delay_queue.insert(5, Duration::from_secs(5)); + /// let key2 = delay_queue.insert(10, Duration::from_secs(10)); + /// let key3 = delay_queue.insert(15, Duration::from_secs(15)); + /// + /// delay_queue.remove(&key2); + /// + /// delay_queue.compact(); + /// assert_eq!(delay_queue.capacity(), 2); + /// # } + /// ``` + pub fn compact(&mut self) { + self.slab.compact(); + } + + /// Returns the next time to poll as determined by the wheel + fn next_deadline(&mut self) -> Option<Instant> { + self.wheel + .poll_at() + .map(|poll_at| self.start + Duration::from_millis(poll_at)) + } + + /// Sets the delay of the item associated with `key` to expire after + /// `timeout`. + /// + /// This function is identical to `reset_at` but takes a `Duration` instead + /// of an `Instant`. + /// + /// The item remains in the queue but the delay is set to expire after + /// `timeout`. If `timeout` is zero, then the item is immediately made + /// available to the caller. + /// + /// # Panics + /// + /// This function panics if `timeout` is greater than the maximum supported + /// duration or if `key` is not contained by the queue. + /// + /// # Examples + /// + /// Basic usage + /// + /// ```rust + /// use tokio_util::time::DelayQueue; + /// use std::time::Duration; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut delay_queue = DelayQueue::new(); + /// let key = delay_queue.insert("foo", Duration::from_secs(5)); + /// + /// // "foo" is scheduled to be returned in 5 seconds + /// + /// delay_queue.reset(&key, Duration::from_secs(10)); + /// + /// // "foo"is now scheduled to be returned in 10 seconds + /// # } + /// ``` + pub fn reset(&mut self, key: &Key, timeout: Duration) { + self.reset_at(key, Instant::now() + timeout); + } + + /// Clears the queue, removing all items. + /// + /// After calling `clear`, [`poll_expired`] will return `Ok(Ready(None))`. + /// + /// Note that this method has no effect on the allocated capacity. + /// + /// [`poll_expired`]: method@Self::poll_expired + /// + /// # Examples + /// + /// ```rust + /// use tokio_util::time::DelayQueue; + /// use std::time::Duration; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut delay_queue = DelayQueue::new(); + /// + /// delay_queue.insert("foo", Duration::from_secs(5)); + /// + /// assert!(!delay_queue.is_empty()); + /// + /// delay_queue.clear(); + /// + /// assert!(delay_queue.is_empty()); + /// # } + /// ``` + pub fn clear(&mut self) { + self.slab.clear(); + self.expired = Stack::default(); + self.wheel = Wheel::new(); + self.delay = None; + } + + /// Returns the number of elements the queue can hold without reallocating. + /// + /// # Examples + /// + /// ```rust + /// use tokio_util::time::DelayQueue; + /// + /// let delay_queue: DelayQueue<i32> = DelayQueue::with_capacity(10); + /// assert_eq!(delay_queue.capacity(), 10); + /// ``` + pub fn capacity(&self) -> usize { + self.slab.capacity() + } + + /// Returns the number of elements currently in the queue. + /// + /// # Examples + /// + /// ```rust + /// use tokio_util::time::DelayQueue; + /// use std::time::Duration; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut delay_queue: DelayQueue<i32> = DelayQueue::with_capacity(10); + /// assert_eq!(delay_queue.len(), 0); + /// delay_queue.insert(3, Duration::from_secs(5)); + /// assert_eq!(delay_queue.len(), 1); + /// # } + /// ``` + pub fn len(&self) -> usize { + self.slab.len() + } + + /// Reserves capacity for at least `additional` more items to be queued + /// without allocating. + /// + /// `reserve` does nothing if the queue already has sufficient capacity for + /// `additional` more values. If more capacity is required, a new segment of + /// memory will be allocated and all existing values will be copied into it. + /// As such, if the queue is already very large, a call to `reserve` can end + /// up being expensive. + /// + /// The queue may reserve more than `additional` extra space in order to + /// avoid frequent reallocations. + /// + /// # Panics + /// + /// Panics if the new capacity exceeds the maximum number of entries the + /// queue can contain. + /// + /// # Examples + /// + /// ``` + /// use tokio_util::time::DelayQueue; + /// use std::time::Duration; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut delay_queue = DelayQueue::new(); + /// + /// delay_queue.insert("hello", Duration::from_secs(10)); + /// delay_queue.reserve(10); + /// + /// assert!(delay_queue.capacity() >= 11); + /// # } + /// ``` + pub fn reserve(&mut self, additional: usize) { + self.slab.reserve(additional); + } + + /// Returns `true` if there are no items in the queue. + /// + /// Note that this function returns `false` even if all items have not yet + /// expired and a call to `poll` will return `Poll::Pending`. + /// + /// # Examples + /// + /// ``` + /// use tokio_util::time::DelayQueue; + /// use std::time::Duration; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut delay_queue = DelayQueue::new(); + /// assert!(delay_queue.is_empty()); + /// + /// delay_queue.insert("hello", Duration::from_secs(5)); + /// assert!(!delay_queue.is_empty()); + /// # } + /// ``` + pub fn is_empty(&self) -> bool { + self.slab.is_empty() + } + + /// Polls the queue, returning the index of the next slot in the slab that + /// should be returned. + /// + /// A slot should be returned when the associated deadline has been reached. + fn poll_idx(&mut self, cx: &mut task::Context<'_>) -> Poll<Option<Key>> { + use self::wheel::Stack; + + let expired = self.expired.pop(&mut self.slab); + + if expired.is_some() { + return Poll::Ready(expired); + } + + loop { + if let Some(ref mut delay) = self.delay { + if !delay.is_elapsed() { + ready!(Pin::new(&mut *delay).poll(cx)); + } + + let now = crate::time::ms(delay.deadline() - self.start, crate::time::Round::Down); + + self.wheel_now = now; + } + + // We poll the wheel to get the next value out before finding the next deadline. + let wheel_idx = self.wheel.poll(self.wheel_now, &mut self.slab); + + self.delay = self.next_deadline().map(|when| Box::pin(sleep_until(when))); + + if let Some(idx) = wheel_idx { + return Poll::Ready(Some(idx)); + } + + if self.delay.is_none() { + return Poll::Ready(None); + } + } + } + + fn normalize_deadline(&self, when: Instant) -> u64 { + let when = if when < self.start { + 0 + } else { + crate::time::ms(when - self.start, crate::time::Round::Up) + }; + + cmp::max(when, self.wheel.elapsed()) + } +} + +// We never put `T` in a `Pin`... +impl<T> Unpin for DelayQueue<T> {} + +impl<T> Default for DelayQueue<T> { + fn default() -> DelayQueue<T> { + DelayQueue::new() + } +} + +impl<T> futures_core::Stream for DelayQueue<T> { + // DelayQueue seems much more specific, where a user may care that it + // has reached capacity, so return those errors instead of panicking. + type Item = Expired<T>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> { + DelayQueue::poll_expired(self.get_mut(), cx) + } +} + +impl<T> wheel::Stack for Stack<T> { + type Owned = Key; + type Borrowed = Key; + type Store = SlabStorage<T>; + + fn is_empty(&self) -> bool { + self.head.is_none() + } + + fn push(&mut self, item: Self::Owned, store: &mut Self::Store) { + // Ensure the entry is not already in a stack. + debug_assert!(store[item].next.is_none()); + debug_assert!(store[item].prev.is_none()); + + // Remove the old head entry + let old = self.head.take(); + + if let Some(idx) = old { + store[idx].prev = Some(item); + } + + store[item].next = old; + self.head = Some(item); + } + + fn pop(&mut self, store: &mut Self::Store) -> Option<Self::Owned> { + if let Some(key) = self.head { + self.head = store[key].next; + + if let Some(idx) = self.head { + store[idx].prev = None; + } + + store[key].next = None; + debug_assert!(store[key].prev.is_none()); + + Some(key) + } else { + None + } + } + + fn remove(&mut self, item: &Self::Borrowed, store: &mut Self::Store) { + let key = *item; + assert!(store.contains(item)); + + // Ensure that the entry is in fact contained by the stack + debug_assert!({ + // This walks the full linked list even if an entry is found. + let mut next = self.head; + let mut contains = false; + + while let Some(idx) = next { + let data = &store[idx]; + + if idx == *item { + debug_assert!(!contains); + contains = true; + } + + next = data.next; + } + + contains + }); + + if let Some(next) = store[key].next { + store[next].prev = store[key].prev; + } + + if let Some(prev) = store[key].prev { + store[prev].next = store[key].next; + } else { + self.head = store[key].next; + } + + store[key].next = None; + store[key].prev = None; + } + + fn when(item: &Self::Borrowed, store: &Self::Store) -> u64 { + store[*item].when + } +} + +impl<T> Default for Stack<T> { + fn default() -> Stack<T> { + Stack { + head: None, + _p: PhantomData, + } + } +} + +impl Key { + pub(crate) fn new(index: usize) -> Key { + Key { index } + } +} + +impl KeyInternal { + pub(crate) fn new(index: usize) -> KeyInternal { + KeyInternal { index } + } +} + +impl From<Key> for KeyInternal { + fn from(item: Key) -> Self { + KeyInternal::new(item.index) + } +} + +impl From<KeyInternal> for Key { + fn from(item: KeyInternal) -> Self { + Key::new(item.index) + } +} + +impl<T> Expired<T> { + /// Returns a reference to the inner value. + pub fn get_ref(&self) -> &T { + &self.data + } + + /// Returns a mutable reference to the inner value. + pub fn get_mut(&mut self) -> &mut T { + &mut self.data + } + + /// Consumes `self` and returns the inner value. + pub fn into_inner(self) -> T { + self.data + } + + /// Returns the deadline that the expiration was set to. + pub fn deadline(&self) -> Instant { + self.deadline + } + + /// Returns the key that the expiration is indexed by. + pub fn key(&self) -> Key { + self.key + } +} diff --git a/third_party/rust/tokio-util/src/time/mod.rs b/third_party/rust/tokio-util/src/time/mod.rs new file mode 100644 index 0000000000..2d34008360 --- /dev/null +++ b/third_party/rust/tokio-util/src/time/mod.rs @@ -0,0 +1,47 @@ +//! Additional utilities for tracking time. +//! +//! This module provides additional utilities for executing code after a set period +//! of time. Currently there is only one: +//! +//! * `DelayQueue`: A queue where items are returned once the requested delay +//! has expired. +//! +//! This type must be used from within the context of the `Runtime`. + +use std::time::Duration; + +mod wheel; + +pub mod delay_queue; + +#[doc(inline)] +pub use delay_queue::DelayQueue; + +// ===== Internal utils ===== + +enum Round { + Up, + Down, +} + +/// Convert a `Duration` to milliseconds, rounding up and saturating at +/// `u64::MAX`. +/// +/// The saturating is fine because `u64::MAX` milliseconds are still many +/// million years. +#[inline] +fn ms(duration: Duration, round: Round) -> u64 { + const NANOS_PER_MILLI: u32 = 1_000_000; + const MILLIS_PER_SEC: u64 = 1_000; + + // Round up. + let millis = match round { + Round::Up => (duration.subsec_nanos() + NANOS_PER_MILLI - 1) / NANOS_PER_MILLI, + Round::Down => duration.subsec_millis(), + }; + + duration + .as_secs() + .saturating_mul(MILLIS_PER_SEC) + .saturating_add(u64::from(millis)) +} diff --git a/third_party/rust/tokio-util/src/time/wheel/level.rs b/third_party/rust/tokio-util/src/time/wheel/level.rs new file mode 100644 index 0000000000..8ea30af30f --- /dev/null +++ b/third_party/rust/tokio-util/src/time/wheel/level.rs @@ -0,0 +1,253 @@ +use crate::time::wheel::Stack; + +use std::fmt; + +/// Wheel for a single level in the timer. This wheel contains 64 slots. +pub(crate) struct Level<T> { + level: usize, + + /// Bit field tracking which slots currently contain entries. + /// + /// Using a bit field to track slots that contain entries allows avoiding a + /// scan to find entries. This field is updated when entries are added or + /// removed from a slot. + /// + /// The least-significant bit represents slot zero. + occupied: u64, + + /// Slots + slot: [T; LEVEL_MULT], +} + +/// Indicates when a slot must be processed next. +#[derive(Debug)] +pub(crate) struct Expiration { + /// The level containing the slot. + pub(crate) level: usize, + + /// The slot index. + pub(crate) slot: usize, + + /// The instant at which the slot needs to be processed. + pub(crate) deadline: u64, +} + +/// Level multiplier. +/// +/// Being a power of 2 is very important. +const LEVEL_MULT: usize = 64; + +impl<T: Stack> Level<T> { + pub(crate) fn new(level: usize) -> Level<T> { + // Rust's derived implementations for arrays require that the value + // contained by the array be `Copy`. So, here we have to manually + // initialize every single slot. + macro_rules! s { + () => { + T::default() + }; + } + + Level { + level, + occupied: 0, + slot: [ + // It does not look like the necessary traits are + // derived for [T; 64]. + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + s!(), + ], + } + } + + /// Finds the slot that needs to be processed next and returns the slot and + /// `Instant` at which this slot must be processed. + pub(crate) fn next_expiration(&self, now: u64) -> Option<Expiration> { + // Use the `occupied` bit field to get the index of the next slot that + // needs to be processed. + let slot = match self.next_occupied_slot(now) { + Some(slot) => slot, + None => return None, + }; + + // From the slot index, calculate the `Instant` at which it needs to be + // processed. This value *must* be in the future with respect to `now`. + + let level_range = level_range(self.level); + let slot_range = slot_range(self.level); + + // TODO: This can probably be simplified w/ power of 2 math + let level_start = now - (now % level_range); + let deadline = level_start + slot as u64 * slot_range; + + debug_assert!( + deadline >= now, + "deadline={}; now={}; level={}; slot={}; occupied={:b}", + deadline, + now, + self.level, + slot, + self.occupied + ); + + Some(Expiration { + level: self.level, + slot, + deadline, + }) + } + + fn next_occupied_slot(&self, now: u64) -> Option<usize> { + if self.occupied == 0 { + return None; + } + + // Get the slot for now using Maths + let now_slot = (now / slot_range(self.level)) as usize; + let occupied = self.occupied.rotate_right(now_slot as u32); + let zeros = occupied.trailing_zeros() as usize; + let slot = (zeros + now_slot) % 64; + + Some(slot) + } + + pub(crate) fn add_entry(&mut self, when: u64, item: T::Owned, store: &mut T::Store) { + let slot = slot_for(when, self.level); + + self.slot[slot].push(item, store); + self.occupied |= occupied_bit(slot); + } + + pub(crate) fn remove_entry(&mut self, when: u64, item: &T::Borrowed, store: &mut T::Store) { + let slot = slot_for(when, self.level); + + self.slot[slot].remove(item, store); + + if self.slot[slot].is_empty() { + // The bit is currently set + debug_assert!(self.occupied & occupied_bit(slot) != 0); + + // Unset the bit + self.occupied ^= occupied_bit(slot); + } + } + + pub(crate) fn pop_entry_slot(&mut self, slot: usize, store: &mut T::Store) -> Option<T::Owned> { + let ret = self.slot[slot].pop(store); + + if ret.is_some() && self.slot[slot].is_empty() { + // The bit is currently set + debug_assert!(self.occupied & occupied_bit(slot) != 0); + + self.occupied ^= occupied_bit(slot); + } + + ret + } +} + +impl<T> fmt::Debug for Level<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Level") + .field("occupied", &self.occupied) + .finish() + } +} + +fn occupied_bit(slot: usize) -> u64 { + 1 << slot +} + +fn slot_range(level: usize) -> u64 { + LEVEL_MULT.pow(level as u32) as u64 +} + +fn level_range(level: usize) -> u64 { + LEVEL_MULT as u64 * slot_range(level) +} + +/// Convert a duration (milliseconds) and a level to a slot position +fn slot_for(duration: u64, level: usize) -> usize { + ((duration >> (level * 6)) % LEVEL_MULT as u64) as usize +} + +#[cfg(all(test, not(loom)))] +mod test { + use super::*; + + #[test] + fn test_slot_for() { + for pos in 0..64 { + assert_eq!(pos as usize, slot_for(pos, 0)); + } + + for level in 1..5 { + for pos in level..64 { + let a = pos * 64_usize.pow(level as u32); + assert_eq!(pos as usize, slot_for(a as u64, level)); + } + } + } +} diff --git a/third_party/rust/tokio-util/src/time/wheel/mod.rs b/third_party/rust/tokio-util/src/time/wheel/mod.rs new file mode 100644 index 0000000000..4191e401df --- /dev/null +++ b/third_party/rust/tokio-util/src/time/wheel/mod.rs @@ -0,0 +1,314 @@ +mod level; +pub(crate) use self::level::Expiration; +use self::level::Level; + +mod stack; +pub(crate) use self::stack::Stack; + +use std::borrow::Borrow; +use std::fmt::Debug; +use std::usize; + +/// Timing wheel implementation. +/// +/// This type provides the hashed timing wheel implementation that backs `Timer` +/// and `DelayQueue`. +/// +/// The structure is generic over `T: Stack`. This allows handling timeout data +/// being stored on the heap or in a slab. In order to support the latter case, +/// the slab must be passed into each function allowing the implementation to +/// lookup timer entries. +/// +/// See `Timer` documentation for some implementation notes. +#[derive(Debug)] +pub(crate) struct Wheel<T> { + /// The number of milliseconds elapsed since the wheel started. + elapsed: u64, + + /// Timer wheel. + /// + /// Levels: + /// + /// * 1 ms slots / 64 ms range + /// * 64 ms slots / ~ 4 sec range + /// * ~ 4 sec slots / ~ 4 min range + /// * ~ 4 min slots / ~ 4 hr range + /// * ~ 4 hr slots / ~ 12 day range + /// * ~ 12 day slots / ~ 2 yr range + levels: Vec<Level<T>>, +} + +/// Number of levels. Each level has 64 slots. By using 6 levels with 64 slots +/// each, the timer is able to track time up to 2 years into the future with a +/// precision of 1 millisecond. +const NUM_LEVELS: usize = 6; + +/// The maximum duration of a delay +const MAX_DURATION: u64 = (1 << (6 * NUM_LEVELS)) - 1; + +#[derive(Debug)] +pub(crate) enum InsertError { + Elapsed, + Invalid, +} + +impl<T> Wheel<T> +where + T: Stack, +{ + /// Create a new timing wheel + pub(crate) fn new() -> Wheel<T> { + let levels = (0..NUM_LEVELS).map(Level::new).collect(); + + Wheel { elapsed: 0, levels } + } + + /// Return the number of milliseconds that have elapsed since the timing + /// wheel's creation. + pub(crate) fn elapsed(&self) -> u64 { + self.elapsed + } + + /// Insert an entry into the timing wheel. + /// + /// # Arguments + /// + /// * `when`: is the instant at which the entry should be fired. It is + /// represented as the number of milliseconds since the creation + /// of the timing wheel. + /// + /// * `item`: The item to insert into the wheel. + /// + /// * `store`: The slab or `()` when using heap storage. + /// + /// # Return + /// + /// Returns `Ok` when the item is successfully inserted, `Err` otherwise. + /// + /// `Err(Elapsed)` indicates that `when` represents an instant that has + /// already passed. In this case, the caller should fire the timeout + /// immediately. + /// + /// `Err(Invalid)` indicates an invalid `when` argument as been supplied. + pub(crate) fn insert( + &mut self, + when: u64, + item: T::Owned, + store: &mut T::Store, + ) -> Result<(), (T::Owned, InsertError)> { + if when <= self.elapsed { + return Err((item, InsertError::Elapsed)); + } else if when - self.elapsed > MAX_DURATION { + return Err((item, InsertError::Invalid)); + } + + // Get the level at which the entry should be stored + let level = self.level_for(when); + + self.levels[level].add_entry(when, item, store); + + debug_assert!({ + self.levels[level] + .next_expiration(self.elapsed) + .map(|e| e.deadline >= self.elapsed) + .unwrap_or(true) + }); + + Ok(()) + } + + /// Remove `item` from the timing wheel. + pub(crate) fn remove(&mut self, item: &T::Borrowed, store: &mut T::Store) { + let when = T::when(item, store); + + assert!( + self.elapsed <= when, + "elapsed={}; when={}", + self.elapsed, + when + ); + + let level = self.level_for(when); + + self.levels[level].remove_entry(when, item, store); + } + + /// Instant at which to poll + pub(crate) fn poll_at(&self) -> Option<u64> { + self.next_expiration().map(|expiration| expiration.deadline) + } + + /// Advances the timer up to the instant represented by `now`. + pub(crate) fn poll(&mut self, now: u64, store: &mut T::Store) -> Option<T::Owned> { + loop { + let expiration = self.next_expiration().and_then(|expiration| { + if expiration.deadline > now { + None + } else { + Some(expiration) + } + }); + + match expiration { + Some(ref expiration) => { + if let Some(item) = self.poll_expiration(expiration, store) { + return Some(item); + } + + self.set_elapsed(expiration.deadline); + } + None => { + // in this case the poll did not indicate an expiration + // _and_ we were not able to find a next expiration in + // the current list of timers. advance to the poll's + // current time and do nothing else. + self.set_elapsed(now); + return None; + } + } + } + } + + /// Returns the instant at which the next timeout expires. + fn next_expiration(&self) -> Option<Expiration> { + // Check all levels + for level in 0..NUM_LEVELS { + if let Some(expiration) = self.levels[level].next_expiration(self.elapsed) { + // There cannot be any expirations at a higher level that happen + // before this one. + debug_assert!(self.no_expirations_before(level + 1, expiration.deadline)); + + return Some(expiration); + } + } + + None + } + + /// Used for debug assertions + fn no_expirations_before(&self, start_level: usize, before: u64) -> bool { + let mut res = true; + + for l2 in start_level..NUM_LEVELS { + if let Some(e2) = self.levels[l2].next_expiration(self.elapsed) { + if e2.deadline < before { + res = false; + } + } + } + + res + } + + /// iteratively find entries that are between the wheel's current + /// time and the expiration time. for each in that population either + /// return it for notification (in the case of the last level) or tier + /// it down to the next level (in all other cases). + pub(crate) fn poll_expiration( + &mut self, + expiration: &Expiration, + store: &mut T::Store, + ) -> Option<T::Owned> { + while let Some(item) = self.pop_entry(expiration, store) { + if expiration.level == 0 { + debug_assert_eq!(T::when(item.borrow(), store), expiration.deadline); + + return Some(item); + } else { + let when = T::when(item.borrow(), store); + + let next_level = expiration.level - 1; + + self.levels[next_level].add_entry(when, item, store); + } + } + + None + } + + fn set_elapsed(&mut self, when: u64) { + assert!( + self.elapsed <= when, + "elapsed={:?}; when={:?}", + self.elapsed, + when + ); + + if when > self.elapsed { + self.elapsed = when; + } + } + + fn pop_entry(&mut self, expiration: &Expiration, store: &mut T::Store) -> Option<T::Owned> { + self.levels[expiration.level].pop_entry_slot(expiration.slot, store) + } + + fn level_for(&self, when: u64) -> usize { + level_for(self.elapsed, when) + } +} + +fn level_for(elapsed: u64, when: u64) -> usize { + const SLOT_MASK: u64 = (1 << 6) - 1; + + // Mask in the trailing bits ignored by the level calculation in order to cap + // the possible leading zeros + let masked = elapsed ^ when | SLOT_MASK; + + let leading_zeros = masked.leading_zeros() as usize; + let significant = 63 - leading_zeros; + significant / 6 +} + +#[cfg(all(test, not(loom)))] +mod test { + use super::*; + + #[test] + fn test_level_for() { + for pos in 0..64 { + assert_eq!( + 0, + level_for(0, pos), + "level_for({}) -- binary = {:b}", + pos, + pos + ); + } + + for level in 1..5 { + for pos in level..64 { + let a = pos * 64_usize.pow(level as u32); + assert_eq!( + level, + level_for(0, a as u64), + "level_for({}) -- binary = {:b}", + a, + a + ); + + if pos > level { + let a = a - 1; + assert_eq!( + level, + level_for(0, a as u64), + "level_for({}) -- binary = {:b}", + a, + a + ); + } + + if pos < 64 { + let a = a + 1; + assert_eq!( + level, + level_for(0, a as u64), + "level_for({}) -- binary = {:b}", + a, + a + ); + } + } + } + } +} diff --git a/third_party/rust/tokio-util/src/time/wheel/stack.rs b/third_party/rust/tokio-util/src/time/wheel/stack.rs new file mode 100644 index 0000000000..c87adcafda --- /dev/null +++ b/third_party/rust/tokio-util/src/time/wheel/stack.rs @@ -0,0 +1,28 @@ +use std::borrow::Borrow; +use std::cmp::Eq; +use std::hash::Hash; + +/// Abstracts the stack operations needed to track timeouts. +pub(crate) trait Stack: Default { + /// Type of the item stored in the stack + type Owned: Borrow<Self::Borrowed>; + + /// Borrowed item + type Borrowed: Eq + Hash; + + /// Item storage, this allows a slab to be used instead of just the heap + type Store; + + /// Returns `true` if the stack is empty + fn is_empty(&self) -> bool; + + /// Push an item onto the stack + fn push(&mut self, item: Self::Owned, store: &mut Self::Store); + + /// Pop an item from the stack + fn pop(&mut self, store: &mut Self::Store) -> Option<Self::Owned>; + + fn remove(&mut self, item: &Self::Borrowed, store: &mut Self::Store); + + fn when(item: &Self::Borrowed, store: &Self::Store) -> u64; +} diff --git a/third_party/rust/tokio-util/src/udp/frame.rs b/third_party/rust/tokio-util/src/udp/frame.rs new file mode 100644 index 0000000000..d900fd7691 --- /dev/null +++ b/third_party/rust/tokio-util/src/udp/frame.rs @@ -0,0 +1,245 @@ +use crate::codec::{Decoder, Encoder}; + +use futures_core::Stream; +use tokio::{io::ReadBuf, net::UdpSocket}; + +use bytes::{BufMut, BytesMut}; +use futures_core::ready; +use futures_sink::Sink; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{ + borrow::Borrow, + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, +}; +use std::{io, mem::MaybeUninit}; + +/// A unified [`Stream`] and [`Sink`] interface to an underlying `UdpSocket`, using +/// the `Encoder` and `Decoder` traits to encode and decode frames. +/// +/// Raw UDP sockets work with datagrams, but higher-level code usually wants to +/// batch these into meaningful chunks, called "frames". This method layers +/// framing on top of this socket by using the `Encoder` and `Decoder` traits to +/// handle encoding and decoding of messages frames. Note that the incoming and +/// outgoing frame types may be distinct. +/// +/// This function returns a *single* object that is both [`Stream`] and [`Sink`]; +/// grouping this into a single object is often useful for layering things which +/// require both read and write access to the underlying object. +/// +/// If you want to work more directly with the streams and sink, consider +/// calling [`split`] on the `UdpFramed` returned by this method, which will break +/// them into separate objects, allowing them to interact more easily. +/// +/// [`Stream`]: futures_core::Stream +/// [`Sink`]: futures_sink::Sink +/// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split +#[must_use = "sinks do nothing unless polled"] +#[derive(Debug)] +pub struct UdpFramed<C, T = UdpSocket> { + socket: T, + codec: C, + rd: BytesMut, + wr: BytesMut, + out_addr: SocketAddr, + flushed: bool, + is_readable: bool, + current_addr: Option<SocketAddr>, +} + +const INITIAL_RD_CAPACITY: usize = 64 * 1024; +const INITIAL_WR_CAPACITY: usize = 8 * 1024; + +impl<C, T> Unpin for UdpFramed<C, T> {} + +impl<C, T> Stream for UdpFramed<C, T> +where + T: Borrow<UdpSocket>, + C: Decoder, +{ + type Item = Result<(C::Item, SocketAddr), C::Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + let pin = self.get_mut(); + + pin.rd.reserve(INITIAL_RD_CAPACITY); + + loop { + // Are there still bytes left in the read buffer to decode? + if pin.is_readable { + if let Some(frame) = pin.codec.decode_eof(&mut pin.rd)? { + let current_addr = pin + .current_addr + .expect("will always be set before this line is called"); + + return Poll::Ready(Some(Ok((frame, current_addr)))); + } + + // if this line has been reached then decode has returned `None`. + pin.is_readable = false; + pin.rd.clear(); + } + + // We're out of data. Try and fetch more data to decode + let addr = unsafe { + // Convert `&mut [MaybeUnit<u8>]` to `&mut [u8]` because we will be + // writing to it via `poll_recv_from` and therefore initializing the memory. + let buf = &mut *(pin.rd.chunk_mut() as *mut _ as *mut [MaybeUninit<u8>]); + let mut read = ReadBuf::uninit(buf); + let ptr = read.filled().as_ptr(); + let res = ready!(pin.socket.borrow().poll_recv_from(cx, &mut read)); + + assert_eq!(ptr, read.filled().as_ptr()); + let addr = res?; + pin.rd.advance_mut(read.filled().len()); + addr + }; + + pin.current_addr = Some(addr); + pin.is_readable = true; + } + } +} + +impl<I, C, T> Sink<(I, SocketAddr)> for UdpFramed<C, T> +where + T: Borrow<UdpSocket>, + C: Encoder<I>, +{ + type Error = C::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + if !self.flushed { + match self.poll_flush(cx)? { + Poll::Ready(()) => {} + Poll::Pending => return Poll::Pending, + } + } + + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error> { + let (frame, out_addr) = item; + + let pin = self.get_mut(); + + pin.codec.encode(frame, &mut pin.wr)?; + pin.out_addr = out_addr; + pin.flushed = false; + + Ok(()) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + if self.flushed { + return Poll::Ready(Ok(())); + } + + let Self { + ref socket, + ref mut out_addr, + ref mut wr, + .. + } = *self; + + let n = ready!(socket.borrow().poll_send_to(cx, wr, *out_addr))?; + + let wrote_all = n == self.wr.len(); + self.wr.clear(); + self.flushed = true; + + let res = if wrote_all { + Ok(()) + } else { + Err(io::Error::new( + io::ErrorKind::Other, + "failed to write entire datagram to socket", + ) + .into()) + }; + + Poll::Ready(res) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + ready!(self.poll_flush(cx))?; + Poll::Ready(Ok(())) + } +} + +impl<C, T> UdpFramed<C, T> +where + T: Borrow<UdpSocket>, +{ + /// Create a new `UdpFramed` backed by the given socket and codec. + /// + /// See struct level documentation for more details. + pub fn new(socket: T, codec: C) -> UdpFramed<C, T> { + Self { + socket, + codec, + out_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)), + rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY), + wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY), + flushed: true, + is_readable: false, + current_addr: None, + } + } + + /// Returns a reference to the underlying I/O stream wrapped by `Framed`. + /// + /// # Note + /// + /// Care should be taken to not tamper with the underlying stream of data + /// coming in as it may corrupt the stream of frames otherwise being worked + /// with. + pub fn get_ref(&self) -> &T { + &self.socket + } + + /// Returns a mutable reference to the underlying I/O stream wrapped by `Framed`. + /// + /// # Note + /// + /// Care should be taken to not tamper with the underlying stream of data + /// coming in as it may corrupt the stream of frames otherwise being worked + /// with. + pub fn get_mut(&mut self) -> &mut T { + &mut self.socket + } + + /// Returns a reference to the underlying codec wrapped by + /// `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying codec + /// as it may corrupt the stream of frames otherwise being worked with. + pub fn codec(&self) -> &C { + &self.codec + } + + /// Returns a mutable reference to the underlying codec wrapped by + /// `UdpFramed`. + /// + /// Note that care should be taken to not tamper with the underlying codec + /// as it may corrupt the stream of frames otherwise being worked with. + pub fn codec_mut(&mut self) -> &mut C { + &mut self.codec + } + + /// Returns a reference to the read buffer. + pub fn read_buffer(&self) -> &BytesMut { + &self.rd + } + + /// Returns a mutable reference to the read buffer. + pub fn read_buffer_mut(&mut self) -> &mut BytesMut { + &mut self.rd + } + + /// Consumes the `Framed`, returning its underlying I/O stream. + pub fn into_inner(self) -> T { + self.socket + } +} diff --git a/third_party/rust/tokio-util/src/udp/mod.rs b/third_party/rust/tokio-util/src/udp/mod.rs new file mode 100644 index 0000000000..f88ea030aa --- /dev/null +++ b/third_party/rust/tokio-util/src/udp/mod.rs @@ -0,0 +1,4 @@ +//! UDP framing + +mod frame; +pub use frame::UdpFramed; diff --git a/third_party/rust/tokio-util/tests/_require_full.rs b/third_party/rust/tokio-util/tests/_require_full.rs new file mode 100644 index 0000000000..045934d175 --- /dev/null +++ b/third_party/rust/tokio-util/tests/_require_full.rs @@ -0,0 +1,2 @@ +#![cfg(not(feature = "full"))] +compile_error!("run tokio-util tests with `--features full`"); diff --git a/third_party/rust/tokio-util/tests/codecs.rs b/third_party/rust/tokio-util/tests/codecs.rs new file mode 100644 index 0000000000..f9a780140a --- /dev/null +++ b/third_party/rust/tokio-util/tests/codecs.rs @@ -0,0 +1,464 @@ +#![warn(rust_2018_idioms)] + +use tokio_util::codec::{AnyDelimiterCodec, BytesCodec, Decoder, Encoder, LinesCodec}; + +use bytes::{BufMut, Bytes, BytesMut}; + +#[test] +fn bytes_decoder() { + let mut codec = BytesCodec::new(); + let buf = &mut BytesMut::new(); + buf.put_slice(b"abc"); + assert_eq!("abc", codec.decode(buf).unwrap().unwrap()); + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"a"); + assert_eq!("a", codec.decode(buf).unwrap().unwrap()); +} + +#[test] +fn bytes_encoder() { + let mut codec = BytesCodec::new(); + + // Default capacity of BytesMut + #[cfg(target_pointer_width = "64")] + const INLINE_CAP: usize = 4 * 8 - 1; + #[cfg(target_pointer_width = "32")] + const INLINE_CAP: usize = 4 * 4 - 1; + + let mut buf = BytesMut::new(); + codec + .encode(Bytes::from_static(&[0; INLINE_CAP + 1]), &mut buf) + .unwrap(); + + // Default capacity of Framed Read + const INITIAL_CAPACITY: usize = 8 * 1024; + + let mut buf = BytesMut::with_capacity(INITIAL_CAPACITY); + codec + .encode(Bytes::from_static(&[0; INITIAL_CAPACITY + 1]), &mut buf) + .unwrap(); + codec + .encode(BytesMut::from(&b"hello"[..]), &mut buf) + .unwrap(); +} + +#[test] +fn lines_decoder() { + let mut codec = LinesCodec::new(); + let buf = &mut BytesMut::new(); + buf.reserve(200); + buf.put_slice(b"line 1\nline 2\r\nline 3\n\r\n\r"); + assert_eq!("line 1", codec.decode(buf).unwrap().unwrap()); + assert_eq!("line 2", codec.decode(buf).unwrap().unwrap()); + assert_eq!("line 3", codec.decode(buf).unwrap().unwrap()); + assert_eq!("", codec.decode(buf).unwrap().unwrap()); + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode_eof(buf).unwrap()); + buf.put_slice(b"k"); + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!("\rk", codec.decode_eof(buf).unwrap().unwrap()); + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode_eof(buf).unwrap()); +} + +#[test] +fn lines_decoder_max_length() { + const MAX_LENGTH: usize = 6; + + let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"line 1 is too long\nline 2\nline 3\r\nline 4\n\r\n\r"); + + assert!(codec.decode(buf).is_err()); + + let line = codec.decode(buf).unwrap().unwrap(); + assert!( + line.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + line, + MAX_LENGTH + ); + assert_eq!("line 2", line); + + assert!(codec.decode(buf).is_err()); + + let line = codec.decode(buf).unwrap().unwrap(); + assert!( + line.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + line, + MAX_LENGTH + ); + assert_eq!("line 4", line); + + let line = codec.decode(buf).unwrap().unwrap(); + assert!( + line.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + line, + MAX_LENGTH + ); + assert_eq!("", line); + + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode_eof(buf).unwrap()); + buf.put_slice(b"k"); + assert_eq!(None, codec.decode(buf).unwrap()); + + let line = codec.decode_eof(buf).unwrap().unwrap(); + assert!( + line.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + line, + MAX_LENGTH + ); + assert_eq!("\rk", line); + + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode_eof(buf).unwrap()); + + // Line that's one character too long. This could cause an out of bounds + // error if we peek at the next characters using slice indexing. + buf.put_slice(b"aaabbbc"); + assert!(codec.decode(buf).is_err()); +} + +#[test] +fn lines_decoder_max_length_underrun() { + const MAX_LENGTH: usize = 6; + + let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"line "); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"too l"); + assert!(codec.decode(buf).is_err()); + buf.put_slice(b"ong\n"); + assert_eq!(None, codec.decode(buf).unwrap()); + + buf.put_slice(b"line 2"); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"\n"); + assert_eq!("line 2", codec.decode(buf).unwrap().unwrap()); +} + +#[test] +fn lines_decoder_max_length_bursts() { + const MAX_LENGTH: usize = 10; + + let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"line "); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"too l"); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"ong\n"); + assert!(codec.decode(buf).is_err()); +} + +#[test] +fn lines_decoder_max_length_big_burst() { + const MAX_LENGTH: usize = 10; + + let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"line "); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"too long!\n"); + assert!(codec.decode(buf).is_err()); +} + +#[test] +fn lines_decoder_max_length_newline_between_decodes() { + const MAX_LENGTH: usize = 5; + + let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"hello"); + assert_eq!(None, codec.decode(buf).unwrap()); + + buf.put_slice(b"\nworld"); + assert_eq!("hello", codec.decode(buf).unwrap().unwrap()); +} + +// Regression test for [infinite loop bug](https://github.com/tokio-rs/tokio/issues/1483) +#[test] +fn lines_decoder_discard_repeat() { + const MAX_LENGTH: usize = 1; + + let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"aa"); + assert!(codec.decode(buf).is_err()); + buf.put_slice(b"a"); + assert_eq!(None, codec.decode(buf).unwrap()); +} + +// Regression test for [subsequent calls to LinesCodec decode does not return the desired results bug](https://github.com/tokio-rs/tokio/issues/3555) +#[test] +fn lines_decoder_max_length_underrun_twice() { + const MAX_LENGTH: usize = 11; + + let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"line "); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"too very l"); + assert!(codec.decode(buf).is_err()); + buf.put_slice(b"aaaaaaaaaaaaaaaaaaaaaaa"); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"ong\nshort\n"); + assert_eq!("short", codec.decode(buf).unwrap().unwrap()); +} + +#[test] +fn lines_encoder() { + let mut codec = LinesCodec::new(); + let mut buf = BytesMut::new(); + + codec.encode("line 1", &mut buf).unwrap(); + assert_eq!("line 1\n", buf); + + codec.encode("line 2", &mut buf).unwrap(); + assert_eq!("line 1\nline 2\n", buf); +} + +#[test] +fn any_delimiters_decoder_any_character() { + let mut codec = AnyDelimiterCodec::new(b",;\n\r".to_vec(), b",".to_vec()); + let buf = &mut BytesMut::new(); + buf.reserve(200); + buf.put_slice(b"chunk 1,chunk 2;chunk 3\n\r"); + assert_eq!("chunk 1", codec.decode(buf).unwrap().unwrap()); + assert_eq!("chunk 2", codec.decode(buf).unwrap().unwrap()); + assert_eq!("chunk 3", codec.decode(buf).unwrap().unwrap()); + assert_eq!("", codec.decode(buf).unwrap().unwrap()); + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode_eof(buf).unwrap()); + buf.put_slice(b"k"); + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!("k", codec.decode_eof(buf).unwrap().unwrap()); + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode_eof(buf).unwrap()); +} + +#[test] +fn any_delimiters_decoder_max_length() { + const MAX_LENGTH: usize = 7; + + let mut codec = + AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"chunk 1 is too long\nchunk 2\nchunk 3\r\nchunk 4\n\r\n"); + + assert!(codec.decode(buf).is_err()); + + let chunk = codec.decode(buf).unwrap().unwrap(); + assert!( + chunk.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + chunk, + MAX_LENGTH + ); + assert_eq!("chunk 2", chunk); + + let chunk = codec.decode(buf).unwrap().unwrap(); + assert!( + chunk.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + chunk, + MAX_LENGTH + ); + assert_eq!("chunk 3", chunk); + + // \r\n cause empty chunk + let chunk = codec.decode(buf).unwrap().unwrap(); + assert!( + chunk.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + chunk, + MAX_LENGTH + ); + assert_eq!("", chunk); + + let chunk = codec.decode(buf).unwrap().unwrap(); + assert!( + chunk.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + chunk, + MAX_LENGTH + ); + assert_eq!("chunk 4", chunk); + + let chunk = codec.decode(buf).unwrap().unwrap(); + assert!( + chunk.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + chunk, + MAX_LENGTH + ); + assert_eq!("", chunk); + + let chunk = codec.decode(buf).unwrap().unwrap(); + assert!( + chunk.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + chunk, + MAX_LENGTH + ); + assert_eq!("", chunk); + + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode_eof(buf).unwrap()); + buf.put_slice(b"k"); + assert_eq!(None, codec.decode(buf).unwrap()); + + let chunk = codec.decode_eof(buf).unwrap().unwrap(); + assert!( + chunk.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + chunk, + MAX_LENGTH + ); + assert_eq!("k", chunk); + + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode_eof(buf).unwrap()); + + // Delimiter that's one character too long. This could cause an out of bounds + // error if we peek at the next characters using slice indexing. + buf.put_slice(b"aaabbbcc"); + assert!(codec.decode(buf).is_err()); +} + +#[test] +fn any_delimiter_decoder_max_length_underrun() { + const MAX_LENGTH: usize = 7; + + let mut codec = + AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"chunk "); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"too l"); + assert!(codec.decode(buf).is_err()); + buf.put_slice(b"ong\n"); + assert_eq!(None, codec.decode(buf).unwrap()); + + buf.put_slice(b"chunk 2"); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b","); + assert_eq!("chunk 2", codec.decode(buf).unwrap().unwrap()); +} + +#[test] +fn any_delimiter_decoder_max_length_underrun_twice() { + const MAX_LENGTH: usize = 11; + + let mut codec = + AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"chunk "); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"too very l"); + assert!(codec.decode(buf).is_err()); + buf.put_slice(b"aaaaaaaaaaaaaaaaaaaaaaa"); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"ong\nshort\n"); + assert_eq!("short", codec.decode(buf).unwrap().unwrap()); +} +#[test] +fn any_delimiter_decoder_max_length_bursts() { + const MAX_LENGTH: usize = 11; + + let mut codec = + AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"chunk "); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"too l"); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"ong\n"); + assert!(codec.decode(buf).is_err()); +} + +#[test] +fn any_delimiter_decoder_max_length_big_burst() { + const MAX_LENGTH: usize = 11; + + let mut codec = + AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"chunk "); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"too long!\n"); + assert!(codec.decode(buf).is_err()); +} + +#[test] +fn any_delimiter_decoder_max_length_delimiter_between_decodes() { + const MAX_LENGTH: usize = 5; + + let mut codec = + AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"hello"); + assert_eq!(None, codec.decode(buf).unwrap()); + + buf.put_slice(b",world"); + assert_eq!("hello", codec.decode(buf).unwrap().unwrap()); +} + +#[test] +fn any_delimiter_decoder_discard_repeat() { + const MAX_LENGTH: usize = 1; + + let mut codec = + AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"aa"); + assert!(codec.decode(buf).is_err()); + buf.put_slice(b"a"); + assert_eq!(None, codec.decode(buf).unwrap()); +} + +#[test] +fn any_delimiter_encoder() { + let mut codec = AnyDelimiterCodec::new(b",".to_vec(), b";--;".to_vec()); + let mut buf = BytesMut::new(); + + codec.encode("chunk 1", &mut buf).unwrap(); + assert_eq!("chunk 1;--;", buf); + + codec.encode("chunk 2", &mut buf).unwrap(); + assert_eq!("chunk 1;--;chunk 2;--;", buf); +} diff --git a/third_party/rust/tokio-util/tests/context.rs b/third_party/rust/tokio-util/tests/context.rs new file mode 100644 index 0000000000..7510f36fd1 --- /dev/null +++ b/third_party/rust/tokio-util/tests/context.rs @@ -0,0 +1,24 @@ +#![cfg(feature = "rt")] +#![warn(rust_2018_idioms)] + +use tokio::runtime::Builder; +use tokio::time::*; +use tokio_util::context::RuntimeExt; + +#[test] +fn tokio_context_with_another_runtime() { + let rt1 = Builder::new_multi_thread() + .worker_threads(1) + // no timer! + .build() + .unwrap(); + let rt2 = Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build() + .unwrap(); + + // Without the `HandleExt.wrap()` there would be a panic because there is + // no timer running, since it would be referencing runtime r1. + let _ = rt1.block_on(rt2.wrap(async move { sleep(Duration::from_millis(2)).await })); +} diff --git a/third_party/rust/tokio-util/tests/framed.rs b/third_party/rust/tokio-util/tests/framed.rs new file mode 100644 index 0000000000..ec8cdf00d0 --- /dev/null +++ b/third_party/rust/tokio-util/tests/framed.rs @@ -0,0 +1,152 @@ +#![warn(rust_2018_idioms)] + +use tokio_stream::StreamExt; +use tokio_test::assert_ok; +use tokio_util::codec::{Decoder, Encoder, Framed, FramedParts}; + +use bytes::{Buf, BufMut, BytesMut}; +use std::io::{self, Read}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +const INITIAL_CAPACITY: usize = 8 * 1024; + +/// Encode and decode u32 values. +#[derive(Default)] +struct U32Codec { + read_bytes: usize, +} + +impl Decoder for U32Codec { + type Item = u32; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u32>> { + if buf.len() < 4 { + return Ok(None); + } + + let n = buf.split_to(4).get_u32(); + self.read_bytes += 4; + Ok(Some(n)) + } +} + +impl Encoder<u32> for U32Codec { + type Error = io::Error; + + fn encode(&mut self, item: u32, dst: &mut BytesMut) -> io::Result<()> { + // Reserve space + dst.reserve(4); + dst.put_u32(item); + Ok(()) + } +} + +/// Encode and decode u64 values. +#[derive(Default)] +struct U64Codec { + read_bytes: usize, +} + +impl Decoder for U64Codec { + type Item = u64; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u64>> { + if buf.len() < 8 { + return Ok(None); + } + + let n = buf.split_to(8).get_u64(); + self.read_bytes += 8; + Ok(Some(n)) + } +} + +impl Encoder<u64> for U64Codec { + type Error = io::Error; + + fn encode(&mut self, item: u64, dst: &mut BytesMut) -> io::Result<()> { + // Reserve space + dst.reserve(8); + dst.put_u64(item); + Ok(()) + } +} + +/// This value should never be used +struct DontReadIntoThis; + +impl Read for DontReadIntoThis { + fn read(&mut self, _: &mut [u8]) -> io::Result<usize> { + Err(io::Error::new( + io::ErrorKind::Other, + "Read into something you weren't supposed to.", + )) + } +} + +impl tokio::io::AsyncRead for DontReadIntoThis { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + unreachable!() + } +} + +#[tokio::test] +async fn can_read_from_existing_buf() { + let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default()); + parts.read_buf = BytesMut::from(&[0, 0, 0, 42][..]); + + let mut framed = Framed::from_parts(parts); + let num = assert_ok!(framed.next().await.unwrap()); + + assert_eq!(num, 42); + assert_eq!(framed.codec().read_bytes, 4); +} + +#[tokio::test] +async fn can_read_from_existing_buf_after_codec_changed() { + let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default()); + parts.read_buf = BytesMut::from(&[0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 0, 84][..]); + + let mut framed = Framed::from_parts(parts); + let num = assert_ok!(framed.next().await.unwrap()); + + assert_eq!(num, 42); + assert_eq!(framed.codec().read_bytes, 4); + + let mut framed = framed.map_codec(|codec| U64Codec { + read_bytes: codec.read_bytes, + }); + let num = assert_ok!(framed.next().await.unwrap()); + + assert_eq!(num, 84); + assert_eq!(framed.codec().read_bytes, 12); +} + +#[test] +fn external_buf_grows_to_init() { + let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default()); + parts.read_buf = BytesMut::from(&[0, 0, 0, 42][..]); + + let framed = Framed::from_parts(parts); + let FramedParts { read_buf, .. } = framed.into_parts(); + + assert_eq!(read_buf.capacity(), INITIAL_CAPACITY); +} + +#[test] +fn external_buf_does_not_shrink() { + let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default()); + parts.read_buf = BytesMut::from(&vec![0; INITIAL_CAPACITY * 2][..]); + + let framed = Framed::from_parts(parts); + let FramedParts { read_buf, .. } = framed.into_parts(); + + assert_eq!(read_buf.capacity(), INITIAL_CAPACITY * 2); +} diff --git a/third_party/rust/tokio-util/tests/framed_read.rs b/third_party/rust/tokio-util/tests/framed_read.rs new file mode 100644 index 0000000000..2a9e27e22f --- /dev/null +++ b/third_party/rust/tokio-util/tests/framed_read.rs @@ -0,0 +1,339 @@ +#![warn(rust_2018_idioms)] + +use tokio::io::{AsyncRead, ReadBuf}; +use tokio_test::assert_ready; +use tokio_test::task; +use tokio_util::codec::{Decoder, FramedRead}; + +use bytes::{Buf, BytesMut}; +use futures::Stream; +use std::collections::VecDeque; +use std::io; +use std::pin::Pin; +use std::task::Poll::{Pending, Ready}; +use std::task::{Context, Poll}; + +macro_rules! mock { + ($($x:expr,)*) => {{ + let mut v = VecDeque::new(); + v.extend(vec![$($x),*]); + Mock { calls: v } + }}; +} + +macro_rules! assert_read { + ($e:expr, $n:expr) => {{ + let val = assert_ready!($e); + assert_eq!(val.unwrap().unwrap(), $n); + }}; +} + +macro_rules! pin { + ($id:ident) => { + Pin::new(&mut $id) + }; +} + +struct U32Decoder; + +impl Decoder for U32Decoder { + type Item = u32; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u32>> { + if buf.len() < 4 { + return Ok(None); + } + + let n = buf.split_to(4).get_u32(); + Ok(Some(n)) + } +} + +struct U64Decoder; + +impl Decoder for U64Decoder { + type Item = u64; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u64>> { + if buf.len() < 8 { + return Ok(None); + } + + let n = buf.split_to(8).get_u64(); + Ok(Some(n)) + } +} + +#[test] +fn read_multi_frame_in_packet() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 0); + assert_read!(pin!(framed).poll_next(cx), 1); + assert_read!(pin!(framed).poll_next(cx), 2); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + +#[test] +fn read_multi_frame_across_packets() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00\x00\x00".to_vec()), + Ok(b"\x00\x00\x00\x01".to_vec()), + Ok(b"\x00\x00\x00\x02".to_vec()), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 0); + assert_read!(pin!(framed).poll_next(cx), 1); + assert_read!(pin!(framed).poll_next(cx), 2); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + +#[test] +fn read_multi_frame_in_packet_after_codec_changed() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x08".to_vec()), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 0x04); + + let mut framed = framed.map_decoder(|_| U64Decoder); + assert_read!(pin!(framed).poll_next(cx), 0x08); + + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + +#[test] +fn read_not_ready() { + let mut task = task::spawn(()); + let mock = mock! { + Err(io::Error::new(io::ErrorKind::WouldBlock, "")), + Ok(b"\x00\x00\x00\x00".to_vec()), + Ok(b"\x00\x00\x00\x01".to_vec()), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert!(pin!(framed).poll_next(cx).is_pending()); + assert_read!(pin!(framed).poll_next(cx), 0); + assert_read!(pin!(framed).poll_next(cx), 1); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + +#[test] +fn read_partial_then_not_ready() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00".to_vec()), + Err(io::Error::new(io::ErrorKind::WouldBlock, "")), + Ok(b"\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert!(pin!(framed).poll_next(cx).is_pending()); + assert_read!(pin!(framed).poll_next(cx), 0); + assert_read!(pin!(framed).poll_next(cx), 1); + assert_read!(pin!(framed).poll_next(cx), 2); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + +#[test] +fn read_err() { + let mut task = task::spawn(()); + let mock = mock! { + Err(io::Error::new(io::ErrorKind::Other, "")), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert_eq!( + io::ErrorKind::Other, + assert_ready!(pin!(framed).poll_next(cx)) + .unwrap() + .unwrap_err() + .kind() + ) + }); +} + +#[test] +fn read_partial_then_err() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00".to_vec()), + Err(io::Error::new(io::ErrorKind::Other, "")), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert_eq!( + io::ErrorKind::Other, + assert_ready!(pin!(framed).poll_next(cx)) + .unwrap() + .unwrap_err() + .kind() + ) + }); +} + +#[test] +fn read_partial_would_block_then_err() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00".to_vec()), + Err(io::Error::new(io::ErrorKind::WouldBlock, "")), + Err(io::Error::new(io::ErrorKind::Other, "")), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert!(pin!(framed).poll_next(cx).is_pending()); + assert_eq!( + io::ErrorKind::Other, + assert_ready!(pin!(framed).poll_next(cx)) + .unwrap() + .unwrap_err() + .kind() + ) + }); +} + +#[test] +fn huge_size() { + let mut task = task::spawn(()); + let data = &[0; 32 * 1024][..]; + let mut framed = FramedRead::new(data, BigDecoder); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 0); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); + + struct BigDecoder; + + impl Decoder for BigDecoder { + type Item = u32; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u32>> { + if buf.len() < 32 * 1024 { + return Ok(None); + } + buf.advance(32 * 1024); + Ok(Some(0)) + } + } +} + +#[test] +fn data_remaining_is_error() { + let mut task = task::spawn(()); + let slice = &[0; 5][..]; + let mut framed = FramedRead::new(slice, U32Decoder); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 0); + assert!(assert_ready!(pin!(framed).poll_next(cx)).unwrap().is_err()); + }); +} + +#[test] +fn multi_frames_on_eof() { + let mut task = task::spawn(()); + struct MyDecoder(Vec<u32>); + + impl Decoder for MyDecoder { + type Item = u32; + type Error = io::Error; + + fn decode(&mut self, _buf: &mut BytesMut) -> io::Result<Option<u32>> { + unreachable!(); + } + + fn decode_eof(&mut self, _buf: &mut BytesMut) -> io::Result<Option<u32>> { + if self.0.is_empty() { + return Ok(None); + } + + Ok(Some(self.0.remove(0))) + } + } + + let mut framed = FramedRead::new(mock!(), MyDecoder(vec![0, 1, 2, 3])); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 0); + assert_read!(pin!(framed).poll_next(cx), 1); + assert_read!(pin!(framed).poll_next(cx), 2); + assert_read!(pin!(framed).poll_next(cx), 3); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + +#[test] +fn read_eof_then_resume() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00\x00\x01".to_vec()), + Ok(b"".to_vec()), + Ok(b"\x00\x00\x00\x02".to_vec()), + Ok(b"".to_vec()), + Ok(b"\x00\x00\x00\x03".to_vec()), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 1); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + assert_read!(pin!(framed).poll_next(cx), 2); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + assert_read!(pin!(framed).poll_next(cx), 3); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + +// ===== Mock ====== + +struct Mock { + calls: VecDeque<io::Result<Vec<u8>>>, +} + +impl AsyncRead for Mock { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + use io::ErrorKind::WouldBlock; + + match self.calls.pop_front() { + Some(Ok(data)) => { + debug_assert!(buf.remaining() >= data.len()); + buf.put_slice(&data); + Ready(Ok(())) + } + Some(Err(ref e)) if e.kind() == WouldBlock => Pending, + Some(Err(e)) => Ready(Err(e)), + None => Ready(Ok(())), + } + } +} diff --git a/third_party/rust/tokio-util/tests/framed_stream.rs b/third_party/rust/tokio-util/tests/framed_stream.rs new file mode 100644 index 0000000000..76d8af7b7d --- /dev/null +++ b/third_party/rust/tokio-util/tests/framed_stream.rs @@ -0,0 +1,38 @@ +use futures_core::stream::Stream; +use std::{io, pin::Pin}; +use tokio_test::{assert_ready, io::Builder, task}; +use tokio_util::codec::{BytesCodec, FramedRead}; + +macro_rules! pin { + ($id:ident) => { + Pin::new(&mut $id) + }; +} + +macro_rules! assert_read { + ($e:expr, $n:expr) => {{ + let val = assert_ready!($e); + assert_eq!(val.unwrap().unwrap(), $n); + }}; +} + +#[tokio::test] +async fn return_none_after_error() { + let mut io = FramedRead::new( + Builder::new() + .read(b"abcdef") + .read_error(io::Error::new(io::ErrorKind::Other, "Resource errored out")) + .read(b"more data") + .build(), + BytesCodec::new(), + ); + + let mut task = task::spawn(()); + + task.enter(|cx, _| { + assert_read!(pin!(io).poll_next(cx), b"abcdef".to_vec()); + assert!(assert_ready!(pin!(io).poll_next(cx)).unwrap().is_err()); + assert!(assert_ready!(pin!(io).poll_next(cx)).is_none()); + assert_read!(pin!(io).poll_next(cx), b"more data".to_vec()); + }) +} diff --git a/third_party/rust/tokio-util/tests/framed_write.rs b/third_party/rust/tokio-util/tests/framed_write.rs new file mode 100644 index 0000000000..259d9b0c9f --- /dev/null +++ b/third_party/rust/tokio-util/tests/framed_write.rs @@ -0,0 +1,211 @@ +#![warn(rust_2018_idioms)] + +use tokio::io::AsyncWrite; +use tokio_test::{assert_ready, task}; +use tokio_util::codec::{Encoder, FramedWrite}; + +use bytes::{BufMut, BytesMut}; +use futures_sink::Sink; +use std::collections::VecDeque; +use std::io::{self, Write}; +use std::pin::Pin; +use std::task::Poll::{Pending, Ready}; +use std::task::{Context, Poll}; + +macro_rules! mock { + ($($x:expr,)*) => {{ + let mut v = VecDeque::new(); + v.extend(vec![$($x),*]); + Mock { calls: v } + }}; +} + +macro_rules! pin { + ($id:ident) => { + Pin::new(&mut $id) + }; +} + +struct U32Encoder; + +impl Encoder<u32> for U32Encoder { + type Error = io::Error; + + fn encode(&mut self, item: u32, dst: &mut BytesMut) -> io::Result<()> { + // Reserve space + dst.reserve(4); + dst.put_u32(item); + Ok(()) + } +} + +struct U64Encoder; + +impl Encoder<u64> for U64Encoder { + type Error = io::Error; + + fn encode(&mut self, item: u64, dst: &mut BytesMut) -> io::Result<()> { + // Reserve space + dst.reserve(8); + dst.put_u64(item); + Ok(()) + } +} + +#[test] +fn write_multi_frame_in_packet() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()), + }; + let mut framed = FramedWrite::new(mock, U32Encoder); + + task.enter(|cx, _| { + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(0).is_ok()); + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(1).is_ok()); + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(2).is_ok()); + + // Nothing written yet + assert_eq!(1, framed.get_ref().calls.len()); + + // Flush the writes + assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok()); + + assert_eq!(0, framed.get_ref().calls.len()); + }); +} + +#[test] +fn write_multi_frame_after_codec_changed() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x08".to_vec()), + }; + let mut framed = FramedWrite::new(mock, U32Encoder); + + task.enter(|cx, _| { + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(0x04).is_ok()); + + let mut framed = framed.map_encoder(|_| U64Encoder); + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(0x08).is_ok()); + + // Nothing written yet + assert_eq!(1, framed.get_ref().calls.len()); + + // Flush the writes + assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok()); + + assert_eq!(0, framed.get_ref().calls.len()); + }); +} + +#[test] +fn write_hits_backpressure() { + const ITER: usize = 2 * 1024; + + let mut mock = mock! { + // Block the `ITER`th write + Err(io::Error::new(io::ErrorKind::WouldBlock, "not ready")), + Ok(b"".to_vec()), + }; + + for i in 0..=ITER { + let mut b = BytesMut::with_capacity(4); + b.put_u32(i as u32); + + // Append to the end + match mock.calls.back_mut().unwrap() { + Ok(ref mut data) => { + // Write in 2kb chunks + if data.len() < ITER { + data.extend_from_slice(&b[..]); + continue; + } // else fall through and create a new buffer + } + _ => unreachable!(), + } + + // Push a new new chunk + mock.calls.push_back(Ok(b[..].to_vec())); + } + // 1 'wouldblock', 4 * 2KB buffers, 1 b-byte buffer + assert_eq!(mock.calls.len(), 6); + + let mut task = task::spawn(()); + let mut framed = FramedWrite::new(mock, U32Encoder); + task.enter(|cx, _| { + // Send 8KB. This fills up FramedWrite2 buffer + for i in 0..ITER { + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(i as u32).is_ok()); + } + + // Now we poll_ready which forces a flush. The mock pops the front message + // and decides to block. + assert!(pin!(framed).poll_ready(cx).is_pending()); + + // We poll again, forcing another flush, which this time succeeds + // The whole 8KB buffer is flushed + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + + // Send more data. This matches the final message expected by the mock + assert!(pin!(framed).start_send(ITER as u32).is_ok()); + + // Flush the rest of the buffer + assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok()); + + // Ensure the mock is empty + assert_eq!(0, framed.get_ref().calls.len()); + }) +} + +// // ===== Mock ====== + +struct Mock { + calls: VecDeque<io::Result<Vec<u8>>>, +} + +impl Write for Mock { + fn write(&mut self, src: &[u8]) -> io::Result<usize> { + match self.calls.pop_front() { + Some(Ok(data)) => { + assert!(src.len() >= data.len()); + assert_eq!(&data[..], &src[..data.len()]); + Ok(data.len()) + } + Some(Err(e)) => Err(e), + None => panic!("unexpected write; {:?}", src), + } + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +impl AsyncWrite for Mock { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, io::Error>> { + match Pin::get_mut(self).write(buf) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Pending, + other => Ready(other), + } + } + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + match Pin::get_mut(self).flush() { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Pending, + other => Ready(other), + } + } + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + unimplemented!() + } +} diff --git a/third_party/rust/tokio-util/tests/io_reader_stream.rs b/third_party/rust/tokio-util/tests/io_reader_stream.rs new file mode 100644 index 0000000000..e30cd85164 --- /dev/null +++ b/third_party/rust/tokio-util/tests/io_reader_stream.rs @@ -0,0 +1,65 @@ +#![warn(rust_2018_idioms)] + +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, ReadBuf}; +use tokio_stream::StreamExt; + +/// produces at most `remaining` zeros, that returns error. +/// each time it reads at most 31 byte. +struct Reader { + remaining: usize, +} + +impl AsyncRead for Reader { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + let this = Pin::into_inner(self); + assert_ne!(buf.remaining(), 0); + if this.remaining > 0 { + let n = std::cmp::min(this.remaining, buf.remaining()); + let n = std::cmp::min(n, 31); + for x in &mut buf.initialize_unfilled_to(n)[..n] { + *x = 0; + } + buf.advance(n); + this.remaining -= n; + Poll::Ready(Ok(())) + } else { + Poll::Ready(Err(std::io::Error::from_raw_os_error(22))) + } + } +} + +#[tokio::test] +async fn correct_behavior_on_errors() { + let reader = Reader { remaining: 8000 }; + let mut stream = tokio_util::io::ReaderStream::new(reader); + let mut zeros_received = 0; + let mut had_error = false; + loop { + let item = stream.next().await.unwrap(); + println!("{:?}", item); + match item { + Ok(bytes) => { + let bytes = &*bytes; + for byte in bytes { + assert_eq!(*byte, 0); + zeros_received += 1; + } + } + Err(_) => { + assert!(!had_error); + had_error = true; + break; + } + } + } + + assert!(had_error); + assert_eq!(zeros_received, 8000); + assert!(stream.next().await.is_none()); +} diff --git a/third_party/rust/tokio-util/tests/io_stream_reader.rs b/third_party/rust/tokio-util/tests/io_stream_reader.rs new file mode 100644 index 0000000000..59759941c5 --- /dev/null +++ b/third_party/rust/tokio-util/tests/io_stream_reader.rs @@ -0,0 +1,35 @@ +#![warn(rust_2018_idioms)] + +use bytes::Bytes; +use tokio::io::AsyncReadExt; +use tokio_stream::iter; +use tokio_util::io::StreamReader; + +#[tokio::test] +async fn test_stream_reader() -> std::io::Result<()> { + let stream = iter(vec![ + std::io::Result::Ok(Bytes::from_static(&[])), + Ok(Bytes::from_static(&[0, 1, 2, 3])), + Ok(Bytes::from_static(&[])), + Ok(Bytes::from_static(&[4, 5, 6, 7])), + Ok(Bytes::from_static(&[])), + Ok(Bytes::from_static(&[8, 9, 10, 11])), + Ok(Bytes::from_static(&[])), + ]); + + let mut read = StreamReader::new(stream); + + let mut buf = [0; 5]; + read.read_exact(&mut buf).await?; + assert_eq!(buf, [0, 1, 2, 3, 4]); + + assert_eq!(read.read(&mut buf).await?, 3); + assert_eq!(&buf[..3], [5, 6, 7]); + + assert_eq!(read.read(&mut buf).await?, 4); + assert_eq!(&buf[..4], [8, 9, 10, 11]); + + assert_eq!(read.read(&mut buf).await?, 0); + + Ok(()) +} diff --git a/third_party/rust/tokio-util/tests/io_sync_bridge.rs b/third_party/rust/tokio-util/tests/io_sync_bridge.rs new file mode 100644 index 0000000000..0d420857b5 --- /dev/null +++ b/third_party/rust/tokio-util/tests/io_sync_bridge.rs @@ -0,0 +1,43 @@ +#![cfg(feature = "io-util")] + +use std::error::Error; +use std::io::{Cursor, Read, Result as IoResult}; +use tokio::io::AsyncRead; +use tokio_util::io::SyncIoBridge; + +async fn test_reader_len( + r: impl AsyncRead + Unpin + Send + 'static, + expected_len: usize, +) -> IoResult<()> { + let mut r = SyncIoBridge::new(r); + let res = tokio::task::spawn_blocking(move || { + let mut buf = Vec::new(); + r.read_to_end(&mut buf)?; + Ok::<_, std::io::Error>(buf) + }) + .await?; + assert_eq!(res?.len(), expected_len); + Ok(()) +} + +#[tokio::test] +async fn test_async_read_to_sync() -> Result<(), Box<dyn Error>> { + test_reader_len(tokio::io::empty(), 0).await?; + let buf = b"hello world"; + test_reader_len(Cursor::new(buf), buf.len()).await?; + Ok(()) +} + +#[tokio::test] +async fn test_async_write_to_sync() -> Result<(), Box<dyn Error>> { + let mut dest = Vec::new(); + let src = b"hello world"; + let dest = tokio::task::spawn_blocking(move || -> Result<_, String> { + let mut w = SyncIoBridge::new(Cursor::new(&mut dest)); + std::io::copy(&mut Cursor::new(src), &mut w).map_err(|e| e.to_string())?; + Ok(dest) + }) + .await??; + assert_eq!(dest.as_slice(), src); + Ok(()) +} diff --git a/third_party/rust/tokio-util/tests/length_delimited.rs b/third_party/rust/tokio-util/tests/length_delimited.rs new file mode 100644 index 0000000000..126e41b5cd --- /dev/null +++ b/third_party/rust/tokio-util/tests/length_delimited.rs @@ -0,0 +1,779 @@ +#![warn(rust_2018_idioms)] + +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_test::task; +use tokio_test::{ + assert_err, assert_ok, assert_pending, assert_ready, assert_ready_err, assert_ready_ok, +}; +use tokio_util::codec::*; + +use bytes::{BufMut, Bytes, BytesMut}; +use futures::{pin_mut, Sink, Stream}; +use std::collections::VecDeque; +use std::io; +use std::pin::Pin; +use std::task::Poll::*; +use std::task::{Context, Poll}; + +macro_rules! mock { + ($($x:expr,)*) => {{ + let mut v = VecDeque::new(); + v.extend(vec![$($x),*]); + Mock { calls: v } + }}; +} + +macro_rules! assert_next_eq { + ($io:ident, $expect:expr) => {{ + task::spawn(()).enter(|cx, _| { + let res = assert_ready!($io.as_mut().poll_next(cx)); + match res { + Some(Ok(v)) => assert_eq!(v, $expect.as_ref()), + Some(Err(e)) => panic!("error = {:?}", e), + None => panic!("none"), + } + }); + }}; +} + +macro_rules! assert_next_pending { + ($io:ident) => {{ + task::spawn(()).enter(|cx, _| match $io.as_mut().poll_next(cx) { + Ready(Some(Ok(v))) => panic!("value = {:?}", v), + Ready(Some(Err(e))) => panic!("error = {:?}", e), + Ready(None) => panic!("done"), + Pending => {} + }); + }}; +} + +macro_rules! assert_next_err { + ($io:ident) => {{ + task::spawn(()).enter(|cx, _| match $io.as_mut().poll_next(cx) { + Ready(Some(Ok(v))) => panic!("value = {:?}", v), + Ready(Some(Err(_))) => {} + Ready(None) => panic!("done"), + Pending => panic!("pending"), + }); + }}; +} + +macro_rules! assert_done { + ($io:ident) => {{ + task::spawn(()).enter(|cx, _| { + let res = assert_ready!($io.as_mut().poll_next(cx)); + match res { + Some(Ok(v)) => panic!("value = {:?}", v), + Some(Err(e)) => panic!("error = {:?}", e), + None => {} + } + }); + }}; +} + +#[test] +fn read_empty_io_yields_nothing() { + let io = Box::pin(FramedRead::new(mock!(), LengthDelimitedCodec::new())); + pin_mut!(io); + + assert_done!(io); +} + +#[test] +fn read_single_frame_one_packet() { + let io = FramedRead::new( + mock! { + data(b"\x00\x00\x00\x09abcdefghi"), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_done!(io); +} + +#[test] +fn read_single_frame_one_packet_little_endian() { + let io = length_delimited::Builder::new() + .little_endian() + .new_read(mock! { + data(b"\x09\x00\x00\x00abcdefghi"), + }); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_done!(io); +} + +#[test] +fn read_single_frame_one_packet_native_endian() { + let d = if cfg!(target_endian = "big") { + b"\x00\x00\x00\x09abcdefghi" + } else { + b"\x09\x00\x00\x00abcdefghi" + }; + let io = length_delimited::Builder::new() + .native_endian() + .new_read(mock! { + data(d), + }); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_done!(io); +} + +#[test] +fn read_single_multi_frame_one_packet() { + let mut d: Vec<u8> = vec![]; + d.extend_from_slice(b"\x00\x00\x00\x09abcdefghi"); + d.extend_from_slice(b"\x00\x00\x00\x03123"); + d.extend_from_slice(b"\x00\x00\x00\x0bhello world"); + + let io = FramedRead::new( + mock! { + data(&d), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_next_eq!(io, b"123"); + assert_next_eq!(io, b"hello world"); + assert_done!(io); +} + +#[test] +fn read_single_frame_multi_packet() { + let io = FramedRead::new( + mock! { + data(b"\x00\x00"), + data(b"\x00\x09abc"), + data(b"defghi"), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_done!(io); +} + +#[test] +fn read_multi_frame_multi_packet() { + let io = FramedRead::new( + mock! { + data(b"\x00\x00"), + data(b"\x00\x09abc"), + data(b"defghi"), + data(b"\x00\x00\x00\x0312"), + data(b"3\x00\x00\x00\x0bhello world"), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_next_eq!(io, b"123"); + assert_next_eq!(io, b"hello world"); + assert_done!(io); +} + +#[test] +fn read_single_frame_multi_packet_wait() { + let io = FramedRead::new( + mock! { + data(b"\x00\x00"), + Pending, + data(b"\x00\x09abc"), + Pending, + data(b"defghi"), + Pending, + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_pending!(io); + assert_next_pending!(io); + assert_next_eq!(io, b"abcdefghi"); + assert_next_pending!(io); + assert_done!(io); +} + +#[test] +fn read_multi_frame_multi_packet_wait() { + let io = FramedRead::new( + mock! { + data(b"\x00\x00"), + Pending, + data(b"\x00\x09abc"), + Pending, + data(b"defghi"), + Pending, + data(b"\x00\x00\x00\x0312"), + Pending, + data(b"3\x00\x00\x00\x0bhello world"), + Pending, + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_pending!(io); + assert_next_pending!(io); + assert_next_eq!(io, b"abcdefghi"); + assert_next_pending!(io); + assert_next_pending!(io); + assert_next_eq!(io, b"123"); + assert_next_eq!(io, b"hello world"); + assert_next_pending!(io); + assert_done!(io); +} + +#[test] +fn read_incomplete_head() { + let io = FramedRead::new( + mock! { + data(b"\x00\x00"), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_err!(io); +} + +#[test] +fn read_incomplete_head_multi() { + let io = FramedRead::new( + mock! { + Pending, + data(b"\x00"), + Pending, + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_pending!(io); + assert_next_pending!(io); + assert_next_err!(io); +} + +#[test] +fn read_incomplete_payload() { + let io = FramedRead::new( + mock! { + data(b"\x00\x00\x00\x09ab"), + Pending, + data(b"cd"), + Pending, + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_pending!(io); + assert_next_pending!(io); + assert_next_err!(io); +} + +#[test] +fn read_max_frame_len() { + let io = length_delimited::Builder::new() + .max_frame_length(5) + .new_read(mock! { + data(b"\x00\x00\x00\x09abcdefghi"), + }); + pin_mut!(io); + + assert_next_err!(io); +} + +#[test] +fn read_update_max_frame_len_at_rest() { + let io = length_delimited::Builder::new().new_read(mock! { + data(b"\x00\x00\x00\x09abcdefghi"), + data(b"\x00\x00\x00\x09abcdefghi"), + }); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + io.decoder_mut().set_max_frame_length(5); + assert_next_err!(io); +} + +#[test] +fn read_update_max_frame_len_in_flight() { + let io = length_delimited::Builder::new().new_read(mock! { + data(b"\x00\x00\x00\x09abcd"), + Pending, + data(b"efghi"), + data(b"\x00\x00\x00\x09abcdefghi"), + }); + pin_mut!(io); + + assert_next_pending!(io); + io.decoder_mut().set_max_frame_length(5); + assert_next_eq!(io, b"abcdefghi"); + assert_next_err!(io); +} + +#[test] +fn read_one_byte_length_field() { + let io = length_delimited::Builder::new() + .length_field_length(1) + .new_read(mock! { + data(b"\x09abcdefghi"), + }); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_done!(io); +} + +#[test] +fn read_header_offset() { + let io = length_delimited::Builder::new() + .length_field_length(2) + .length_field_offset(4) + .new_read(mock! { + data(b"zzzz\x00\x09abcdefghi"), + }); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_done!(io); +} + +#[test] +fn read_single_multi_frame_one_packet_skip_none_adjusted() { + let mut d: Vec<u8> = vec![]; + d.extend_from_slice(b"xx\x00\x09abcdefghi"); + d.extend_from_slice(b"yy\x00\x03123"); + d.extend_from_slice(b"zz\x00\x0bhello world"); + + let io = length_delimited::Builder::new() + .length_field_length(2) + .length_field_offset(2) + .num_skip(0) + .length_adjustment(4) + .new_read(mock! { + data(&d), + }); + pin_mut!(io); + + assert_next_eq!(io, b"xx\x00\x09abcdefghi"); + assert_next_eq!(io, b"yy\x00\x03123"); + assert_next_eq!(io, b"zz\x00\x0bhello world"); + assert_done!(io); +} + +#[test] +fn read_single_frame_length_adjusted() { + let mut d: Vec<u8> = vec![]; + d.extend_from_slice(b"\x00\x00\x0b\x0cHello world"); + + let io = length_delimited::Builder::new() + .length_field_offset(0) + .length_field_length(3) + .length_adjustment(0) + .num_skip(4) + .new_read(mock! { + data(&d), + }); + pin_mut!(io); + + assert_next_eq!(io, b"Hello world"); + assert_done!(io); +} + +#[test] +fn read_single_multi_frame_one_packet_length_includes_head() { + let mut d: Vec<u8> = vec![]; + d.extend_from_slice(b"\x00\x0babcdefghi"); + d.extend_from_slice(b"\x00\x05123"); + d.extend_from_slice(b"\x00\x0dhello world"); + + let io = length_delimited::Builder::new() + .length_field_length(2) + .length_adjustment(-2) + .new_read(mock! { + data(&d), + }); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_next_eq!(io, b"123"); + assert_next_eq!(io, b"hello world"); + assert_done!(io); +} + +#[test] +fn write_single_frame_length_adjusted() { + let io = length_delimited::Builder::new() + .length_adjustment(-2) + .new_write(mock! { + data(b"\x00\x00\x00\x0b"), + data(b"abcdefghi"), + flush(), + }); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + assert_ready_ok!(io.as_mut().poll_flush(cx)); + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_nothing_yields_nothing() { + let io = FramedWrite::new(mock!(), LengthDelimitedCodec::new()); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.poll_flush(cx)); + }); +} + +#[test] +fn write_single_frame_one_packet() { + let io = FramedWrite::new( + mock! { + data(b"\x00\x00\x00\x09"), + data(b"abcdefghi"), + flush(), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + assert_ready_ok!(io.as_mut().poll_flush(cx)); + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_single_multi_frame_one_packet() { + let io = FramedWrite::new( + mock! { + data(b"\x00\x00\x00\x09"), + data(b"abcdefghi"), + data(b"\x00\x00\x00\x03"), + data(b"123"), + data(b"\x00\x00\x00\x0b"), + data(b"hello world"), + flush(), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("123"))); + + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("hello world"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_single_multi_frame_multi_packet() { + let io = FramedWrite::new( + mock! { + data(b"\x00\x00\x00\x09"), + data(b"abcdefghi"), + flush(), + data(b"\x00\x00\x00\x03"), + data(b"123"), + flush(), + data(b"\x00\x00\x00\x0b"), + data(b"hello world"), + flush(), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("123"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("hello world"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_single_frame_would_block() { + let io = FramedWrite::new( + mock! { + Pending, + data(b"\x00\x00"), + Pending, + data(b"\x00\x09"), + data(b"abcdefghi"), + flush(), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + + assert_pending!(io.as_mut().poll_flush(cx)); + assert_pending!(io.as_mut().poll_flush(cx)); + assert_ready_ok!(io.as_mut().poll_flush(cx)); + + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_single_frame_little_endian() { + let io = length_delimited::Builder::new() + .little_endian() + .new_write(mock! { + data(b"\x09\x00\x00\x00"), + data(b"abcdefghi"), + flush(), + }); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_single_frame_with_short_length_field() { + let io = length_delimited::Builder::new() + .length_field_length(1) + .new_write(mock! { + data(b"\x09"), + data(b"abcdefghi"), + flush(), + }); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_max_frame_len() { + let io = length_delimited::Builder::new() + .max_frame_length(5) + .new_write(mock! {}); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_err!(io.as_mut().start_send(Bytes::from("abcdef"))); + + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_update_max_frame_len_at_rest() { + let io = length_delimited::Builder::new().new_write(mock! { + data(b"\x00\x00\x00\x06"), + data(b"abcdef"), + flush(), + }); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdef"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + + io.encoder_mut().set_max_frame_length(5); + + assert_err!(io.as_mut().start_send(Bytes::from("abcdef"))); + + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_update_max_frame_len_in_flight() { + let io = length_delimited::Builder::new().new_write(mock! { + data(b"\x00\x00\x00\x06"), + data(b"ab"), + Pending, + data(b"cdef"), + flush(), + }); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdef"))); + + assert_pending!(io.as_mut().poll_flush(cx)); + + io.encoder_mut().set_max_frame_length(5); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + + assert_err!(io.as_mut().start_send(Bytes::from("abcdef"))); + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_zero() { + let io = length_delimited::Builder::new().new_write(mock! {}); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdef"))); + + assert_ready_err!(io.as_mut().poll_flush(cx)); + + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn encode_overflow() { + // Test reproducing tokio-rs/tokio#681. + let mut codec = length_delimited::Builder::new().new_codec(); + let mut buf = BytesMut::with_capacity(1024); + + // Put some data into the buffer without resizing it to hold more. + let some_as = std::iter::repeat(b'a').take(1024).collect::<Vec<_>>(); + buf.put_slice(&some_as[..]); + + // Trying to encode the length header should resize the buffer if it won't fit. + codec.encode(Bytes::from("hello"), &mut buf).unwrap(); +} + +// ===== Test utils ===== + +struct Mock { + calls: VecDeque<Poll<io::Result<Op>>>, +} + +enum Op { + Data(Vec<u8>), + Flush, +} + +use self::Op::*; + +impl AsyncRead for Mock { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + dst: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + match self.calls.pop_front() { + Some(Ready(Ok(Op::Data(data)))) => { + debug_assert!(dst.remaining() >= data.len()); + dst.put_slice(&data); + Ready(Ok(())) + } + Some(Ready(Ok(_))) => panic!(), + Some(Ready(Err(e))) => Ready(Err(e)), + Some(Pending) => Pending, + None => Ready(Ok(())), + } + } +} + +impl AsyncWrite for Mock { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + src: &[u8], + ) -> Poll<Result<usize, io::Error>> { + match self.calls.pop_front() { + Some(Ready(Ok(Op::Data(data)))) => { + let len = data.len(); + assert!(src.len() >= len, "expect={:?}; actual={:?}", data, src); + assert_eq!(&data[..], &src[..len]); + Ready(Ok(len)) + } + Some(Ready(Ok(_))) => panic!(), + Some(Ready(Err(e))) => Ready(Err(e)), + Some(Pending) => Pending, + None => Ready(Ok(0)), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + match self.calls.pop_front() { + Some(Ready(Ok(Op::Flush))) => Ready(Ok(())), + Some(Ready(Ok(_))) => panic!(), + Some(Ready(Err(e))) => Ready(Err(e)), + Some(Pending) => Pending, + None => Ready(Ok(())), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + Ready(Ok(())) + } +} + +impl<'a> From<&'a [u8]> for Op { + fn from(src: &'a [u8]) -> Op { + Op::Data(src.into()) + } +} + +impl From<Vec<u8>> for Op { + fn from(src: Vec<u8>) -> Op { + Op::Data(src) + } +} + +fn data(bytes: &[u8]) -> Poll<io::Result<Op>> { + Ready(Ok(bytes.into())) +} + +fn flush() -> Poll<io::Result<Op>> { + Ready(Ok(Flush)) +} diff --git a/third_party/rust/tokio-util/tests/mpsc.rs b/third_party/rust/tokio-util/tests/mpsc.rs new file mode 100644 index 0000000000..a3c164d3ec --- /dev/null +++ b/third_party/rust/tokio-util/tests/mpsc.rs @@ -0,0 +1,239 @@ +use futures::future::poll_fn; +use tokio::sync::mpsc::channel; +use tokio_test::task::spawn; +use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok}; +use tokio_util::sync::PollSender; + +#[tokio::test] +async fn simple() { + let (send, mut recv) = channel(3); + let mut send = PollSender::new(send); + + for i in 1..=3i32 { + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); + send.send_item(i).unwrap(); + } + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_pending!(reserve.poll()); + + assert_eq!(recv.recv().await.unwrap(), 1); + assert!(reserve.is_woken()); + assert_ready_ok!(reserve.poll()); + + drop(recv); + + send.send_item(42).unwrap(); +} + +#[tokio::test] +async fn repeated_poll_reserve() { + let (send, mut recv) = channel::<i32>(1); + let mut send = PollSender::new(send); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); + assert_ready_ok!(reserve.poll()); + send.send_item(1).unwrap(); + + assert_eq!(recv.recv().await.unwrap(), 1); +} + +#[tokio::test] +async fn abort_send() { + let (send, mut recv) = channel(3); + let mut send = PollSender::new(send); + let send2 = send.get_ref().cloned().unwrap(); + + for i in 1..=3i32 { + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); + send.send_item(i).unwrap(); + } + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_pending!(reserve.poll()); + assert_eq!(recv.recv().await.unwrap(), 1); + assert!(reserve.is_woken()); + assert_ready_ok!(reserve.poll()); + + let mut send2_send = spawn(send2.send(5)); + assert_pending!(send2_send.poll()); + assert!(send.abort_send()); + assert!(send2_send.is_woken()); + assert_ready_ok!(send2_send.poll()); + + assert_eq!(recv.recv().await.unwrap(), 2); + assert_eq!(recv.recv().await.unwrap(), 3); + assert_eq!(recv.recv().await.unwrap(), 5); +} + +#[tokio::test] +async fn close_sender_last() { + let (send, mut recv) = channel::<i32>(3); + let mut send = PollSender::new(send); + + let mut recv_task = spawn(recv.recv()); + assert_pending!(recv_task.poll()); + + send.close(); + + assert!(recv_task.is_woken()); + assert!(assert_ready!(recv_task.poll()).is_none()); +} + +#[tokio::test] +async fn close_sender_not_last() { + let (send, mut recv) = channel::<i32>(3); + let mut send = PollSender::new(send); + let send2 = send.get_ref().cloned().unwrap(); + + let mut recv_task = spawn(recv.recv()); + assert_pending!(recv_task.poll()); + + send.close(); + + assert!(!recv_task.is_woken()); + assert_pending!(recv_task.poll()); + + drop(send2); + + assert!(recv_task.is_woken()); + assert!(assert_ready!(recv_task.poll()).is_none()); +} + +#[tokio::test] +async fn close_sender_before_reserve() { + let (send, mut recv) = channel::<i32>(3); + let mut send = PollSender::new(send); + + let mut recv_task = spawn(recv.recv()); + assert_pending!(recv_task.poll()); + + send.close(); + + assert!(recv_task.is_woken()); + assert!(assert_ready!(recv_task.poll()).is_none()); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_err!(reserve.poll()); +} + +#[tokio::test] +async fn close_sender_after_pending_reserve() { + let (send, mut recv) = channel::<i32>(1); + let mut send = PollSender::new(send); + + let mut recv_task = spawn(recv.recv()); + assert_pending!(recv_task.poll()); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); + send.send_item(1).unwrap(); + + assert!(recv_task.is_woken()); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_pending!(reserve.poll()); + drop(reserve); + + send.close(); + + assert!(send.is_closed()); + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_err!(reserve.poll()); +} + +#[tokio::test] +async fn close_sender_after_successful_reserve() { + let (send, mut recv) = channel::<i32>(3); + let mut send = PollSender::new(send); + + let mut recv_task = spawn(recv.recv()); + assert_pending!(recv_task.poll()); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); + drop(reserve); + + send.close(); + assert!(send.is_closed()); + assert!(!recv_task.is_woken()); + assert_pending!(recv_task.poll()); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); +} + +#[tokio::test] +async fn abort_send_after_pending_reserve() { + let (send, mut recv) = channel::<i32>(1); + let mut send = PollSender::new(send); + + let mut recv_task = spawn(recv.recv()); + assert_pending!(recv_task.poll()); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); + send.send_item(1).unwrap(); + + assert_eq!(send.get_ref().unwrap().capacity(), 0); + assert!(!send.abort_send()); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_pending!(reserve.poll()); + + assert!(send.abort_send()); + assert_eq!(send.get_ref().unwrap().capacity(), 0); +} + +#[tokio::test] +async fn abort_send_after_successful_reserve() { + let (send, mut recv) = channel::<i32>(1); + let mut send = PollSender::new(send); + + let mut recv_task = spawn(recv.recv()); + assert_pending!(recv_task.poll()); + + assert_eq!(send.get_ref().unwrap().capacity(), 1); + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); + assert_eq!(send.get_ref().unwrap().capacity(), 0); + + assert!(send.abort_send()); + assert_eq!(send.get_ref().unwrap().capacity(), 1); +} + +#[tokio::test] +async fn closed_when_receiver_drops() { + let (send, _) = channel::<i32>(1); + let mut send = PollSender::new(send); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_err!(reserve.poll()); +} + +#[should_panic] +#[test] +fn start_send_panics_when_idle() { + let (send, _) = channel::<i32>(3); + let mut send = PollSender::new(send); + + send.send_item(1).unwrap(); +} + +#[should_panic] +#[test] +fn start_send_panics_when_acquiring() { + let (send, _) = channel::<i32>(1); + let mut send = PollSender::new(send); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); + send.send_item(1).unwrap(); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_pending!(reserve.poll()); + send.send_item(2).unwrap(); +} diff --git a/third_party/rust/tokio-util/tests/poll_semaphore.rs b/third_party/rust/tokio-util/tests/poll_semaphore.rs new file mode 100644 index 0000000000..50f36dd803 --- /dev/null +++ b/third_party/rust/tokio-util/tests/poll_semaphore.rs @@ -0,0 +1,36 @@ +use std::future::Future; +use std::sync::Arc; +use std::task::Poll; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use tokio_util::sync::PollSemaphore; + +type SemRet = Option<OwnedSemaphorePermit>; + +fn semaphore_poll( + sem: &mut PollSemaphore, +) -> tokio_test::task::Spawn<impl Future<Output = SemRet> + '_> { + let fut = futures::future::poll_fn(move |cx| sem.poll_acquire(cx)); + tokio_test::task::spawn(fut) +} + +#[tokio::test] +async fn it_works() { + let sem = Arc::new(Semaphore::new(1)); + let mut poll_sem = PollSemaphore::new(sem.clone()); + + let permit = sem.acquire().await.unwrap(); + let mut poll = semaphore_poll(&mut poll_sem); + assert!(poll.poll().is_pending()); + drop(permit); + + assert!(matches!(poll.poll(), Poll::Ready(Some(_)))); + drop(poll); + + sem.close(); + + assert!(semaphore_poll(&mut poll_sem).await.is_none()); + + // Check that it is fused. + assert!(semaphore_poll(&mut poll_sem).await.is_none()); + assert!(semaphore_poll(&mut poll_sem).await.is_none()); +} diff --git a/third_party/rust/tokio-util/tests/reusable_box.rs b/third_party/rust/tokio-util/tests/reusable_box.rs new file mode 100644 index 0000000000..c8f6da02ae --- /dev/null +++ b/third_party/rust/tokio-util/tests/reusable_box.rs @@ -0,0 +1,72 @@ +use futures::future::FutureExt; +use std::alloc::Layout; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio_util::sync::ReusableBoxFuture; + +#[test] +fn test_different_futures() { + let fut = async move { 10 }; + // Not zero sized! + assert_eq!(Layout::for_value(&fut).size(), 1); + + let mut b = ReusableBoxFuture::new(fut); + + assert_eq!(b.get_pin().now_or_never(), Some(10)); + + b.try_set(async move { 20 }) + .unwrap_or_else(|_| panic!("incorrect size")); + + assert_eq!(b.get_pin().now_or_never(), Some(20)); + + b.try_set(async move { 30 }) + .unwrap_or_else(|_| panic!("incorrect size")); + + assert_eq!(b.get_pin().now_or_never(), Some(30)); +} + +#[test] +fn test_different_sizes() { + let fut1 = async move { 10 }; + let val = [0u32; 1000]; + let fut2 = async move { val[0] }; + let fut3 = ZeroSizedFuture {}; + + assert_eq!(Layout::for_value(&fut1).size(), 1); + assert_eq!(Layout::for_value(&fut2).size(), 4004); + assert_eq!(Layout::for_value(&fut3).size(), 0); + + let mut b = ReusableBoxFuture::new(fut1); + assert_eq!(b.get_pin().now_or_never(), Some(10)); + b.set(fut2); + assert_eq!(b.get_pin().now_or_never(), Some(0)); + b.set(fut3); + assert_eq!(b.get_pin().now_or_never(), Some(5)); +} + +struct ZeroSizedFuture {} +impl Future for ZeroSizedFuture { + type Output = u32; + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<u32> { + Poll::Ready(5) + } +} + +#[test] +fn test_zero_sized() { + let fut = ZeroSizedFuture {}; + // Zero sized! + assert_eq!(Layout::for_value(&fut).size(), 0); + + let mut b = ReusableBoxFuture::new(fut); + + assert_eq!(b.get_pin().now_or_never(), Some(5)); + assert_eq!(b.get_pin().now_or_never(), Some(5)); + + b.try_set(ZeroSizedFuture {}) + .unwrap_or_else(|_| panic!("incorrect size")); + + assert_eq!(b.get_pin().now_or_never(), Some(5)); + assert_eq!(b.get_pin().now_or_never(), Some(5)); +} diff --git a/third_party/rust/tokio-util/tests/spawn_pinned.rs b/third_party/rust/tokio-util/tests/spawn_pinned.rs new file mode 100644 index 0000000000..409b8dadab --- /dev/null +++ b/third_party/rust/tokio-util/tests/spawn_pinned.rs @@ -0,0 +1,193 @@ +#![warn(rust_2018_idioms)] + +use std::rc::Rc; +use std::sync::Arc; +use tokio_util::task; + +/// Simple test of running a !Send future via spawn_pinned +#[tokio::test] +async fn can_spawn_not_send_future() { + let pool = task::LocalPoolHandle::new(1); + + let output = pool + .spawn_pinned(|| { + // Rc is !Send + !Sync + let local_data = Rc::new("test"); + + // This future holds an Rc, so it is !Send + async move { local_data.to_string() } + }) + .await + .unwrap(); + + assert_eq!(output, "test"); +} + +/// Dropping the join handle still lets the task execute +#[test] +fn can_drop_future_and_still_get_output() { + let pool = task::LocalPoolHandle::new(1); + let (sender, receiver) = std::sync::mpsc::channel(); + + let _ = pool.spawn_pinned(move || { + // Rc is !Send + !Sync + let local_data = Rc::new("test"); + + // This future holds an Rc, so it is !Send + async move { + let _ = sender.send(local_data.to_string()); + } + }); + + assert_eq!(receiver.recv(), Ok("test".to_string())); +} + +#[test] +#[should_panic(expected = "assertion failed: pool_size > 0")] +fn cannot_create_zero_sized_pool() { + let _pool = task::LocalPoolHandle::new(0); +} + +/// We should be able to spawn multiple futures onto the pool at the same time. +#[tokio::test] +async fn can_spawn_multiple_futures() { + let pool = task::LocalPoolHandle::new(2); + + let join_handle1 = pool.spawn_pinned(|| { + let local_data = Rc::new("test1"); + async move { local_data.to_string() } + }); + let join_handle2 = pool.spawn_pinned(|| { + let local_data = Rc::new("test2"); + async move { local_data.to_string() } + }); + + assert_eq!(join_handle1.await.unwrap(), "test1"); + assert_eq!(join_handle2.await.unwrap(), "test2"); +} + +/// A panic in the spawned task causes the join handle to return an error. +/// But, you can continue to spawn tasks. +#[tokio::test] +async fn task_panic_propagates() { + let pool = task::LocalPoolHandle::new(1); + + let join_handle = pool.spawn_pinned(|| async { + panic!("Test panic"); + }); + + let result = join_handle.await; + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.is_panic()); + let panic_str: &str = *error.into_panic().downcast().unwrap(); + assert_eq!(panic_str, "Test panic"); + + // Trying again with a "safe" task still works + let join_handle = pool.spawn_pinned(|| async { "test" }); + let result = join_handle.await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "test"); +} + +/// A panic during task creation causes the join handle to return an error. +/// But, you can continue to spawn tasks. +#[tokio::test] +async fn callback_panic_does_not_kill_worker() { + let pool = task::LocalPoolHandle::new(1); + + let join_handle = pool.spawn_pinned(|| { + panic!("Test panic"); + #[allow(unreachable_code)] + async {} + }); + + let result = join_handle.await; + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.is_panic()); + let panic_str: &str = *error.into_panic().downcast().unwrap(); + assert_eq!(panic_str, "Test panic"); + + // Trying again with a "safe" callback works + let join_handle = pool.spawn_pinned(|| async { "test" }); + let result = join_handle.await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "test"); +} + +/// Canceling the task via the returned join handle cancels the spawned task +/// (which has a different, internal join handle). +#[tokio::test] +async fn task_cancellation_propagates() { + let pool = task::LocalPoolHandle::new(1); + let notify_dropped = Arc::new(()); + let weak_notify_dropped = Arc::downgrade(¬ify_dropped); + + let (start_sender, start_receiver) = tokio::sync::oneshot::channel(); + let (drop_sender, drop_receiver) = tokio::sync::oneshot::channel::<()>(); + let join_handle = pool.spawn_pinned(|| async move { + let _drop_sender = drop_sender; + // Move the Arc into the task + let _notify_dropped = notify_dropped; + let _ = start_sender.send(()); + + // Keep the task running until it gets aborted + futures::future::pending::<()>().await; + }); + + // Wait for the task to start + let _ = start_receiver.await; + + join_handle.abort(); + + // Wait for the inner task to abort, dropping the sender. + // The top level join handle aborts quicker than the inner task (the abort + // needs to propagate and get processed on the worker thread), so we can't + // just await the top level join handle. + let _ = drop_receiver.await; + + // Check that the Arc has been dropped. This verifies that the inner task + // was canceled as well. + assert!(weak_notify_dropped.upgrade().is_none()); +} + +/// Tasks should be given to the least burdened worker. When spawning two tasks +/// on a pool with two empty workers the tasks should be spawned on separate +/// workers. +#[tokio::test] +async fn tasks_are_balanced() { + let pool = task::LocalPoolHandle::new(2); + + // Spawn a task so one thread has a task count of 1 + let (start_sender1, start_receiver1) = tokio::sync::oneshot::channel(); + let (end_sender1, end_receiver1) = tokio::sync::oneshot::channel(); + let join_handle1 = pool.spawn_pinned(|| async move { + let _ = start_sender1.send(()); + let _ = end_receiver1.await; + std::thread::current().id() + }); + + // Wait for the first task to start up + let _ = start_receiver1.await; + + // This task should be spawned on the other thread + let (start_sender2, start_receiver2) = tokio::sync::oneshot::channel(); + let join_handle2 = pool.spawn_pinned(|| async move { + let _ = start_sender2.send(()); + std::thread::current().id() + }); + + // Wait for the second task to start up + let _ = start_receiver2.await; + + // Allow the first task to end + let _ = end_sender1.send(()); + + let thread_id1 = join_handle1.await.unwrap(); + let thread_id2 = join_handle2.await.unwrap(); + + // Since the first task was active when the second task spawned, they should + // be on separate workers/threads. + assert_ne!(thread_id1, thread_id2); +} diff --git a/third_party/rust/tokio-util/tests/sync_cancellation_token.rs b/third_party/rust/tokio-util/tests/sync_cancellation_token.rs new file mode 100644 index 0000000000..28ba284b6c --- /dev/null +++ b/third_party/rust/tokio-util/tests/sync_cancellation_token.rs @@ -0,0 +1,400 @@ +#![warn(rust_2018_idioms)] + +use tokio::pin; +use tokio_util::sync::{CancellationToken, WaitForCancellationFuture}; + +use core::future::Future; +use core::task::{Context, Poll}; +use futures_test::task::new_count_waker; + +#[test] +fn cancel_token() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + assert!(!token.is_cancelled()); + + let wait_fut = token.cancelled(); + pin!(wait_fut); + + assert_eq!( + Poll::Pending, + wait_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + let wait_fut_2 = token.cancelled(); + pin!(wait_fut_2); + + token.cancel(); + assert_eq!(wake_counter, 1); + assert!(token.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + wait_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + wait_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) + ); +} + +#[test] +fn cancel_child_token_through_parent() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + + let child_token = token.child_token(); + assert!(!child_token.is_cancelled()); + + let child_fut = child_token.cancelled(); + pin!(child_fut); + let parent_fut = token.cancelled(); + pin!(parent_fut); + + assert_eq!( + Poll::Pending, + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + token.cancel(); + assert_eq!(wake_counter, 2); + assert!(token.is_cancelled()); + assert!(child_token.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); +} + +#[test] +fn cancel_grandchild_token_through_parent_if_child_was_dropped() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + + let intermediate_token = token.child_token(); + let child_token = intermediate_token.child_token(); + drop(intermediate_token); + assert!(!child_token.is_cancelled()); + + let child_fut = child_token.cancelled(); + pin!(child_fut); + let parent_fut = token.cancelled(); + pin!(parent_fut); + + assert_eq!( + Poll::Pending, + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + token.cancel(); + assert_eq!(wake_counter, 2); + assert!(token.is_cancelled()); + assert!(child_token.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); +} + +#[test] +fn cancel_child_token_without_parent() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + + let child_token_1 = token.child_token(); + + let child_fut = child_token_1.cancelled(); + pin!(child_fut); + let parent_fut = token.cancelled(); + pin!(parent_fut); + + assert_eq!( + Poll::Pending, + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + child_token_1.cancel(); + assert_eq!(wake_counter, 1); + assert!(!token.is_cancelled()); + assert!(child_token_1.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + + let child_token_2 = token.child_token(); + let child_fut_2 = child_token_2.cancelled(); + pin!(child_fut_2); + + assert_eq!( + Poll::Pending, + child_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + + token.cancel(); + assert_eq!(wake_counter, 3); + assert!(token.is_cancelled()); + assert!(child_token_2.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + child_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); +} + +#[test] +fn create_child_token_after_parent_was_cancelled() { + for drop_child_first in [true, false].iter().cloned() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + token.cancel(); + + let child_token = token.child_token(); + assert!(child_token.is_cancelled()); + + { + let child_fut = child_token.cancelled(); + pin!(child_fut); + let parent_fut = token.cancelled(); + pin!(parent_fut); + + assert_eq!( + Poll::Ready(()), + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + drop(child_fut); + drop(parent_fut); + } + + if drop_child_first { + drop(child_token); + drop(token); + } else { + drop(token); + drop(child_token); + } + } +} + +#[test] +fn drop_multiple_child_tokens() { + for drop_first_child_first in &[true, false] { + let token = CancellationToken::new(); + let mut child_tokens = [None, None, None]; + for child in &mut child_tokens { + *child = Some(token.child_token()); + } + + assert!(!token.is_cancelled()); + assert!(!child_tokens[0].as_ref().unwrap().is_cancelled()); + + for i in 0..child_tokens.len() { + if *drop_first_child_first { + child_tokens[i] = None; + } else { + child_tokens[child_tokens.len() - 1 - i] = None; + } + assert!(!token.is_cancelled()); + } + + drop(token); + } +} + +#[test] +fn cancel_only_all_descendants() { + // ARRANGE + let (waker, wake_counter) = new_count_waker(); + + let parent_token = CancellationToken::new(); + let token = parent_token.child_token(); + let sibling_token = parent_token.child_token(); + let child1_token = token.child_token(); + let child2_token = token.child_token(); + let grandchild_token = child1_token.child_token(); + let grandchild2_token = child1_token.child_token(); + let grandgrandchild_token = grandchild_token.child_token(); + + assert!(!parent_token.is_cancelled()); + assert!(!token.is_cancelled()); + assert!(!sibling_token.is_cancelled()); + assert!(!child1_token.is_cancelled()); + assert!(!child2_token.is_cancelled()); + assert!(!grandchild_token.is_cancelled()); + assert!(!grandchild2_token.is_cancelled()); + assert!(!grandgrandchild_token.is_cancelled()); + + let parent_fut = parent_token.cancelled(); + let fut = token.cancelled(); + let sibling_fut = sibling_token.cancelled(); + let child1_fut = child1_token.cancelled(); + let child2_fut = child2_token.cancelled(); + let grandchild_fut = grandchild_token.cancelled(); + let grandchild2_fut = grandchild2_token.cancelled(); + let grandgrandchild_fut = grandgrandchild_token.cancelled(); + + pin!(parent_fut); + pin!(fut); + pin!(sibling_fut); + pin!(child1_fut); + pin!(child2_fut); + pin!(grandchild_fut); + pin!(grandchild2_fut); + pin!(grandgrandchild_fut); + + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + sibling_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + child1_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + child2_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + grandchild_fut + .as_mut() + .poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + grandchild2_fut + .as_mut() + .poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + grandgrandchild_fut + .as_mut() + .poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + // ACT + token.cancel(); + + // ASSERT + assert_eq!(wake_counter, 6); + assert!(!parent_token.is_cancelled()); + assert!(token.is_cancelled()); + assert!(!sibling_token.is_cancelled()); + assert!(child1_token.is_cancelled()); + assert!(child2_token.is_cancelled()); + assert!(grandchild_token.is_cancelled()); + assert!(grandchild2_token.is_cancelled()); + assert!(grandgrandchild_token.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + child1_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + child2_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + grandchild_fut + .as_mut() + .poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + grandchild2_fut + .as_mut() + .poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + grandgrandchild_fut + .as_mut() + .poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 6); +} + +#[test] +fn drop_parent_before_child_tokens() { + let token = CancellationToken::new(); + let child1 = token.child_token(); + let child2 = token.child_token(); + + drop(token); + assert!(!child1.is_cancelled()); + + drop(child1); + drop(child2); +} + +#[test] +fn derives_send_sync() { + fn assert_send<T: Send>() {} + fn assert_sync<T: Sync>() {} + + assert_send::<CancellationToken>(); + assert_sync::<CancellationToken>(); + + assert_send::<WaitForCancellationFuture<'static>>(); + assert_sync::<WaitForCancellationFuture<'static>>(); +} diff --git a/third_party/rust/tokio-util/tests/time_delay_queue.rs b/third_party/rust/tokio-util/tests/time_delay_queue.rs new file mode 100644 index 0000000000..cb163adf3a --- /dev/null +++ b/third_party/rust/tokio-util/tests/time_delay_queue.rs @@ -0,0 +1,818 @@ +#![allow(clippy::blacklisted_name)] +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use tokio::time::{self, sleep, sleep_until, Duration, Instant}; +use tokio_test::{assert_pending, assert_ready, task}; +use tokio_util::time::DelayQueue; + +macro_rules! poll { + ($queue:ident) => { + $queue.enter(|cx, mut queue| queue.poll_expired(cx)) + }; +} + +macro_rules! assert_ready_some { + ($e:expr) => {{ + match assert_ready!($e) { + Some(v) => v, + None => panic!("None"), + } + }}; +} + +#[tokio::test] +async fn single_immediate_delay() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + let _key = queue.insert_at("foo", Instant::now()); + + // Advance time by 1ms to handle thee rounding + sleep(ms(1)).await; + + assert_ready_some!(poll!(queue)); + + let entry = assert_ready!(poll!(queue)); + assert!(entry.is_none()) +} + +#[tokio::test] +async fn multi_immediate_delays() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let _k = queue.insert_at("1", Instant::now()); + let _k = queue.insert_at("2", Instant::now()); + let _k = queue.insert_at("3", Instant::now()); + + sleep(ms(1)).await; + + let mut res = vec![]; + + while res.len() < 3 { + let entry = assert_ready_some!(poll!(queue)); + res.push(entry.into_inner()); + } + + let entry = assert_ready!(poll!(queue)); + assert!(entry.is_none()); + + res.sort_unstable(); + + assert_eq!("1", res[0]); + assert_eq!("2", res[1]); + assert_eq!("3", res[2]); +} + +#[tokio::test] +async fn single_short_delay() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + let _key = queue.insert_at("foo", Instant::now() + ms(5)); + + assert_pending!(poll!(queue)); + + sleep(ms(1)).await; + + assert!(!queue.is_woken()); + + sleep(ms(5)).await; + + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)); + assert_eq!(*entry.get_ref(), "foo"); + + let entry = assert_ready!(poll!(queue)); + assert!(entry.is_none()); +} + +#[tokio::test] +async fn multi_delay_at_start() { + time::pause(); + + let long = 262_144 + 9 * 4096; + let delays = &[1000, 2, 234, long, 60, 10]; + + let mut queue = task::spawn(DelayQueue::new()); + + // Setup the delays + for &i in delays { + let _key = queue.insert_at(i, Instant::now() + ms(i)); + } + + assert_pending!(poll!(queue)); + assert!(!queue.is_woken()); + + let start = Instant::now(); + for elapsed in 0..1200 { + println!("elapsed: {:?}", elapsed); + let elapsed = elapsed + 1; + tokio::time::sleep_until(start + ms(elapsed)).await; + + if delays.contains(&elapsed) { + assert!(queue.is_woken()); + assert_ready!(poll!(queue)); + assert_pending!(poll!(queue)); + } else if queue.is_woken() { + let cascade = &[192, 960]; + assert!( + cascade.contains(&elapsed), + "elapsed={} dt={:?}", + elapsed, + Instant::now() - start + ); + + assert_pending!(poll!(queue)); + } + } + println!("finished multi_delay_start"); +} + +#[tokio::test] +async fn insert_in_past_fires_immediately() { + println!("running insert_in_past_fires_immediately"); + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + let now = Instant::now(); + + sleep(ms(10)).await; + + queue.insert_at("foo", now); + + assert_ready!(poll!(queue)); + println!("finished insert_in_past_fires_immediately"); +} + +#[tokio::test] +async fn remove_entry() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let key = queue.insert_at("foo", Instant::now() + ms(5)); + + assert_pending!(poll!(queue)); + + let entry = queue.remove(&key); + assert_eq!(entry.into_inner(), "foo"); + + sleep(ms(10)).await; + + let entry = assert_ready!(poll!(queue)); + assert!(entry.is_none()); +} + +#[tokio::test] +async fn reset_entry() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + let key = queue.insert_at("foo", now + ms(5)); + + assert_pending!(poll!(queue)); + sleep(ms(1)).await; + + queue.reset_at(&key, now + ms(10)); + + assert_pending!(poll!(queue)); + + sleep(ms(7)).await; + + assert!(!queue.is_woken()); + + assert_pending!(poll!(queue)); + + sleep(ms(3)).await; + + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)); + assert_eq!(*entry.get_ref(), "foo"); + + let entry = assert_ready!(poll!(queue)); + assert!(entry.is_none()) +} + +// Reproduces tokio-rs/tokio#849. +#[tokio::test] +async fn reset_much_later() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + sleep(ms(1)).await; + + let key = queue.insert_at("foo", now + ms(200)); + assert_pending!(poll!(queue)); + + sleep(ms(3)).await; + + queue.reset_at(&key, now + ms(10)); + + sleep(ms(20)).await; + + assert!(queue.is_woken()); +} + +// Reproduces tokio-rs/tokio#849. +#[tokio::test] +async fn reset_twice() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + let now = Instant::now(); + + sleep(ms(1)).await; + + let key = queue.insert_at("foo", now + ms(200)); + + assert_pending!(poll!(queue)); + + sleep(ms(3)).await; + + queue.reset_at(&key, now + ms(50)); + + sleep(ms(20)).await; + + queue.reset_at(&key, now + ms(40)); + + sleep(ms(20)).await; + + assert!(queue.is_woken()); +} + +/// Regression test: Given an entry inserted with a deadline in the past, so +/// that it is placed directly on the expired queue, reset the entry to a +/// deadline in the future. Validate that this leaves the entry and queue in an +/// internally consistent state by running an additional reset on the entry +/// before polling it to completion. +#[tokio::test] +async fn repeatedly_reset_entry_inserted_as_expired() { + time::pause(); + let mut queue = task::spawn(DelayQueue::new()); + let now = Instant::now(); + + let key = queue.insert_at("foo", now - ms(100)); + + queue.reset_at(&key, now + ms(100)); + queue.reset_at(&key, now + ms(50)); + + assert_pending!(poll!(queue)); + + time::sleep_until(now + ms(60)).await; + + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "foo"); + + let entry = assert_ready!(poll!(queue)); + assert!(entry.is_none()); +} + +#[tokio::test] +async fn remove_expired_item() { + time::pause(); + + let mut queue = DelayQueue::new(); + + let now = Instant::now(); + + sleep(ms(10)).await; + + let key = queue.insert_at("foo", now); + + let entry = queue.remove(&key); + assert_eq!(entry.into_inner(), "foo"); +} + +/// Regression test: it should be possible to remove entries which fall in the +/// 0th slot of the internal timer wheel — that is, entries whose expiration +/// (a) falls at the beginning of one of the wheel's hierarchical levels and (b) +/// is equal to the wheel's current elapsed time. +#[tokio::test] +async fn remove_at_timer_wheel_threshold() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let key1 = queue.insert_at("foo", now + ms(64)); + let key2 = queue.insert_at("bar", now + ms(64)); + + sleep(ms(80)).await; + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + + match entry { + "foo" => { + let entry = queue.remove(&key2).into_inner(); + assert_eq!(entry, "bar"); + } + "bar" => { + let entry = queue.remove(&key1).into_inner(); + assert_eq!(entry, "foo"); + } + other => panic!("other: {:?}", other), + } +} + +#[tokio::test] +async fn expires_before_last_insert() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + queue.insert_at("foo", now + ms(10_000)); + + // Delay should be set to 8.192s here. + assert_pending!(poll!(queue)); + + // Delay should be set to the delay of the new item here + queue.insert_at("bar", now + ms(600)); + + assert_pending!(poll!(queue)); + + sleep(ms(600)).await; + + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "bar"); +} + +#[tokio::test] +async fn multi_reset() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let one = queue.insert_at("one", now + ms(200)); + let two = queue.insert_at("two", now + ms(250)); + + assert_pending!(poll!(queue)); + + queue.reset_at(&one, now + ms(300)); + queue.reset_at(&two, now + ms(350)); + queue.reset_at(&one, now + ms(400)); + + sleep(ms(310)).await; + + assert_pending!(poll!(queue)); + + sleep(ms(50)).await; + + let entry = assert_ready_some!(poll!(queue)); + assert_eq!(*entry.get_ref(), "two"); + + assert_pending!(poll!(queue)); + + sleep(ms(50)).await; + + let entry = assert_ready_some!(poll!(queue)); + assert_eq!(*entry.get_ref(), "one"); + + let entry = assert_ready!(poll!(queue)); + assert!(entry.is_none()) +} + +#[tokio::test] +async fn expire_first_key_when_reset_to_expire_earlier() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let one = queue.insert_at("one", now + ms(200)); + queue.insert_at("two", now + ms(250)); + + assert_pending!(poll!(queue)); + + queue.reset_at(&one, now + ms(100)); + + sleep(ms(100)).await; + + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "one"); +} + +#[tokio::test] +async fn expire_second_key_when_reset_to_expire_earlier() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + queue.insert_at("one", now + ms(200)); + let two = queue.insert_at("two", now + ms(250)); + + assert_pending!(poll!(queue)); + + queue.reset_at(&two, now + ms(100)); + + sleep(ms(100)).await; + + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "two"); +} + +#[tokio::test] +async fn reset_first_expiring_item_to_expire_later() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let one = queue.insert_at("one", now + ms(200)); + let _two = queue.insert_at("two", now + ms(250)); + + assert_pending!(poll!(queue)); + + queue.reset_at(&one, now + ms(300)); + sleep(ms(250)).await; + + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "two"); +} + +#[tokio::test] +async fn insert_before_first_after_poll() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let _one = queue.insert_at("one", now + ms(200)); + + assert_pending!(poll!(queue)); + + let _two = queue.insert_at("two", now + ms(100)); + + sleep(ms(99)).await; + + assert_pending!(poll!(queue)); + + sleep(ms(1)).await; + + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "two"); +} + +#[tokio::test] +async fn insert_after_ready_poll() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + queue.insert_at("1", now + ms(100)); + queue.insert_at("2", now + ms(100)); + queue.insert_at("3", now + ms(100)); + + assert_pending!(poll!(queue)); + + sleep(ms(100)).await; + + assert!(queue.is_woken()); + + let mut res = vec![]; + + while res.len() < 3 { + let entry = assert_ready_some!(poll!(queue)); + res.push(entry.into_inner()); + queue.insert_at("foo", now + ms(500)); + } + + res.sort_unstable(); + + assert_eq!("1", res[0]); + assert_eq!("2", res[1]); + assert_eq!("3", res[2]); +} + +#[tokio::test] +async fn reset_later_after_slot_starts() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let foo = queue.insert_at("foo", now + ms(100)); + + assert_pending!(poll!(queue)); + + sleep_until(now + Duration::from_millis(80)).await; + + assert!(!queue.is_woken()); + + // At this point the queue hasn't been polled, so `elapsed` on the wheel + // for the queue is still at 0 and hence the 1ms resolution slots cover + // [0-64). Resetting the time on the entry to 120 causes it to get put in + // the [64-128) slot. As the queue knows that the first entry is within + // that slot, but doesn't know when, it must wake immediately to advance + // the wheel. + queue.reset_at(&foo, now + ms(120)); + assert!(queue.is_woken()); + + assert_pending!(poll!(queue)); + + sleep_until(now + Duration::from_millis(119)).await; + assert!(!queue.is_woken()); + + sleep(ms(1)).await; + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "foo"); +} + +#[tokio::test] +async fn reset_inserted_expired() { + time::pause(); + let mut queue = task::spawn(DelayQueue::new()); + let now = Instant::now(); + + let key = queue.insert_at("foo", now - ms(100)); + + // this causes the panic described in #2473 + queue.reset_at(&key, now + ms(100)); + + assert_eq!(1, queue.len()); + + sleep(ms(200)).await; + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "foo"); + + assert_eq!(queue.len(), 0); +} + +#[tokio::test] +async fn reset_earlier_after_slot_starts() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let foo = queue.insert_at("foo", now + ms(200)); + + assert_pending!(poll!(queue)); + + sleep_until(now + Duration::from_millis(80)).await; + + assert!(!queue.is_woken()); + + // At this point the queue hasn't been polled, so `elapsed` on the wheel + // for the queue is still at 0 and hence the 1ms resolution slots cover + // [0-64). Resetting the time on the entry to 120 causes it to get put in + // the [64-128) slot. As the queue knows that the first entry is within + // that slot, but doesn't know when, it must wake immediately to advance + // the wheel. + queue.reset_at(&foo, now + ms(120)); + assert!(queue.is_woken()); + + assert_pending!(poll!(queue)); + + sleep_until(now + Duration::from_millis(119)).await; + assert!(!queue.is_woken()); + + sleep(ms(1)).await; + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "foo"); +} + +#[tokio::test] +async fn insert_in_past_after_poll_fires_immediately() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + queue.insert_at("foo", now + ms(200)); + + assert_pending!(poll!(queue)); + + sleep(ms(80)).await; + + assert!(!queue.is_woken()); + queue.insert_at("bar", now + ms(40)); + + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "bar"); +} + +#[tokio::test] +async fn delay_queue_poll_expired_when_empty() { + let mut delay_queue = task::spawn(DelayQueue::new()); + let key = delay_queue.insert(0, std::time::Duration::from_secs(10)); + assert_pending!(poll!(delay_queue)); + + delay_queue.remove(&key); + assert!(assert_ready!(poll!(delay_queue)).is_none()); +} + +#[tokio::test(start_paused = true)] +async fn compact_expire_empty() { + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + queue.insert_at("foo1", now + ms(10)); + queue.insert_at("foo2", now + ms(10)); + + sleep(ms(10)).await; + + let mut res = vec![]; + while res.len() < 2 { + let entry = assert_ready_some!(poll!(queue)); + res.push(entry.into_inner()); + } + + queue.compact(); + + assert_eq!(queue.len(), 0); + assert_eq!(queue.capacity(), 0); +} + +#[tokio::test(start_paused = true)] +async fn compact_remove_empty() { + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let key1 = queue.insert_at("foo1", now + ms(10)); + let key2 = queue.insert_at("foo2", now + ms(10)); + + queue.remove(&key1); + queue.remove(&key2); + + queue.compact(); + + assert_eq!(queue.len(), 0); + assert_eq!(queue.capacity(), 0); +} + +#[tokio::test(start_paused = true)] +// Trigger a re-mapping of keys in the slab due to a `compact` call and +// test removal of re-mapped keys +async fn compact_remove_remapped_keys() { + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + queue.insert_at("foo1", now + ms(10)); + queue.insert_at("foo2", now + ms(10)); + + // should be assigned indices 3 and 4 + let key3 = queue.insert_at("foo3", now + ms(20)); + let key4 = queue.insert_at("foo4", now + ms(20)); + + sleep(ms(10)).await; + + let mut res = vec![]; + while res.len() < 2 { + let entry = assert_ready_some!(poll!(queue)); + res.push(entry.into_inner()); + } + + // items corresponding to `foo3` and `foo4` will be assigned + // new indices here + queue.compact(); + + queue.insert_at("foo5", now + ms(10)); + + // test removal of re-mapped keys + let expired3 = queue.remove(&key3); + let expired4 = queue.remove(&key4); + + assert_eq!(expired3.into_inner(), "foo3"); + assert_eq!(expired4.into_inner(), "foo4"); + + queue.compact(); + assert_eq!(queue.len(), 1); + assert_eq!(queue.capacity(), 1); +} + +#[tokio::test(start_paused = true)] +async fn compact_change_deadline() { + let mut queue = task::spawn(DelayQueue::new()); + + let mut now = Instant::now(); + + queue.insert_at("foo1", now + ms(10)); + queue.insert_at("foo2", now + ms(10)); + + // should be assigned indices 3 and 4 + queue.insert_at("foo3", now + ms(20)); + let key4 = queue.insert_at("foo4", now + ms(20)); + + sleep(ms(10)).await; + + let mut res = vec![]; + while res.len() < 2 { + let entry = assert_ready_some!(poll!(queue)); + res.push(entry.into_inner()); + } + + // items corresponding to `foo3` and `foo4` should be assigned + // new indices + queue.compact(); + + now = Instant::now(); + + queue.insert_at("foo5", now + ms(10)); + let key6 = queue.insert_at("foo6", now + ms(10)); + + queue.reset_at(&key4, now + ms(20)); + queue.reset_at(&key6, now + ms(20)); + + // foo3 and foo5 will expire + sleep(ms(10)).await; + + while res.len() < 4 { + let entry = assert_ready_some!(poll!(queue)); + res.push(entry.into_inner()); + } + + sleep(ms(10)).await; + + while res.len() < 6 { + let entry = assert_ready_some!(poll!(queue)); + res.push(entry.into_inner()); + } + + let entry = assert_ready!(poll!(queue)); + assert!(entry.is_none()); +} + +#[tokio::test(start_paused = true)] +async fn remove_after_compact() { + let now = Instant::now(); + let mut queue = DelayQueue::new(); + + let foo_key = queue.insert_at("foo", now + ms(10)); + queue.insert_at("bar", now + ms(20)); + queue.remove(&foo_key); + queue.compact(); + + let panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + queue.remove(&foo_key); + })); + assert!(panic.is_err()); +} + +#[tokio::test(start_paused = true)] +async fn remove_after_compact_poll() { + let now = Instant::now(); + let mut queue = task::spawn(DelayQueue::new()); + + let foo_key = queue.insert_at("foo", now + ms(10)); + queue.insert_at("bar", now + ms(20)); + + sleep(ms(10)).await; + assert_eq!(assert_ready_some!(poll!(queue)).key(), foo_key); + + queue.compact(); + + let panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + queue.remove(&foo_key); + })); + assert!(panic.is_err()); +} + +fn ms(n: u64) -> Duration { + Duration::from_millis(n) +} diff --git a/third_party/rust/tokio-util/tests/udp.rs b/third_party/rust/tokio-util/tests/udp.rs new file mode 100644 index 0000000000..b9436a30aa --- /dev/null +++ b/third_party/rust/tokio-util/tests/udp.rs @@ -0,0 +1,132 @@ +#![warn(rust_2018_idioms)] + +use tokio::net::UdpSocket; +use tokio_stream::StreamExt; +use tokio_util::codec::{Decoder, Encoder, LinesCodec}; +use tokio_util::udp::UdpFramed; + +use bytes::{BufMut, BytesMut}; +use futures::future::try_join; +use futures::future::FutureExt; +use futures::sink::SinkExt; +use std::io; +use std::sync::Arc; + +#[cfg_attr(any(target_os = "macos", target_os = "ios"), allow(unused_assignments))] +#[tokio::test] +async fn send_framed_byte_codec() -> std::io::Result<()> { + let mut a_soc = UdpSocket::bind("127.0.0.1:0").await?; + let mut b_soc = UdpSocket::bind("127.0.0.1:0").await?; + + let a_addr = a_soc.local_addr()?; + let b_addr = b_soc.local_addr()?; + + // test sending & receiving bytes + { + let mut a = UdpFramed::new(a_soc, ByteCodec); + let mut b = UdpFramed::new(b_soc, ByteCodec); + + let msg = b"4567"; + + let send = a.send((msg, b_addr)); + let recv = b.next().map(|e| e.unwrap()); + let (_, received) = try_join(send, recv).await.unwrap(); + + let (data, addr) = received; + assert_eq!(msg, &*data); + assert_eq!(a_addr, addr); + + a_soc = a.into_inner(); + b_soc = b.into_inner(); + } + + #[cfg(not(any(target_os = "macos", target_os = "ios")))] + // test sending & receiving an empty message + { + let mut a = UdpFramed::new(a_soc, ByteCodec); + let mut b = UdpFramed::new(b_soc, ByteCodec); + + let msg = b""; + + let send = a.send((msg, b_addr)); + let recv = b.next().map(|e| e.unwrap()); + let (_, received) = try_join(send, recv).await.unwrap(); + + let (data, addr) = received; + assert_eq!(msg, &*data); + assert_eq!(a_addr, addr); + } + + Ok(()) +} + +pub struct ByteCodec; + +impl Decoder for ByteCodec { + type Item = Vec<u8>; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Vec<u8>>, io::Error> { + let len = buf.len(); + Ok(Some(buf.split_to(len).to_vec())) + } +} + +impl Encoder<&[u8]> for ByteCodec { + type Error = io::Error; + + fn encode(&mut self, data: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> { + buf.reserve(data.len()); + buf.put_slice(data); + Ok(()) + } +} + +#[tokio::test] +async fn send_framed_lines_codec() -> std::io::Result<()> { + let a_soc = UdpSocket::bind("127.0.0.1:0").await?; + let b_soc = UdpSocket::bind("127.0.0.1:0").await?; + + let a_addr = a_soc.local_addr()?; + let b_addr = b_soc.local_addr()?; + + let mut a = UdpFramed::new(a_soc, ByteCodec); + let mut b = UdpFramed::new(b_soc, LinesCodec::new()); + + let msg = b"1\r\n2\r\n3\r\n".to_vec(); + a.send((&msg, b_addr)).await?; + + assert_eq!(b.next().await.unwrap().unwrap(), ("1".to_string(), a_addr)); + assert_eq!(b.next().await.unwrap().unwrap(), ("2".to_string(), a_addr)); + assert_eq!(b.next().await.unwrap().unwrap(), ("3".to_string(), a_addr)); + + Ok(()) +} + +#[tokio::test] +async fn framed_half() -> std::io::Result<()> { + let a_soc = Arc::new(UdpSocket::bind("127.0.0.1:0").await?); + let b_soc = a_soc.clone(); + + let a_addr = a_soc.local_addr()?; + let b_addr = b_soc.local_addr()?; + + let mut a = UdpFramed::new(a_soc, ByteCodec); + let mut b = UdpFramed::new(b_soc, LinesCodec::new()); + + let msg = b"1\r\n2\r\n3\r\n".to_vec(); + a.send((&msg, b_addr)).await?; + + let msg = b"4\r\n5\r\n6\r\n".to_vec(); + a.send((&msg, b_addr)).await?; + + assert_eq!(b.next().await.unwrap().unwrap(), ("1".to_string(), a_addr)); + assert_eq!(b.next().await.unwrap().unwrap(), ("2".to_string(), a_addr)); + assert_eq!(b.next().await.unwrap().unwrap(), ("3".to_string(), a_addr)); + + assert_eq!(b.next().await.unwrap().unwrap(), ("4".to_string(), a_addr)); + assert_eq!(b.next().await.unwrap().unwrap(), ("5".to_string(), a_addr)); + assert_eq!(b.next().await.unwrap().unwrap(), ("6".to_string(), a_addr)); + + Ok(()) +} |