diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 09:22:09 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 09:22:09 +0000 |
commit | 43a97878ce14b72f0981164f87f2e35e14151312 (patch) | |
tree | 620249daf56c0258faa40cbdcf9cfba06de2a846 /third_party/rust/tokio-util/tests | |
parent | Initial commit. (diff) | |
download | firefox-upstream.tar.xz firefox-upstream.zip |
Adding upstream version 110.0.1.upstream/110.0.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/tokio-util/tests')
18 files changed, 4042 insertions, 0 deletions
diff --git a/third_party/rust/tokio-util/tests/_require_full.rs b/third_party/rust/tokio-util/tests/_require_full.rs new file mode 100644 index 0000000000..045934d175 --- /dev/null +++ b/third_party/rust/tokio-util/tests/_require_full.rs @@ -0,0 +1,2 @@ +#![cfg(not(feature = "full"))] +compile_error!("run tokio-util tests with `--features full`"); diff --git a/third_party/rust/tokio-util/tests/codecs.rs b/third_party/rust/tokio-util/tests/codecs.rs new file mode 100644 index 0000000000..f9a780140a --- /dev/null +++ b/third_party/rust/tokio-util/tests/codecs.rs @@ -0,0 +1,464 @@ +#![warn(rust_2018_idioms)] + +use tokio_util::codec::{AnyDelimiterCodec, BytesCodec, Decoder, Encoder, LinesCodec}; + +use bytes::{BufMut, Bytes, BytesMut}; + +#[test] +fn bytes_decoder() { + let mut codec = BytesCodec::new(); + let buf = &mut BytesMut::new(); + buf.put_slice(b"abc"); + assert_eq!("abc", codec.decode(buf).unwrap().unwrap()); + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"a"); + assert_eq!("a", codec.decode(buf).unwrap().unwrap()); +} + +#[test] +fn bytes_encoder() { + let mut codec = BytesCodec::new(); + + // Default capacity of BytesMut + #[cfg(target_pointer_width = "64")] + const INLINE_CAP: usize = 4 * 8 - 1; + #[cfg(target_pointer_width = "32")] + const INLINE_CAP: usize = 4 * 4 - 1; + + let mut buf = BytesMut::new(); + codec + .encode(Bytes::from_static(&[0; INLINE_CAP + 1]), &mut buf) + .unwrap(); + + // Default capacity of Framed Read + const INITIAL_CAPACITY: usize = 8 * 1024; + + let mut buf = BytesMut::with_capacity(INITIAL_CAPACITY); + codec + .encode(Bytes::from_static(&[0; INITIAL_CAPACITY + 1]), &mut buf) + .unwrap(); + codec + .encode(BytesMut::from(&b"hello"[..]), &mut buf) + .unwrap(); +} + +#[test] +fn lines_decoder() { + let mut codec = LinesCodec::new(); + let buf = &mut BytesMut::new(); + buf.reserve(200); + buf.put_slice(b"line 1\nline 2\r\nline 3\n\r\n\r"); + assert_eq!("line 1", codec.decode(buf).unwrap().unwrap()); + assert_eq!("line 2", codec.decode(buf).unwrap().unwrap()); + assert_eq!("line 3", codec.decode(buf).unwrap().unwrap()); + assert_eq!("", codec.decode(buf).unwrap().unwrap()); + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode_eof(buf).unwrap()); + buf.put_slice(b"k"); + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!("\rk", codec.decode_eof(buf).unwrap().unwrap()); + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode_eof(buf).unwrap()); +} + +#[test] +fn lines_decoder_max_length() { + const MAX_LENGTH: usize = 6; + + let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"line 1 is too long\nline 2\nline 3\r\nline 4\n\r\n\r"); + + assert!(codec.decode(buf).is_err()); + + let line = codec.decode(buf).unwrap().unwrap(); + assert!( + line.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + line, + MAX_LENGTH + ); + assert_eq!("line 2", line); + + assert!(codec.decode(buf).is_err()); + + let line = codec.decode(buf).unwrap().unwrap(); + assert!( + line.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + line, + MAX_LENGTH + ); + assert_eq!("line 4", line); + + let line = codec.decode(buf).unwrap().unwrap(); + assert!( + line.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + line, + MAX_LENGTH + ); + assert_eq!("", line); + + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode_eof(buf).unwrap()); + buf.put_slice(b"k"); + assert_eq!(None, codec.decode(buf).unwrap()); + + let line = codec.decode_eof(buf).unwrap().unwrap(); + assert!( + line.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + line, + MAX_LENGTH + ); + assert_eq!("\rk", line); + + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode_eof(buf).unwrap()); + + // Line that's one character too long. This could cause an out of bounds + // error if we peek at the next characters using slice indexing. + buf.put_slice(b"aaabbbc"); + assert!(codec.decode(buf).is_err()); +} + +#[test] +fn lines_decoder_max_length_underrun() { + const MAX_LENGTH: usize = 6; + + let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"line "); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"too l"); + assert!(codec.decode(buf).is_err()); + buf.put_slice(b"ong\n"); + assert_eq!(None, codec.decode(buf).unwrap()); + + buf.put_slice(b"line 2"); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"\n"); + assert_eq!("line 2", codec.decode(buf).unwrap().unwrap()); +} + +#[test] +fn lines_decoder_max_length_bursts() { + const MAX_LENGTH: usize = 10; + + let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"line "); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"too l"); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"ong\n"); + assert!(codec.decode(buf).is_err()); +} + +#[test] +fn lines_decoder_max_length_big_burst() { + const MAX_LENGTH: usize = 10; + + let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"line "); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"too long!\n"); + assert!(codec.decode(buf).is_err()); +} + +#[test] +fn lines_decoder_max_length_newline_between_decodes() { + const MAX_LENGTH: usize = 5; + + let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"hello"); + assert_eq!(None, codec.decode(buf).unwrap()); + + buf.put_slice(b"\nworld"); + assert_eq!("hello", codec.decode(buf).unwrap().unwrap()); +} + +// Regression test for [infinite loop bug](https://github.com/tokio-rs/tokio/issues/1483) +#[test] +fn lines_decoder_discard_repeat() { + const MAX_LENGTH: usize = 1; + + let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"aa"); + assert!(codec.decode(buf).is_err()); + buf.put_slice(b"a"); + assert_eq!(None, codec.decode(buf).unwrap()); +} + +// Regression test for [subsequent calls to LinesCodec decode does not return the desired results bug](https://github.com/tokio-rs/tokio/issues/3555) +#[test] +fn lines_decoder_max_length_underrun_twice() { + const MAX_LENGTH: usize = 11; + + let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"line "); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"too very l"); + assert!(codec.decode(buf).is_err()); + buf.put_slice(b"aaaaaaaaaaaaaaaaaaaaaaa"); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"ong\nshort\n"); + assert_eq!("short", codec.decode(buf).unwrap().unwrap()); +} + +#[test] +fn lines_encoder() { + let mut codec = LinesCodec::new(); + let mut buf = BytesMut::new(); + + codec.encode("line 1", &mut buf).unwrap(); + assert_eq!("line 1\n", buf); + + codec.encode("line 2", &mut buf).unwrap(); + assert_eq!("line 1\nline 2\n", buf); +} + +#[test] +fn any_delimiters_decoder_any_character() { + let mut codec = AnyDelimiterCodec::new(b",;\n\r".to_vec(), b",".to_vec()); + let buf = &mut BytesMut::new(); + buf.reserve(200); + buf.put_slice(b"chunk 1,chunk 2;chunk 3\n\r"); + assert_eq!("chunk 1", codec.decode(buf).unwrap().unwrap()); + assert_eq!("chunk 2", codec.decode(buf).unwrap().unwrap()); + assert_eq!("chunk 3", codec.decode(buf).unwrap().unwrap()); + assert_eq!("", codec.decode(buf).unwrap().unwrap()); + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode_eof(buf).unwrap()); + buf.put_slice(b"k"); + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!("k", codec.decode_eof(buf).unwrap().unwrap()); + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode_eof(buf).unwrap()); +} + +#[test] +fn any_delimiters_decoder_max_length() { + const MAX_LENGTH: usize = 7; + + let mut codec = + AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"chunk 1 is too long\nchunk 2\nchunk 3\r\nchunk 4\n\r\n"); + + assert!(codec.decode(buf).is_err()); + + let chunk = codec.decode(buf).unwrap().unwrap(); + assert!( + chunk.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + chunk, + MAX_LENGTH + ); + assert_eq!("chunk 2", chunk); + + let chunk = codec.decode(buf).unwrap().unwrap(); + assert!( + chunk.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + chunk, + MAX_LENGTH + ); + assert_eq!("chunk 3", chunk); + + // \r\n cause empty chunk + let chunk = codec.decode(buf).unwrap().unwrap(); + assert!( + chunk.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + chunk, + MAX_LENGTH + ); + assert_eq!("", chunk); + + let chunk = codec.decode(buf).unwrap().unwrap(); + assert!( + chunk.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + chunk, + MAX_LENGTH + ); + assert_eq!("chunk 4", chunk); + + let chunk = codec.decode(buf).unwrap().unwrap(); + assert!( + chunk.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + chunk, + MAX_LENGTH + ); + assert_eq!("", chunk); + + let chunk = codec.decode(buf).unwrap().unwrap(); + assert!( + chunk.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + chunk, + MAX_LENGTH + ); + assert_eq!("", chunk); + + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode_eof(buf).unwrap()); + buf.put_slice(b"k"); + assert_eq!(None, codec.decode(buf).unwrap()); + + let chunk = codec.decode_eof(buf).unwrap().unwrap(); + assert!( + chunk.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + chunk, + MAX_LENGTH + ); + assert_eq!("k", chunk); + + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode_eof(buf).unwrap()); + + // Delimiter that's one character too long. This could cause an out of bounds + // error if we peek at the next characters using slice indexing. + buf.put_slice(b"aaabbbcc"); + assert!(codec.decode(buf).is_err()); +} + +#[test] +fn any_delimiter_decoder_max_length_underrun() { + const MAX_LENGTH: usize = 7; + + let mut codec = + AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"chunk "); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"too l"); + assert!(codec.decode(buf).is_err()); + buf.put_slice(b"ong\n"); + assert_eq!(None, codec.decode(buf).unwrap()); + + buf.put_slice(b"chunk 2"); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b","); + assert_eq!("chunk 2", codec.decode(buf).unwrap().unwrap()); +} + +#[test] +fn any_delimiter_decoder_max_length_underrun_twice() { + const MAX_LENGTH: usize = 11; + + let mut codec = + AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"chunk "); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"too very l"); + assert!(codec.decode(buf).is_err()); + buf.put_slice(b"aaaaaaaaaaaaaaaaaaaaaaa"); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"ong\nshort\n"); + assert_eq!("short", codec.decode(buf).unwrap().unwrap()); +} +#[test] +fn any_delimiter_decoder_max_length_bursts() { + const MAX_LENGTH: usize = 11; + + let mut codec = + AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"chunk "); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"too l"); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"ong\n"); + assert!(codec.decode(buf).is_err()); +} + +#[test] +fn any_delimiter_decoder_max_length_big_burst() { + const MAX_LENGTH: usize = 11; + + let mut codec = + AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"chunk "); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"too long!\n"); + assert!(codec.decode(buf).is_err()); +} + +#[test] +fn any_delimiter_decoder_max_length_delimiter_between_decodes() { + const MAX_LENGTH: usize = 5; + + let mut codec = + AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"hello"); + assert_eq!(None, codec.decode(buf).unwrap()); + + buf.put_slice(b",world"); + assert_eq!("hello", codec.decode(buf).unwrap().unwrap()); +} + +#[test] +fn any_delimiter_decoder_discard_repeat() { + const MAX_LENGTH: usize = 1; + + let mut codec = + AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"aa"); + assert!(codec.decode(buf).is_err()); + buf.put_slice(b"a"); + assert_eq!(None, codec.decode(buf).unwrap()); +} + +#[test] +fn any_delimiter_encoder() { + let mut codec = AnyDelimiterCodec::new(b",".to_vec(), b";--;".to_vec()); + let mut buf = BytesMut::new(); + + codec.encode("chunk 1", &mut buf).unwrap(); + assert_eq!("chunk 1;--;", buf); + + codec.encode("chunk 2", &mut buf).unwrap(); + assert_eq!("chunk 1;--;chunk 2;--;", buf); +} diff --git a/third_party/rust/tokio-util/tests/context.rs b/third_party/rust/tokio-util/tests/context.rs new file mode 100644 index 0000000000..7510f36fd1 --- /dev/null +++ b/third_party/rust/tokio-util/tests/context.rs @@ -0,0 +1,24 @@ +#![cfg(feature = "rt")] +#![warn(rust_2018_idioms)] + +use tokio::runtime::Builder; +use tokio::time::*; +use tokio_util::context::RuntimeExt; + +#[test] +fn tokio_context_with_another_runtime() { + let rt1 = Builder::new_multi_thread() + .worker_threads(1) + // no timer! + .build() + .unwrap(); + let rt2 = Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build() + .unwrap(); + + // Without the `HandleExt.wrap()` there would be a panic because there is + // no timer running, since it would be referencing runtime r1. + let _ = rt1.block_on(rt2.wrap(async move { sleep(Duration::from_millis(2)).await })); +} diff --git a/third_party/rust/tokio-util/tests/framed.rs b/third_party/rust/tokio-util/tests/framed.rs new file mode 100644 index 0000000000..ec8cdf00d0 --- /dev/null +++ b/third_party/rust/tokio-util/tests/framed.rs @@ -0,0 +1,152 @@ +#![warn(rust_2018_idioms)] + +use tokio_stream::StreamExt; +use tokio_test::assert_ok; +use tokio_util::codec::{Decoder, Encoder, Framed, FramedParts}; + +use bytes::{Buf, BufMut, BytesMut}; +use std::io::{self, Read}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +const INITIAL_CAPACITY: usize = 8 * 1024; + +/// Encode and decode u32 values. +#[derive(Default)] +struct U32Codec { + read_bytes: usize, +} + +impl Decoder for U32Codec { + type Item = u32; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u32>> { + if buf.len() < 4 { + return Ok(None); + } + + let n = buf.split_to(4).get_u32(); + self.read_bytes += 4; + Ok(Some(n)) + } +} + +impl Encoder<u32> for U32Codec { + type Error = io::Error; + + fn encode(&mut self, item: u32, dst: &mut BytesMut) -> io::Result<()> { + // Reserve space + dst.reserve(4); + dst.put_u32(item); + Ok(()) + } +} + +/// Encode and decode u64 values. +#[derive(Default)] +struct U64Codec { + read_bytes: usize, +} + +impl Decoder for U64Codec { + type Item = u64; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u64>> { + if buf.len() < 8 { + return Ok(None); + } + + let n = buf.split_to(8).get_u64(); + self.read_bytes += 8; + Ok(Some(n)) + } +} + +impl Encoder<u64> for U64Codec { + type Error = io::Error; + + fn encode(&mut self, item: u64, dst: &mut BytesMut) -> io::Result<()> { + // Reserve space + dst.reserve(8); + dst.put_u64(item); + Ok(()) + } +} + +/// This value should never be used +struct DontReadIntoThis; + +impl Read for DontReadIntoThis { + fn read(&mut self, _: &mut [u8]) -> io::Result<usize> { + Err(io::Error::new( + io::ErrorKind::Other, + "Read into something you weren't supposed to.", + )) + } +} + +impl tokio::io::AsyncRead for DontReadIntoThis { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + unreachable!() + } +} + +#[tokio::test] +async fn can_read_from_existing_buf() { + let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default()); + parts.read_buf = BytesMut::from(&[0, 0, 0, 42][..]); + + let mut framed = Framed::from_parts(parts); + let num = assert_ok!(framed.next().await.unwrap()); + + assert_eq!(num, 42); + assert_eq!(framed.codec().read_bytes, 4); +} + +#[tokio::test] +async fn can_read_from_existing_buf_after_codec_changed() { + let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default()); + parts.read_buf = BytesMut::from(&[0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 0, 84][..]); + + let mut framed = Framed::from_parts(parts); + let num = assert_ok!(framed.next().await.unwrap()); + + assert_eq!(num, 42); + assert_eq!(framed.codec().read_bytes, 4); + + let mut framed = framed.map_codec(|codec| U64Codec { + read_bytes: codec.read_bytes, + }); + let num = assert_ok!(framed.next().await.unwrap()); + + assert_eq!(num, 84); + assert_eq!(framed.codec().read_bytes, 12); +} + +#[test] +fn external_buf_grows_to_init() { + let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default()); + parts.read_buf = BytesMut::from(&[0, 0, 0, 42][..]); + + let framed = Framed::from_parts(parts); + let FramedParts { read_buf, .. } = framed.into_parts(); + + assert_eq!(read_buf.capacity(), INITIAL_CAPACITY); +} + +#[test] +fn external_buf_does_not_shrink() { + let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default()); + parts.read_buf = BytesMut::from(&vec![0; INITIAL_CAPACITY * 2][..]); + + let framed = Framed::from_parts(parts); + let FramedParts { read_buf, .. } = framed.into_parts(); + + assert_eq!(read_buf.capacity(), INITIAL_CAPACITY * 2); +} diff --git a/third_party/rust/tokio-util/tests/framed_read.rs b/third_party/rust/tokio-util/tests/framed_read.rs new file mode 100644 index 0000000000..2a9e27e22f --- /dev/null +++ b/third_party/rust/tokio-util/tests/framed_read.rs @@ -0,0 +1,339 @@ +#![warn(rust_2018_idioms)] + +use tokio::io::{AsyncRead, ReadBuf}; +use tokio_test::assert_ready; +use tokio_test::task; +use tokio_util::codec::{Decoder, FramedRead}; + +use bytes::{Buf, BytesMut}; +use futures::Stream; +use std::collections::VecDeque; +use std::io; +use std::pin::Pin; +use std::task::Poll::{Pending, Ready}; +use std::task::{Context, Poll}; + +macro_rules! mock { + ($($x:expr,)*) => {{ + let mut v = VecDeque::new(); + v.extend(vec![$($x),*]); + Mock { calls: v } + }}; +} + +macro_rules! assert_read { + ($e:expr, $n:expr) => {{ + let val = assert_ready!($e); + assert_eq!(val.unwrap().unwrap(), $n); + }}; +} + +macro_rules! pin { + ($id:ident) => { + Pin::new(&mut $id) + }; +} + +struct U32Decoder; + +impl Decoder for U32Decoder { + type Item = u32; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u32>> { + if buf.len() < 4 { + return Ok(None); + } + + let n = buf.split_to(4).get_u32(); + Ok(Some(n)) + } +} + +struct U64Decoder; + +impl Decoder for U64Decoder { + type Item = u64; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u64>> { + if buf.len() < 8 { + return Ok(None); + } + + let n = buf.split_to(8).get_u64(); + Ok(Some(n)) + } +} + +#[test] +fn read_multi_frame_in_packet() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 0); + assert_read!(pin!(framed).poll_next(cx), 1); + assert_read!(pin!(framed).poll_next(cx), 2); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + +#[test] +fn read_multi_frame_across_packets() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00\x00\x00".to_vec()), + Ok(b"\x00\x00\x00\x01".to_vec()), + Ok(b"\x00\x00\x00\x02".to_vec()), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 0); + assert_read!(pin!(framed).poll_next(cx), 1); + assert_read!(pin!(framed).poll_next(cx), 2); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + +#[test] +fn read_multi_frame_in_packet_after_codec_changed() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x08".to_vec()), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 0x04); + + let mut framed = framed.map_decoder(|_| U64Decoder); + assert_read!(pin!(framed).poll_next(cx), 0x08); + + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + +#[test] +fn read_not_ready() { + let mut task = task::spawn(()); + let mock = mock! { + Err(io::Error::new(io::ErrorKind::WouldBlock, "")), + Ok(b"\x00\x00\x00\x00".to_vec()), + Ok(b"\x00\x00\x00\x01".to_vec()), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert!(pin!(framed).poll_next(cx).is_pending()); + assert_read!(pin!(framed).poll_next(cx), 0); + assert_read!(pin!(framed).poll_next(cx), 1); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + +#[test] +fn read_partial_then_not_ready() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00".to_vec()), + Err(io::Error::new(io::ErrorKind::WouldBlock, "")), + Ok(b"\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert!(pin!(framed).poll_next(cx).is_pending()); + assert_read!(pin!(framed).poll_next(cx), 0); + assert_read!(pin!(framed).poll_next(cx), 1); + assert_read!(pin!(framed).poll_next(cx), 2); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + +#[test] +fn read_err() { + let mut task = task::spawn(()); + let mock = mock! { + Err(io::Error::new(io::ErrorKind::Other, "")), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert_eq!( + io::ErrorKind::Other, + assert_ready!(pin!(framed).poll_next(cx)) + .unwrap() + .unwrap_err() + .kind() + ) + }); +} + +#[test] +fn read_partial_then_err() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00".to_vec()), + Err(io::Error::new(io::ErrorKind::Other, "")), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert_eq!( + io::ErrorKind::Other, + assert_ready!(pin!(framed).poll_next(cx)) + .unwrap() + .unwrap_err() + .kind() + ) + }); +} + +#[test] +fn read_partial_would_block_then_err() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00".to_vec()), + Err(io::Error::new(io::ErrorKind::WouldBlock, "")), + Err(io::Error::new(io::ErrorKind::Other, "")), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert!(pin!(framed).poll_next(cx).is_pending()); + assert_eq!( + io::ErrorKind::Other, + assert_ready!(pin!(framed).poll_next(cx)) + .unwrap() + .unwrap_err() + .kind() + ) + }); +} + +#[test] +fn huge_size() { + let mut task = task::spawn(()); + let data = &[0; 32 * 1024][..]; + let mut framed = FramedRead::new(data, BigDecoder); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 0); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); + + struct BigDecoder; + + impl Decoder for BigDecoder { + type Item = u32; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u32>> { + if buf.len() < 32 * 1024 { + return Ok(None); + } + buf.advance(32 * 1024); + Ok(Some(0)) + } + } +} + +#[test] +fn data_remaining_is_error() { + let mut task = task::spawn(()); + let slice = &[0; 5][..]; + let mut framed = FramedRead::new(slice, U32Decoder); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 0); + assert!(assert_ready!(pin!(framed).poll_next(cx)).unwrap().is_err()); + }); +} + +#[test] +fn multi_frames_on_eof() { + let mut task = task::spawn(()); + struct MyDecoder(Vec<u32>); + + impl Decoder for MyDecoder { + type Item = u32; + type Error = io::Error; + + fn decode(&mut self, _buf: &mut BytesMut) -> io::Result<Option<u32>> { + unreachable!(); + } + + fn decode_eof(&mut self, _buf: &mut BytesMut) -> io::Result<Option<u32>> { + if self.0.is_empty() { + return Ok(None); + } + + Ok(Some(self.0.remove(0))) + } + } + + let mut framed = FramedRead::new(mock!(), MyDecoder(vec![0, 1, 2, 3])); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 0); + assert_read!(pin!(framed).poll_next(cx), 1); + assert_read!(pin!(framed).poll_next(cx), 2); + assert_read!(pin!(framed).poll_next(cx), 3); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + +#[test] +fn read_eof_then_resume() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00\x00\x01".to_vec()), + Ok(b"".to_vec()), + Ok(b"\x00\x00\x00\x02".to_vec()), + Ok(b"".to_vec()), + Ok(b"\x00\x00\x00\x03".to_vec()), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 1); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + assert_read!(pin!(framed).poll_next(cx), 2); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + assert_read!(pin!(framed).poll_next(cx), 3); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + +// ===== Mock ====== + +struct Mock { + calls: VecDeque<io::Result<Vec<u8>>>, +} + +impl AsyncRead for Mock { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + use io::ErrorKind::WouldBlock; + + match self.calls.pop_front() { + Some(Ok(data)) => { + debug_assert!(buf.remaining() >= data.len()); + buf.put_slice(&data); + Ready(Ok(())) + } + Some(Err(ref e)) if e.kind() == WouldBlock => Pending, + Some(Err(e)) => Ready(Err(e)), + None => Ready(Ok(())), + } + } +} diff --git a/third_party/rust/tokio-util/tests/framed_stream.rs b/third_party/rust/tokio-util/tests/framed_stream.rs new file mode 100644 index 0000000000..76d8af7b7d --- /dev/null +++ b/third_party/rust/tokio-util/tests/framed_stream.rs @@ -0,0 +1,38 @@ +use futures_core::stream::Stream; +use std::{io, pin::Pin}; +use tokio_test::{assert_ready, io::Builder, task}; +use tokio_util::codec::{BytesCodec, FramedRead}; + +macro_rules! pin { + ($id:ident) => { + Pin::new(&mut $id) + }; +} + +macro_rules! assert_read { + ($e:expr, $n:expr) => {{ + let val = assert_ready!($e); + assert_eq!(val.unwrap().unwrap(), $n); + }}; +} + +#[tokio::test] +async fn return_none_after_error() { + let mut io = FramedRead::new( + Builder::new() + .read(b"abcdef") + .read_error(io::Error::new(io::ErrorKind::Other, "Resource errored out")) + .read(b"more data") + .build(), + BytesCodec::new(), + ); + + let mut task = task::spawn(()); + + task.enter(|cx, _| { + assert_read!(pin!(io).poll_next(cx), b"abcdef".to_vec()); + assert!(assert_ready!(pin!(io).poll_next(cx)).unwrap().is_err()); + assert!(assert_ready!(pin!(io).poll_next(cx)).is_none()); + assert_read!(pin!(io).poll_next(cx), b"more data".to_vec()); + }) +} diff --git a/third_party/rust/tokio-util/tests/framed_write.rs b/third_party/rust/tokio-util/tests/framed_write.rs new file mode 100644 index 0000000000..259d9b0c9f --- /dev/null +++ b/third_party/rust/tokio-util/tests/framed_write.rs @@ -0,0 +1,211 @@ +#![warn(rust_2018_idioms)] + +use tokio::io::AsyncWrite; +use tokio_test::{assert_ready, task}; +use tokio_util::codec::{Encoder, FramedWrite}; + +use bytes::{BufMut, BytesMut}; +use futures_sink::Sink; +use std::collections::VecDeque; +use std::io::{self, Write}; +use std::pin::Pin; +use std::task::Poll::{Pending, Ready}; +use std::task::{Context, Poll}; + +macro_rules! mock { + ($($x:expr,)*) => {{ + let mut v = VecDeque::new(); + v.extend(vec![$($x),*]); + Mock { calls: v } + }}; +} + +macro_rules! pin { + ($id:ident) => { + Pin::new(&mut $id) + }; +} + +struct U32Encoder; + +impl Encoder<u32> for U32Encoder { + type Error = io::Error; + + fn encode(&mut self, item: u32, dst: &mut BytesMut) -> io::Result<()> { + // Reserve space + dst.reserve(4); + dst.put_u32(item); + Ok(()) + } +} + +struct U64Encoder; + +impl Encoder<u64> for U64Encoder { + type Error = io::Error; + + fn encode(&mut self, item: u64, dst: &mut BytesMut) -> io::Result<()> { + // Reserve space + dst.reserve(8); + dst.put_u64(item); + Ok(()) + } +} + +#[test] +fn write_multi_frame_in_packet() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()), + }; + let mut framed = FramedWrite::new(mock, U32Encoder); + + task.enter(|cx, _| { + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(0).is_ok()); + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(1).is_ok()); + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(2).is_ok()); + + // Nothing written yet + assert_eq!(1, framed.get_ref().calls.len()); + + // Flush the writes + assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok()); + + assert_eq!(0, framed.get_ref().calls.len()); + }); +} + +#[test] +fn write_multi_frame_after_codec_changed() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x08".to_vec()), + }; + let mut framed = FramedWrite::new(mock, U32Encoder); + + task.enter(|cx, _| { + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(0x04).is_ok()); + + let mut framed = framed.map_encoder(|_| U64Encoder); + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(0x08).is_ok()); + + // Nothing written yet + assert_eq!(1, framed.get_ref().calls.len()); + + // Flush the writes + assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok()); + + assert_eq!(0, framed.get_ref().calls.len()); + }); +} + +#[test] +fn write_hits_backpressure() { + const ITER: usize = 2 * 1024; + + let mut mock = mock! { + // Block the `ITER`th write + Err(io::Error::new(io::ErrorKind::WouldBlock, "not ready")), + Ok(b"".to_vec()), + }; + + for i in 0..=ITER { + let mut b = BytesMut::with_capacity(4); + b.put_u32(i as u32); + + // Append to the end + match mock.calls.back_mut().unwrap() { + Ok(ref mut data) => { + // Write in 2kb chunks + if data.len() < ITER { + data.extend_from_slice(&b[..]); + continue; + } // else fall through and create a new buffer + } + _ => unreachable!(), + } + + // Push a new new chunk + mock.calls.push_back(Ok(b[..].to_vec())); + } + // 1 'wouldblock', 4 * 2KB buffers, 1 b-byte buffer + assert_eq!(mock.calls.len(), 6); + + let mut task = task::spawn(()); + let mut framed = FramedWrite::new(mock, U32Encoder); + task.enter(|cx, _| { + // Send 8KB. This fills up FramedWrite2 buffer + for i in 0..ITER { + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(i as u32).is_ok()); + } + + // Now we poll_ready which forces a flush. The mock pops the front message + // and decides to block. + assert!(pin!(framed).poll_ready(cx).is_pending()); + + // We poll again, forcing another flush, which this time succeeds + // The whole 8KB buffer is flushed + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + + // Send more data. This matches the final message expected by the mock + assert!(pin!(framed).start_send(ITER as u32).is_ok()); + + // Flush the rest of the buffer + assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok()); + + // Ensure the mock is empty + assert_eq!(0, framed.get_ref().calls.len()); + }) +} + +// // ===== Mock ====== + +struct Mock { + calls: VecDeque<io::Result<Vec<u8>>>, +} + +impl Write for Mock { + fn write(&mut self, src: &[u8]) -> io::Result<usize> { + match self.calls.pop_front() { + Some(Ok(data)) => { + assert!(src.len() >= data.len()); + assert_eq!(&data[..], &src[..data.len()]); + Ok(data.len()) + } + Some(Err(e)) => Err(e), + None => panic!("unexpected write; {:?}", src), + } + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +impl AsyncWrite for Mock { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, io::Error>> { + match Pin::get_mut(self).write(buf) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Pending, + other => Ready(other), + } + } + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + match Pin::get_mut(self).flush() { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Pending, + other => Ready(other), + } + } + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + unimplemented!() + } +} diff --git a/third_party/rust/tokio-util/tests/io_reader_stream.rs b/third_party/rust/tokio-util/tests/io_reader_stream.rs new file mode 100644 index 0000000000..e30cd85164 --- /dev/null +++ b/third_party/rust/tokio-util/tests/io_reader_stream.rs @@ -0,0 +1,65 @@ +#![warn(rust_2018_idioms)] + +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, ReadBuf}; +use tokio_stream::StreamExt; + +/// produces at most `remaining` zeros, that returns error. +/// each time it reads at most 31 byte. +struct Reader { + remaining: usize, +} + +impl AsyncRead for Reader { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + let this = Pin::into_inner(self); + assert_ne!(buf.remaining(), 0); + if this.remaining > 0 { + let n = std::cmp::min(this.remaining, buf.remaining()); + let n = std::cmp::min(n, 31); + for x in &mut buf.initialize_unfilled_to(n)[..n] { + *x = 0; + } + buf.advance(n); + this.remaining -= n; + Poll::Ready(Ok(())) + } else { + Poll::Ready(Err(std::io::Error::from_raw_os_error(22))) + } + } +} + +#[tokio::test] +async fn correct_behavior_on_errors() { + let reader = Reader { remaining: 8000 }; + let mut stream = tokio_util::io::ReaderStream::new(reader); + let mut zeros_received = 0; + let mut had_error = false; + loop { + let item = stream.next().await.unwrap(); + println!("{:?}", item); + match item { + Ok(bytes) => { + let bytes = &*bytes; + for byte in bytes { + assert_eq!(*byte, 0); + zeros_received += 1; + } + } + Err(_) => { + assert!(!had_error); + had_error = true; + break; + } + } + } + + assert!(had_error); + assert_eq!(zeros_received, 8000); + assert!(stream.next().await.is_none()); +} diff --git a/third_party/rust/tokio-util/tests/io_stream_reader.rs b/third_party/rust/tokio-util/tests/io_stream_reader.rs new file mode 100644 index 0000000000..59759941c5 --- /dev/null +++ b/third_party/rust/tokio-util/tests/io_stream_reader.rs @@ -0,0 +1,35 @@ +#![warn(rust_2018_idioms)] + +use bytes::Bytes; +use tokio::io::AsyncReadExt; +use tokio_stream::iter; +use tokio_util::io::StreamReader; + +#[tokio::test] +async fn test_stream_reader() -> std::io::Result<()> { + let stream = iter(vec![ + std::io::Result::Ok(Bytes::from_static(&[])), + Ok(Bytes::from_static(&[0, 1, 2, 3])), + Ok(Bytes::from_static(&[])), + Ok(Bytes::from_static(&[4, 5, 6, 7])), + Ok(Bytes::from_static(&[])), + Ok(Bytes::from_static(&[8, 9, 10, 11])), + Ok(Bytes::from_static(&[])), + ]); + + let mut read = StreamReader::new(stream); + + let mut buf = [0; 5]; + read.read_exact(&mut buf).await?; + assert_eq!(buf, [0, 1, 2, 3, 4]); + + assert_eq!(read.read(&mut buf).await?, 3); + assert_eq!(&buf[..3], [5, 6, 7]); + + assert_eq!(read.read(&mut buf).await?, 4); + assert_eq!(&buf[..4], [8, 9, 10, 11]); + + assert_eq!(read.read(&mut buf).await?, 0); + + Ok(()) +} diff --git a/third_party/rust/tokio-util/tests/io_sync_bridge.rs b/third_party/rust/tokio-util/tests/io_sync_bridge.rs new file mode 100644 index 0000000000..0d420857b5 --- /dev/null +++ b/third_party/rust/tokio-util/tests/io_sync_bridge.rs @@ -0,0 +1,43 @@ +#![cfg(feature = "io-util")] + +use std::error::Error; +use std::io::{Cursor, Read, Result as IoResult}; +use tokio::io::AsyncRead; +use tokio_util::io::SyncIoBridge; + +async fn test_reader_len( + r: impl AsyncRead + Unpin + Send + 'static, + expected_len: usize, +) -> IoResult<()> { + let mut r = SyncIoBridge::new(r); + let res = tokio::task::spawn_blocking(move || { + let mut buf = Vec::new(); + r.read_to_end(&mut buf)?; + Ok::<_, std::io::Error>(buf) + }) + .await?; + assert_eq!(res?.len(), expected_len); + Ok(()) +} + +#[tokio::test] +async fn test_async_read_to_sync() -> Result<(), Box<dyn Error>> { + test_reader_len(tokio::io::empty(), 0).await?; + let buf = b"hello world"; + test_reader_len(Cursor::new(buf), buf.len()).await?; + Ok(()) +} + +#[tokio::test] +async fn test_async_write_to_sync() -> Result<(), Box<dyn Error>> { + let mut dest = Vec::new(); + let src = b"hello world"; + let dest = tokio::task::spawn_blocking(move || -> Result<_, String> { + let mut w = SyncIoBridge::new(Cursor::new(&mut dest)); + std::io::copy(&mut Cursor::new(src), &mut w).map_err(|e| e.to_string())?; + Ok(dest) + }) + .await??; + assert_eq!(dest.as_slice(), src); + Ok(()) +} diff --git a/third_party/rust/tokio-util/tests/length_delimited.rs b/third_party/rust/tokio-util/tests/length_delimited.rs new file mode 100644 index 0000000000..126e41b5cd --- /dev/null +++ b/third_party/rust/tokio-util/tests/length_delimited.rs @@ -0,0 +1,779 @@ +#![warn(rust_2018_idioms)] + +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_test::task; +use tokio_test::{ + assert_err, assert_ok, assert_pending, assert_ready, assert_ready_err, assert_ready_ok, +}; +use tokio_util::codec::*; + +use bytes::{BufMut, Bytes, BytesMut}; +use futures::{pin_mut, Sink, Stream}; +use std::collections::VecDeque; +use std::io; +use std::pin::Pin; +use std::task::Poll::*; +use std::task::{Context, Poll}; + +macro_rules! mock { + ($($x:expr,)*) => {{ + let mut v = VecDeque::new(); + v.extend(vec![$($x),*]); + Mock { calls: v } + }}; +} + +macro_rules! assert_next_eq { + ($io:ident, $expect:expr) => {{ + task::spawn(()).enter(|cx, _| { + let res = assert_ready!($io.as_mut().poll_next(cx)); + match res { + Some(Ok(v)) => assert_eq!(v, $expect.as_ref()), + Some(Err(e)) => panic!("error = {:?}", e), + None => panic!("none"), + } + }); + }}; +} + +macro_rules! assert_next_pending { + ($io:ident) => {{ + task::spawn(()).enter(|cx, _| match $io.as_mut().poll_next(cx) { + Ready(Some(Ok(v))) => panic!("value = {:?}", v), + Ready(Some(Err(e))) => panic!("error = {:?}", e), + Ready(None) => panic!("done"), + Pending => {} + }); + }}; +} + +macro_rules! assert_next_err { + ($io:ident) => {{ + task::spawn(()).enter(|cx, _| match $io.as_mut().poll_next(cx) { + Ready(Some(Ok(v))) => panic!("value = {:?}", v), + Ready(Some(Err(_))) => {} + Ready(None) => panic!("done"), + Pending => panic!("pending"), + }); + }}; +} + +macro_rules! assert_done { + ($io:ident) => {{ + task::spawn(()).enter(|cx, _| { + let res = assert_ready!($io.as_mut().poll_next(cx)); + match res { + Some(Ok(v)) => panic!("value = {:?}", v), + Some(Err(e)) => panic!("error = {:?}", e), + None => {} + } + }); + }}; +} + +#[test] +fn read_empty_io_yields_nothing() { + let io = Box::pin(FramedRead::new(mock!(), LengthDelimitedCodec::new())); + pin_mut!(io); + + assert_done!(io); +} + +#[test] +fn read_single_frame_one_packet() { + let io = FramedRead::new( + mock! { + data(b"\x00\x00\x00\x09abcdefghi"), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_done!(io); +} + +#[test] +fn read_single_frame_one_packet_little_endian() { + let io = length_delimited::Builder::new() + .little_endian() + .new_read(mock! { + data(b"\x09\x00\x00\x00abcdefghi"), + }); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_done!(io); +} + +#[test] +fn read_single_frame_one_packet_native_endian() { + let d = if cfg!(target_endian = "big") { + b"\x00\x00\x00\x09abcdefghi" + } else { + b"\x09\x00\x00\x00abcdefghi" + }; + let io = length_delimited::Builder::new() + .native_endian() + .new_read(mock! { + data(d), + }); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_done!(io); +} + +#[test] +fn read_single_multi_frame_one_packet() { + let mut d: Vec<u8> = vec![]; + d.extend_from_slice(b"\x00\x00\x00\x09abcdefghi"); + d.extend_from_slice(b"\x00\x00\x00\x03123"); + d.extend_from_slice(b"\x00\x00\x00\x0bhello world"); + + let io = FramedRead::new( + mock! { + data(&d), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_next_eq!(io, b"123"); + assert_next_eq!(io, b"hello world"); + assert_done!(io); +} + +#[test] +fn read_single_frame_multi_packet() { + let io = FramedRead::new( + mock! { + data(b"\x00\x00"), + data(b"\x00\x09abc"), + data(b"defghi"), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_done!(io); +} + +#[test] +fn read_multi_frame_multi_packet() { + let io = FramedRead::new( + mock! { + data(b"\x00\x00"), + data(b"\x00\x09abc"), + data(b"defghi"), + data(b"\x00\x00\x00\x0312"), + data(b"3\x00\x00\x00\x0bhello world"), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_next_eq!(io, b"123"); + assert_next_eq!(io, b"hello world"); + assert_done!(io); +} + +#[test] +fn read_single_frame_multi_packet_wait() { + let io = FramedRead::new( + mock! { + data(b"\x00\x00"), + Pending, + data(b"\x00\x09abc"), + Pending, + data(b"defghi"), + Pending, + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_pending!(io); + assert_next_pending!(io); + assert_next_eq!(io, b"abcdefghi"); + assert_next_pending!(io); + assert_done!(io); +} + +#[test] +fn read_multi_frame_multi_packet_wait() { + let io = FramedRead::new( + mock! { + data(b"\x00\x00"), + Pending, + data(b"\x00\x09abc"), + Pending, + data(b"defghi"), + Pending, + data(b"\x00\x00\x00\x0312"), + Pending, + data(b"3\x00\x00\x00\x0bhello world"), + Pending, + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_pending!(io); + assert_next_pending!(io); + assert_next_eq!(io, b"abcdefghi"); + assert_next_pending!(io); + assert_next_pending!(io); + assert_next_eq!(io, b"123"); + assert_next_eq!(io, b"hello world"); + assert_next_pending!(io); + assert_done!(io); +} + +#[test] +fn read_incomplete_head() { + let io = FramedRead::new( + mock! { + data(b"\x00\x00"), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_err!(io); +} + +#[test] +fn read_incomplete_head_multi() { + let io = FramedRead::new( + mock! { + Pending, + data(b"\x00"), + Pending, + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_pending!(io); + assert_next_pending!(io); + assert_next_err!(io); +} + +#[test] +fn read_incomplete_payload() { + let io = FramedRead::new( + mock! { + data(b"\x00\x00\x00\x09ab"), + Pending, + data(b"cd"), + Pending, + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_pending!(io); + assert_next_pending!(io); + assert_next_err!(io); +} + +#[test] +fn read_max_frame_len() { + let io = length_delimited::Builder::new() + .max_frame_length(5) + .new_read(mock! { + data(b"\x00\x00\x00\x09abcdefghi"), + }); + pin_mut!(io); + + assert_next_err!(io); +} + +#[test] +fn read_update_max_frame_len_at_rest() { + let io = length_delimited::Builder::new().new_read(mock! { + data(b"\x00\x00\x00\x09abcdefghi"), + data(b"\x00\x00\x00\x09abcdefghi"), + }); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + io.decoder_mut().set_max_frame_length(5); + assert_next_err!(io); +} + +#[test] +fn read_update_max_frame_len_in_flight() { + let io = length_delimited::Builder::new().new_read(mock! { + data(b"\x00\x00\x00\x09abcd"), + Pending, + data(b"efghi"), + data(b"\x00\x00\x00\x09abcdefghi"), + }); + pin_mut!(io); + + assert_next_pending!(io); + io.decoder_mut().set_max_frame_length(5); + assert_next_eq!(io, b"abcdefghi"); + assert_next_err!(io); +} + +#[test] +fn read_one_byte_length_field() { + let io = length_delimited::Builder::new() + .length_field_length(1) + .new_read(mock! { + data(b"\x09abcdefghi"), + }); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_done!(io); +} + +#[test] +fn read_header_offset() { + let io = length_delimited::Builder::new() + .length_field_length(2) + .length_field_offset(4) + .new_read(mock! { + data(b"zzzz\x00\x09abcdefghi"), + }); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_done!(io); +} + +#[test] +fn read_single_multi_frame_one_packet_skip_none_adjusted() { + let mut d: Vec<u8> = vec![]; + d.extend_from_slice(b"xx\x00\x09abcdefghi"); + d.extend_from_slice(b"yy\x00\x03123"); + d.extend_from_slice(b"zz\x00\x0bhello world"); + + let io = length_delimited::Builder::new() + .length_field_length(2) + .length_field_offset(2) + .num_skip(0) + .length_adjustment(4) + .new_read(mock! { + data(&d), + }); + pin_mut!(io); + + assert_next_eq!(io, b"xx\x00\x09abcdefghi"); + assert_next_eq!(io, b"yy\x00\x03123"); + assert_next_eq!(io, b"zz\x00\x0bhello world"); + assert_done!(io); +} + +#[test] +fn read_single_frame_length_adjusted() { + let mut d: Vec<u8> = vec![]; + d.extend_from_slice(b"\x00\x00\x0b\x0cHello world"); + + let io = length_delimited::Builder::new() + .length_field_offset(0) + .length_field_length(3) + .length_adjustment(0) + .num_skip(4) + .new_read(mock! { + data(&d), + }); + pin_mut!(io); + + assert_next_eq!(io, b"Hello world"); + assert_done!(io); +} + +#[test] +fn read_single_multi_frame_one_packet_length_includes_head() { + let mut d: Vec<u8> = vec![]; + d.extend_from_slice(b"\x00\x0babcdefghi"); + d.extend_from_slice(b"\x00\x05123"); + d.extend_from_slice(b"\x00\x0dhello world"); + + let io = length_delimited::Builder::new() + .length_field_length(2) + .length_adjustment(-2) + .new_read(mock! { + data(&d), + }); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_next_eq!(io, b"123"); + assert_next_eq!(io, b"hello world"); + assert_done!(io); +} + +#[test] +fn write_single_frame_length_adjusted() { + let io = length_delimited::Builder::new() + .length_adjustment(-2) + .new_write(mock! { + data(b"\x00\x00\x00\x0b"), + data(b"abcdefghi"), + flush(), + }); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + assert_ready_ok!(io.as_mut().poll_flush(cx)); + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_nothing_yields_nothing() { + let io = FramedWrite::new(mock!(), LengthDelimitedCodec::new()); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.poll_flush(cx)); + }); +} + +#[test] +fn write_single_frame_one_packet() { + let io = FramedWrite::new( + mock! { + data(b"\x00\x00\x00\x09"), + data(b"abcdefghi"), + flush(), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + assert_ready_ok!(io.as_mut().poll_flush(cx)); + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_single_multi_frame_one_packet() { + let io = FramedWrite::new( + mock! { + data(b"\x00\x00\x00\x09"), + data(b"abcdefghi"), + data(b"\x00\x00\x00\x03"), + data(b"123"), + data(b"\x00\x00\x00\x0b"), + data(b"hello world"), + flush(), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("123"))); + + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("hello world"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_single_multi_frame_multi_packet() { + let io = FramedWrite::new( + mock! { + data(b"\x00\x00\x00\x09"), + data(b"abcdefghi"), + flush(), + data(b"\x00\x00\x00\x03"), + data(b"123"), + flush(), + data(b"\x00\x00\x00\x0b"), + data(b"hello world"), + flush(), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("123"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("hello world"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_single_frame_would_block() { + let io = FramedWrite::new( + mock! { + Pending, + data(b"\x00\x00"), + Pending, + data(b"\x00\x09"), + data(b"abcdefghi"), + flush(), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + + assert_pending!(io.as_mut().poll_flush(cx)); + assert_pending!(io.as_mut().poll_flush(cx)); + assert_ready_ok!(io.as_mut().poll_flush(cx)); + + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_single_frame_little_endian() { + let io = length_delimited::Builder::new() + .little_endian() + .new_write(mock! { + data(b"\x09\x00\x00\x00"), + data(b"abcdefghi"), + flush(), + }); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_single_frame_with_short_length_field() { + let io = length_delimited::Builder::new() + .length_field_length(1) + .new_write(mock! { + data(b"\x09"), + data(b"abcdefghi"), + flush(), + }); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_max_frame_len() { + let io = length_delimited::Builder::new() + .max_frame_length(5) + .new_write(mock! {}); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_err!(io.as_mut().start_send(Bytes::from("abcdef"))); + + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_update_max_frame_len_at_rest() { + let io = length_delimited::Builder::new().new_write(mock! { + data(b"\x00\x00\x00\x06"), + data(b"abcdef"), + flush(), + }); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdef"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + + io.encoder_mut().set_max_frame_length(5); + + assert_err!(io.as_mut().start_send(Bytes::from("abcdef"))); + + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_update_max_frame_len_in_flight() { + let io = length_delimited::Builder::new().new_write(mock! { + data(b"\x00\x00\x00\x06"), + data(b"ab"), + Pending, + data(b"cdef"), + flush(), + }); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdef"))); + + assert_pending!(io.as_mut().poll_flush(cx)); + + io.encoder_mut().set_max_frame_length(5); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + + assert_err!(io.as_mut().start_send(Bytes::from("abcdef"))); + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_zero() { + let io = length_delimited::Builder::new().new_write(mock! {}); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdef"))); + + assert_ready_err!(io.as_mut().poll_flush(cx)); + + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn encode_overflow() { + // Test reproducing tokio-rs/tokio#681. + let mut codec = length_delimited::Builder::new().new_codec(); + let mut buf = BytesMut::with_capacity(1024); + + // Put some data into the buffer without resizing it to hold more. + let some_as = std::iter::repeat(b'a').take(1024).collect::<Vec<_>>(); + buf.put_slice(&some_as[..]); + + // Trying to encode the length header should resize the buffer if it won't fit. + codec.encode(Bytes::from("hello"), &mut buf).unwrap(); +} + +// ===== Test utils ===== + +struct Mock { + calls: VecDeque<Poll<io::Result<Op>>>, +} + +enum Op { + Data(Vec<u8>), + Flush, +} + +use self::Op::*; + +impl AsyncRead for Mock { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + dst: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + match self.calls.pop_front() { + Some(Ready(Ok(Op::Data(data)))) => { + debug_assert!(dst.remaining() >= data.len()); + dst.put_slice(&data); + Ready(Ok(())) + } + Some(Ready(Ok(_))) => panic!(), + Some(Ready(Err(e))) => Ready(Err(e)), + Some(Pending) => Pending, + None => Ready(Ok(())), + } + } +} + +impl AsyncWrite for Mock { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + src: &[u8], + ) -> Poll<Result<usize, io::Error>> { + match self.calls.pop_front() { + Some(Ready(Ok(Op::Data(data)))) => { + let len = data.len(); + assert!(src.len() >= len, "expect={:?}; actual={:?}", data, src); + assert_eq!(&data[..], &src[..len]); + Ready(Ok(len)) + } + Some(Ready(Ok(_))) => panic!(), + Some(Ready(Err(e))) => Ready(Err(e)), + Some(Pending) => Pending, + None => Ready(Ok(0)), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + match self.calls.pop_front() { + Some(Ready(Ok(Op::Flush))) => Ready(Ok(())), + Some(Ready(Ok(_))) => panic!(), + Some(Ready(Err(e))) => Ready(Err(e)), + Some(Pending) => Pending, + None => Ready(Ok(())), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + Ready(Ok(())) + } +} + +impl<'a> From<&'a [u8]> for Op { + fn from(src: &'a [u8]) -> Op { + Op::Data(src.into()) + } +} + +impl From<Vec<u8>> for Op { + fn from(src: Vec<u8>) -> Op { + Op::Data(src) + } +} + +fn data(bytes: &[u8]) -> Poll<io::Result<Op>> { + Ready(Ok(bytes.into())) +} + +fn flush() -> Poll<io::Result<Op>> { + Ready(Ok(Flush)) +} diff --git a/third_party/rust/tokio-util/tests/mpsc.rs b/third_party/rust/tokio-util/tests/mpsc.rs new file mode 100644 index 0000000000..a3c164d3ec --- /dev/null +++ b/third_party/rust/tokio-util/tests/mpsc.rs @@ -0,0 +1,239 @@ +use futures::future::poll_fn; +use tokio::sync::mpsc::channel; +use tokio_test::task::spawn; +use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok}; +use tokio_util::sync::PollSender; + +#[tokio::test] +async fn simple() { + let (send, mut recv) = channel(3); + let mut send = PollSender::new(send); + + for i in 1..=3i32 { + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); + send.send_item(i).unwrap(); + } + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_pending!(reserve.poll()); + + assert_eq!(recv.recv().await.unwrap(), 1); + assert!(reserve.is_woken()); + assert_ready_ok!(reserve.poll()); + + drop(recv); + + send.send_item(42).unwrap(); +} + +#[tokio::test] +async fn repeated_poll_reserve() { + let (send, mut recv) = channel::<i32>(1); + let mut send = PollSender::new(send); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); + assert_ready_ok!(reserve.poll()); + send.send_item(1).unwrap(); + + assert_eq!(recv.recv().await.unwrap(), 1); +} + +#[tokio::test] +async fn abort_send() { + let (send, mut recv) = channel(3); + let mut send = PollSender::new(send); + let send2 = send.get_ref().cloned().unwrap(); + + for i in 1..=3i32 { + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); + send.send_item(i).unwrap(); + } + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_pending!(reserve.poll()); + assert_eq!(recv.recv().await.unwrap(), 1); + assert!(reserve.is_woken()); + assert_ready_ok!(reserve.poll()); + + let mut send2_send = spawn(send2.send(5)); + assert_pending!(send2_send.poll()); + assert!(send.abort_send()); + assert!(send2_send.is_woken()); + assert_ready_ok!(send2_send.poll()); + + assert_eq!(recv.recv().await.unwrap(), 2); + assert_eq!(recv.recv().await.unwrap(), 3); + assert_eq!(recv.recv().await.unwrap(), 5); +} + +#[tokio::test] +async fn close_sender_last() { + let (send, mut recv) = channel::<i32>(3); + let mut send = PollSender::new(send); + + let mut recv_task = spawn(recv.recv()); + assert_pending!(recv_task.poll()); + + send.close(); + + assert!(recv_task.is_woken()); + assert!(assert_ready!(recv_task.poll()).is_none()); +} + +#[tokio::test] +async fn close_sender_not_last() { + let (send, mut recv) = channel::<i32>(3); + let mut send = PollSender::new(send); + let send2 = send.get_ref().cloned().unwrap(); + + let mut recv_task = spawn(recv.recv()); + assert_pending!(recv_task.poll()); + + send.close(); + + assert!(!recv_task.is_woken()); + assert_pending!(recv_task.poll()); + + drop(send2); + + assert!(recv_task.is_woken()); + assert!(assert_ready!(recv_task.poll()).is_none()); +} + +#[tokio::test] +async fn close_sender_before_reserve() { + let (send, mut recv) = channel::<i32>(3); + let mut send = PollSender::new(send); + + let mut recv_task = spawn(recv.recv()); + assert_pending!(recv_task.poll()); + + send.close(); + + assert!(recv_task.is_woken()); + assert!(assert_ready!(recv_task.poll()).is_none()); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_err!(reserve.poll()); +} + +#[tokio::test] +async fn close_sender_after_pending_reserve() { + let (send, mut recv) = channel::<i32>(1); + let mut send = PollSender::new(send); + + let mut recv_task = spawn(recv.recv()); + assert_pending!(recv_task.poll()); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); + send.send_item(1).unwrap(); + + assert!(recv_task.is_woken()); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_pending!(reserve.poll()); + drop(reserve); + + send.close(); + + assert!(send.is_closed()); + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_err!(reserve.poll()); +} + +#[tokio::test] +async fn close_sender_after_successful_reserve() { + let (send, mut recv) = channel::<i32>(3); + let mut send = PollSender::new(send); + + let mut recv_task = spawn(recv.recv()); + assert_pending!(recv_task.poll()); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); + drop(reserve); + + send.close(); + assert!(send.is_closed()); + assert!(!recv_task.is_woken()); + assert_pending!(recv_task.poll()); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); +} + +#[tokio::test] +async fn abort_send_after_pending_reserve() { + let (send, mut recv) = channel::<i32>(1); + let mut send = PollSender::new(send); + + let mut recv_task = spawn(recv.recv()); + assert_pending!(recv_task.poll()); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); + send.send_item(1).unwrap(); + + assert_eq!(send.get_ref().unwrap().capacity(), 0); + assert!(!send.abort_send()); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_pending!(reserve.poll()); + + assert!(send.abort_send()); + assert_eq!(send.get_ref().unwrap().capacity(), 0); +} + +#[tokio::test] +async fn abort_send_after_successful_reserve() { + let (send, mut recv) = channel::<i32>(1); + let mut send = PollSender::new(send); + + let mut recv_task = spawn(recv.recv()); + assert_pending!(recv_task.poll()); + + assert_eq!(send.get_ref().unwrap().capacity(), 1); + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); + assert_eq!(send.get_ref().unwrap().capacity(), 0); + + assert!(send.abort_send()); + assert_eq!(send.get_ref().unwrap().capacity(), 1); +} + +#[tokio::test] +async fn closed_when_receiver_drops() { + let (send, _) = channel::<i32>(1); + let mut send = PollSender::new(send); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_err!(reserve.poll()); +} + +#[should_panic] +#[test] +fn start_send_panics_when_idle() { + let (send, _) = channel::<i32>(3); + let mut send = PollSender::new(send); + + send.send_item(1).unwrap(); +} + +#[should_panic] +#[test] +fn start_send_panics_when_acquiring() { + let (send, _) = channel::<i32>(1); + let mut send = PollSender::new(send); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); + send.send_item(1).unwrap(); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_pending!(reserve.poll()); + send.send_item(2).unwrap(); +} diff --git a/third_party/rust/tokio-util/tests/poll_semaphore.rs b/third_party/rust/tokio-util/tests/poll_semaphore.rs new file mode 100644 index 0000000000..50f36dd803 --- /dev/null +++ b/third_party/rust/tokio-util/tests/poll_semaphore.rs @@ -0,0 +1,36 @@ +use std::future::Future; +use std::sync::Arc; +use std::task::Poll; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use tokio_util::sync::PollSemaphore; + +type SemRet = Option<OwnedSemaphorePermit>; + +fn semaphore_poll( + sem: &mut PollSemaphore, +) -> tokio_test::task::Spawn<impl Future<Output = SemRet> + '_> { + let fut = futures::future::poll_fn(move |cx| sem.poll_acquire(cx)); + tokio_test::task::spawn(fut) +} + +#[tokio::test] +async fn it_works() { + let sem = Arc::new(Semaphore::new(1)); + let mut poll_sem = PollSemaphore::new(sem.clone()); + + let permit = sem.acquire().await.unwrap(); + let mut poll = semaphore_poll(&mut poll_sem); + assert!(poll.poll().is_pending()); + drop(permit); + + assert!(matches!(poll.poll(), Poll::Ready(Some(_)))); + drop(poll); + + sem.close(); + + assert!(semaphore_poll(&mut poll_sem).await.is_none()); + + // Check that it is fused. + assert!(semaphore_poll(&mut poll_sem).await.is_none()); + assert!(semaphore_poll(&mut poll_sem).await.is_none()); +} diff --git a/third_party/rust/tokio-util/tests/reusable_box.rs b/third_party/rust/tokio-util/tests/reusable_box.rs new file mode 100644 index 0000000000..c8f6da02ae --- /dev/null +++ b/third_party/rust/tokio-util/tests/reusable_box.rs @@ -0,0 +1,72 @@ +use futures::future::FutureExt; +use std::alloc::Layout; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio_util::sync::ReusableBoxFuture; + +#[test] +fn test_different_futures() { + let fut = async move { 10 }; + // Not zero sized! + assert_eq!(Layout::for_value(&fut).size(), 1); + + let mut b = ReusableBoxFuture::new(fut); + + assert_eq!(b.get_pin().now_or_never(), Some(10)); + + b.try_set(async move { 20 }) + .unwrap_or_else(|_| panic!("incorrect size")); + + assert_eq!(b.get_pin().now_or_never(), Some(20)); + + b.try_set(async move { 30 }) + .unwrap_or_else(|_| panic!("incorrect size")); + + assert_eq!(b.get_pin().now_or_never(), Some(30)); +} + +#[test] +fn test_different_sizes() { + let fut1 = async move { 10 }; + let val = [0u32; 1000]; + let fut2 = async move { val[0] }; + let fut3 = ZeroSizedFuture {}; + + assert_eq!(Layout::for_value(&fut1).size(), 1); + assert_eq!(Layout::for_value(&fut2).size(), 4004); + assert_eq!(Layout::for_value(&fut3).size(), 0); + + let mut b = ReusableBoxFuture::new(fut1); + assert_eq!(b.get_pin().now_or_never(), Some(10)); + b.set(fut2); + assert_eq!(b.get_pin().now_or_never(), Some(0)); + b.set(fut3); + assert_eq!(b.get_pin().now_or_never(), Some(5)); +} + +struct ZeroSizedFuture {} +impl Future for ZeroSizedFuture { + type Output = u32; + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<u32> { + Poll::Ready(5) + } +} + +#[test] +fn test_zero_sized() { + let fut = ZeroSizedFuture {}; + // Zero sized! + assert_eq!(Layout::for_value(&fut).size(), 0); + + let mut b = ReusableBoxFuture::new(fut); + + assert_eq!(b.get_pin().now_or_never(), Some(5)); + assert_eq!(b.get_pin().now_or_never(), Some(5)); + + b.try_set(ZeroSizedFuture {}) + .unwrap_or_else(|_| panic!("incorrect size")); + + assert_eq!(b.get_pin().now_or_never(), Some(5)); + assert_eq!(b.get_pin().now_or_never(), Some(5)); +} diff --git a/third_party/rust/tokio-util/tests/spawn_pinned.rs b/third_party/rust/tokio-util/tests/spawn_pinned.rs new file mode 100644 index 0000000000..409b8dadab --- /dev/null +++ b/third_party/rust/tokio-util/tests/spawn_pinned.rs @@ -0,0 +1,193 @@ +#![warn(rust_2018_idioms)] + +use std::rc::Rc; +use std::sync::Arc; +use tokio_util::task; + +/// Simple test of running a !Send future via spawn_pinned +#[tokio::test] +async fn can_spawn_not_send_future() { + let pool = task::LocalPoolHandle::new(1); + + let output = pool + .spawn_pinned(|| { + // Rc is !Send + !Sync + let local_data = Rc::new("test"); + + // This future holds an Rc, so it is !Send + async move { local_data.to_string() } + }) + .await + .unwrap(); + + assert_eq!(output, "test"); +} + +/// Dropping the join handle still lets the task execute +#[test] +fn can_drop_future_and_still_get_output() { + let pool = task::LocalPoolHandle::new(1); + let (sender, receiver) = std::sync::mpsc::channel(); + + let _ = pool.spawn_pinned(move || { + // Rc is !Send + !Sync + let local_data = Rc::new("test"); + + // This future holds an Rc, so it is !Send + async move { + let _ = sender.send(local_data.to_string()); + } + }); + + assert_eq!(receiver.recv(), Ok("test".to_string())); +} + +#[test] +#[should_panic(expected = "assertion failed: pool_size > 0")] +fn cannot_create_zero_sized_pool() { + let _pool = task::LocalPoolHandle::new(0); +} + +/// We should be able to spawn multiple futures onto the pool at the same time. +#[tokio::test] +async fn can_spawn_multiple_futures() { + let pool = task::LocalPoolHandle::new(2); + + let join_handle1 = pool.spawn_pinned(|| { + let local_data = Rc::new("test1"); + async move { local_data.to_string() } + }); + let join_handle2 = pool.spawn_pinned(|| { + let local_data = Rc::new("test2"); + async move { local_data.to_string() } + }); + + assert_eq!(join_handle1.await.unwrap(), "test1"); + assert_eq!(join_handle2.await.unwrap(), "test2"); +} + +/// A panic in the spawned task causes the join handle to return an error. +/// But, you can continue to spawn tasks. +#[tokio::test] +async fn task_panic_propagates() { + let pool = task::LocalPoolHandle::new(1); + + let join_handle = pool.spawn_pinned(|| async { + panic!("Test panic"); + }); + + let result = join_handle.await; + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.is_panic()); + let panic_str: &str = *error.into_panic().downcast().unwrap(); + assert_eq!(panic_str, "Test panic"); + + // Trying again with a "safe" task still works + let join_handle = pool.spawn_pinned(|| async { "test" }); + let result = join_handle.await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "test"); +} + +/// A panic during task creation causes the join handle to return an error. +/// But, you can continue to spawn tasks. +#[tokio::test] +async fn callback_panic_does_not_kill_worker() { + let pool = task::LocalPoolHandle::new(1); + + let join_handle = pool.spawn_pinned(|| { + panic!("Test panic"); + #[allow(unreachable_code)] + async {} + }); + + let result = join_handle.await; + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.is_panic()); + let panic_str: &str = *error.into_panic().downcast().unwrap(); + assert_eq!(panic_str, "Test panic"); + + // Trying again with a "safe" callback works + let join_handle = pool.spawn_pinned(|| async { "test" }); + let result = join_handle.await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "test"); +} + +/// Canceling the task via the returned join handle cancels the spawned task +/// (which has a different, internal join handle). +#[tokio::test] +async fn task_cancellation_propagates() { + let pool = task::LocalPoolHandle::new(1); + let notify_dropped = Arc::new(()); + let weak_notify_dropped = Arc::downgrade(¬ify_dropped); + + let (start_sender, start_receiver) = tokio::sync::oneshot::channel(); + let (drop_sender, drop_receiver) = tokio::sync::oneshot::channel::<()>(); + let join_handle = pool.spawn_pinned(|| async move { + let _drop_sender = drop_sender; + // Move the Arc into the task + let _notify_dropped = notify_dropped; + let _ = start_sender.send(()); + + // Keep the task running until it gets aborted + futures::future::pending::<()>().await; + }); + + // Wait for the task to start + let _ = start_receiver.await; + + join_handle.abort(); + + // Wait for the inner task to abort, dropping the sender. + // The top level join handle aborts quicker than the inner task (the abort + // needs to propagate and get processed on the worker thread), so we can't + // just await the top level join handle. + let _ = drop_receiver.await; + + // Check that the Arc has been dropped. This verifies that the inner task + // was canceled as well. + assert!(weak_notify_dropped.upgrade().is_none()); +} + +/// Tasks should be given to the least burdened worker. When spawning two tasks +/// on a pool with two empty workers the tasks should be spawned on separate +/// workers. +#[tokio::test] +async fn tasks_are_balanced() { + let pool = task::LocalPoolHandle::new(2); + + // Spawn a task so one thread has a task count of 1 + let (start_sender1, start_receiver1) = tokio::sync::oneshot::channel(); + let (end_sender1, end_receiver1) = tokio::sync::oneshot::channel(); + let join_handle1 = pool.spawn_pinned(|| async move { + let _ = start_sender1.send(()); + let _ = end_receiver1.await; + std::thread::current().id() + }); + + // Wait for the first task to start up + let _ = start_receiver1.await; + + // This task should be spawned on the other thread + let (start_sender2, start_receiver2) = tokio::sync::oneshot::channel(); + let join_handle2 = pool.spawn_pinned(|| async move { + let _ = start_sender2.send(()); + std::thread::current().id() + }); + + // Wait for the second task to start up + let _ = start_receiver2.await; + + // Allow the first task to end + let _ = end_sender1.send(()); + + let thread_id1 = join_handle1.await.unwrap(); + let thread_id2 = join_handle2.await.unwrap(); + + // Since the first task was active when the second task spawned, they should + // be on separate workers/threads. + assert_ne!(thread_id1, thread_id2); +} diff --git a/third_party/rust/tokio-util/tests/sync_cancellation_token.rs b/third_party/rust/tokio-util/tests/sync_cancellation_token.rs new file mode 100644 index 0000000000..28ba284b6c --- /dev/null +++ b/third_party/rust/tokio-util/tests/sync_cancellation_token.rs @@ -0,0 +1,400 @@ +#![warn(rust_2018_idioms)] + +use tokio::pin; +use tokio_util::sync::{CancellationToken, WaitForCancellationFuture}; + +use core::future::Future; +use core::task::{Context, Poll}; +use futures_test::task::new_count_waker; + +#[test] +fn cancel_token() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + assert!(!token.is_cancelled()); + + let wait_fut = token.cancelled(); + pin!(wait_fut); + + assert_eq!( + Poll::Pending, + wait_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + let wait_fut_2 = token.cancelled(); + pin!(wait_fut_2); + + token.cancel(); + assert_eq!(wake_counter, 1); + assert!(token.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + wait_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + wait_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) + ); +} + +#[test] +fn cancel_child_token_through_parent() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + + let child_token = token.child_token(); + assert!(!child_token.is_cancelled()); + + let child_fut = child_token.cancelled(); + pin!(child_fut); + let parent_fut = token.cancelled(); + pin!(parent_fut); + + assert_eq!( + Poll::Pending, + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + token.cancel(); + assert_eq!(wake_counter, 2); + assert!(token.is_cancelled()); + assert!(child_token.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); +} + +#[test] +fn cancel_grandchild_token_through_parent_if_child_was_dropped() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + + let intermediate_token = token.child_token(); + let child_token = intermediate_token.child_token(); + drop(intermediate_token); + assert!(!child_token.is_cancelled()); + + let child_fut = child_token.cancelled(); + pin!(child_fut); + let parent_fut = token.cancelled(); + pin!(parent_fut); + + assert_eq!( + Poll::Pending, + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + token.cancel(); + assert_eq!(wake_counter, 2); + assert!(token.is_cancelled()); + assert!(child_token.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); +} + +#[test] +fn cancel_child_token_without_parent() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + + let child_token_1 = token.child_token(); + + let child_fut = child_token_1.cancelled(); + pin!(child_fut); + let parent_fut = token.cancelled(); + pin!(parent_fut); + + assert_eq!( + Poll::Pending, + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + child_token_1.cancel(); + assert_eq!(wake_counter, 1); + assert!(!token.is_cancelled()); + assert!(child_token_1.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + + let child_token_2 = token.child_token(); + let child_fut_2 = child_token_2.cancelled(); + pin!(child_fut_2); + + assert_eq!( + Poll::Pending, + child_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + + token.cancel(); + assert_eq!(wake_counter, 3); + assert!(token.is_cancelled()); + assert!(child_token_2.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + child_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); +} + +#[test] +fn create_child_token_after_parent_was_cancelled() { + for drop_child_first in [true, false].iter().cloned() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + token.cancel(); + + let child_token = token.child_token(); + assert!(child_token.is_cancelled()); + + { + let child_fut = child_token.cancelled(); + pin!(child_fut); + let parent_fut = token.cancelled(); + pin!(parent_fut); + + assert_eq!( + Poll::Ready(()), + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + drop(child_fut); + drop(parent_fut); + } + + if drop_child_first { + drop(child_token); + drop(token); + } else { + drop(token); + drop(child_token); + } + } +} + +#[test] +fn drop_multiple_child_tokens() { + for drop_first_child_first in &[true, false] { + let token = CancellationToken::new(); + let mut child_tokens = [None, None, None]; + for child in &mut child_tokens { + *child = Some(token.child_token()); + } + + assert!(!token.is_cancelled()); + assert!(!child_tokens[0].as_ref().unwrap().is_cancelled()); + + for i in 0..child_tokens.len() { + if *drop_first_child_first { + child_tokens[i] = None; + } else { + child_tokens[child_tokens.len() - 1 - i] = None; + } + assert!(!token.is_cancelled()); + } + + drop(token); + } +} + +#[test] +fn cancel_only_all_descendants() { + // ARRANGE + let (waker, wake_counter) = new_count_waker(); + + let parent_token = CancellationToken::new(); + let token = parent_token.child_token(); + let sibling_token = parent_token.child_token(); + let child1_token = token.child_token(); + let child2_token = token.child_token(); + let grandchild_token = child1_token.child_token(); + let grandchild2_token = child1_token.child_token(); + let grandgrandchild_token = grandchild_token.child_token(); + + assert!(!parent_token.is_cancelled()); + assert!(!token.is_cancelled()); + assert!(!sibling_token.is_cancelled()); + assert!(!child1_token.is_cancelled()); + assert!(!child2_token.is_cancelled()); + assert!(!grandchild_token.is_cancelled()); + assert!(!grandchild2_token.is_cancelled()); + assert!(!grandgrandchild_token.is_cancelled()); + + let parent_fut = parent_token.cancelled(); + let fut = token.cancelled(); + let sibling_fut = sibling_token.cancelled(); + let child1_fut = child1_token.cancelled(); + let child2_fut = child2_token.cancelled(); + let grandchild_fut = grandchild_token.cancelled(); + let grandchild2_fut = grandchild2_token.cancelled(); + let grandgrandchild_fut = grandgrandchild_token.cancelled(); + + pin!(parent_fut); + pin!(fut); + pin!(sibling_fut); + pin!(child1_fut); + pin!(child2_fut); + pin!(grandchild_fut); + pin!(grandchild2_fut); + pin!(grandgrandchild_fut); + + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + sibling_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + child1_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + child2_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + grandchild_fut + .as_mut() + .poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + grandchild2_fut + .as_mut() + .poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + grandgrandchild_fut + .as_mut() + .poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + // ACT + token.cancel(); + + // ASSERT + assert_eq!(wake_counter, 6); + assert!(!parent_token.is_cancelled()); + assert!(token.is_cancelled()); + assert!(!sibling_token.is_cancelled()); + assert!(child1_token.is_cancelled()); + assert!(child2_token.is_cancelled()); + assert!(grandchild_token.is_cancelled()); + assert!(grandchild2_token.is_cancelled()); + assert!(grandgrandchild_token.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + child1_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + child2_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + grandchild_fut + .as_mut() + .poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + grandchild2_fut + .as_mut() + .poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + grandgrandchild_fut + .as_mut() + .poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 6); +} + +#[test] +fn drop_parent_before_child_tokens() { + let token = CancellationToken::new(); + let child1 = token.child_token(); + let child2 = token.child_token(); + + drop(token); + assert!(!child1.is_cancelled()); + + drop(child1); + drop(child2); +} + +#[test] +fn derives_send_sync() { + fn assert_send<T: Send>() {} + fn assert_sync<T: Sync>() {} + + assert_send::<CancellationToken>(); + assert_sync::<CancellationToken>(); + + assert_send::<WaitForCancellationFuture<'static>>(); + assert_sync::<WaitForCancellationFuture<'static>>(); +} diff --git a/third_party/rust/tokio-util/tests/time_delay_queue.rs b/third_party/rust/tokio-util/tests/time_delay_queue.rs new file mode 100644 index 0000000000..cb163adf3a --- /dev/null +++ b/third_party/rust/tokio-util/tests/time_delay_queue.rs @@ -0,0 +1,818 @@ +#![allow(clippy::blacklisted_name)] +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use tokio::time::{self, sleep, sleep_until, Duration, Instant}; +use tokio_test::{assert_pending, assert_ready, task}; +use tokio_util::time::DelayQueue; + +macro_rules! poll { + ($queue:ident) => { + $queue.enter(|cx, mut queue| queue.poll_expired(cx)) + }; +} + +macro_rules! assert_ready_some { + ($e:expr) => {{ + match assert_ready!($e) { + Some(v) => v, + None => panic!("None"), + } + }}; +} + +#[tokio::test] +async fn single_immediate_delay() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + let _key = queue.insert_at("foo", Instant::now()); + + // Advance time by 1ms to handle thee rounding + sleep(ms(1)).await; + + assert_ready_some!(poll!(queue)); + + let entry = assert_ready!(poll!(queue)); + assert!(entry.is_none()) +} + +#[tokio::test] +async fn multi_immediate_delays() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let _k = queue.insert_at("1", Instant::now()); + let _k = queue.insert_at("2", Instant::now()); + let _k = queue.insert_at("3", Instant::now()); + + sleep(ms(1)).await; + + let mut res = vec![]; + + while res.len() < 3 { + let entry = assert_ready_some!(poll!(queue)); + res.push(entry.into_inner()); + } + + let entry = assert_ready!(poll!(queue)); + assert!(entry.is_none()); + + res.sort_unstable(); + + assert_eq!("1", res[0]); + assert_eq!("2", res[1]); + assert_eq!("3", res[2]); +} + +#[tokio::test] +async fn single_short_delay() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + let _key = queue.insert_at("foo", Instant::now() + ms(5)); + + assert_pending!(poll!(queue)); + + sleep(ms(1)).await; + + assert!(!queue.is_woken()); + + sleep(ms(5)).await; + + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)); + assert_eq!(*entry.get_ref(), "foo"); + + let entry = assert_ready!(poll!(queue)); + assert!(entry.is_none()); +} + +#[tokio::test] +async fn multi_delay_at_start() { + time::pause(); + + let long = 262_144 + 9 * 4096; + let delays = &[1000, 2, 234, long, 60, 10]; + + let mut queue = task::spawn(DelayQueue::new()); + + // Setup the delays + for &i in delays { + let _key = queue.insert_at(i, Instant::now() + ms(i)); + } + + assert_pending!(poll!(queue)); + assert!(!queue.is_woken()); + + let start = Instant::now(); + for elapsed in 0..1200 { + println!("elapsed: {:?}", elapsed); + let elapsed = elapsed + 1; + tokio::time::sleep_until(start + ms(elapsed)).await; + + if delays.contains(&elapsed) { + assert!(queue.is_woken()); + assert_ready!(poll!(queue)); + assert_pending!(poll!(queue)); + } else if queue.is_woken() { + let cascade = &[192, 960]; + assert!( + cascade.contains(&elapsed), + "elapsed={} dt={:?}", + elapsed, + Instant::now() - start + ); + + assert_pending!(poll!(queue)); + } + } + println!("finished multi_delay_start"); +} + +#[tokio::test] +async fn insert_in_past_fires_immediately() { + println!("running insert_in_past_fires_immediately"); + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + let now = Instant::now(); + + sleep(ms(10)).await; + + queue.insert_at("foo", now); + + assert_ready!(poll!(queue)); + println!("finished insert_in_past_fires_immediately"); +} + +#[tokio::test] +async fn remove_entry() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let key = queue.insert_at("foo", Instant::now() + ms(5)); + + assert_pending!(poll!(queue)); + + let entry = queue.remove(&key); + assert_eq!(entry.into_inner(), "foo"); + + sleep(ms(10)).await; + + let entry = assert_ready!(poll!(queue)); + assert!(entry.is_none()); +} + +#[tokio::test] +async fn reset_entry() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + let key = queue.insert_at("foo", now + ms(5)); + + assert_pending!(poll!(queue)); + sleep(ms(1)).await; + + queue.reset_at(&key, now + ms(10)); + + assert_pending!(poll!(queue)); + + sleep(ms(7)).await; + + assert!(!queue.is_woken()); + + assert_pending!(poll!(queue)); + + sleep(ms(3)).await; + + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)); + assert_eq!(*entry.get_ref(), "foo"); + + let entry = assert_ready!(poll!(queue)); + assert!(entry.is_none()) +} + +// Reproduces tokio-rs/tokio#849. +#[tokio::test] +async fn reset_much_later() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + sleep(ms(1)).await; + + let key = queue.insert_at("foo", now + ms(200)); + assert_pending!(poll!(queue)); + + sleep(ms(3)).await; + + queue.reset_at(&key, now + ms(10)); + + sleep(ms(20)).await; + + assert!(queue.is_woken()); +} + +// Reproduces tokio-rs/tokio#849. +#[tokio::test] +async fn reset_twice() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + let now = Instant::now(); + + sleep(ms(1)).await; + + let key = queue.insert_at("foo", now + ms(200)); + + assert_pending!(poll!(queue)); + + sleep(ms(3)).await; + + queue.reset_at(&key, now + ms(50)); + + sleep(ms(20)).await; + + queue.reset_at(&key, now + ms(40)); + + sleep(ms(20)).await; + + assert!(queue.is_woken()); +} + +/// Regression test: Given an entry inserted with a deadline in the past, so +/// that it is placed directly on the expired queue, reset the entry to a +/// deadline in the future. Validate that this leaves the entry and queue in an +/// internally consistent state by running an additional reset on the entry +/// before polling it to completion. +#[tokio::test] +async fn repeatedly_reset_entry_inserted_as_expired() { + time::pause(); + let mut queue = task::spawn(DelayQueue::new()); + let now = Instant::now(); + + let key = queue.insert_at("foo", now - ms(100)); + + queue.reset_at(&key, now + ms(100)); + queue.reset_at(&key, now + ms(50)); + + assert_pending!(poll!(queue)); + + time::sleep_until(now + ms(60)).await; + + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "foo"); + + let entry = assert_ready!(poll!(queue)); + assert!(entry.is_none()); +} + +#[tokio::test] +async fn remove_expired_item() { + time::pause(); + + let mut queue = DelayQueue::new(); + + let now = Instant::now(); + + sleep(ms(10)).await; + + let key = queue.insert_at("foo", now); + + let entry = queue.remove(&key); + assert_eq!(entry.into_inner(), "foo"); +} + +/// Regression test: it should be possible to remove entries which fall in the +/// 0th slot of the internal timer wheel — that is, entries whose expiration +/// (a) falls at the beginning of one of the wheel's hierarchical levels and (b) +/// is equal to the wheel's current elapsed time. +#[tokio::test] +async fn remove_at_timer_wheel_threshold() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let key1 = queue.insert_at("foo", now + ms(64)); + let key2 = queue.insert_at("bar", now + ms(64)); + + sleep(ms(80)).await; + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + + match entry { + "foo" => { + let entry = queue.remove(&key2).into_inner(); + assert_eq!(entry, "bar"); + } + "bar" => { + let entry = queue.remove(&key1).into_inner(); + assert_eq!(entry, "foo"); + } + other => panic!("other: {:?}", other), + } +} + +#[tokio::test] +async fn expires_before_last_insert() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + queue.insert_at("foo", now + ms(10_000)); + + // Delay should be set to 8.192s here. + assert_pending!(poll!(queue)); + + // Delay should be set to the delay of the new item here + queue.insert_at("bar", now + ms(600)); + + assert_pending!(poll!(queue)); + + sleep(ms(600)).await; + + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "bar"); +} + +#[tokio::test] +async fn multi_reset() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let one = queue.insert_at("one", now + ms(200)); + let two = queue.insert_at("two", now + ms(250)); + + assert_pending!(poll!(queue)); + + queue.reset_at(&one, now + ms(300)); + queue.reset_at(&two, now + ms(350)); + queue.reset_at(&one, now + ms(400)); + + sleep(ms(310)).await; + + assert_pending!(poll!(queue)); + + sleep(ms(50)).await; + + let entry = assert_ready_some!(poll!(queue)); + assert_eq!(*entry.get_ref(), "two"); + + assert_pending!(poll!(queue)); + + sleep(ms(50)).await; + + let entry = assert_ready_some!(poll!(queue)); + assert_eq!(*entry.get_ref(), "one"); + + let entry = assert_ready!(poll!(queue)); + assert!(entry.is_none()) +} + +#[tokio::test] +async fn expire_first_key_when_reset_to_expire_earlier() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let one = queue.insert_at("one", now + ms(200)); + queue.insert_at("two", now + ms(250)); + + assert_pending!(poll!(queue)); + + queue.reset_at(&one, now + ms(100)); + + sleep(ms(100)).await; + + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "one"); +} + +#[tokio::test] +async fn expire_second_key_when_reset_to_expire_earlier() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + queue.insert_at("one", now + ms(200)); + let two = queue.insert_at("two", now + ms(250)); + + assert_pending!(poll!(queue)); + + queue.reset_at(&two, now + ms(100)); + + sleep(ms(100)).await; + + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "two"); +} + +#[tokio::test] +async fn reset_first_expiring_item_to_expire_later() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let one = queue.insert_at("one", now + ms(200)); + let _two = queue.insert_at("two", now + ms(250)); + + assert_pending!(poll!(queue)); + + queue.reset_at(&one, now + ms(300)); + sleep(ms(250)).await; + + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "two"); +} + +#[tokio::test] +async fn insert_before_first_after_poll() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let _one = queue.insert_at("one", now + ms(200)); + + assert_pending!(poll!(queue)); + + let _two = queue.insert_at("two", now + ms(100)); + + sleep(ms(99)).await; + + assert_pending!(poll!(queue)); + + sleep(ms(1)).await; + + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "two"); +} + +#[tokio::test] +async fn insert_after_ready_poll() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + queue.insert_at("1", now + ms(100)); + queue.insert_at("2", now + ms(100)); + queue.insert_at("3", now + ms(100)); + + assert_pending!(poll!(queue)); + + sleep(ms(100)).await; + + assert!(queue.is_woken()); + + let mut res = vec![]; + + while res.len() < 3 { + let entry = assert_ready_some!(poll!(queue)); + res.push(entry.into_inner()); + queue.insert_at("foo", now + ms(500)); + } + + res.sort_unstable(); + + assert_eq!("1", res[0]); + assert_eq!("2", res[1]); + assert_eq!("3", res[2]); +} + +#[tokio::test] +async fn reset_later_after_slot_starts() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let foo = queue.insert_at("foo", now + ms(100)); + + assert_pending!(poll!(queue)); + + sleep_until(now + Duration::from_millis(80)).await; + + assert!(!queue.is_woken()); + + // At this point the queue hasn't been polled, so `elapsed` on the wheel + // for the queue is still at 0 and hence the 1ms resolution slots cover + // [0-64). Resetting the time on the entry to 120 causes it to get put in + // the [64-128) slot. As the queue knows that the first entry is within + // that slot, but doesn't know when, it must wake immediately to advance + // the wheel. + queue.reset_at(&foo, now + ms(120)); + assert!(queue.is_woken()); + + assert_pending!(poll!(queue)); + + sleep_until(now + Duration::from_millis(119)).await; + assert!(!queue.is_woken()); + + sleep(ms(1)).await; + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "foo"); +} + +#[tokio::test] +async fn reset_inserted_expired() { + time::pause(); + let mut queue = task::spawn(DelayQueue::new()); + let now = Instant::now(); + + let key = queue.insert_at("foo", now - ms(100)); + + // this causes the panic described in #2473 + queue.reset_at(&key, now + ms(100)); + + assert_eq!(1, queue.len()); + + sleep(ms(200)).await; + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "foo"); + + assert_eq!(queue.len(), 0); +} + +#[tokio::test] +async fn reset_earlier_after_slot_starts() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let foo = queue.insert_at("foo", now + ms(200)); + + assert_pending!(poll!(queue)); + + sleep_until(now + Duration::from_millis(80)).await; + + assert!(!queue.is_woken()); + + // At this point the queue hasn't been polled, so `elapsed` on the wheel + // for the queue is still at 0 and hence the 1ms resolution slots cover + // [0-64). Resetting the time on the entry to 120 causes it to get put in + // the [64-128) slot. As the queue knows that the first entry is within + // that slot, but doesn't know when, it must wake immediately to advance + // the wheel. + queue.reset_at(&foo, now + ms(120)); + assert!(queue.is_woken()); + + assert_pending!(poll!(queue)); + + sleep_until(now + Duration::from_millis(119)).await; + assert!(!queue.is_woken()); + + sleep(ms(1)).await; + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "foo"); +} + +#[tokio::test] +async fn insert_in_past_after_poll_fires_immediately() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + queue.insert_at("foo", now + ms(200)); + + assert_pending!(poll!(queue)); + + sleep(ms(80)).await; + + assert!(!queue.is_woken()); + queue.insert_at("bar", now + ms(40)); + + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "bar"); +} + +#[tokio::test] +async fn delay_queue_poll_expired_when_empty() { + let mut delay_queue = task::spawn(DelayQueue::new()); + let key = delay_queue.insert(0, std::time::Duration::from_secs(10)); + assert_pending!(poll!(delay_queue)); + + delay_queue.remove(&key); + assert!(assert_ready!(poll!(delay_queue)).is_none()); +} + +#[tokio::test(start_paused = true)] +async fn compact_expire_empty() { + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + queue.insert_at("foo1", now + ms(10)); + queue.insert_at("foo2", now + ms(10)); + + sleep(ms(10)).await; + + let mut res = vec![]; + while res.len() < 2 { + let entry = assert_ready_some!(poll!(queue)); + res.push(entry.into_inner()); + } + + queue.compact(); + + assert_eq!(queue.len(), 0); + assert_eq!(queue.capacity(), 0); +} + +#[tokio::test(start_paused = true)] +async fn compact_remove_empty() { + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let key1 = queue.insert_at("foo1", now + ms(10)); + let key2 = queue.insert_at("foo2", now + ms(10)); + + queue.remove(&key1); + queue.remove(&key2); + + queue.compact(); + + assert_eq!(queue.len(), 0); + assert_eq!(queue.capacity(), 0); +} + +#[tokio::test(start_paused = true)] +// Trigger a re-mapping of keys in the slab due to a `compact` call and +// test removal of re-mapped keys +async fn compact_remove_remapped_keys() { + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + queue.insert_at("foo1", now + ms(10)); + queue.insert_at("foo2", now + ms(10)); + + // should be assigned indices 3 and 4 + let key3 = queue.insert_at("foo3", now + ms(20)); + let key4 = queue.insert_at("foo4", now + ms(20)); + + sleep(ms(10)).await; + + let mut res = vec![]; + while res.len() < 2 { + let entry = assert_ready_some!(poll!(queue)); + res.push(entry.into_inner()); + } + + // items corresponding to `foo3` and `foo4` will be assigned + // new indices here + queue.compact(); + + queue.insert_at("foo5", now + ms(10)); + + // test removal of re-mapped keys + let expired3 = queue.remove(&key3); + let expired4 = queue.remove(&key4); + + assert_eq!(expired3.into_inner(), "foo3"); + assert_eq!(expired4.into_inner(), "foo4"); + + queue.compact(); + assert_eq!(queue.len(), 1); + assert_eq!(queue.capacity(), 1); +} + +#[tokio::test(start_paused = true)] +async fn compact_change_deadline() { + let mut queue = task::spawn(DelayQueue::new()); + + let mut now = Instant::now(); + + queue.insert_at("foo1", now + ms(10)); + queue.insert_at("foo2", now + ms(10)); + + // should be assigned indices 3 and 4 + queue.insert_at("foo3", now + ms(20)); + let key4 = queue.insert_at("foo4", now + ms(20)); + + sleep(ms(10)).await; + + let mut res = vec![]; + while res.len() < 2 { + let entry = assert_ready_some!(poll!(queue)); + res.push(entry.into_inner()); + } + + // items corresponding to `foo3` and `foo4` should be assigned + // new indices + queue.compact(); + + now = Instant::now(); + + queue.insert_at("foo5", now + ms(10)); + let key6 = queue.insert_at("foo6", now + ms(10)); + + queue.reset_at(&key4, now + ms(20)); + queue.reset_at(&key6, now + ms(20)); + + // foo3 and foo5 will expire + sleep(ms(10)).await; + + while res.len() < 4 { + let entry = assert_ready_some!(poll!(queue)); + res.push(entry.into_inner()); + } + + sleep(ms(10)).await; + + while res.len() < 6 { + let entry = assert_ready_some!(poll!(queue)); + res.push(entry.into_inner()); + } + + let entry = assert_ready!(poll!(queue)); + assert!(entry.is_none()); +} + +#[tokio::test(start_paused = true)] +async fn remove_after_compact() { + let now = Instant::now(); + let mut queue = DelayQueue::new(); + + let foo_key = queue.insert_at("foo", now + ms(10)); + queue.insert_at("bar", now + ms(20)); + queue.remove(&foo_key); + queue.compact(); + + let panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + queue.remove(&foo_key); + })); + assert!(panic.is_err()); +} + +#[tokio::test(start_paused = true)] +async fn remove_after_compact_poll() { + let now = Instant::now(); + let mut queue = task::spawn(DelayQueue::new()); + + let foo_key = queue.insert_at("foo", now + ms(10)); + queue.insert_at("bar", now + ms(20)); + + sleep(ms(10)).await; + assert_eq!(assert_ready_some!(poll!(queue)).key(), foo_key); + + queue.compact(); + + let panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + queue.remove(&foo_key); + })); + assert!(panic.is_err()); +} + +fn ms(n: u64) -> Duration { + Duration::from_millis(n) +} diff --git a/third_party/rust/tokio-util/tests/udp.rs b/third_party/rust/tokio-util/tests/udp.rs new file mode 100644 index 0000000000..b9436a30aa --- /dev/null +++ b/third_party/rust/tokio-util/tests/udp.rs @@ -0,0 +1,132 @@ +#![warn(rust_2018_idioms)] + +use tokio::net::UdpSocket; +use tokio_stream::StreamExt; +use tokio_util::codec::{Decoder, Encoder, LinesCodec}; +use tokio_util::udp::UdpFramed; + +use bytes::{BufMut, BytesMut}; +use futures::future::try_join; +use futures::future::FutureExt; +use futures::sink::SinkExt; +use std::io; +use std::sync::Arc; + +#[cfg_attr(any(target_os = "macos", target_os = "ios"), allow(unused_assignments))] +#[tokio::test] +async fn send_framed_byte_codec() -> std::io::Result<()> { + let mut a_soc = UdpSocket::bind("127.0.0.1:0").await?; + let mut b_soc = UdpSocket::bind("127.0.0.1:0").await?; + + let a_addr = a_soc.local_addr()?; + let b_addr = b_soc.local_addr()?; + + // test sending & receiving bytes + { + let mut a = UdpFramed::new(a_soc, ByteCodec); + let mut b = UdpFramed::new(b_soc, ByteCodec); + + let msg = b"4567"; + + let send = a.send((msg, b_addr)); + let recv = b.next().map(|e| e.unwrap()); + let (_, received) = try_join(send, recv).await.unwrap(); + + let (data, addr) = received; + assert_eq!(msg, &*data); + assert_eq!(a_addr, addr); + + a_soc = a.into_inner(); + b_soc = b.into_inner(); + } + + #[cfg(not(any(target_os = "macos", target_os = "ios")))] + // test sending & receiving an empty message + { + let mut a = UdpFramed::new(a_soc, ByteCodec); + let mut b = UdpFramed::new(b_soc, ByteCodec); + + let msg = b""; + + let send = a.send((msg, b_addr)); + let recv = b.next().map(|e| e.unwrap()); + let (_, received) = try_join(send, recv).await.unwrap(); + + let (data, addr) = received; + assert_eq!(msg, &*data); + assert_eq!(a_addr, addr); + } + + Ok(()) +} + +pub struct ByteCodec; + +impl Decoder for ByteCodec { + type Item = Vec<u8>; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Vec<u8>>, io::Error> { + let len = buf.len(); + Ok(Some(buf.split_to(len).to_vec())) + } +} + +impl Encoder<&[u8]> for ByteCodec { + type Error = io::Error; + + fn encode(&mut self, data: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> { + buf.reserve(data.len()); + buf.put_slice(data); + Ok(()) + } +} + +#[tokio::test] +async fn send_framed_lines_codec() -> std::io::Result<()> { + let a_soc = UdpSocket::bind("127.0.0.1:0").await?; + let b_soc = UdpSocket::bind("127.0.0.1:0").await?; + + let a_addr = a_soc.local_addr()?; + let b_addr = b_soc.local_addr()?; + + let mut a = UdpFramed::new(a_soc, ByteCodec); + let mut b = UdpFramed::new(b_soc, LinesCodec::new()); + + let msg = b"1\r\n2\r\n3\r\n".to_vec(); + a.send((&msg, b_addr)).await?; + + assert_eq!(b.next().await.unwrap().unwrap(), ("1".to_string(), a_addr)); + assert_eq!(b.next().await.unwrap().unwrap(), ("2".to_string(), a_addr)); + assert_eq!(b.next().await.unwrap().unwrap(), ("3".to_string(), a_addr)); + + Ok(()) +} + +#[tokio::test] +async fn framed_half() -> std::io::Result<()> { + let a_soc = Arc::new(UdpSocket::bind("127.0.0.1:0").await?); + let b_soc = a_soc.clone(); + + let a_addr = a_soc.local_addr()?; + let b_addr = b_soc.local_addr()?; + + let mut a = UdpFramed::new(a_soc, ByteCodec); + let mut b = UdpFramed::new(b_soc, LinesCodec::new()); + + let msg = b"1\r\n2\r\n3\r\n".to_vec(); + a.send((&msg, b_addr)).await?; + + let msg = b"4\r\n5\r\n6\r\n".to_vec(); + a.send((&msg, b_addr)).await?; + + assert_eq!(b.next().await.unwrap().unwrap(), ("1".to_string(), a_addr)); + assert_eq!(b.next().await.unwrap().unwrap(), ("2".to_string(), a_addr)); + assert_eq!(b.next().await.unwrap().unwrap(), ("3".to_string(), a_addr)); + + assert_eq!(b.next().await.unwrap().unwrap(), ("4".to_string(), a_addr)); + assert_eq!(b.next().await.unwrap().unwrap(), ("5".to_string(), a_addr)); + assert_eq!(b.next().await.unwrap().unwrap(), ("6".to_string(), a_addr)); + + Ok(()) +} |