diff options
Diffstat (limited to 'third_party/rust/tokio/src/runtime/tests')
12 files changed, 2032 insertions, 0 deletions
diff --git a/third_party/rust/tokio/src/runtime/tests/loom_basic_scheduler.rs b/third_party/rust/tokio/src/runtime/tests/loom_basic_scheduler.rs new file mode 100644 index 0000000000..a772603f71 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/loom_basic_scheduler.rs @@ -0,0 +1,142 @@ +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::Arc; +use crate::loom::thread; +use crate::runtime::{Builder, Runtime}; +use crate::sync::oneshot::{self, Receiver}; +use crate::task; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::Ordering::{Acquire, Release}; +use std::task::{Context, Poll}; + +fn assert_at_most_num_polls(rt: Arc<Runtime>, at_most_polls: usize) { + let (tx, rx) = oneshot::channel(); + let num_polls = Arc::new(AtomicUsize::new(0)); + rt.spawn(async move { + for _ in 0..12 { + task::yield_now().await; + } + tx.send(()).unwrap(); + }); + + rt.block_on(async { + BlockedFuture { + rx, + num_polls: num_polls.clone(), + } + .await; + }); + + let polls = num_polls.load(Acquire); + assert!(polls <= at_most_polls); +} + +#[test] +fn block_on_num_polls() { + loom::model(|| { + // we expect at most 4 number of polls because there are three points at + // which we poll the future and an opportunity for a false-positive.. At + // any of these points it can be ready: + // + // - when we fail to steal the parker and we block on a notification + // that it is available. + // + // - when we steal the parker and we schedule the future + // + // - when the future is woken up and we have ran the max number of tasks + // for the current tick or there are no more tasks to run. + // + // - a thread is notified that the parker is available but a third + // thread acquires it before the notified thread can. + // + let at_most = 4; + + let rt1 = Arc::new(Builder::new_current_thread().build().unwrap()); + let rt2 = rt1.clone(); + let rt3 = rt1.clone(); + + let th1 = thread::spawn(move || assert_at_most_num_polls(rt1, at_most)); + let th2 = thread::spawn(move || assert_at_most_num_polls(rt2, at_most)); + let th3 = thread::spawn(move || assert_at_most_num_polls(rt3, at_most)); + + th1.join().unwrap(); + th2.join().unwrap(); + th3.join().unwrap(); + }); +} + +#[test] +fn assert_no_unnecessary_polls() { + loom::model(|| { + // // After we poll outer future, woken should reset to false + let rt = Builder::new_current_thread().build().unwrap(); + let (tx, rx) = oneshot::channel(); + let pending_cnt = Arc::new(AtomicUsize::new(0)); + + rt.spawn(async move { + for _ in 0..24 { + task::yield_now().await; + } + tx.send(()).unwrap(); + }); + + let pending_cnt_clone = pending_cnt.clone(); + rt.block_on(async move { + // use task::yield_now() to ensure woken set to true + // ResetFuture will be polled at most once + // Here comes two cases + // 1. recv no message from channel, ResetFuture will be polled + // but get Pending and we record ResetFuture.pending_cnt ++. + // Then when message arrive, ResetFuture returns Ready. So we + // expect ResetFuture.pending_cnt = 1 + // 2. recv message from channel, ResetFuture returns Ready immediately. + // We expect ResetFuture.pending_cnt = 0 + task::yield_now().await; + ResetFuture { + rx, + pending_cnt: pending_cnt_clone, + } + .await; + }); + + let pending_cnt = pending_cnt.load(Acquire); + assert!(pending_cnt <= 1); + }); +} + +struct BlockedFuture { + rx: Receiver<()>, + num_polls: Arc<AtomicUsize>, +} + +impl Future for BlockedFuture { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + self.num_polls.fetch_add(1, Release); + + match Pin::new(&mut self.rx).poll(cx) { + Poll::Pending => Poll::Pending, + _ => Poll::Ready(()), + } + } +} + +struct ResetFuture { + rx: Receiver<()>, + pending_cnt: Arc<AtomicUsize>, +} + +impl Future for ResetFuture { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + match Pin::new(&mut self.rx).poll(cx) { + Poll::Pending => { + self.pending_cnt.fetch_add(1, Release); + Poll::Pending + } + _ => Poll::Ready(()), + } + } +} diff --git a/third_party/rust/tokio/src/runtime/tests/loom_blocking.rs b/third_party/rust/tokio/src/runtime/tests/loom_blocking.rs new file mode 100644 index 0000000000..89de85e436 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/loom_blocking.rs @@ -0,0 +1,81 @@ +use crate::runtime::{self, Runtime}; + +use std::sync::Arc; + +#[test] +fn blocking_shutdown() { + loom::model(|| { + let v = Arc::new(()); + + let rt = mk_runtime(1); + { + let _enter = rt.enter(); + for _ in 0..2 { + let v = v.clone(); + crate::task::spawn_blocking(move || { + assert!(1 < Arc::strong_count(&v)); + }); + } + } + + drop(rt); + assert_eq!(1, Arc::strong_count(&v)); + }); +} + +#[test] +fn spawn_mandatory_blocking_should_always_run() { + use crate::runtime::tests::loom_oneshot; + loom::model(|| { + let rt = runtime::Builder::new_current_thread().build().unwrap(); + + let (tx, rx) = loom_oneshot::channel(); + let _enter = rt.enter(); + runtime::spawn_blocking(|| {}); + runtime::spawn_mandatory_blocking(move || { + let _ = tx.send(()); + }) + .unwrap(); + + drop(rt); + + // This call will deadlock if `spawn_mandatory_blocking` doesn't run. + let () = rx.recv(); + }); +} + +#[test] +fn spawn_mandatory_blocking_should_run_even_when_shutting_down_from_other_thread() { + use crate::runtime::tests::loom_oneshot; + loom::model(|| { + let rt = runtime::Builder::new_current_thread().build().unwrap(); + let handle = rt.handle().clone(); + + // Drop the runtime in a different thread + { + loom::thread::spawn(move || { + drop(rt); + }); + } + + let _enter = handle.enter(); + let (tx, rx) = loom_oneshot::channel(); + let handle = runtime::spawn_mandatory_blocking(move || { + let _ = tx.send(()); + }); + + // handle.is_some() means that `spawn_mandatory_blocking` + // promised us to run the blocking task + if handle.is_some() { + // This call will deadlock if `spawn_mandatory_blocking` doesn't run. + let () = rx.recv(); + } + }); +} + +fn mk_runtime(num_threads: usize) -> Runtime { + runtime::Builder::new_multi_thread() + .worker_threads(num_threads) + .build() + .unwrap() +} diff --git a/third_party/rust/tokio/src/runtime/tests/loom_join_set.rs b/third_party/rust/tokio/src/runtime/tests/loom_join_set.rs new file mode 100644 index 0000000000..e87ddb0140 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/loom_join_set.rs @@ -0,0 +1,82 @@ +use crate::runtime::Builder; +use crate::task::JoinSet; + +#[test] +fn test_join_set() { + loom::model(|| { + let rt = Builder::new_multi_thread() + .worker_threads(1) + .build() + .unwrap(); + let mut set = JoinSet::new(); + + rt.block_on(async { + assert_eq!(set.len(), 0); + set.spawn(async { () }); + assert_eq!(set.len(), 1); + set.spawn(async { () }); + assert_eq!(set.len(), 2); + let () = set.join_one().await.unwrap().unwrap(); + assert_eq!(set.len(), 1); + set.spawn(async { () }); + assert_eq!(set.len(), 2); + let () = set.join_one().await.unwrap().unwrap(); + assert_eq!(set.len(), 1); + let () = set.join_one().await.unwrap().unwrap(); + assert_eq!(set.len(), 0); + set.spawn(async { () }); + assert_eq!(set.len(), 1); + }); + + drop(set); + drop(rt); + }); +} + +#[test] +fn abort_all_during_completion() { + use std::sync::{ + atomic::{AtomicBool, Ordering::SeqCst}, + Arc, + }; + + // These booleans assert that at least one execution had the task complete first, and that at + // least one execution had the task be cancelled before it completed. + let complete_happened = Arc::new(AtomicBool::new(false)); + let cancel_happened = Arc::new(AtomicBool::new(false)); + + { + let complete_happened = complete_happened.clone(); + let cancel_happened = cancel_happened.clone(); + loom::model(move || { + let rt = Builder::new_multi_thread() + .worker_threads(1) + .build() + .unwrap(); + + let mut set = JoinSet::new(); + + rt.block_on(async { + set.spawn(async { () }); + set.abort_all(); + + match set.join_one().await { + Ok(Some(())) => complete_happened.store(true, SeqCst), + Err(err) if err.is_cancelled() => cancel_happened.store(true, SeqCst), + Err(err) => panic!("fail: {}", err), + Ok(None) => { + unreachable!("Aborting the task does not remove it from the JoinSet.") + } + } + + assert!(matches!(set.join_one().await, Ok(None))); + }); + + drop(set); + drop(rt); + }); + } + + assert!(complete_happened.load(SeqCst)); + assert!(cancel_happened.load(SeqCst)); +} diff --git a/third_party/rust/tokio/src/runtime/tests/loom_local.rs b/third_party/rust/tokio/src/runtime/tests/loom_local.rs new file mode 100644 index 0000000000..d9a07a45f0 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/loom_local.rs @@ -0,0 +1,47 @@ +use crate::runtime::tests::loom_oneshot as oneshot; +use crate::runtime::Builder; +use crate::task::LocalSet; + +use std::task::Poll; + +/// Waking a runtime will attempt to push a task into a queue of notifications +/// in the runtime, however the tasks in such a queue usually have a reference +/// to the runtime itself. This means that if they are not properly removed at +/// runtime shutdown, this will cause a memory leak. +/// +/// This test verifies that waking something during shutdown of a LocalSet does +/// not result in tasks lingering in the queue once shutdown is complete. This +/// is verified using loom's leak finder. +#[test] +fn wake_during_shutdown() { + loom::model(|| { + let rt = Builder::new_current_thread().build().unwrap(); + let ls = LocalSet::new(); + + let (send, recv) = oneshot::channel(); + + ls.spawn_local(async move { + let mut send = Some(send); + + let () = futures::future::poll_fn(|cx| { + if let Some(send) = send.take() { + send.send(cx.waker().clone()); + } + + Poll::Pending + }) + .await; + }); + + let handle = loom::thread::spawn(move || { + let waker = recv.recv(); + waker.wake(); + }); + + ls.block_on(&rt, crate::task::yield_now()); + + drop(ls); + handle.join().unwrap(); + drop(rt); + }); +} diff --git a/third_party/rust/tokio/src/runtime/tests/loom_oneshot.rs b/third_party/rust/tokio/src/runtime/tests/loom_oneshot.rs new file mode 100644 index 0000000000..87eb638642 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/loom_oneshot.rs @@ -0,0 +1,48 @@ +use crate::loom::sync::{Arc, Mutex}; +use loom::sync::Notify; + +pub(crate) fn channel<T>() -> (Sender<T>, Receiver<T>) { + let inner = Arc::new(Inner { + notify: Notify::new(), + value: Mutex::new(None), + }); + + let tx = Sender { + inner: inner.clone(), + }; + let rx = Receiver { inner }; + + (tx, rx) +} + +pub(crate) struct Sender<T> { + inner: Arc<Inner<T>>, +} + +pub(crate) struct Receiver<T> { + inner: Arc<Inner<T>>, +} + +struct Inner<T> { + notify: Notify, + value: Mutex<Option<T>>, +} + +impl<T> Sender<T> { + pub(crate) fn send(self, value: T) { + *self.inner.value.lock() = Some(value); + self.inner.notify.notify(); + } +} + +impl<T> Receiver<T> { + pub(crate) fn recv(self) -> T { + loop { + if let Some(v) = self.inner.value.lock().take() { + return v; + } + + self.inner.notify.wait(); + } + } +} diff --git a/third_party/rust/tokio/src/runtime/tests/loom_pool.rs b/third_party/rust/tokio/src/runtime/tests/loom_pool.rs new file mode 100644 index 0000000000..b3ecd43124 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/loom_pool.rs @@ -0,0 +1,430 @@ +/// Full runtime loom tests. These are heavy tests and take significant time to +/// run on CI. +/// +/// Use `LOOM_MAX_PREEMPTIONS=1` to do a "quick" run as a smoke test. +/// +/// In order to speed up the C +use crate::future::poll_fn; +use crate::runtime::tests::loom_oneshot as oneshot; +use crate::runtime::{self, Runtime}; +use crate::{spawn, task}; +use tokio_test::assert_ok; + +use loom::sync::atomic::{AtomicBool, AtomicUsize}; +use loom::sync::Arc; + +use pin_project_lite::pin_project; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::Ordering::{Relaxed, SeqCst}; +use std::task::{Context, Poll}; + +mod atomic_take { + use loom::sync::atomic::AtomicBool; + use std::mem::MaybeUninit; + use std::sync::atomic::Ordering::SeqCst; + + pub(super) struct AtomicTake<T> { + inner: MaybeUninit<T>, + taken: AtomicBool, + } + + impl<T> AtomicTake<T> { + pub(super) fn new(value: T) -> Self { + Self { + inner: MaybeUninit::new(value), + taken: AtomicBool::new(false), + } + } + + pub(super) fn take(&self) -> Option<T> { + // safety: Only one thread will see the boolean change from false + // to true, so that thread is able to take the value. + match self.taken.fetch_or(true, SeqCst) { + false => unsafe { Some(std::ptr::read(self.inner.as_ptr())) }, + true => None, + } + } + } + + impl<T> Drop for AtomicTake<T> { + fn drop(&mut self) { + drop(self.take()); + } + } +} + +#[derive(Clone)] +struct AtomicOneshot<T> { + value: std::sync::Arc<atomic_take::AtomicTake<oneshot::Sender<T>>>, +} +impl<T> AtomicOneshot<T> { + fn new(sender: oneshot::Sender<T>) -> Self { + Self { + value: std::sync::Arc::new(atomic_take::AtomicTake::new(sender)), + } + } + + fn assert_send(&self, value: T) { + self.value.take().unwrap().send(value); + } +} + +/// Tests are divided into groups to make the runs faster on CI. +mod group_a { + use super::*; + + #[test] + fn racy_shutdown() { + loom::model(|| { + let pool = mk_pool(1); + + // here's the case we want to exercise: + // + // a worker that still has tasks in its local queue gets sent to the blocking pool (due to + // block_in_place). the blocking pool is shut down, so drops the worker. the worker's + // shutdown method never gets run. + // + // we do this by spawning two tasks on one worker, the first of which does block_in_place, + // and then immediately drop the pool. + + pool.spawn(track(async { + crate::task::block_in_place(|| {}); + })); + pool.spawn(track(async {})); + drop(pool); + }); + } + + #[test] + fn pool_multi_spawn() { + loom::model(|| { + let pool = mk_pool(2); + let c1 = Arc::new(AtomicUsize::new(0)); + + let (tx, rx) = oneshot::channel(); + let tx1 = AtomicOneshot::new(tx); + + // Spawn a task + let c2 = c1.clone(); + let tx2 = tx1.clone(); + pool.spawn(track(async move { + spawn(track(async move { + if 1 == c1.fetch_add(1, Relaxed) { + tx1.assert_send(()); + } + })); + })); + + // Spawn a second task + pool.spawn(track(async move { + spawn(track(async move { + if 1 == c2.fetch_add(1, Relaxed) { + tx2.assert_send(()); + } + })); + })); + + rx.recv(); + }); + } + + fn only_blocking_inner(first_pending: bool) { + loom::model(move || { + let pool = mk_pool(1); + let (block_tx, block_rx) = oneshot::channel(); + + pool.spawn(track(async move { + crate::task::block_in_place(move || { + block_tx.send(()); + }); + if first_pending { + task::yield_now().await + } + })); + + block_rx.recv(); + drop(pool); + }); + } + + #[test] + fn only_blocking_without_pending() { + only_blocking_inner(false) + } + + #[test] + fn only_blocking_with_pending() { + only_blocking_inner(true) + } +} + +mod group_b { + use super::*; + + fn blocking_and_regular_inner(first_pending: bool) { + const NUM: usize = 3; + loom::model(move || { + let pool = mk_pool(1); + let cnt = Arc::new(AtomicUsize::new(0)); + + let (block_tx, block_rx) = oneshot::channel(); + let (done_tx, done_rx) = oneshot::channel(); + let done_tx = AtomicOneshot::new(done_tx); + + pool.spawn(track(async move { + crate::task::block_in_place(move || { + block_tx.send(()); + }); + if first_pending { + task::yield_now().await + } + })); + + for _ in 0..NUM { + let cnt = cnt.clone(); + let done_tx = done_tx.clone(); + + pool.spawn(track(async move { + if NUM == cnt.fetch_add(1, Relaxed) + 1 { + done_tx.assert_send(()); + } + })); + } + + done_rx.recv(); + block_rx.recv(); + + drop(pool); + }); + } + + #[test] + fn blocking_and_regular() { + blocking_and_regular_inner(false); + } + + #[test] + fn blocking_and_regular_with_pending() { + blocking_and_regular_inner(true); + } + + #[test] + fn join_output() { + loom::model(|| { + let rt = mk_pool(1); + + rt.block_on(async { + let t = crate::spawn(track(async { "hello" })); + + let out = assert_ok!(t.await); + assert_eq!("hello", out.into_inner()); + }); + }); + } + + #[test] + fn poll_drop_handle_then_drop() { + loom::model(|| { + let rt = mk_pool(1); + + rt.block_on(async move { + let mut t = crate::spawn(track(async { "hello" })); + + poll_fn(|cx| { + let _ = Pin::new(&mut t).poll(cx); + Poll::Ready(()) + }) + .await; + }); + }) + } + + #[test] + fn complete_block_on_under_load() { + loom::model(|| { + let pool = mk_pool(1); + + pool.block_on(async { + // Trigger a re-schedule + crate::spawn(track(async { + for _ in 0..2 { + task::yield_now().await; + } + })); + + gated2(true).await + }); + }); + } + + #[test] + fn shutdown_with_notification() { + use crate::sync::oneshot; + + loom::model(|| { + let rt = mk_pool(2); + let (done_tx, done_rx) = oneshot::channel::<()>(); + + rt.spawn(track(async move { + let (tx, rx) = oneshot::channel::<()>(); + + crate::spawn(async move { + crate::task::spawn_blocking(move || { + let _ = tx.send(()); + }); + + let _ = done_rx.await; + }); + + let _ = rx.await; + + let _ = done_tx.send(()); + })); + }); + } +} + +mod group_c { + use super::*; + + #[test] + fn pool_shutdown() { + loom::model(|| { + let pool = mk_pool(2); + + pool.spawn(track(async move { + gated2(true).await; + })); + + pool.spawn(track(async move { + gated2(false).await; + })); + + drop(pool); + }); + } +} + +mod group_d { + use super::*; + + #[test] + fn pool_multi_notify() { + loom::model(|| { + let pool = mk_pool(2); + + let c1 = Arc::new(AtomicUsize::new(0)); + + let (done_tx, done_rx) = oneshot::channel(); + let done_tx1 = AtomicOneshot::new(done_tx); + let done_tx2 = done_tx1.clone(); + + // Spawn a task + let c2 = c1.clone(); + pool.spawn(track(async move { + gated().await; + gated().await; + + if 1 == c1.fetch_add(1, Relaxed) { + done_tx1.assert_send(()); + } + })); + + // Spawn a second task + pool.spawn(track(async move { + gated().await; + gated().await; + + if 1 == c2.fetch_add(1, Relaxed) { + done_tx2.assert_send(()); + } + })); + + done_rx.recv(); + }); + } +} + +fn mk_pool(num_threads: usize) -> Runtime { + runtime::Builder::new_multi_thread() + .worker_threads(num_threads) + .build() + .unwrap() +} + +fn gated() -> impl Future<Output = &'static str> { + gated2(false) +} + +fn gated2(thread: bool) -> impl Future<Output = &'static str> { + use loom::thread; + use std::sync::Arc; + + let gate = Arc::new(AtomicBool::new(false)); + let mut fired = false; + + poll_fn(move |cx| { + if !fired { + let gate = gate.clone(); + let waker = cx.waker().clone(); + + if thread { + thread::spawn(move || { + gate.store(true, SeqCst); + waker.wake_by_ref(); + }); + } else { + spawn(track(async move { + gate.store(true, SeqCst); + waker.wake_by_ref(); + })); + } + + fired = true; + + return Poll::Pending; + } + + if gate.load(SeqCst) { + Poll::Ready("hello world") + } else { + Poll::Pending + } + }) +} + +fn track<T: Future>(f: T) -> Track<T> { + Track { + inner: f, + arc: Arc::new(()), + } +} + +pin_project! { + struct Track<T> { + #[pin] + inner: T, + // Arc is used to hook into loom's leak tracking. + arc: Arc<()>, + } +} + +impl<T> Track<T> { + fn into_inner(self) -> T { + self.inner + } +} + +impl<T: Future> Future for Track<T> { + type Output = Track<T::Output>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + + Poll::Ready(Track { + inner: ready!(me.inner.poll(cx)), + arc: me.arc.clone(), + }) + } +} diff --git a/third_party/rust/tokio/src/runtime/tests/loom_queue.rs b/third_party/rust/tokio/src/runtime/tests/loom_queue.rs new file mode 100644 index 0000000000..b5f78d7ebe --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/loom_queue.rs @@ -0,0 +1,208 @@ +use crate::runtime::blocking::NoopSchedule; +use crate::runtime::task::Inject; +use crate::runtime::{queue, MetricsBatch}; + +use loom::thread; + +#[test] +fn basic() { + loom::model(|| { + let (steal, mut local) = queue::local(); + let inject = Inject::new(); + let mut metrics = MetricsBatch::new(); + + let th = thread::spawn(move || { + let mut metrics = MetricsBatch::new(); + let (_, mut local) = queue::local(); + let mut n = 0; + + for _ in 0..3 { + if steal.steal_into(&mut local, &mut metrics).is_some() { + n += 1; + } + + while local.pop().is_some() { + n += 1; + } + } + + n + }); + + let mut n = 0; + + for _ in 0..2 { + for _ in 0..2 { + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + } + + if local.pop().is_some() { + n += 1; + } + + // Push another task + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + + while local.pop().is_some() { + n += 1; + } + } + + while inject.pop().is_some() { + n += 1; + } + + n += th.join().unwrap(); + + assert_eq!(6, n); + }); +} + +#[test] +fn steal_overflow() { + loom::model(|| { + let (steal, mut local) = queue::local(); + let inject = Inject::new(); + let mut metrics = MetricsBatch::new(); + + let th = thread::spawn(move || { + let mut metrics = MetricsBatch::new(); + let (_, mut local) = queue::local(); + let mut n = 0; + + if steal.steal_into(&mut local, &mut metrics).is_some() { + n += 1; + } + + while local.pop().is_some() { + n += 1; + } + + n + }); + + let mut n = 0; + + // push a task, pop a task + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + + if local.pop().is_some() { + n += 1; + } + + for _ in 0..6 { + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + } + + n += th.join().unwrap(); + + while local.pop().is_some() { + n += 1; + } + + while inject.pop().is_some() { + n += 1; + } + + assert_eq!(7, n); + }); +} + +#[test] +fn multi_stealer() { + const NUM_TASKS: usize = 5; + + fn steal_tasks(steal: queue::Steal<NoopSchedule>) -> usize { + let mut metrics = MetricsBatch::new(); + let (_, mut local) = queue::local(); + + if steal.steal_into(&mut local, &mut metrics).is_none() { + return 0; + } + + let mut n = 1; + + while local.pop().is_some() { + n += 1; + } + + n + } + + loom::model(|| { + let (steal, mut local) = queue::local(); + let inject = Inject::new(); + let mut metrics = MetricsBatch::new(); + + // Push work + for _ in 0..NUM_TASKS { + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + } + + let th1 = { + let steal = steal.clone(); + thread::spawn(move || steal_tasks(steal)) + }; + + let th2 = thread::spawn(move || steal_tasks(steal)); + + let mut n = 0; + + while local.pop().is_some() { + n += 1; + } + + while inject.pop().is_some() { + n += 1; + } + + n += th1.join().unwrap(); + n += th2.join().unwrap(); + + assert_eq!(n, NUM_TASKS); + }); +} + +#[test] +fn chained_steal() { + loom::model(|| { + let mut metrics = MetricsBatch::new(); + let (s1, mut l1) = queue::local(); + let (s2, mut l2) = queue::local(); + let inject = Inject::new(); + + // Load up some tasks + for _ in 0..4 { + let (task, _) = super::unowned(async {}); + l1.push_back(task, &inject, &mut metrics); + + let (task, _) = super::unowned(async {}); + l2.push_back(task, &inject, &mut metrics); + } + + // Spawn a task to steal from **our** queue + let th = thread::spawn(move || { + let mut metrics = MetricsBatch::new(); + let (_, mut local) = queue::local(); + s1.steal_into(&mut local, &mut metrics); + + while local.pop().is_some() {} + }); + + // Drain our tasks, then attempt to steal + while l1.pop().is_some() {} + + s2.steal_into(&mut l1, &mut metrics); + + th.join().unwrap(); + + while l1.pop().is_some() {} + while l2.pop().is_some() {} + while inject.pop().is_some() {} + }); +} diff --git a/third_party/rust/tokio/src/runtime/tests/loom_shutdown_join.rs b/third_party/rust/tokio/src/runtime/tests/loom_shutdown_join.rs new file mode 100644 index 0000000000..6fbc4bfded --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/loom_shutdown_join.rs @@ -0,0 +1,28 @@ +use crate::runtime::{Builder, Handle}; + +#[test] +fn join_handle_cancel_on_shutdown() { + let mut builder = loom::model::Builder::new(); + builder.preemption_bound = Some(2); + builder.check(|| { + use futures::future::FutureExt; + + let rt = Builder::new_multi_thread() + .worker_threads(2) + .build() + .unwrap(); + + let handle = rt.block_on(async move { Handle::current() }); + + let jh1 = handle.spawn(futures::future::pending::<()>()); + + drop(rt); + + let jh2 = handle.spawn(futures::future::pending::<()>()); + + let err1 = jh1.now_or_never().unwrap().unwrap_err(); + let err2 = jh2.now_or_never().unwrap().unwrap_err(); + assert!(err1.is_cancelled()); + assert!(err2.is_cancelled()); + }); +} diff --git a/third_party/rust/tokio/src/runtime/tests/mod.rs b/third_party/rust/tokio/src/runtime/tests/mod.rs new file mode 100644 index 0000000000..4b49698a86 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/mod.rs @@ -0,0 +1,50 @@ +use self::unowned_wrapper::unowned; + +mod unowned_wrapper { + use crate::runtime::blocking::NoopSchedule; + use crate::runtime::task::{JoinHandle, Notified}; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(crate) fn unowned<T>(task: T) -> (Notified<NoopSchedule>, JoinHandle<T::Output>) + where + T: std::future::Future + Send + 'static, + T::Output: Send + 'static, + { + use tracing::Instrument; + let span = tracing::trace_span!("test_span"); + let task = task.instrument(span); + let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule); + (task.into_notified(), handle) + } + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + pub(crate) fn unowned<T>(task: T) -> (Notified<NoopSchedule>, JoinHandle<T::Output>) + where + T: std::future::Future + Send + 'static, + T::Output: Send + 'static, + { + let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule); + (task.into_notified(), handle) + } +} + +cfg_loom! { + mod loom_basic_scheduler; + mod loom_blocking; + mod loom_local; + mod loom_oneshot; + mod loom_pool; + mod loom_queue; + mod loom_shutdown_join; + mod loom_join_set; +} + +cfg_not_loom! { + mod queue; + + #[cfg(not(miri))] + mod task_combinations; + + #[cfg(miri)] + mod task; +} diff --git a/third_party/rust/tokio/src/runtime/tests/queue.rs b/third_party/rust/tokio/src/runtime/tests/queue.rs new file mode 100644 index 0000000000..0fd1e0c6d9 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/queue.rs @@ -0,0 +1,248 @@ +use crate::runtime::queue; +use crate::runtime::task::{self, Inject, Schedule, Task}; +use crate::runtime::MetricsBatch; + +use std::thread; +use std::time::Duration; + +#[allow(unused)] +macro_rules! assert_metrics { + ($metrics:ident, $field:ident == $v:expr) => {{ + use crate::runtime::WorkerMetrics; + use std::sync::atomic::Ordering::Relaxed; + + let worker = WorkerMetrics::new(); + $metrics.submit(&worker); + + let expect = $v; + let actual = worker.$field.load(Relaxed); + + assert!(actual == expect, "expect = {}; actual = {}", expect, actual) + }}; +} + +#[test] +fn fits_256() { + let (_, mut local) = queue::local(); + let inject = Inject::new(); + let mut metrics = MetricsBatch::new(); + + for _ in 0..256 { + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + } + + cfg_metrics! { + assert_metrics!(metrics, overflow_count == 0); + } + + assert!(inject.pop().is_none()); + + while local.pop().is_some() {} +} + +#[test] +fn overflow() { + let (_, mut local) = queue::local(); + let inject = Inject::new(); + let mut metrics = MetricsBatch::new(); + + for _ in 0..257 { + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + } + + cfg_metrics! { + assert_metrics!(metrics, overflow_count == 1); + } + + let mut n = 0; + + while inject.pop().is_some() { + n += 1; + } + + while local.pop().is_some() { + n += 1; + } + + assert_eq!(n, 257); +} + +#[test] +fn steal_batch() { + let mut metrics = MetricsBatch::new(); + + let (steal1, mut local1) = queue::local(); + let (_, mut local2) = queue::local(); + let inject = Inject::new(); + + for _ in 0..4 { + let (task, _) = super::unowned(async {}); + local1.push_back(task, &inject, &mut metrics); + } + + assert!(steal1.steal_into(&mut local2, &mut metrics).is_some()); + + cfg_metrics! { + assert_metrics!(metrics, steal_count == 2); + } + + for _ in 0..1 { + assert!(local2.pop().is_some()); + } + + assert!(local2.pop().is_none()); + + for _ in 0..2 { + assert!(local1.pop().is_some()); + } + + assert!(local1.pop().is_none()); +} + +const fn normal_or_miri(normal: usize, miri: usize) -> usize { + if cfg!(miri) { + miri + } else { + normal + } +} + +#[test] +fn stress1() { + const NUM_ITER: usize = 1; + const NUM_STEAL: usize = normal_or_miri(1_000, 10); + const NUM_LOCAL: usize = normal_or_miri(1_000, 10); + const NUM_PUSH: usize = normal_or_miri(500, 10); + const NUM_POP: usize = normal_or_miri(250, 10); + + let mut metrics = MetricsBatch::new(); + + for _ in 0..NUM_ITER { + let (steal, mut local) = queue::local(); + let inject = Inject::new(); + + let th = thread::spawn(move || { + let mut metrics = MetricsBatch::new(); + let (_, mut local) = queue::local(); + let mut n = 0; + + for _ in 0..NUM_STEAL { + if steal.steal_into(&mut local, &mut metrics).is_some() { + n += 1; + } + + while local.pop().is_some() { + n += 1; + } + + thread::yield_now(); + } + + cfg_metrics! { + assert_metrics!(metrics, steal_count == n as _); + } + + n + }); + + let mut n = 0; + + for _ in 0..NUM_LOCAL { + for _ in 0..NUM_PUSH { + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + } + + for _ in 0..NUM_POP { + if local.pop().is_some() { + n += 1; + } else { + break; + } + } + } + + while inject.pop().is_some() { + n += 1; + } + + n += th.join().unwrap(); + + assert_eq!(n, NUM_LOCAL * NUM_PUSH); + } +} + +#[test] +fn stress2() { + const NUM_ITER: usize = 1; + const NUM_TASKS: usize = normal_or_miri(1_000_000, 50); + const NUM_STEAL: usize = normal_or_miri(1_000, 10); + + let mut metrics = MetricsBatch::new(); + + for _ in 0..NUM_ITER { + let (steal, mut local) = queue::local(); + let inject = Inject::new(); + + let th = thread::spawn(move || { + let mut stats = MetricsBatch::new(); + let (_, mut local) = queue::local(); + let mut n = 0; + + for _ in 0..NUM_STEAL { + if steal.steal_into(&mut local, &mut stats).is_some() { + n += 1; + } + + while local.pop().is_some() { + n += 1; + } + + thread::sleep(Duration::from_micros(10)); + } + + n + }); + + let mut num_pop = 0; + + for i in 0..NUM_TASKS { + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + + if i % 128 == 0 && local.pop().is_some() { + num_pop += 1; + } + + while inject.pop().is_some() { + num_pop += 1; + } + } + + num_pop += th.join().unwrap(); + + while local.pop().is_some() { + num_pop += 1; + } + + while inject.pop().is_some() { + num_pop += 1; + } + + assert_eq!(num_pop, NUM_TASKS); + } +} + +struct Runtime; + +impl Schedule for Runtime { + fn release(&self, _task: &Task<Self>) -> Option<Task<Self>> { + None + } + + fn schedule(&self, _task: task::Notified<Self>) { + unreachable!(); + } +} diff --git a/third_party/rust/tokio/src/runtime/tests/task.rs b/third_party/rust/tokio/src/runtime/tests/task.rs new file mode 100644 index 0000000000..04e1b56e77 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/task.rs @@ -0,0 +1,288 @@ +use crate::runtime::blocking::NoopSchedule; +use crate::runtime::task::{self, unowned, JoinHandle, OwnedTasks, Schedule, Task}; +use crate::util::TryLock; + +use std::collections::VecDeque; +use std::future::Future; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; + +struct AssertDropHandle { + is_dropped: Arc<AtomicBool>, +} +impl AssertDropHandle { + #[track_caller] + fn assert_dropped(&self) { + assert!(self.is_dropped.load(Ordering::SeqCst)); + } + + #[track_caller] + fn assert_not_dropped(&self) { + assert!(!self.is_dropped.load(Ordering::SeqCst)); + } +} + +struct AssertDrop { + is_dropped: Arc<AtomicBool>, +} +impl AssertDrop { + fn new() -> (Self, AssertDropHandle) { + let shared = Arc::new(AtomicBool::new(false)); + ( + AssertDrop { + is_dropped: shared.clone(), + }, + AssertDropHandle { + is_dropped: shared.clone(), + }, + ) + } +} +impl Drop for AssertDrop { + fn drop(&mut self) { + self.is_dropped.store(true, Ordering::SeqCst); + } +} + +// A Notified does not shut down on drop, but it is dropped once the ref-count +// hits zero. +#[test] +fn create_drop1() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + drop(notified); + handle.assert_not_dropped(); + drop(join); + handle.assert_dropped(); +} + +#[test] +fn create_drop2() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + drop(join); + handle.assert_not_dropped(); + drop(notified); + handle.assert_dropped(); +} + +// Shutting down through Notified works +#[test] +fn create_shutdown1() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + drop(join); + handle.assert_not_dropped(); + notified.shutdown(); + handle.assert_dropped(); +} + +#[test] +fn create_shutdown2() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + handle.assert_not_dropped(); + notified.shutdown(); + handle.assert_dropped(); + drop(join); +} + +#[test] +fn unowned_poll() { + let (task, _) = unowned(async {}, NoopSchedule); + task.run(); +} + +#[test] +fn schedule() { + with(|rt| { + rt.spawn(async { + crate::task::yield_now().await; + }); + + assert_eq!(2, rt.tick()); + rt.shutdown(); + }) +} + +#[test] +fn shutdown() { + with(|rt| { + rt.spawn(async { + loop { + crate::task::yield_now().await; + } + }); + + rt.tick_max(1); + + rt.shutdown(); + }) +} + +#[test] +fn shutdown_immediately() { + with(|rt| { + rt.spawn(async { + loop { + crate::task::yield_now().await; + } + }); + + rt.shutdown(); + }) +} + +#[test] +fn spawn_during_shutdown() { + static DID_SPAWN: AtomicBool = AtomicBool::new(false); + + struct SpawnOnDrop(Runtime); + impl Drop for SpawnOnDrop { + fn drop(&mut self) { + DID_SPAWN.store(true, Ordering::SeqCst); + self.0.spawn(async {}); + } + } + + with(|rt| { + let rt2 = rt.clone(); + rt.spawn(async move { + let _spawn_on_drop = SpawnOnDrop(rt2); + + loop { + crate::task::yield_now().await; + } + }); + + rt.tick_max(1); + rt.shutdown(); + }); + + assert!(DID_SPAWN.load(Ordering::SeqCst)); +} + +fn with(f: impl FnOnce(Runtime)) { + struct Reset; + + impl Drop for Reset { + fn drop(&mut self) { + let _rt = CURRENT.try_lock().unwrap().take(); + } + } + + let _reset = Reset; + + let rt = Runtime(Arc::new(Inner { + owned: OwnedTasks::new(), + core: TryLock::new(Core { + queue: VecDeque::new(), + }), + })); + + *CURRENT.try_lock().unwrap() = Some(rt.clone()); + f(rt) +} + +#[derive(Clone)] +struct Runtime(Arc<Inner>); + +struct Inner { + core: TryLock<Core>, + owned: OwnedTasks<Runtime>, +} + +struct Core { + queue: VecDeque<task::Notified<Runtime>>, +} + +static CURRENT: TryLock<Option<Runtime>> = TryLock::new(None); + +impl Runtime { + fn spawn<T>(&self, future: T) -> JoinHandle<T::Output> + where + T: 'static + Send + Future, + T::Output: 'static + Send, + { + let (handle, notified) = self.0.owned.bind(future, self.clone()); + + if let Some(notified) = notified { + self.schedule(notified); + } + + handle + } + + fn tick(&self) -> usize { + self.tick_max(usize::MAX) + } + + fn tick_max(&self, max: usize) -> usize { + let mut n = 0; + + while !self.is_empty() && n < max { + let task = self.next_task(); + n += 1; + let task = self.0.owned.assert_owner(task); + task.run(); + } + + n + } + + fn is_empty(&self) -> bool { + self.0.core.try_lock().unwrap().queue.is_empty() + } + + fn next_task(&self) -> task::Notified<Runtime> { + self.0.core.try_lock().unwrap().queue.pop_front().unwrap() + } + + fn shutdown(&self) { + let mut core = self.0.core.try_lock().unwrap(); + + self.0.owned.close_and_shutdown_all(); + + while let Some(task) = core.queue.pop_back() { + drop(task); + } + + drop(core); + + assert!(self.0.owned.is_empty()); + } +} + +impl Schedule for Runtime { + fn release(&self, task: &Task<Self>) -> Option<Task<Self>> { + self.0.owned.remove(task) + } + + fn schedule(&self, task: task::Notified<Self>) { + self.0.core.try_lock().unwrap().queue.push_back(task); + } +} diff --git a/third_party/rust/tokio/src/runtime/tests/task_combinations.rs b/third_party/rust/tokio/src/runtime/tests/task_combinations.rs new file mode 100644 index 0000000000..76ce2330c2 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/task_combinations.rs @@ -0,0 +1,380 @@ +use std::future::Future; +use std::panic; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use crate::runtime::Builder; +use crate::sync::oneshot; +use crate::task::JoinHandle; + +use futures::future::FutureExt; + +// Enums for each option in the combinations being tested + +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiRuntime { + CurrentThread, + Multi1, + Multi2, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiLocalSet { + Yes, + No, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiTask { + PanicOnRun, + PanicOnDrop, + PanicOnRunAndDrop, + NoPanic, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiOutput { + PanicOnDrop, + NoPanic, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiJoinInterest { + Polled, + NotPolled, +} +#[allow(clippy::enum_variant_names)] // we aren't using glob imports +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiJoinHandle { + DropImmediately = 1, + DropFirstPoll = 2, + DropAfterNoConsume = 3, + DropAfterConsume = 4, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiAbort { + NotAborted = 0, + AbortedImmediately = 1, + AbortedFirstPoll = 2, + AbortedAfterFinish = 3, + AbortedAfterConsumeOutput = 4, +} + +#[test] +fn test_combinations() { + let mut rt = &[ + CombiRuntime::CurrentThread, + CombiRuntime::Multi1, + CombiRuntime::Multi2, + ][..]; + + if cfg!(miri) { + rt = &[CombiRuntime::CurrentThread]; + } + + let ls = [CombiLocalSet::Yes, CombiLocalSet::No]; + let task = [ + CombiTask::NoPanic, + CombiTask::PanicOnRun, + CombiTask::PanicOnDrop, + CombiTask::PanicOnRunAndDrop, + ]; + let output = [CombiOutput::NoPanic, CombiOutput::PanicOnDrop]; + let ji = [CombiJoinInterest::Polled, CombiJoinInterest::NotPolled]; + let jh = [ + CombiJoinHandle::DropImmediately, + CombiJoinHandle::DropFirstPoll, + CombiJoinHandle::DropAfterNoConsume, + CombiJoinHandle::DropAfterConsume, + ]; + let abort = [ + CombiAbort::NotAborted, + CombiAbort::AbortedImmediately, + CombiAbort::AbortedFirstPoll, + CombiAbort::AbortedAfterFinish, + CombiAbort::AbortedAfterConsumeOutput, + ]; + + for rt in rt.iter().copied() { + for ls in ls.iter().copied() { + for task in task.iter().copied() { + for output in output.iter().copied() { + for ji in ji.iter().copied() { + for jh in jh.iter().copied() { + for abort in abort.iter().copied() { + test_combination(rt, ls, task, output, ji, jh, abort); + } + } + } + } + } + } + } +} + +fn test_combination( + rt: CombiRuntime, + ls: CombiLocalSet, + task: CombiTask, + output: CombiOutput, + ji: CombiJoinInterest, + jh: CombiJoinHandle, + abort: CombiAbort, +) { + if (jh as usize) < (abort as usize) { + // drop before abort not possible + return; + } + if (task == CombiTask::PanicOnDrop) && (output == CombiOutput::PanicOnDrop) { + // this causes double panic + return; + } + if (task == CombiTask::PanicOnRunAndDrop) && (abort != CombiAbort::AbortedImmediately) { + // this causes double panic + return; + } + + println!("Runtime {:?}, LocalSet {:?}, Task {:?}, Output {:?}, JoinInterest {:?}, JoinHandle {:?}, Abort {:?}", rt, ls, task, output, ji, jh, abort); + + // A runtime optionally with a LocalSet + struct Rt { + rt: crate::runtime::Runtime, + ls: Option<crate::task::LocalSet>, + } + impl Rt { + fn new(rt: CombiRuntime, ls: CombiLocalSet) -> Self { + let rt = match rt { + CombiRuntime::CurrentThread => Builder::new_current_thread().build().unwrap(), + CombiRuntime::Multi1 => Builder::new_multi_thread() + .worker_threads(1) + .build() + .unwrap(), + CombiRuntime::Multi2 => Builder::new_multi_thread() + .worker_threads(2) + .build() + .unwrap(), + }; + + let ls = match ls { + CombiLocalSet::Yes => Some(crate::task::LocalSet::new()), + CombiLocalSet::No => None, + }; + + Self { rt, ls } + } + fn block_on<T>(&self, task: T) -> T::Output + where + T: Future, + { + match &self.ls { + Some(ls) => ls.block_on(&self.rt, task), + None => self.rt.block_on(task), + } + } + fn spawn<T>(&self, task: T) -> JoinHandle<T::Output> + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + match &self.ls { + Some(ls) => ls.spawn_local(task), + None => self.rt.spawn(task), + } + } + } + + // The type used for the output of the future + struct Output { + panic_on_drop: bool, + on_drop: Option<oneshot::Sender<()>>, + } + impl Output { + fn disarm(&mut self) { + self.panic_on_drop = false; + } + } + impl Drop for Output { + fn drop(&mut self) { + let _ = self.on_drop.take().unwrap().send(()); + if self.panic_on_drop { + panic!("Panicking in Output"); + } + } + } + + // A wrapper around the future that is spawned + struct FutWrapper<F> { + inner: F, + on_drop: Option<oneshot::Sender<()>>, + panic_on_drop: bool, + } + impl<F: Future> Future for FutWrapper<F> { + type Output = F::Output; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> { + unsafe { + let me = Pin::into_inner_unchecked(self); + let inner = Pin::new_unchecked(&mut me.inner); + inner.poll(cx) + } + } + } + impl<F> Drop for FutWrapper<F> { + fn drop(&mut self) { + let _: Result<(), ()> = self.on_drop.take().unwrap().send(()); + if self.panic_on_drop { + panic!("Panicking in FutWrapper"); + } + } + } + + // The channels passed to the task + struct Signals { + on_first_poll: Option<oneshot::Sender<()>>, + wait_complete: Option<oneshot::Receiver<()>>, + on_output_drop: Option<oneshot::Sender<()>>, + } + + // The task we will spawn + async fn my_task(mut signal: Signals, task: CombiTask, out: CombiOutput) -> Output { + // Signal that we have been polled once + let _ = signal.on_first_poll.take().unwrap().send(()); + + // Wait for a signal, then complete the future + let _ = signal.wait_complete.take().unwrap().await; + + // If the task gets past wait_complete without yielding, then aborts + // may not be caught without this yield_now. + crate::task::yield_now().await; + + if task == CombiTask::PanicOnRun || task == CombiTask::PanicOnRunAndDrop { + panic!("Panicking in my_task on {:?}", std::thread::current().id()); + } + + Output { + panic_on_drop: out == CombiOutput::PanicOnDrop, + on_drop: signal.on_output_drop.take(), + } + } + + let rt = Rt::new(rt, ls); + + let (on_first_poll, wait_first_poll) = oneshot::channel(); + let (on_complete, wait_complete) = oneshot::channel(); + let (on_future_drop, wait_future_drop) = oneshot::channel(); + let (on_output_drop, wait_output_drop) = oneshot::channel(); + let signal = Signals { + on_first_poll: Some(on_first_poll), + wait_complete: Some(wait_complete), + on_output_drop: Some(on_output_drop), + }; + + // === Spawn task === + let mut handle = Some(rt.spawn(FutWrapper { + inner: my_task(signal, task, output), + on_drop: Some(on_future_drop), + panic_on_drop: task == CombiTask::PanicOnDrop || task == CombiTask::PanicOnRunAndDrop, + })); + + // Keep track of whether the task has been killed with an abort + let mut aborted = false; + + // If we want to poll the JoinHandle, do it now + if ji == CombiJoinInterest::Polled { + assert!( + handle.as_mut().unwrap().now_or_never().is_none(), + "Polling handle succeeded" + ); + } + + if abort == CombiAbort::AbortedImmediately { + handle.as_mut().unwrap().abort(); + aborted = true; + } + if jh == CombiJoinHandle::DropImmediately { + drop(handle.take().unwrap()); + } + + // === Wait for first poll === + let got_polled = rt.block_on(wait_first_poll).is_ok(); + if !got_polled { + // it's possible that we are aborted but still got polled + assert!( + aborted, + "Task completed without ever being polled but was not aborted." + ); + } + + if abort == CombiAbort::AbortedFirstPoll { + handle.as_mut().unwrap().abort(); + aborted = true; + } + if jh == CombiJoinHandle::DropFirstPoll { + drop(handle.take().unwrap()); + } + + // Signal the future that it can return now + let _ = on_complete.send(()); + // === Wait for future to be dropped === + assert!( + rt.block_on(wait_future_drop).is_ok(), + "The future should always be dropped." + ); + + if abort == CombiAbort::AbortedAfterFinish { + // Don't set aborted to true here as the task already finished + handle.as_mut().unwrap().abort(); + } + if jh == CombiJoinHandle::DropAfterNoConsume { + // The runtime will usually have dropped every ref-count at this point, + // in which case dropping the JoinHandle drops the output. + // + // (But it might race and still hold a ref-count) + let panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { + drop(handle.take().unwrap()); + })); + if panic.is_err() { + assert!( + (output == CombiOutput::PanicOnDrop) + && (!matches!(task, CombiTask::PanicOnRun | CombiTask::PanicOnRunAndDrop)) + && !aborted, + "Dropping JoinHandle shouldn't panic here" + ); + } + } + + // Check whether we drop after consuming the output + if jh == CombiJoinHandle::DropAfterConsume { + // Using as_mut() to not immediately drop the handle + let result = rt.block_on(handle.as_mut().unwrap()); + + match result { + Ok(mut output) => { + // Don't panic here. + output.disarm(); + assert!(!aborted, "Task was aborted but returned output"); + } + Err(err) if err.is_cancelled() => assert!(aborted, "Cancelled output but not aborted"), + Err(err) if err.is_panic() => { + assert!( + (task == CombiTask::PanicOnRun) + || (task == CombiTask::PanicOnDrop) + || (task == CombiTask::PanicOnRunAndDrop) + || (output == CombiOutput::PanicOnDrop), + "Panic but nothing should panic" + ); + } + _ => unreachable!(), + } + + let handle = handle.take().unwrap(); + if abort == CombiAbort::AbortedAfterConsumeOutput { + handle.abort(); + } + drop(handle); + } + + // The output should have been dropped now. Check whether the output + // object was created at all. + let output_created = rt.block_on(wait_output_drop).is_ok(); + assert_eq!( + output_created, + (!matches!(task, CombiTask::PanicOnRun | CombiTask::PanicOnRunAndDrop)) && !aborted, + "Creation of output object" + ); +} |