//! **J**oin **O**n **D**rop thread (`jod_thread`) is a thin wrapper around `std::thread`, //! which makes sure that by default all threads are joined. use std::fmt; /// Like `thread::JoinHandle`, but joins the thread on drop and propagates /// panics by default. pub struct JoinHandle(Option>); impl fmt::Debug for JoinHandle { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.pad("JoinHandle { .. }") } } impl Drop for JoinHandle { fn drop(&mut self) { if let Some(inner) = self.0.take() { let res = inner.join(); if res.is_err() && !std::thread::panicking() { res.unwrap(); } } } } impl JoinHandle { pub fn thread(&self) -> &std::thread::Thread { self.0.as_ref().unwrap().thread() } pub fn join(mut self) -> T { let inner = self.0.take().unwrap(); inner.join().unwrap() } pub fn detach(mut self) -> std::thread::JoinHandle { let inner = self.0.take().unwrap(); inner } } impl From> for JoinHandle { fn from(inner: std::thread::JoinHandle) -> JoinHandle { JoinHandle(Some(inner)) } } #[derive(Debug)] pub struct Builder(std::thread::Builder); impl Builder { pub fn new() -> Builder { Builder(std::thread::Builder::new()) } pub fn name(self, name: String) -> Builder { Builder(self.0.name(name)) } pub fn stack_size(self, size: usize) -> Builder { Builder(self.0.stack_size(size)) } pub fn spawn(self, f: F) -> std::io::Result> where F: FnOnce() -> T, F: Send + 'static, T: Send + 'static, { self.0.spawn(f).map(JoinHandle::from) } } pub fn spawn(f: F) -> JoinHandle where F: FnOnce() -> T, F: Send + 'static, T: Send + 'static, { Builder::new().spawn(f).expect("failed to spawn thread") } #[test] fn smoke() { use std::sync::atomic::{AtomicU32, Ordering}; static COUNTER: AtomicU32 = AtomicU32::new(0); drop(spawn(|| COUNTER.fetch_add(1, Ordering::SeqCst))); assert_eq!(COUNTER.load(Ordering::SeqCst), 1); let res = std::panic::catch_unwind(|| { let _handle = Builder::new() .name("panicky".to_string()) .spawn(|| COUNTER.fetch_add(1, Ordering::SeqCst)) .unwrap(); panic!("boom") }); assert!(res.is_err()); assert_eq!(COUNTER.load(Ordering::SeqCst), 2); let res = std::panic::catch_unwind(|| { let handle = spawn(|| panic!("boom")); let () = handle.join(); }); assert!(res.is_err()); let res = std::panic::catch_unwind(|| { let _handle = spawn(|| panic!("boom")); }); assert!(res.is_err()); }