summaryrefslogtreecommitdiffstats
path: root/third_party/rust/tokio/src/runtime/tests
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/tokio/src/runtime/tests')
-rw-r--r--third_party/rust/tokio/src/runtime/tests/loom_basic_scheduler.rs142
-rw-r--r--third_party/rust/tokio/src/runtime/tests/loom_blocking.rs81
-rw-r--r--third_party/rust/tokio/src/runtime/tests/loom_join_set.rs82
-rw-r--r--third_party/rust/tokio/src/runtime/tests/loom_local.rs47
-rw-r--r--third_party/rust/tokio/src/runtime/tests/loom_oneshot.rs48
-rw-r--r--third_party/rust/tokio/src/runtime/tests/loom_pool.rs430
-rw-r--r--third_party/rust/tokio/src/runtime/tests/loom_queue.rs208
-rw-r--r--third_party/rust/tokio/src/runtime/tests/loom_shutdown_join.rs28
-rw-r--r--third_party/rust/tokio/src/runtime/tests/mod.rs50
-rw-r--r--third_party/rust/tokio/src/runtime/tests/queue.rs248
-rw-r--r--third_party/rust/tokio/src/runtime/tests/task.rs288
-rw-r--r--third_party/rust/tokio/src/runtime/tests/task_combinations.rs380
12 files changed, 2032 insertions, 0 deletions
diff --git a/third_party/rust/tokio/src/runtime/tests/loom_basic_scheduler.rs b/third_party/rust/tokio/src/runtime/tests/loom_basic_scheduler.rs
new file mode 100644
index 0000000000..a772603f71
--- /dev/null
+++ b/third_party/rust/tokio/src/runtime/tests/loom_basic_scheduler.rs
@@ -0,0 +1,142 @@
+use crate::loom::sync::atomic::AtomicUsize;
+use crate::loom::sync::Arc;
+use crate::loom::thread;
+use crate::runtime::{Builder, Runtime};
+use crate::sync::oneshot::{self, Receiver};
+use crate::task;
+use std::future::Future;
+use std::pin::Pin;
+use std::sync::atomic::Ordering::{Acquire, Release};
+use std::task::{Context, Poll};
+
+fn assert_at_most_num_polls(rt: Arc<Runtime>, at_most_polls: usize) {
+ let (tx, rx) = oneshot::channel();
+ let num_polls = Arc::new(AtomicUsize::new(0));
+ rt.spawn(async move {
+ for _ in 0..12 {
+ task::yield_now().await;
+ }
+ tx.send(()).unwrap();
+ });
+
+ rt.block_on(async {
+ BlockedFuture {
+ rx,
+ num_polls: num_polls.clone(),
+ }
+ .await;
+ });
+
+ let polls = num_polls.load(Acquire);
+ assert!(polls <= at_most_polls);
+}
+
+#[test]
+fn block_on_num_polls() {
+ loom::model(|| {
+ // we expect at most 4 number of polls because there are three points at
+ // which we poll the future and an opportunity for a false-positive.. At
+ // any of these points it can be ready:
+ //
+ // - when we fail to steal the parker and we block on a notification
+ // that it is available.
+ //
+ // - when we steal the parker and we schedule the future
+ //
+ // - when the future is woken up and we have ran the max number of tasks
+ // for the current tick or there are no more tasks to run.
+ //
+ // - a thread is notified that the parker is available but a third
+ // thread acquires it before the notified thread can.
+ //
+ let at_most = 4;
+
+ let rt1 = Arc::new(Builder::new_current_thread().build().unwrap());
+ let rt2 = rt1.clone();
+ let rt3 = rt1.clone();
+
+ let th1 = thread::spawn(move || assert_at_most_num_polls(rt1, at_most));
+ let th2 = thread::spawn(move || assert_at_most_num_polls(rt2, at_most));
+ let th3 = thread::spawn(move || assert_at_most_num_polls(rt3, at_most));
+
+ th1.join().unwrap();
+ th2.join().unwrap();
+ th3.join().unwrap();
+ });
+}
+
+#[test]
+fn assert_no_unnecessary_polls() {
+ loom::model(|| {
+ // // After we poll outer future, woken should reset to false
+ let rt = Builder::new_current_thread().build().unwrap();
+ let (tx, rx) = oneshot::channel();
+ let pending_cnt = Arc::new(AtomicUsize::new(0));
+
+ rt.spawn(async move {
+ for _ in 0..24 {
+ task::yield_now().await;
+ }
+ tx.send(()).unwrap();
+ });
+
+ let pending_cnt_clone = pending_cnt.clone();
+ rt.block_on(async move {
+ // use task::yield_now() to ensure woken set to true
+ // ResetFuture will be polled at most once
+ // Here comes two cases
+ // 1. recv no message from channel, ResetFuture will be polled
+ // but get Pending and we record ResetFuture.pending_cnt ++.
+ // Then when message arrive, ResetFuture returns Ready. So we
+ // expect ResetFuture.pending_cnt = 1
+ // 2. recv message from channel, ResetFuture returns Ready immediately.
+ // We expect ResetFuture.pending_cnt = 0
+ task::yield_now().await;
+ ResetFuture {
+ rx,
+ pending_cnt: pending_cnt_clone,
+ }
+ .await;
+ });
+
+ let pending_cnt = pending_cnt.load(Acquire);
+ assert!(pending_cnt <= 1);
+ });
+}
+
+struct BlockedFuture {
+ rx: Receiver<()>,
+ num_polls: Arc<AtomicUsize>,
+}
+
+impl Future for BlockedFuture {
+ type Output = ();
+
+ fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ self.num_polls.fetch_add(1, Release);
+
+ match Pin::new(&mut self.rx).poll(cx) {
+ Poll::Pending => Poll::Pending,
+ _ => Poll::Ready(()),
+ }
+ }
+}
+
+struct ResetFuture {
+ rx: Receiver<()>,
+ pending_cnt: Arc<AtomicUsize>,
+}
+
+impl Future for ResetFuture {
+ type Output = ();
+
+ fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ match Pin::new(&mut self.rx).poll(cx) {
+ Poll::Pending => {
+ self.pending_cnt.fetch_add(1, Release);
+ Poll::Pending
+ }
+ _ => Poll::Ready(()),
+ }
+ }
+}
diff --git a/third_party/rust/tokio/src/runtime/tests/loom_blocking.rs b/third_party/rust/tokio/src/runtime/tests/loom_blocking.rs
new file mode 100644
index 0000000000..89de85e436
--- /dev/null
+++ b/third_party/rust/tokio/src/runtime/tests/loom_blocking.rs
@@ -0,0 +1,81 @@
+use crate::runtime::{self, Runtime};
+
+use std::sync::Arc;
+
+#[test]
+fn blocking_shutdown() {
+ loom::model(|| {
+ let v = Arc::new(());
+
+ let rt = mk_runtime(1);
+ {
+ let _enter = rt.enter();
+ for _ in 0..2 {
+ let v = v.clone();
+ crate::task::spawn_blocking(move || {
+ assert!(1 < Arc::strong_count(&v));
+ });
+ }
+ }
+
+ drop(rt);
+ assert_eq!(1, Arc::strong_count(&v));
+ });
+}
+
+#[test]
+fn spawn_mandatory_blocking_should_always_run() {
+ use crate::runtime::tests::loom_oneshot;
+ loom::model(|| {
+ let rt = runtime::Builder::new_current_thread().build().unwrap();
+
+ let (tx, rx) = loom_oneshot::channel();
+ let _enter = rt.enter();
+ runtime::spawn_blocking(|| {});
+ runtime::spawn_mandatory_blocking(move || {
+ let _ = tx.send(());
+ })
+ .unwrap();
+
+ drop(rt);
+
+ // This call will deadlock if `spawn_mandatory_blocking` doesn't run.
+ let () = rx.recv();
+ });
+}
+
+#[test]
+fn spawn_mandatory_blocking_should_run_even_when_shutting_down_from_other_thread() {
+ use crate::runtime::tests::loom_oneshot;
+ loom::model(|| {
+ let rt = runtime::Builder::new_current_thread().build().unwrap();
+ let handle = rt.handle().clone();
+
+ // Drop the runtime in a different thread
+ {
+ loom::thread::spawn(move || {
+ drop(rt);
+ });
+ }
+
+ let _enter = handle.enter();
+ let (tx, rx) = loom_oneshot::channel();
+ let handle = runtime::spawn_mandatory_blocking(move || {
+ let _ = tx.send(());
+ });
+
+ // handle.is_some() means that `spawn_mandatory_blocking`
+ // promised us to run the blocking task
+ if handle.is_some() {
+ // This call will deadlock if `spawn_mandatory_blocking` doesn't run.
+ let () = rx.recv();
+ }
+ });
+}
+
+fn mk_runtime(num_threads: usize) -> Runtime {
+ runtime::Builder::new_multi_thread()
+ .worker_threads(num_threads)
+ .build()
+ .unwrap()
+}
diff --git a/third_party/rust/tokio/src/runtime/tests/loom_join_set.rs b/third_party/rust/tokio/src/runtime/tests/loom_join_set.rs
new file mode 100644
index 0000000000..e87ddb0140
--- /dev/null
+++ b/third_party/rust/tokio/src/runtime/tests/loom_join_set.rs
@@ -0,0 +1,82 @@
+use crate::runtime::Builder;
+use crate::task::JoinSet;
+
+#[test]
+fn test_join_set() {
+ loom::model(|| {
+ let rt = Builder::new_multi_thread()
+ .worker_threads(1)
+ .build()
+ .unwrap();
+ let mut set = JoinSet::new();
+
+ rt.block_on(async {
+ assert_eq!(set.len(), 0);
+ set.spawn(async { () });
+ assert_eq!(set.len(), 1);
+ set.spawn(async { () });
+ assert_eq!(set.len(), 2);
+ let () = set.join_one().await.unwrap().unwrap();
+ assert_eq!(set.len(), 1);
+ set.spawn(async { () });
+ assert_eq!(set.len(), 2);
+ let () = set.join_one().await.unwrap().unwrap();
+ assert_eq!(set.len(), 1);
+ let () = set.join_one().await.unwrap().unwrap();
+ assert_eq!(set.len(), 0);
+ set.spawn(async { () });
+ assert_eq!(set.len(), 1);
+ });
+
+ drop(set);
+ drop(rt);
+ });
+}
+
+#[test]
+fn abort_all_during_completion() {
+ use std::sync::{
+ atomic::{AtomicBool, Ordering::SeqCst},
+ Arc,
+ };
+
+ // These booleans assert that at least one execution had the task complete first, and that at
+ // least one execution had the task be cancelled before it completed.
+ let complete_happened = Arc::new(AtomicBool::new(false));
+ let cancel_happened = Arc::new(AtomicBool::new(false));
+
+ {
+ let complete_happened = complete_happened.clone();
+ let cancel_happened = cancel_happened.clone();
+ loom::model(move || {
+ let rt = Builder::new_multi_thread()
+ .worker_threads(1)
+ .build()
+ .unwrap();
+
+ let mut set = JoinSet::new();
+
+ rt.block_on(async {
+ set.spawn(async { () });
+ set.abort_all();
+
+ match set.join_one().await {
+ Ok(Some(())) => complete_happened.store(true, SeqCst),
+ Err(err) if err.is_cancelled() => cancel_happened.store(true, SeqCst),
+ Err(err) => panic!("fail: {}", err),
+ Ok(None) => {
+ unreachable!("Aborting the task does not remove it from the JoinSet.")
+ }
+ }
+
+ assert!(matches!(set.join_one().await, Ok(None)));
+ });
+
+ drop(set);
+ drop(rt);
+ });
+ }
+
+ assert!(complete_happened.load(SeqCst));
+ assert!(cancel_happened.load(SeqCst));
+}
diff --git a/third_party/rust/tokio/src/runtime/tests/loom_local.rs b/third_party/rust/tokio/src/runtime/tests/loom_local.rs
new file mode 100644
index 0000000000..d9a07a45f0
--- /dev/null
+++ b/third_party/rust/tokio/src/runtime/tests/loom_local.rs
@@ -0,0 +1,47 @@
+use crate::runtime::tests::loom_oneshot as oneshot;
+use crate::runtime::Builder;
+use crate::task::LocalSet;
+
+use std::task::Poll;
+
+/// Waking a runtime will attempt to push a task into a queue of notifications
+/// in the runtime, however the tasks in such a queue usually have a reference
+/// to the runtime itself. This means that if they are not properly removed at
+/// runtime shutdown, this will cause a memory leak.
+///
+/// This test verifies that waking something during shutdown of a LocalSet does
+/// not result in tasks lingering in the queue once shutdown is complete. This
+/// is verified using loom's leak finder.
+#[test]
+fn wake_during_shutdown() {
+ loom::model(|| {
+ let rt = Builder::new_current_thread().build().unwrap();
+ let ls = LocalSet::new();
+
+ let (send, recv) = oneshot::channel();
+
+ ls.spawn_local(async move {
+ let mut send = Some(send);
+
+ let () = futures::future::poll_fn(|cx| {
+ if let Some(send) = send.take() {
+ send.send(cx.waker().clone());
+ }
+
+ Poll::Pending
+ })
+ .await;
+ });
+
+ let handle = loom::thread::spawn(move || {
+ let waker = recv.recv();
+ waker.wake();
+ });
+
+ ls.block_on(&rt, crate::task::yield_now());
+
+ drop(ls);
+ handle.join().unwrap();
+ drop(rt);
+ });
+}
diff --git a/third_party/rust/tokio/src/runtime/tests/loom_oneshot.rs b/third_party/rust/tokio/src/runtime/tests/loom_oneshot.rs
new file mode 100644
index 0000000000..87eb638642
--- /dev/null
+++ b/third_party/rust/tokio/src/runtime/tests/loom_oneshot.rs
@@ -0,0 +1,48 @@
+use crate::loom::sync::{Arc, Mutex};
+use loom::sync::Notify;
+
+pub(crate) fn channel<T>() -> (Sender<T>, Receiver<T>) {
+ let inner = Arc::new(Inner {
+ notify: Notify::new(),
+ value: Mutex::new(None),
+ });
+
+ let tx = Sender {
+ inner: inner.clone(),
+ };
+ let rx = Receiver { inner };
+
+ (tx, rx)
+}
+
+pub(crate) struct Sender<T> {
+ inner: Arc<Inner<T>>,
+}
+
+pub(crate) struct Receiver<T> {
+ inner: Arc<Inner<T>>,
+}
+
+struct Inner<T> {
+ notify: Notify,
+ value: Mutex<Option<T>>,
+}
+
+impl<T> Sender<T> {
+ pub(crate) fn send(self, value: T) {
+ *self.inner.value.lock() = Some(value);
+ self.inner.notify.notify();
+ }
+}
+
+impl<T> Receiver<T> {
+ pub(crate) fn recv(self) -> T {
+ loop {
+ if let Some(v) = self.inner.value.lock().take() {
+ return v;
+ }
+
+ self.inner.notify.wait();
+ }
+ }
+}
diff --git a/third_party/rust/tokio/src/runtime/tests/loom_pool.rs b/third_party/rust/tokio/src/runtime/tests/loom_pool.rs
new file mode 100644
index 0000000000..b3ecd43124
--- /dev/null
+++ b/third_party/rust/tokio/src/runtime/tests/loom_pool.rs
@@ -0,0 +1,430 @@
+/// Full runtime loom tests. These are heavy tests and take significant time to
+/// run on CI.
+///
+/// Use `LOOM_MAX_PREEMPTIONS=1` to do a "quick" run as a smoke test.
+///
+/// In order to speed up the C
+use crate::future::poll_fn;
+use crate::runtime::tests::loom_oneshot as oneshot;
+use crate::runtime::{self, Runtime};
+use crate::{spawn, task};
+use tokio_test::assert_ok;
+
+use loom::sync::atomic::{AtomicBool, AtomicUsize};
+use loom::sync::Arc;
+
+use pin_project_lite::pin_project;
+use std::future::Future;
+use std::pin::Pin;
+use std::sync::atomic::Ordering::{Relaxed, SeqCst};
+use std::task::{Context, Poll};
+
+mod atomic_take {
+ use loom::sync::atomic::AtomicBool;
+ use std::mem::MaybeUninit;
+ use std::sync::atomic::Ordering::SeqCst;
+
+ pub(super) struct AtomicTake<T> {
+ inner: MaybeUninit<T>,
+ taken: AtomicBool,
+ }
+
+ impl<T> AtomicTake<T> {
+ pub(super) fn new(value: T) -> Self {
+ Self {
+ inner: MaybeUninit::new(value),
+ taken: AtomicBool::new(false),
+ }
+ }
+
+ pub(super) fn take(&self) -> Option<T> {
+ // safety: Only one thread will see the boolean change from false
+ // to true, so that thread is able to take the value.
+ match self.taken.fetch_or(true, SeqCst) {
+ false => unsafe { Some(std::ptr::read(self.inner.as_ptr())) },
+ true => None,
+ }
+ }
+ }
+
+ impl<T> Drop for AtomicTake<T> {
+ fn drop(&mut self) {
+ drop(self.take());
+ }
+ }
+}
+
+#[derive(Clone)]
+struct AtomicOneshot<T> {
+ value: std::sync::Arc<atomic_take::AtomicTake<oneshot::Sender<T>>>,
+}
+impl<T> AtomicOneshot<T> {
+ fn new(sender: oneshot::Sender<T>) -> Self {
+ Self {
+ value: std::sync::Arc::new(atomic_take::AtomicTake::new(sender)),
+ }
+ }
+
+ fn assert_send(&self, value: T) {
+ self.value.take().unwrap().send(value);
+ }
+}
+
+/// Tests are divided into groups to make the runs faster on CI.
+mod group_a {
+ use super::*;
+
+ #[test]
+ fn racy_shutdown() {
+ loom::model(|| {
+ let pool = mk_pool(1);
+
+ // here's the case we want to exercise:
+ //
+ // a worker that still has tasks in its local queue gets sent to the blocking pool (due to
+ // block_in_place). the blocking pool is shut down, so drops the worker. the worker's
+ // shutdown method never gets run.
+ //
+ // we do this by spawning two tasks on one worker, the first of which does block_in_place,
+ // and then immediately drop the pool.
+
+ pool.spawn(track(async {
+ crate::task::block_in_place(|| {});
+ }));
+ pool.spawn(track(async {}));
+ drop(pool);
+ });
+ }
+
+ #[test]
+ fn pool_multi_spawn() {
+ loom::model(|| {
+ let pool = mk_pool(2);
+ let c1 = Arc::new(AtomicUsize::new(0));
+
+ let (tx, rx) = oneshot::channel();
+ let tx1 = AtomicOneshot::new(tx);
+
+ // Spawn a task
+ let c2 = c1.clone();
+ let tx2 = tx1.clone();
+ pool.spawn(track(async move {
+ spawn(track(async move {
+ if 1 == c1.fetch_add(1, Relaxed) {
+ tx1.assert_send(());
+ }
+ }));
+ }));
+
+ // Spawn a second task
+ pool.spawn(track(async move {
+ spawn(track(async move {
+ if 1 == c2.fetch_add(1, Relaxed) {
+ tx2.assert_send(());
+ }
+ }));
+ }));
+
+ rx.recv();
+ });
+ }
+
+ fn only_blocking_inner(first_pending: bool) {
+ loom::model(move || {
+ let pool = mk_pool(1);
+ let (block_tx, block_rx) = oneshot::channel();
+
+ pool.spawn(track(async move {
+ crate::task::block_in_place(move || {
+ block_tx.send(());
+ });
+ if first_pending {
+ task::yield_now().await
+ }
+ }));
+
+ block_rx.recv();
+ drop(pool);
+ });
+ }
+
+ #[test]
+ fn only_blocking_without_pending() {
+ only_blocking_inner(false)
+ }
+
+ #[test]
+ fn only_blocking_with_pending() {
+ only_blocking_inner(true)
+ }
+}
+
+mod group_b {
+ use super::*;
+
+ fn blocking_and_regular_inner(first_pending: bool) {
+ const NUM: usize = 3;
+ loom::model(move || {
+ let pool = mk_pool(1);
+ let cnt = Arc::new(AtomicUsize::new(0));
+
+ let (block_tx, block_rx) = oneshot::channel();
+ let (done_tx, done_rx) = oneshot::channel();
+ let done_tx = AtomicOneshot::new(done_tx);
+
+ pool.spawn(track(async move {
+ crate::task::block_in_place(move || {
+ block_tx.send(());
+ });
+ if first_pending {
+ task::yield_now().await
+ }
+ }));
+
+ for _ in 0..NUM {
+ let cnt = cnt.clone();
+ let done_tx = done_tx.clone();
+
+ pool.spawn(track(async move {
+ if NUM == cnt.fetch_add(1, Relaxed) + 1 {
+ done_tx.assert_send(());
+ }
+ }));
+ }
+
+ done_rx.recv();
+ block_rx.recv();
+
+ drop(pool);
+ });
+ }
+
+ #[test]
+ fn blocking_and_regular() {
+ blocking_and_regular_inner(false);
+ }
+
+ #[test]
+ fn blocking_and_regular_with_pending() {
+ blocking_and_regular_inner(true);
+ }
+
+ #[test]
+ fn join_output() {
+ loom::model(|| {
+ let rt = mk_pool(1);
+
+ rt.block_on(async {
+ let t = crate::spawn(track(async { "hello" }));
+
+ let out = assert_ok!(t.await);
+ assert_eq!("hello", out.into_inner());
+ });
+ });
+ }
+
+ #[test]
+ fn poll_drop_handle_then_drop() {
+ loom::model(|| {
+ let rt = mk_pool(1);
+
+ rt.block_on(async move {
+ let mut t = crate::spawn(track(async { "hello" }));
+
+ poll_fn(|cx| {
+ let _ = Pin::new(&mut t).poll(cx);
+ Poll::Ready(())
+ })
+ .await;
+ });
+ })
+ }
+
+ #[test]
+ fn complete_block_on_under_load() {
+ loom::model(|| {
+ let pool = mk_pool(1);
+
+ pool.block_on(async {
+ // Trigger a re-schedule
+ crate::spawn(track(async {
+ for _ in 0..2 {
+ task::yield_now().await;
+ }
+ }));
+
+ gated2(true).await
+ });
+ });
+ }
+
+ #[test]
+ fn shutdown_with_notification() {
+ use crate::sync::oneshot;
+
+ loom::model(|| {
+ let rt = mk_pool(2);
+ let (done_tx, done_rx) = oneshot::channel::<()>();
+
+ rt.spawn(track(async move {
+ let (tx, rx) = oneshot::channel::<()>();
+
+ crate::spawn(async move {
+ crate::task::spawn_blocking(move || {
+ let _ = tx.send(());
+ });
+
+ let _ = done_rx.await;
+ });
+
+ let _ = rx.await;
+
+ let _ = done_tx.send(());
+ }));
+ });
+ }
+}
+
+mod group_c {
+ use super::*;
+
+ #[test]
+ fn pool_shutdown() {
+ loom::model(|| {
+ let pool = mk_pool(2);
+
+ pool.spawn(track(async move {
+ gated2(true).await;
+ }));
+
+ pool.spawn(track(async move {
+ gated2(false).await;
+ }));
+
+ drop(pool);
+ });
+ }
+}
+
+mod group_d {
+ use super::*;
+
+ #[test]
+ fn pool_multi_notify() {
+ loom::model(|| {
+ let pool = mk_pool(2);
+
+ let c1 = Arc::new(AtomicUsize::new(0));
+
+ let (done_tx, done_rx) = oneshot::channel();
+ let done_tx1 = AtomicOneshot::new(done_tx);
+ let done_tx2 = done_tx1.clone();
+
+ // Spawn a task
+ let c2 = c1.clone();
+ pool.spawn(track(async move {
+ gated().await;
+ gated().await;
+
+ if 1 == c1.fetch_add(1, Relaxed) {
+ done_tx1.assert_send(());
+ }
+ }));
+
+ // Spawn a second task
+ pool.spawn(track(async move {
+ gated().await;
+ gated().await;
+
+ if 1 == c2.fetch_add(1, Relaxed) {
+ done_tx2.assert_send(());
+ }
+ }));
+
+ done_rx.recv();
+ });
+ }
+}
+
+fn mk_pool(num_threads: usize) -> Runtime {
+ runtime::Builder::new_multi_thread()
+ .worker_threads(num_threads)
+ .build()
+ .unwrap()
+}
+
+fn gated() -> impl Future<Output = &'static str> {
+ gated2(false)
+}
+
+fn gated2(thread: bool) -> impl Future<Output = &'static str> {
+ use loom::thread;
+ use std::sync::Arc;
+
+ let gate = Arc::new(AtomicBool::new(false));
+ let mut fired = false;
+
+ poll_fn(move |cx| {
+ if !fired {
+ let gate = gate.clone();
+ let waker = cx.waker().clone();
+
+ if thread {
+ thread::spawn(move || {
+ gate.store(true, SeqCst);
+ waker.wake_by_ref();
+ });
+ } else {
+ spawn(track(async move {
+ gate.store(true, SeqCst);
+ waker.wake_by_ref();
+ }));
+ }
+
+ fired = true;
+
+ return Poll::Pending;
+ }
+
+ if gate.load(SeqCst) {
+ Poll::Ready("hello world")
+ } else {
+ Poll::Pending
+ }
+ })
+}
+
+fn track<T: Future>(f: T) -> Track<T> {
+ Track {
+ inner: f,
+ arc: Arc::new(()),
+ }
+}
+
+pin_project! {
+ struct Track<T> {
+ #[pin]
+ inner: T,
+ // Arc is used to hook into loom's leak tracking.
+ arc: Arc<()>,
+ }
+}
+
+impl<T> Track<T> {
+ fn into_inner(self) -> T {
+ self.inner
+ }
+}
+
+impl<T: Future> Future for Track<T> {
+ type Output = Track<T::Output>;
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let me = self.project();
+
+ Poll::Ready(Track {
+ inner: ready!(me.inner.poll(cx)),
+ arc: me.arc.clone(),
+ })
+ }
+}
diff --git a/third_party/rust/tokio/src/runtime/tests/loom_queue.rs b/third_party/rust/tokio/src/runtime/tests/loom_queue.rs
new file mode 100644
index 0000000000..b5f78d7ebe
--- /dev/null
+++ b/third_party/rust/tokio/src/runtime/tests/loom_queue.rs
@@ -0,0 +1,208 @@
+use crate::runtime::blocking::NoopSchedule;
+use crate::runtime::task::Inject;
+use crate::runtime::{queue, MetricsBatch};
+
+use loom::thread;
+
+#[test]
+fn basic() {
+ loom::model(|| {
+ let (steal, mut local) = queue::local();
+ let inject = Inject::new();
+ let mut metrics = MetricsBatch::new();
+
+ let th = thread::spawn(move || {
+ let mut metrics = MetricsBatch::new();
+ let (_, mut local) = queue::local();
+ let mut n = 0;
+
+ for _ in 0..3 {
+ if steal.steal_into(&mut local, &mut metrics).is_some() {
+ n += 1;
+ }
+
+ while local.pop().is_some() {
+ n += 1;
+ }
+ }
+
+ n
+ });
+
+ let mut n = 0;
+
+ for _ in 0..2 {
+ for _ in 0..2 {
+ let (task, _) = super::unowned(async {});
+ local.push_back(task, &inject, &mut metrics);
+ }
+
+ if local.pop().is_some() {
+ n += 1;
+ }
+
+ // Push another task
+ let (task, _) = super::unowned(async {});
+ local.push_back(task, &inject, &mut metrics);
+
+ while local.pop().is_some() {
+ n += 1;
+ }
+ }
+
+ while inject.pop().is_some() {
+ n += 1;
+ }
+
+ n += th.join().unwrap();
+
+ assert_eq!(6, n);
+ });
+}
+
+#[test]
+fn steal_overflow() {
+ loom::model(|| {
+ let (steal, mut local) = queue::local();
+ let inject = Inject::new();
+ let mut metrics = MetricsBatch::new();
+
+ let th = thread::spawn(move || {
+ let mut metrics = MetricsBatch::new();
+ let (_, mut local) = queue::local();
+ let mut n = 0;
+
+ if steal.steal_into(&mut local, &mut metrics).is_some() {
+ n += 1;
+ }
+
+ while local.pop().is_some() {
+ n += 1;
+ }
+
+ n
+ });
+
+ let mut n = 0;
+
+ // push a task, pop a task
+ let (task, _) = super::unowned(async {});
+ local.push_back(task, &inject, &mut metrics);
+
+ if local.pop().is_some() {
+ n += 1;
+ }
+
+ for _ in 0..6 {
+ let (task, _) = super::unowned(async {});
+ local.push_back(task, &inject, &mut metrics);
+ }
+
+ n += th.join().unwrap();
+
+ while local.pop().is_some() {
+ n += 1;
+ }
+
+ while inject.pop().is_some() {
+ n += 1;
+ }
+
+ assert_eq!(7, n);
+ });
+}
+
+#[test]
+fn multi_stealer() {
+ const NUM_TASKS: usize = 5;
+
+ fn steal_tasks(steal: queue::Steal<NoopSchedule>) -> usize {
+ let mut metrics = MetricsBatch::new();
+ let (_, mut local) = queue::local();
+
+ if steal.steal_into(&mut local, &mut metrics).is_none() {
+ return 0;
+ }
+
+ let mut n = 1;
+
+ while local.pop().is_some() {
+ n += 1;
+ }
+
+ n
+ }
+
+ loom::model(|| {
+ let (steal, mut local) = queue::local();
+ let inject = Inject::new();
+ let mut metrics = MetricsBatch::new();
+
+ // Push work
+ for _ in 0..NUM_TASKS {
+ let (task, _) = super::unowned(async {});
+ local.push_back(task, &inject, &mut metrics);
+ }
+
+ let th1 = {
+ let steal = steal.clone();
+ thread::spawn(move || steal_tasks(steal))
+ };
+
+ let th2 = thread::spawn(move || steal_tasks(steal));
+
+ let mut n = 0;
+
+ while local.pop().is_some() {
+ n += 1;
+ }
+
+ while inject.pop().is_some() {
+ n += 1;
+ }
+
+ n += th1.join().unwrap();
+ n += th2.join().unwrap();
+
+ assert_eq!(n, NUM_TASKS);
+ });
+}
+
+#[test]
+fn chained_steal() {
+ loom::model(|| {
+ let mut metrics = MetricsBatch::new();
+ let (s1, mut l1) = queue::local();
+ let (s2, mut l2) = queue::local();
+ let inject = Inject::new();
+
+ // Load up some tasks
+ for _ in 0..4 {
+ let (task, _) = super::unowned(async {});
+ l1.push_back(task, &inject, &mut metrics);
+
+ let (task, _) = super::unowned(async {});
+ l2.push_back(task, &inject, &mut metrics);
+ }
+
+ // Spawn a task to steal from **our** queue
+ let th = thread::spawn(move || {
+ let mut metrics = MetricsBatch::new();
+ let (_, mut local) = queue::local();
+ s1.steal_into(&mut local, &mut metrics);
+
+ while local.pop().is_some() {}
+ });
+
+ // Drain our tasks, then attempt to steal
+ while l1.pop().is_some() {}
+
+ s2.steal_into(&mut l1, &mut metrics);
+
+ th.join().unwrap();
+
+ while l1.pop().is_some() {}
+ while l2.pop().is_some() {}
+ while inject.pop().is_some() {}
+ });
+}
diff --git a/third_party/rust/tokio/src/runtime/tests/loom_shutdown_join.rs b/third_party/rust/tokio/src/runtime/tests/loom_shutdown_join.rs
new file mode 100644
index 0000000000..6fbc4bfded
--- /dev/null
+++ b/third_party/rust/tokio/src/runtime/tests/loom_shutdown_join.rs
@@ -0,0 +1,28 @@
+use crate::runtime::{Builder, Handle};
+
+#[test]
+fn join_handle_cancel_on_shutdown() {
+ let mut builder = loom::model::Builder::new();
+ builder.preemption_bound = Some(2);
+ builder.check(|| {
+ use futures::future::FutureExt;
+
+ let rt = Builder::new_multi_thread()
+ .worker_threads(2)
+ .build()
+ .unwrap();
+
+ let handle = rt.block_on(async move { Handle::current() });
+
+ let jh1 = handle.spawn(futures::future::pending::<()>());
+
+ drop(rt);
+
+ let jh2 = handle.spawn(futures::future::pending::<()>());
+
+ let err1 = jh1.now_or_never().unwrap().unwrap_err();
+ let err2 = jh2.now_or_never().unwrap().unwrap_err();
+ assert!(err1.is_cancelled());
+ assert!(err2.is_cancelled());
+ });
+}
diff --git a/third_party/rust/tokio/src/runtime/tests/mod.rs b/third_party/rust/tokio/src/runtime/tests/mod.rs
new file mode 100644
index 0000000000..4b49698a86
--- /dev/null
+++ b/third_party/rust/tokio/src/runtime/tests/mod.rs
@@ -0,0 +1,50 @@
+use self::unowned_wrapper::unowned;
+
+mod unowned_wrapper {
+ use crate::runtime::blocking::NoopSchedule;
+ use crate::runtime::task::{JoinHandle, Notified};
+
+ #[cfg(all(tokio_unstable, feature = "tracing"))]
+ pub(crate) fn unowned<T>(task: T) -> (Notified<NoopSchedule>, JoinHandle<T::Output>)
+ where
+ T: std::future::Future + Send + 'static,
+ T::Output: Send + 'static,
+ {
+ use tracing::Instrument;
+ let span = tracing::trace_span!("test_span");
+ let task = task.instrument(span);
+ let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule);
+ (task.into_notified(), handle)
+ }
+
+ #[cfg(not(all(tokio_unstable, feature = "tracing")))]
+ pub(crate) fn unowned<T>(task: T) -> (Notified<NoopSchedule>, JoinHandle<T::Output>)
+ where
+ T: std::future::Future + Send + 'static,
+ T::Output: Send + 'static,
+ {
+ let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule);
+ (task.into_notified(), handle)
+ }
+}
+
+cfg_loom! {
+ mod loom_basic_scheduler;
+ mod loom_blocking;
+ mod loom_local;
+ mod loom_oneshot;
+ mod loom_pool;
+ mod loom_queue;
+ mod loom_shutdown_join;
+ mod loom_join_set;
+}
+
+cfg_not_loom! {
+ mod queue;
+
+ #[cfg(not(miri))]
+ mod task_combinations;
+
+ #[cfg(miri)]
+ mod task;
+}
diff --git a/third_party/rust/tokio/src/runtime/tests/queue.rs b/third_party/rust/tokio/src/runtime/tests/queue.rs
new file mode 100644
index 0000000000..0fd1e0c6d9
--- /dev/null
+++ b/third_party/rust/tokio/src/runtime/tests/queue.rs
@@ -0,0 +1,248 @@
+use crate::runtime::queue;
+use crate::runtime::task::{self, Inject, Schedule, Task};
+use crate::runtime::MetricsBatch;
+
+use std::thread;
+use std::time::Duration;
+
+#[allow(unused)]
+macro_rules! assert_metrics {
+ ($metrics:ident, $field:ident == $v:expr) => {{
+ use crate::runtime::WorkerMetrics;
+ use std::sync::atomic::Ordering::Relaxed;
+
+ let worker = WorkerMetrics::new();
+ $metrics.submit(&worker);
+
+ let expect = $v;
+ let actual = worker.$field.load(Relaxed);
+
+ assert!(actual == expect, "expect = {}; actual = {}", expect, actual)
+ }};
+}
+
+#[test]
+fn fits_256() {
+ let (_, mut local) = queue::local();
+ let inject = Inject::new();
+ let mut metrics = MetricsBatch::new();
+
+ for _ in 0..256 {
+ let (task, _) = super::unowned(async {});
+ local.push_back(task, &inject, &mut metrics);
+ }
+
+ cfg_metrics! {
+ assert_metrics!(metrics, overflow_count == 0);
+ }
+
+ assert!(inject.pop().is_none());
+
+ while local.pop().is_some() {}
+}
+
+#[test]
+fn overflow() {
+ let (_, mut local) = queue::local();
+ let inject = Inject::new();
+ let mut metrics = MetricsBatch::new();
+
+ for _ in 0..257 {
+ let (task, _) = super::unowned(async {});
+ local.push_back(task, &inject, &mut metrics);
+ }
+
+ cfg_metrics! {
+ assert_metrics!(metrics, overflow_count == 1);
+ }
+
+ let mut n = 0;
+
+ while inject.pop().is_some() {
+ n += 1;
+ }
+
+ while local.pop().is_some() {
+ n += 1;
+ }
+
+ assert_eq!(n, 257);
+}
+
+#[test]
+fn steal_batch() {
+ let mut metrics = MetricsBatch::new();
+
+ let (steal1, mut local1) = queue::local();
+ let (_, mut local2) = queue::local();
+ let inject = Inject::new();
+
+ for _ in 0..4 {
+ let (task, _) = super::unowned(async {});
+ local1.push_back(task, &inject, &mut metrics);
+ }
+
+ assert!(steal1.steal_into(&mut local2, &mut metrics).is_some());
+
+ cfg_metrics! {
+ assert_metrics!(metrics, steal_count == 2);
+ }
+
+ for _ in 0..1 {
+ assert!(local2.pop().is_some());
+ }
+
+ assert!(local2.pop().is_none());
+
+ for _ in 0..2 {
+ assert!(local1.pop().is_some());
+ }
+
+ assert!(local1.pop().is_none());
+}
+
+const fn normal_or_miri(normal: usize, miri: usize) -> usize {
+ if cfg!(miri) {
+ miri
+ } else {
+ normal
+ }
+}
+
+#[test]
+fn stress1() {
+ const NUM_ITER: usize = 1;
+ const NUM_STEAL: usize = normal_or_miri(1_000, 10);
+ const NUM_LOCAL: usize = normal_or_miri(1_000, 10);
+ const NUM_PUSH: usize = normal_or_miri(500, 10);
+ const NUM_POP: usize = normal_or_miri(250, 10);
+
+ let mut metrics = MetricsBatch::new();
+
+ for _ in 0..NUM_ITER {
+ let (steal, mut local) = queue::local();
+ let inject = Inject::new();
+
+ let th = thread::spawn(move || {
+ let mut metrics = MetricsBatch::new();
+ let (_, mut local) = queue::local();
+ let mut n = 0;
+
+ for _ in 0..NUM_STEAL {
+ if steal.steal_into(&mut local, &mut metrics).is_some() {
+ n += 1;
+ }
+
+ while local.pop().is_some() {
+ n += 1;
+ }
+
+ thread::yield_now();
+ }
+
+ cfg_metrics! {
+ assert_metrics!(metrics, steal_count == n as _);
+ }
+
+ n
+ });
+
+ let mut n = 0;
+
+ for _ in 0..NUM_LOCAL {
+ for _ in 0..NUM_PUSH {
+ let (task, _) = super::unowned(async {});
+ local.push_back(task, &inject, &mut metrics);
+ }
+
+ for _ in 0..NUM_POP {
+ if local.pop().is_some() {
+ n += 1;
+ } else {
+ break;
+ }
+ }
+ }
+
+ while inject.pop().is_some() {
+ n += 1;
+ }
+
+ n += th.join().unwrap();
+
+ assert_eq!(n, NUM_LOCAL * NUM_PUSH);
+ }
+}
+
+#[test]
+fn stress2() {
+ const NUM_ITER: usize = 1;
+ const NUM_TASKS: usize = normal_or_miri(1_000_000, 50);
+ const NUM_STEAL: usize = normal_or_miri(1_000, 10);
+
+ let mut metrics = MetricsBatch::new();
+
+ for _ in 0..NUM_ITER {
+ let (steal, mut local) = queue::local();
+ let inject = Inject::new();
+
+ let th = thread::spawn(move || {
+ let mut stats = MetricsBatch::new();
+ let (_, mut local) = queue::local();
+ let mut n = 0;
+
+ for _ in 0..NUM_STEAL {
+ if steal.steal_into(&mut local, &mut stats).is_some() {
+ n += 1;
+ }
+
+ while local.pop().is_some() {
+ n += 1;
+ }
+
+ thread::sleep(Duration::from_micros(10));
+ }
+
+ n
+ });
+
+ let mut num_pop = 0;
+
+ for i in 0..NUM_TASKS {
+ let (task, _) = super::unowned(async {});
+ local.push_back(task, &inject, &mut metrics);
+
+ if i % 128 == 0 && local.pop().is_some() {
+ num_pop += 1;
+ }
+
+ while inject.pop().is_some() {
+ num_pop += 1;
+ }
+ }
+
+ num_pop += th.join().unwrap();
+
+ while local.pop().is_some() {
+ num_pop += 1;
+ }
+
+ while inject.pop().is_some() {
+ num_pop += 1;
+ }
+
+ assert_eq!(num_pop, NUM_TASKS);
+ }
+}
+
+struct Runtime;
+
+impl Schedule for Runtime {
+ fn release(&self, _task: &Task<Self>) -> Option<Task<Self>> {
+ None
+ }
+
+ fn schedule(&self, _task: task::Notified<Self>) {
+ unreachable!();
+ }
+}
diff --git a/third_party/rust/tokio/src/runtime/tests/task.rs b/third_party/rust/tokio/src/runtime/tests/task.rs
new file mode 100644
index 0000000000..04e1b56e77
--- /dev/null
+++ b/third_party/rust/tokio/src/runtime/tests/task.rs
@@ -0,0 +1,288 @@
+use crate::runtime::blocking::NoopSchedule;
+use crate::runtime::task::{self, unowned, JoinHandle, OwnedTasks, Schedule, Task};
+use crate::util::TryLock;
+
+use std::collections::VecDeque;
+use std::future::Future;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
+
+struct AssertDropHandle {
+ is_dropped: Arc<AtomicBool>,
+}
+impl AssertDropHandle {
+ #[track_caller]
+ fn assert_dropped(&self) {
+ assert!(self.is_dropped.load(Ordering::SeqCst));
+ }
+
+ #[track_caller]
+ fn assert_not_dropped(&self) {
+ assert!(!self.is_dropped.load(Ordering::SeqCst));
+ }
+}
+
+struct AssertDrop {
+ is_dropped: Arc<AtomicBool>,
+}
+impl AssertDrop {
+ fn new() -> (Self, AssertDropHandle) {
+ let shared = Arc::new(AtomicBool::new(false));
+ (
+ AssertDrop {
+ is_dropped: shared.clone(),
+ },
+ AssertDropHandle {
+ is_dropped: shared.clone(),
+ },
+ )
+ }
+}
+impl Drop for AssertDrop {
+ fn drop(&mut self) {
+ self.is_dropped.store(true, Ordering::SeqCst);
+ }
+}
+
+// A Notified does not shut down on drop, but it is dropped once the ref-count
+// hits zero.
+#[test]
+fn create_drop1() {
+ let (ad, handle) = AssertDrop::new();
+ let (notified, join) = unowned(
+ async {
+ drop(ad);
+ unreachable!()
+ },
+ NoopSchedule,
+ );
+ drop(notified);
+ handle.assert_not_dropped();
+ drop(join);
+ handle.assert_dropped();
+}
+
+#[test]
+fn create_drop2() {
+ let (ad, handle) = AssertDrop::new();
+ let (notified, join) = unowned(
+ async {
+ drop(ad);
+ unreachable!()
+ },
+ NoopSchedule,
+ );
+ drop(join);
+ handle.assert_not_dropped();
+ drop(notified);
+ handle.assert_dropped();
+}
+
+// Shutting down through Notified works
+#[test]
+fn create_shutdown1() {
+ let (ad, handle) = AssertDrop::new();
+ let (notified, join) = unowned(
+ async {
+ drop(ad);
+ unreachable!()
+ },
+ NoopSchedule,
+ );
+ drop(join);
+ handle.assert_not_dropped();
+ notified.shutdown();
+ handle.assert_dropped();
+}
+
+#[test]
+fn create_shutdown2() {
+ let (ad, handle) = AssertDrop::new();
+ let (notified, join) = unowned(
+ async {
+ drop(ad);
+ unreachable!()
+ },
+ NoopSchedule,
+ );
+ handle.assert_not_dropped();
+ notified.shutdown();
+ handle.assert_dropped();
+ drop(join);
+}
+
+#[test]
+fn unowned_poll() {
+ let (task, _) = unowned(async {}, NoopSchedule);
+ task.run();
+}
+
+#[test]
+fn schedule() {
+ with(|rt| {
+ rt.spawn(async {
+ crate::task::yield_now().await;
+ });
+
+ assert_eq!(2, rt.tick());
+ rt.shutdown();
+ })
+}
+
+#[test]
+fn shutdown() {
+ with(|rt| {
+ rt.spawn(async {
+ loop {
+ crate::task::yield_now().await;
+ }
+ });
+
+ rt.tick_max(1);
+
+ rt.shutdown();
+ })
+}
+
+#[test]
+fn shutdown_immediately() {
+ with(|rt| {
+ rt.spawn(async {
+ loop {
+ crate::task::yield_now().await;
+ }
+ });
+
+ rt.shutdown();
+ })
+}
+
+#[test]
+fn spawn_during_shutdown() {
+ static DID_SPAWN: AtomicBool = AtomicBool::new(false);
+
+ struct SpawnOnDrop(Runtime);
+ impl Drop for SpawnOnDrop {
+ fn drop(&mut self) {
+ DID_SPAWN.store(true, Ordering::SeqCst);
+ self.0.spawn(async {});
+ }
+ }
+
+ with(|rt| {
+ let rt2 = rt.clone();
+ rt.spawn(async move {
+ let _spawn_on_drop = SpawnOnDrop(rt2);
+
+ loop {
+ crate::task::yield_now().await;
+ }
+ });
+
+ rt.tick_max(1);
+ rt.shutdown();
+ });
+
+ assert!(DID_SPAWN.load(Ordering::SeqCst));
+}
+
+fn with(f: impl FnOnce(Runtime)) {
+ struct Reset;
+
+ impl Drop for Reset {
+ fn drop(&mut self) {
+ let _rt = CURRENT.try_lock().unwrap().take();
+ }
+ }
+
+ let _reset = Reset;
+
+ let rt = Runtime(Arc::new(Inner {
+ owned: OwnedTasks::new(),
+ core: TryLock::new(Core {
+ queue: VecDeque::new(),
+ }),
+ }));
+
+ *CURRENT.try_lock().unwrap() = Some(rt.clone());
+ f(rt)
+}
+
+#[derive(Clone)]
+struct Runtime(Arc<Inner>);
+
+struct Inner {
+ core: TryLock<Core>,
+ owned: OwnedTasks<Runtime>,
+}
+
+struct Core {
+ queue: VecDeque<task::Notified<Runtime>>,
+}
+
+static CURRENT: TryLock<Option<Runtime>> = TryLock::new(None);
+
+impl Runtime {
+ fn spawn<T>(&self, future: T) -> JoinHandle<T::Output>
+ where
+ T: 'static + Send + Future,
+ T::Output: 'static + Send,
+ {
+ let (handle, notified) = self.0.owned.bind(future, self.clone());
+
+ if let Some(notified) = notified {
+ self.schedule(notified);
+ }
+
+ handle
+ }
+
+ fn tick(&self) -> usize {
+ self.tick_max(usize::MAX)
+ }
+
+ fn tick_max(&self, max: usize) -> usize {
+ let mut n = 0;
+
+ while !self.is_empty() && n < max {
+ let task = self.next_task();
+ n += 1;
+ let task = self.0.owned.assert_owner(task);
+ task.run();
+ }
+
+ n
+ }
+
+ fn is_empty(&self) -> bool {
+ self.0.core.try_lock().unwrap().queue.is_empty()
+ }
+
+ fn next_task(&self) -> task::Notified<Runtime> {
+ self.0.core.try_lock().unwrap().queue.pop_front().unwrap()
+ }
+
+ fn shutdown(&self) {
+ let mut core = self.0.core.try_lock().unwrap();
+
+ self.0.owned.close_and_shutdown_all();
+
+ while let Some(task) = core.queue.pop_back() {
+ drop(task);
+ }
+
+ drop(core);
+
+ assert!(self.0.owned.is_empty());
+ }
+}
+
+impl Schedule for Runtime {
+ fn release(&self, task: &Task<Self>) -> Option<Task<Self>> {
+ self.0.owned.remove(task)
+ }
+
+ fn schedule(&self, task: task::Notified<Self>) {
+ self.0.core.try_lock().unwrap().queue.push_back(task);
+ }
+}
diff --git a/third_party/rust/tokio/src/runtime/tests/task_combinations.rs b/third_party/rust/tokio/src/runtime/tests/task_combinations.rs
new file mode 100644
index 0000000000..76ce2330c2
--- /dev/null
+++ b/third_party/rust/tokio/src/runtime/tests/task_combinations.rs
@@ -0,0 +1,380 @@
+use std::future::Future;
+use std::panic;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+use crate::runtime::Builder;
+use crate::sync::oneshot;
+use crate::task::JoinHandle;
+
+use futures::future::FutureExt;
+
+// Enums for each option in the combinations being tested
+
+#[derive(Copy, Clone, Debug, PartialEq)]
+enum CombiRuntime {
+ CurrentThread,
+ Multi1,
+ Multi2,
+}
+#[derive(Copy, Clone, Debug, PartialEq)]
+enum CombiLocalSet {
+ Yes,
+ No,
+}
+#[derive(Copy, Clone, Debug, PartialEq)]
+enum CombiTask {
+ PanicOnRun,
+ PanicOnDrop,
+ PanicOnRunAndDrop,
+ NoPanic,
+}
+#[derive(Copy, Clone, Debug, PartialEq)]
+enum CombiOutput {
+ PanicOnDrop,
+ NoPanic,
+}
+#[derive(Copy, Clone, Debug, PartialEq)]
+enum CombiJoinInterest {
+ Polled,
+ NotPolled,
+}
+#[allow(clippy::enum_variant_names)] // we aren't using glob imports
+#[derive(Copy, Clone, Debug, PartialEq)]
+enum CombiJoinHandle {
+ DropImmediately = 1,
+ DropFirstPoll = 2,
+ DropAfterNoConsume = 3,
+ DropAfterConsume = 4,
+}
+#[derive(Copy, Clone, Debug, PartialEq)]
+enum CombiAbort {
+ NotAborted = 0,
+ AbortedImmediately = 1,
+ AbortedFirstPoll = 2,
+ AbortedAfterFinish = 3,
+ AbortedAfterConsumeOutput = 4,
+}
+
+#[test]
+fn test_combinations() {
+ let mut rt = &[
+ CombiRuntime::CurrentThread,
+ CombiRuntime::Multi1,
+ CombiRuntime::Multi2,
+ ][..];
+
+ if cfg!(miri) {
+ rt = &[CombiRuntime::CurrentThread];
+ }
+
+ let ls = [CombiLocalSet::Yes, CombiLocalSet::No];
+ let task = [
+ CombiTask::NoPanic,
+ CombiTask::PanicOnRun,
+ CombiTask::PanicOnDrop,
+ CombiTask::PanicOnRunAndDrop,
+ ];
+ let output = [CombiOutput::NoPanic, CombiOutput::PanicOnDrop];
+ let ji = [CombiJoinInterest::Polled, CombiJoinInterest::NotPolled];
+ let jh = [
+ CombiJoinHandle::DropImmediately,
+ CombiJoinHandle::DropFirstPoll,
+ CombiJoinHandle::DropAfterNoConsume,
+ CombiJoinHandle::DropAfterConsume,
+ ];
+ let abort = [
+ CombiAbort::NotAborted,
+ CombiAbort::AbortedImmediately,
+ CombiAbort::AbortedFirstPoll,
+ CombiAbort::AbortedAfterFinish,
+ CombiAbort::AbortedAfterConsumeOutput,
+ ];
+
+ for rt in rt.iter().copied() {
+ for ls in ls.iter().copied() {
+ for task in task.iter().copied() {
+ for output in output.iter().copied() {
+ for ji in ji.iter().copied() {
+ for jh in jh.iter().copied() {
+ for abort in abort.iter().copied() {
+ test_combination(rt, ls, task, output, ji, jh, abort);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+fn test_combination(
+ rt: CombiRuntime,
+ ls: CombiLocalSet,
+ task: CombiTask,
+ output: CombiOutput,
+ ji: CombiJoinInterest,
+ jh: CombiJoinHandle,
+ abort: CombiAbort,
+) {
+ if (jh as usize) < (abort as usize) {
+ // drop before abort not possible
+ return;
+ }
+ if (task == CombiTask::PanicOnDrop) && (output == CombiOutput::PanicOnDrop) {
+ // this causes double panic
+ return;
+ }
+ if (task == CombiTask::PanicOnRunAndDrop) && (abort != CombiAbort::AbortedImmediately) {
+ // this causes double panic
+ return;
+ }
+
+ println!("Runtime {:?}, LocalSet {:?}, Task {:?}, Output {:?}, JoinInterest {:?}, JoinHandle {:?}, Abort {:?}", rt, ls, task, output, ji, jh, abort);
+
+ // A runtime optionally with a LocalSet
+ struct Rt {
+ rt: crate::runtime::Runtime,
+ ls: Option<crate::task::LocalSet>,
+ }
+ impl Rt {
+ fn new(rt: CombiRuntime, ls: CombiLocalSet) -> Self {
+ let rt = match rt {
+ CombiRuntime::CurrentThread => Builder::new_current_thread().build().unwrap(),
+ CombiRuntime::Multi1 => Builder::new_multi_thread()
+ .worker_threads(1)
+ .build()
+ .unwrap(),
+ CombiRuntime::Multi2 => Builder::new_multi_thread()
+ .worker_threads(2)
+ .build()
+ .unwrap(),
+ };
+
+ let ls = match ls {
+ CombiLocalSet::Yes => Some(crate::task::LocalSet::new()),
+ CombiLocalSet::No => None,
+ };
+
+ Self { rt, ls }
+ }
+ fn block_on<T>(&self, task: T) -> T::Output
+ where
+ T: Future,
+ {
+ match &self.ls {
+ Some(ls) => ls.block_on(&self.rt, task),
+ None => self.rt.block_on(task),
+ }
+ }
+ fn spawn<T>(&self, task: T) -> JoinHandle<T::Output>
+ where
+ T: Future + Send + 'static,
+ T::Output: Send + 'static,
+ {
+ match &self.ls {
+ Some(ls) => ls.spawn_local(task),
+ None => self.rt.spawn(task),
+ }
+ }
+ }
+
+ // The type used for the output of the future
+ struct Output {
+ panic_on_drop: bool,
+ on_drop: Option<oneshot::Sender<()>>,
+ }
+ impl Output {
+ fn disarm(&mut self) {
+ self.panic_on_drop = false;
+ }
+ }
+ impl Drop for Output {
+ fn drop(&mut self) {
+ let _ = self.on_drop.take().unwrap().send(());
+ if self.panic_on_drop {
+ panic!("Panicking in Output");
+ }
+ }
+ }
+
+ // A wrapper around the future that is spawned
+ struct FutWrapper<F> {
+ inner: F,
+ on_drop: Option<oneshot::Sender<()>>,
+ panic_on_drop: bool,
+ }
+ impl<F: Future> Future for FutWrapper<F> {
+ type Output = F::Output;
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> {
+ unsafe {
+ let me = Pin::into_inner_unchecked(self);
+ let inner = Pin::new_unchecked(&mut me.inner);
+ inner.poll(cx)
+ }
+ }
+ }
+ impl<F> Drop for FutWrapper<F> {
+ fn drop(&mut self) {
+ let _: Result<(), ()> = self.on_drop.take().unwrap().send(());
+ if self.panic_on_drop {
+ panic!("Panicking in FutWrapper");
+ }
+ }
+ }
+
+ // The channels passed to the task
+ struct Signals {
+ on_first_poll: Option<oneshot::Sender<()>>,
+ wait_complete: Option<oneshot::Receiver<()>>,
+ on_output_drop: Option<oneshot::Sender<()>>,
+ }
+
+ // The task we will spawn
+ async fn my_task(mut signal: Signals, task: CombiTask, out: CombiOutput) -> Output {
+ // Signal that we have been polled once
+ let _ = signal.on_first_poll.take().unwrap().send(());
+
+ // Wait for a signal, then complete the future
+ let _ = signal.wait_complete.take().unwrap().await;
+
+ // If the task gets past wait_complete without yielding, then aborts
+ // may not be caught without this yield_now.
+ crate::task::yield_now().await;
+
+ if task == CombiTask::PanicOnRun || task == CombiTask::PanicOnRunAndDrop {
+ panic!("Panicking in my_task on {:?}", std::thread::current().id());
+ }
+
+ Output {
+ panic_on_drop: out == CombiOutput::PanicOnDrop,
+ on_drop: signal.on_output_drop.take(),
+ }
+ }
+
+ let rt = Rt::new(rt, ls);
+
+ let (on_first_poll, wait_first_poll) = oneshot::channel();
+ let (on_complete, wait_complete) = oneshot::channel();
+ let (on_future_drop, wait_future_drop) = oneshot::channel();
+ let (on_output_drop, wait_output_drop) = oneshot::channel();
+ let signal = Signals {
+ on_first_poll: Some(on_first_poll),
+ wait_complete: Some(wait_complete),
+ on_output_drop: Some(on_output_drop),
+ };
+
+ // === Spawn task ===
+ let mut handle = Some(rt.spawn(FutWrapper {
+ inner: my_task(signal, task, output),
+ on_drop: Some(on_future_drop),
+ panic_on_drop: task == CombiTask::PanicOnDrop || task == CombiTask::PanicOnRunAndDrop,
+ }));
+
+ // Keep track of whether the task has been killed with an abort
+ let mut aborted = false;
+
+ // If we want to poll the JoinHandle, do it now
+ if ji == CombiJoinInterest::Polled {
+ assert!(
+ handle.as_mut().unwrap().now_or_never().is_none(),
+ "Polling handle succeeded"
+ );
+ }
+
+ if abort == CombiAbort::AbortedImmediately {
+ handle.as_mut().unwrap().abort();
+ aborted = true;
+ }
+ if jh == CombiJoinHandle::DropImmediately {
+ drop(handle.take().unwrap());
+ }
+
+ // === Wait for first poll ===
+ let got_polled = rt.block_on(wait_first_poll).is_ok();
+ if !got_polled {
+ // it's possible that we are aborted but still got polled
+ assert!(
+ aborted,
+ "Task completed without ever being polled but was not aborted."
+ );
+ }
+
+ if abort == CombiAbort::AbortedFirstPoll {
+ handle.as_mut().unwrap().abort();
+ aborted = true;
+ }
+ if jh == CombiJoinHandle::DropFirstPoll {
+ drop(handle.take().unwrap());
+ }
+
+ // Signal the future that it can return now
+ let _ = on_complete.send(());
+ // === Wait for future to be dropped ===
+ assert!(
+ rt.block_on(wait_future_drop).is_ok(),
+ "The future should always be dropped."
+ );
+
+ if abort == CombiAbort::AbortedAfterFinish {
+ // Don't set aborted to true here as the task already finished
+ handle.as_mut().unwrap().abort();
+ }
+ if jh == CombiJoinHandle::DropAfterNoConsume {
+ // The runtime will usually have dropped every ref-count at this point,
+ // in which case dropping the JoinHandle drops the output.
+ //
+ // (But it might race and still hold a ref-count)
+ let panic = panic::catch_unwind(panic::AssertUnwindSafe(|| {
+ drop(handle.take().unwrap());
+ }));
+ if panic.is_err() {
+ assert!(
+ (output == CombiOutput::PanicOnDrop)
+ && (!matches!(task, CombiTask::PanicOnRun | CombiTask::PanicOnRunAndDrop))
+ && !aborted,
+ "Dropping JoinHandle shouldn't panic here"
+ );
+ }
+ }
+
+ // Check whether we drop after consuming the output
+ if jh == CombiJoinHandle::DropAfterConsume {
+ // Using as_mut() to not immediately drop the handle
+ let result = rt.block_on(handle.as_mut().unwrap());
+
+ match result {
+ Ok(mut output) => {
+ // Don't panic here.
+ output.disarm();
+ assert!(!aborted, "Task was aborted but returned output");
+ }
+ Err(err) if err.is_cancelled() => assert!(aborted, "Cancelled output but not aborted"),
+ Err(err) if err.is_panic() => {
+ assert!(
+ (task == CombiTask::PanicOnRun)
+ || (task == CombiTask::PanicOnDrop)
+ || (task == CombiTask::PanicOnRunAndDrop)
+ || (output == CombiOutput::PanicOnDrop),
+ "Panic but nothing should panic"
+ );
+ }
+ _ => unreachable!(),
+ }
+
+ let handle = handle.take().unwrap();
+ if abort == CombiAbort::AbortedAfterConsumeOutput {
+ handle.abort();
+ }
+ drop(handle);
+ }
+
+ // The output should have been dropped now. Check whether the output
+ // object was created at all.
+ let output_created = rt.block_on(wait_output_drop).is_ok();
+ assert_eq!(
+ output_created,
+ (!matches!(task, CombiTask::PanicOnRun | CombiTask::PanicOnRunAndDrop)) && !aborted,
+ "Creation of output object"
+ );
+}