diff options
Diffstat (limited to 'third_party/rust/tokio/tests/task_local_set.rs')
-rw-r--r-- | third_party/rust/tokio/tests/task_local_set.rs | 525 |
1 files changed, 525 insertions, 0 deletions
diff --git a/third_party/rust/tokio/tests/task_local_set.rs b/third_party/rust/tokio/tests/task_local_set.rs new file mode 100644 index 0000000000..f8a35d0ede --- /dev/null +++ b/third_party/rust/tokio/tests/task_local_set.rs @@ -0,0 +1,525 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use futures::{ + future::{pending, ready}, + FutureExt, +}; + +use tokio::runtime::{self, Runtime}; +use tokio::sync::{mpsc, oneshot}; +use tokio::task::{self, LocalSet}; +use tokio::time; + +use std::cell::Cell; +use std::sync::atomic::Ordering::{self, SeqCst}; +use std::sync::atomic::{AtomicBool, AtomicUsize}; +use std::time::Duration; + +#[tokio::test(flavor = "current_thread")] +async fn local_basic_scheduler() { + LocalSet::new() + .run_until(async { + task::spawn_local(async {}).await.unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "multi_thread")] +async fn local_threadpool() { + thread_local! { + static ON_RT_THREAD: Cell<bool> = Cell::new(false); + } + + ON_RT_THREAD.with(|cell| cell.set(true)); + + LocalSet::new() + .run_until(async { + assert!(ON_RT_THREAD.with(|cell| cell.get())); + task::spawn_local(async { + assert!(ON_RT_THREAD.with(|cell| cell.get())); + }) + .await + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "multi_thread")] +async fn localset_future_threadpool() { + thread_local! { + static ON_LOCAL_THREAD: Cell<bool> = Cell::new(false); + } + + ON_LOCAL_THREAD.with(|cell| cell.set(true)); + + let local = LocalSet::new(); + local.spawn_local(async move { + assert!(ON_LOCAL_THREAD.with(|cell| cell.get())); + }); + local.await; +} + +#[tokio::test(flavor = "multi_thread")] +async fn localset_future_timers() { + static RAN1: AtomicBool = AtomicBool::new(false); + static RAN2: AtomicBool = AtomicBool::new(false); + + let local = LocalSet::new(); + local.spawn_local(async move { + time::sleep(Duration::from_millis(5)).await; + RAN1.store(true, Ordering::SeqCst); + }); + local.spawn_local(async move { + time::sleep(Duration::from_millis(10)).await; + RAN2.store(true, Ordering::SeqCst); + }); + local.await; + assert!(RAN1.load(Ordering::SeqCst)); + assert!(RAN2.load(Ordering::SeqCst)); +} + +#[tokio::test] +async fn localset_future_drives_all_local_futs() { + static RAN1: AtomicBool = AtomicBool::new(false); + static RAN2: AtomicBool = AtomicBool::new(false); + static RAN3: AtomicBool = AtomicBool::new(false); + + let local = LocalSet::new(); + local.spawn_local(async move { + task::spawn_local(async { + task::yield_now().await; + RAN3.store(true, Ordering::SeqCst); + }); + task::yield_now().await; + RAN1.store(true, Ordering::SeqCst); + }); + local.spawn_local(async move { + task::yield_now().await; + RAN2.store(true, Ordering::SeqCst); + }); + local.await; + assert!(RAN1.load(Ordering::SeqCst)); + assert!(RAN2.load(Ordering::SeqCst)); + assert!(RAN3.load(Ordering::SeqCst)); +} + +#[tokio::test(flavor = "multi_thread")] +async fn local_threadpool_timer() { + // This test ensures that runtime services like the timer are properly + // set for the local task set. + thread_local! { + static ON_RT_THREAD: Cell<bool> = Cell::new(false); + } + + ON_RT_THREAD.with(|cell| cell.set(true)); + + LocalSet::new() + .run_until(async { + assert!(ON_RT_THREAD.with(|cell| cell.get())); + let join = task::spawn_local(async move { + assert!(ON_RT_THREAD.with(|cell| cell.get())); + time::sleep(Duration::from_millis(10)).await; + assert!(ON_RT_THREAD.with(|cell| cell.get())); + }); + join.await.unwrap(); + }) + .await; +} + +#[test] +// This will panic, since the thread that calls `block_on` cannot use +// in-place blocking inside of `block_on`. +#[should_panic] +fn local_threadpool_blocking_in_place() { + thread_local! { + static ON_RT_THREAD: Cell<bool> = Cell::new(false); + } + + ON_RT_THREAD.with(|cell| cell.set(true)); + + let rt = runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + LocalSet::new().block_on(&rt, async { + assert!(ON_RT_THREAD.with(|cell| cell.get())); + let join = task::spawn_local(async move { + assert!(ON_RT_THREAD.with(|cell| cell.get())); + task::block_in_place(|| {}); + assert!(ON_RT_THREAD.with(|cell| cell.get())); + }); + join.await.unwrap(); + }); +} + +#[tokio::test(flavor = "multi_thread")] +async fn local_threadpool_blocking_run() { + thread_local! { + static ON_RT_THREAD: Cell<bool> = Cell::new(false); + } + + ON_RT_THREAD.with(|cell| cell.set(true)); + + LocalSet::new() + .run_until(async { + assert!(ON_RT_THREAD.with(|cell| cell.get())); + let join = task::spawn_local(async move { + assert!(ON_RT_THREAD.with(|cell| cell.get())); + task::spawn_blocking(|| { + assert!( + !ON_RT_THREAD.with(|cell| cell.get()), + "blocking must not run on the local task set's thread" + ); + }) + .await + .unwrap(); + assert!(ON_RT_THREAD.with(|cell| cell.get())); + }); + join.await.unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "multi_thread")] +async fn all_spawns_are_local() { + use futures::future; + thread_local! { + static ON_RT_THREAD: Cell<bool> = Cell::new(false); + } + + ON_RT_THREAD.with(|cell| cell.set(true)); + + LocalSet::new() + .run_until(async { + assert!(ON_RT_THREAD.with(|cell| cell.get())); + let handles = (0..128) + .map(|_| { + task::spawn_local(async { + assert!(ON_RT_THREAD.with(|cell| cell.get())); + }) + }) + .collect::<Vec<_>>(); + for joined in future::join_all(handles).await { + joined.unwrap(); + } + }) + .await; +} + +#[tokio::test(flavor = "multi_thread")] +async fn nested_spawn_is_local() { + thread_local! { + static ON_RT_THREAD: Cell<bool> = Cell::new(false); + } + + ON_RT_THREAD.with(|cell| cell.set(true)); + + LocalSet::new() + .run_until(async { + assert!(ON_RT_THREAD.with(|cell| cell.get())); + task::spawn_local(async { + assert!(ON_RT_THREAD.with(|cell| cell.get())); + task::spawn_local(async { + assert!(ON_RT_THREAD.with(|cell| cell.get())); + task::spawn_local(async { + assert!(ON_RT_THREAD.with(|cell| cell.get())); + task::spawn_local(async { + assert!(ON_RT_THREAD.with(|cell| cell.get())); + }) + .await + .unwrap(); + }) + .await + .unwrap(); + }) + .await + .unwrap(); + }) + .await + .unwrap(); + }) + .await; +} + +#[test] +fn join_local_future_elsewhere() { + thread_local! { + static ON_RT_THREAD: Cell<bool> = Cell::new(false); + } + + ON_RT_THREAD.with(|cell| cell.set(true)); + + let rt = runtime::Runtime::new().unwrap(); + let local = LocalSet::new(); + local.block_on(&rt, async move { + let (tx, rx) = oneshot::channel(); + let join = task::spawn_local(async move { + println!("hello world running..."); + assert!( + ON_RT_THREAD.with(|cell| cell.get()), + "local task must run on local thread, no matter where it is awaited" + ); + rx.await.unwrap(); + + println!("hello world task done"); + "hello world" + }); + let join2 = task::spawn(async move { + assert!( + !ON_RT_THREAD.with(|cell| cell.get()), + "spawned task should be on a worker" + ); + + tx.send(()).expect("task shouldn't have ended yet"); + println!("waking up hello world..."); + + join.await.expect("task should complete successfully"); + + println!("hello world task joined"); + }); + join2.await.unwrap() + }); +} + +#[test] +fn drop_cancels_tasks() { + use std::rc::Rc; + + // This test reproduces issue #1842 + let rt = rt(); + let rc1 = Rc::new(()); + let rc2 = rc1.clone(); + + let (started_tx, started_rx) = oneshot::channel(); + + let local = LocalSet::new(); + local.spawn_local(async move { + // Move this in + let _rc2 = rc2; + + started_tx.send(()).unwrap(); + futures::future::pending::<()>().await; + }); + + local.block_on(&rt, async { + started_rx.await.unwrap(); + }); + drop(local); + drop(rt); + + assert_eq!(1, Rc::strong_count(&rc1)); +} + +/// Runs a test function in a separate thread, and panics if the test does not +/// complete within the specified timeout, or if the test function panics. +/// +/// This is intended for running tests whose failure mode is a hang or infinite +/// loop that cannot be detected otherwise. +fn with_timeout(timeout: Duration, f: impl FnOnce() + Send + 'static) { + use std::sync::mpsc::RecvTimeoutError; + + let (done_tx, done_rx) = std::sync::mpsc::channel(); + let thread = std::thread::spawn(move || { + f(); + + // Send a message on the channel so that the test thread can + // determine if we have entered an infinite loop: + done_tx.send(()).unwrap(); + }); + + // Since the failure mode of this test is an infinite loop, rather than + // something we can easily make assertions about, we'll run it in a + // thread. When the test thread finishes, it will send a message on a + // channel to this thread. We'll wait for that message with a fairly + // generous timeout, and if we don't receive it, we assume the test + // thread has hung. + // + // Note that it should definitely complete in under a minute, but just + // in case CI is slow, we'll give it a long timeout. + match done_rx.recv_timeout(timeout) { + Err(RecvTimeoutError::Timeout) => panic!( + "test did not complete within {:?} seconds, \ + we have (probably) entered an infinite loop!", + timeout, + ), + // Did the test thread panic? We'll find out for sure when we `join` + // with it. + Err(RecvTimeoutError::Disconnected) => { + println!("done_rx dropped, did the test thread panic?"); + } + // Test completed successfully! + Ok(()) => {} + } + + thread.join().expect("test thread should not panic!") +} + +#[test] +fn drop_cancels_remote_tasks() { + // This test reproduces issue #1885. + with_timeout(Duration::from_secs(60), || { + let (tx, mut rx) = mpsc::channel::<()>(1024); + + let rt = rt(); + + let local = LocalSet::new(); + local.spawn_local(async move { while rx.recv().await.is_some() {} }); + local.block_on(&rt, async { + time::sleep(Duration::from_millis(1)).await; + }); + + drop(tx); + + // This enters an infinite loop if the remote notified tasks are not + // properly cancelled. + drop(local); + }); +} + +#[test] +fn local_tasks_wake_join_all() { + // This test reproduces issue #2460. + with_timeout(Duration::from_secs(60), || { + use futures::future::join_all; + use tokio::task::LocalSet; + + let rt = rt(); + let set = LocalSet::new(); + let mut handles = Vec::new(); + + for _ in 1..=128 { + handles.push(set.spawn_local(async move { + tokio::task::spawn_local(async move {}).await.unwrap(); + })); + } + + rt.block_on(set.run_until(join_all(handles))); + }); +} + +#[test] +fn local_tasks_are_polled_after_tick() { + // This test depends on timing, so we run it up to five times. + for _ in 0..4 { + let res = std::panic::catch_unwind(local_tasks_are_polled_after_tick_inner); + if res.is_ok() { + // success + return; + } + } + + // Test failed 4 times. Try one more time without catching panics. If it + // fails again, the test fails. + local_tasks_are_polled_after_tick_inner(); +} + +#[tokio::main(flavor = "current_thread")] +async fn local_tasks_are_polled_after_tick_inner() { + // Reproduces issues #1899 and #1900 + + static RX1: AtomicUsize = AtomicUsize::new(0); + static RX2: AtomicUsize = AtomicUsize::new(0); + const EXPECTED: usize = 500; + + RX1.store(0, SeqCst); + RX2.store(0, SeqCst); + + let (tx, mut rx) = mpsc::unbounded_channel(); + + let local = LocalSet::new(); + + local + .run_until(async { + let task2 = task::spawn(async move { + // Wait a bit + time::sleep(Duration::from_millis(10)).await; + + let mut oneshots = Vec::with_capacity(EXPECTED); + + // Send values + for _ in 0..EXPECTED { + let (oneshot_tx, oneshot_rx) = oneshot::channel(); + oneshots.push(oneshot_tx); + tx.send(oneshot_rx).unwrap(); + } + + time::sleep(Duration::from_millis(10)).await; + + for tx in oneshots.drain(..) { + tx.send(()).unwrap(); + } + + time::sleep(Duration::from_millis(20)).await; + let rx1 = RX1.load(SeqCst); + let rx2 = RX2.load(SeqCst); + println!("EXPECT = {}; RX1 = {}; RX2 = {}", EXPECTED, rx1, rx2); + assert_eq!(EXPECTED, rx1); + assert_eq!(EXPECTED, rx2); + }); + + while let Some(oneshot) = rx.recv().await { + RX1.fetch_add(1, SeqCst); + + task::spawn_local(async move { + oneshot.await.unwrap(); + RX2.fetch_add(1, SeqCst); + }); + } + + task2.await.unwrap(); + }) + .await; +} + +#[tokio::test] +async fn acquire_mutex_in_drop() { + use futures::future::pending; + + let (tx1, rx1) = oneshot::channel(); + let (tx2, rx2) = oneshot::channel(); + let local = LocalSet::new(); + + local.spawn_local(async move { + let _ = rx2.await; + unreachable!(); + }); + + local.spawn_local(async move { + let _ = rx1.await; + tx2.send(()).unwrap(); + unreachable!(); + }); + + // Spawn a task that will never notify + local.spawn_local(async move { + pending::<()>().await; + tx1.send(()).unwrap(); + }); + + // Tick the loop + local + .run_until(async { + task::yield_now().await; + }) + .await; + + // Drop the LocalSet + drop(local); +} + +#[tokio::test] +async fn spawn_wakes_localset() { + let local = LocalSet::new(); + futures::select! { + _ = local.run_until(pending::<()>()).fuse() => unreachable!(), + ret = async { local.spawn_local(ready(())).await.unwrap()}.fuse() => ret + } +} + +fn rt() -> Runtime { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap() +} |