summaryrefslogtreecommitdiffstats
path: root/third_party/rust/tokio/tests/uds_stream.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/tokio/tests/uds_stream.rs')
-rw-r--r--third_party/rust/tokio/tests/uds_stream.rs411
1 files changed, 411 insertions, 0 deletions
diff --git a/third_party/rust/tokio/tests/uds_stream.rs b/third_party/rust/tokio/tests/uds_stream.rs
new file mode 100644
index 0000000000..5f1b4cffbc
--- /dev/null
+++ b/third_party/rust/tokio/tests/uds_stream.rs
@@ -0,0 +1,411 @@
+#![cfg(feature = "full")]
+#![warn(rust_2018_idioms)]
+#![cfg(unix)]
+
+use std::io;
+use std::task::Poll;
+
+use tokio::io::{AsyncReadExt, AsyncWriteExt, Interest};
+use tokio::net::{UnixListener, UnixStream};
+use tokio_test::{assert_ok, assert_pending, assert_ready_ok, task};
+
+use futures::future::{poll_fn, try_join};
+
+#[tokio::test]
+async fn accept_read_write() -> std::io::Result<()> {
+ let dir = tempfile::Builder::new()
+ .prefix("tokio-uds-tests")
+ .tempdir()
+ .unwrap();
+ let sock_path = dir.path().join("connect.sock");
+
+ let listener = UnixListener::bind(&sock_path)?;
+
+ let accept = listener.accept();
+ let connect = UnixStream::connect(&sock_path);
+ let ((mut server, _), mut client) = try_join(accept, connect).await?;
+
+ // Write to the client. TODO: Switch to write_all.
+ let write_len = client.write(b"hello").await?;
+ assert_eq!(write_len, 5);
+ drop(client);
+ // Read from the server. TODO: Switch to read_to_end.
+ let mut buf = [0u8; 5];
+ server.read_exact(&mut buf).await?;
+ assert_eq!(&buf, b"hello");
+ let len = server.read(&mut buf).await?;
+ assert_eq!(len, 0);
+ Ok(())
+}
+
+#[tokio::test]
+async fn shutdown() -> std::io::Result<()> {
+ let dir = tempfile::Builder::new()
+ .prefix("tokio-uds-tests")
+ .tempdir()
+ .unwrap();
+ let sock_path = dir.path().join("connect.sock");
+
+ let listener = UnixListener::bind(&sock_path)?;
+
+ let accept = listener.accept();
+ let connect = UnixStream::connect(&sock_path);
+ let ((mut server, _), mut client) = try_join(accept, connect).await?;
+
+ // Shut down the client
+ AsyncWriteExt::shutdown(&mut client).await?;
+ // Read from the server should return 0 to indicate the channel has been closed.
+ let mut buf = [0u8; 1];
+ let n = server.read(&mut buf).await?;
+ assert_eq!(n, 0);
+ Ok(())
+}
+
+#[tokio::test]
+async fn try_read_write() -> std::io::Result<()> {
+ let msg = b"hello world";
+
+ let dir = tempfile::tempdir()?;
+ let bind_path = dir.path().join("bind.sock");
+
+ // Create listener
+ let listener = UnixListener::bind(&bind_path)?;
+
+ // Create socket pair
+ let client = UnixStream::connect(&bind_path).await?;
+
+ let (server, _) = listener.accept().await?;
+ let mut written = msg.to_vec();
+
+ // Track the server receiving data
+ let mut readable = task::spawn(server.readable());
+ assert_pending!(readable.poll());
+
+ // Write data.
+ client.writable().await?;
+ assert_eq!(msg.len(), client.try_write(msg)?);
+
+ // The task should be notified
+ while !readable.is_woken() {
+ tokio::task::yield_now().await;
+ }
+
+ // Fill the write buffer using non-vectored I/O
+ loop {
+ // Still ready
+ let mut writable = task::spawn(client.writable());
+ assert_ready_ok!(writable.poll());
+
+ match client.try_write(msg) {
+ Ok(n) => written.extend(&msg[..n]),
+ Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
+ break;
+ }
+ Err(e) => panic!("error = {:?}", e),
+ }
+ }
+
+ {
+ // Write buffer full
+ let mut writable = task::spawn(client.writable());
+ assert_pending!(writable.poll());
+
+ // Drain the socket from the server end using non-vectored I/O
+ let mut read = vec![0; written.len()];
+ let mut i = 0;
+
+ while i < read.len() {
+ server.readable().await?;
+
+ match server.try_read(&mut read[i..]) {
+ Ok(n) => i += n,
+ Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
+ Err(e) => panic!("error = {:?}", e),
+ }
+ }
+
+ assert_eq!(read, written);
+ }
+
+ written.clear();
+ client.writable().await.unwrap();
+
+ // Fill the write buffer using vectored I/O
+ let msg_bufs: Vec<_> = msg.chunks(3).map(io::IoSlice::new).collect();
+ loop {
+ // Still ready
+ let mut writable = task::spawn(client.writable());
+ assert_ready_ok!(writable.poll());
+
+ match client.try_write_vectored(&msg_bufs) {
+ Ok(n) => written.extend(&msg[..n]),
+ Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
+ break;
+ }
+ Err(e) => panic!("error = {:?}", e),
+ }
+ }
+
+ {
+ // Write buffer full
+ let mut writable = task::spawn(client.writable());
+ assert_pending!(writable.poll());
+
+ // Drain the socket from the server end using vectored I/O
+ let mut read = vec![0; written.len()];
+ let mut i = 0;
+
+ while i < read.len() {
+ server.readable().await?;
+
+ let mut bufs: Vec<_> = read[i..]
+ .chunks_mut(0x10000)
+ .map(io::IoSliceMut::new)
+ .collect();
+ match server.try_read_vectored(&mut bufs) {
+ Ok(n) => i += n,
+ Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
+ Err(e) => panic!("error = {:?}", e),
+ }
+ }
+
+ assert_eq!(read, written);
+ }
+
+ // Now, we listen for shutdown
+ drop(client);
+
+ loop {
+ let ready = server.ready(Interest::READABLE).await?;
+
+ if ready.is_read_closed() {
+ break;
+ } else {
+ tokio::task::yield_now().await;
+ }
+ }
+
+ Ok(())
+}
+
+async fn create_pair() -> (UnixStream, UnixStream) {
+ let dir = assert_ok!(tempfile::tempdir());
+ let bind_path = dir.path().join("bind.sock");
+
+ let listener = assert_ok!(UnixListener::bind(&bind_path));
+
+ let accept = listener.accept();
+ let connect = UnixStream::connect(&bind_path);
+ let ((server, _), client) = assert_ok!(try_join(accept, connect).await);
+
+ (client, server)
+}
+
+macro_rules! assert_readable_by_polling {
+ ($stream:expr) => {
+ assert_ok!(poll_fn(|cx| $stream.poll_read_ready(cx)).await);
+ };
+}
+
+macro_rules! assert_not_readable_by_polling {
+ ($stream:expr) => {
+ poll_fn(|cx| {
+ assert_pending!($stream.poll_read_ready(cx));
+ Poll::Ready(())
+ })
+ .await;
+ };
+}
+
+macro_rules! assert_writable_by_polling {
+ ($stream:expr) => {
+ assert_ok!(poll_fn(|cx| $stream.poll_write_ready(cx)).await);
+ };
+}
+
+macro_rules! assert_not_writable_by_polling {
+ ($stream:expr) => {
+ poll_fn(|cx| {
+ assert_pending!($stream.poll_write_ready(cx));
+ Poll::Ready(())
+ })
+ .await;
+ };
+}
+
+#[tokio::test]
+async fn poll_read_ready() {
+ let (mut client, mut server) = create_pair().await;
+
+ // Initial state - not readable.
+ assert_not_readable_by_polling!(server);
+
+ // There is data in the buffer - readable.
+ assert_ok!(client.write_all(b"ping").await);
+ assert_readable_by_polling!(server);
+
+ // Readable until calls to `poll_read` return `Poll::Pending`.
+ let mut buf = [0u8; 4];
+ assert_ok!(server.read_exact(&mut buf).await);
+ assert_readable_by_polling!(server);
+ read_until_pending(&mut server);
+ assert_not_readable_by_polling!(server);
+
+ // Detect the client disconnect.
+ drop(client);
+ assert_readable_by_polling!(server);
+}
+
+#[tokio::test]
+async fn poll_write_ready() {
+ let (mut client, server) = create_pair().await;
+
+ // Initial state - writable.
+ assert_writable_by_polling!(client);
+
+ // No space to write - not writable.
+ write_until_pending(&mut client);
+ assert_not_writable_by_polling!(client);
+
+ // Detect the server disconnect.
+ drop(server);
+ assert_writable_by_polling!(client);
+}
+
+fn read_until_pending(stream: &mut UnixStream) {
+ let mut buf = vec![0u8; 1024 * 1024];
+ loop {
+ match stream.try_read(&mut buf) {
+ Ok(_) => (),
+ Err(err) => {
+ assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
+ break;
+ }
+ }
+ }
+}
+
+fn write_until_pending(stream: &mut UnixStream) {
+ let buf = vec![0u8; 1024 * 1024];
+ loop {
+ match stream.try_write(&buf) {
+ Ok(_) => (),
+ Err(err) => {
+ assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
+ break;
+ }
+ }
+ }
+}
+
+#[tokio::test]
+async fn try_read_buf() -> std::io::Result<()> {
+ let msg = b"hello world";
+
+ let dir = tempfile::tempdir()?;
+ let bind_path = dir.path().join("bind.sock");
+
+ // Create listener
+ let listener = UnixListener::bind(&bind_path)?;
+
+ // Create socket pair
+ let client = UnixStream::connect(&bind_path).await?;
+
+ let (server, _) = listener.accept().await?;
+ let mut written = msg.to_vec();
+
+ // Track the server receiving data
+ let mut readable = task::spawn(server.readable());
+ assert_pending!(readable.poll());
+
+ // Write data.
+ client.writable().await?;
+ assert_eq!(msg.len(), client.try_write(msg)?);
+
+ // The task should be notified
+ while !readable.is_woken() {
+ tokio::task::yield_now().await;
+ }
+
+ // Fill the write buffer
+ loop {
+ // Still ready
+ let mut writable = task::spawn(client.writable());
+ assert_ready_ok!(writable.poll());
+
+ match client.try_write(msg) {
+ Ok(n) => written.extend(&msg[..n]),
+ Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
+ break;
+ }
+ Err(e) => panic!("error = {:?}", e),
+ }
+ }
+
+ {
+ // Write buffer full
+ let mut writable = task::spawn(client.writable());
+ assert_pending!(writable.poll());
+
+ // Drain the socket from the server end
+ let mut read = Vec::with_capacity(written.len());
+ let mut i = 0;
+
+ while i < read.capacity() {
+ server.readable().await?;
+
+ match server.try_read_buf(&mut read) {
+ Ok(n) => i += n,
+ Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
+ Err(e) => panic!("error = {:?}", e),
+ }
+ }
+
+ assert_eq!(read, written);
+ }
+
+ // Now, we listen for shutdown
+ drop(client);
+
+ loop {
+ let ready = server.ready(Interest::READABLE).await?;
+
+ if ready.is_read_closed() {
+ break;
+ } else {
+ tokio::task::yield_now().await;
+ }
+ }
+
+ Ok(())
+}
+
+// https://github.com/tokio-rs/tokio/issues/3879
+#[tokio::test]
+#[cfg(not(target_os = "macos"))]
+async fn epollhup() -> io::Result<()> {
+ let dir = tempfile::Builder::new()
+ .prefix("tokio-uds-tests")
+ .tempdir()
+ .unwrap();
+ let sock_path = dir.path().join("connect.sock");
+
+ let listener = UnixListener::bind(&sock_path)?;
+ let connect = UnixStream::connect(&sock_path);
+ tokio::pin!(connect);
+
+ // Poll `connect` once.
+ poll_fn(|cx| {
+ use std::future::Future;
+
+ assert_pending!(connect.as_mut().poll(cx));
+ Poll::Ready(())
+ })
+ .await;
+
+ drop(listener);
+
+ let err = connect.await.unwrap_err();
+ assert_eq!(err.kind(), io::ErrorKind::ConnectionReset);
+ Ok(())
+}