summaryrefslogtreecommitdiffstats
path: root/third_party/rust/tokio/src/runtime/enter.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/tokio/src/runtime/enter.rs')
-rw-r--r--third_party/rust/tokio/src/runtime/enter.rs205
1 files changed, 205 insertions, 0 deletions
diff --git a/third_party/rust/tokio/src/runtime/enter.rs b/third_party/rust/tokio/src/runtime/enter.rs
new file mode 100644
index 0000000000..3f14cb5878
--- /dev/null
+++ b/third_party/rust/tokio/src/runtime/enter.rs
@@ -0,0 +1,205 @@
+use std::cell::{Cell, RefCell};
+use std::fmt;
+use std::marker::PhantomData;
+
+#[derive(Debug, Clone, Copy)]
+pub(crate) enum EnterContext {
+ #[cfg_attr(not(feature = "rt"), allow(dead_code))]
+ Entered {
+ allow_blocking: bool,
+ },
+ NotEntered,
+}
+
+impl EnterContext {
+ pub(crate) fn is_entered(self) -> bool {
+ matches!(self, EnterContext::Entered { .. })
+ }
+}
+
+thread_local!(static ENTERED: Cell<EnterContext> = Cell::new(EnterContext::NotEntered));
+
+/// Represents an executor context.
+pub(crate) struct Enter {
+ _p: PhantomData<RefCell<()>>,
+}
+
+cfg_rt! {
+ use crate::park::thread::ParkError;
+
+ use std::time::Duration;
+
+ /// Marks the current thread as being within the dynamic extent of an
+ /// executor.
+ pub(crate) fn enter(allow_blocking: bool) -> Enter {
+ if let Some(enter) = try_enter(allow_blocking) {
+ return enter;
+ }
+
+ panic!(
+ "Cannot start a runtime from within a runtime. This happens \
+ because a function (like `block_on`) attempted to block the \
+ current thread while the thread is being used to drive \
+ asynchronous tasks."
+ );
+ }
+
+ /// Tries to enter a runtime context, returns `None` if already in a runtime
+ /// context.
+ pub(crate) fn try_enter(allow_blocking: bool) -> Option<Enter> {
+ ENTERED.with(|c| {
+ if c.get().is_entered() {
+ None
+ } else {
+ c.set(EnterContext::Entered { allow_blocking });
+ Some(Enter { _p: PhantomData })
+ }
+ })
+ }
+}
+
+// Forces the current "entered" state to be cleared while the closure
+// is executed.
+//
+// # Warning
+//
+// This is hidden for a reason. Do not use without fully understanding
+// executors. Misusing can easily cause your program to deadlock.
+cfg_rt_multi_thread! {
+ pub(crate) fn exit<F: FnOnce() -> R, R>(f: F) -> R {
+ // Reset in case the closure panics
+ struct Reset(EnterContext);
+ impl Drop for Reset {
+ fn drop(&mut self) {
+ ENTERED.with(|c| {
+ assert!(!c.get().is_entered(), "closure claimed permanent executor");
+ c.set(self.0);
+ });
+ }
+ }
+
+ let was = ENTERED.with(|c| {
+ let e = c.get();
+ assert!(e.is_entered(), "asked to exit when not entered");
+ c.set(EnterContext::NotEntered);
+ e
+ });
+
+ let _reset = Reset(was);
+ // dropping _reset after f() will reset ENTERED
+ f()
+ }
+}
+
+cfg_rt! {
+ /// Disallows blocking in the current runtime context until the guard is dropped.
+ pub(crate) fn disallow_blocking() -> DisallowBlockingGuard {
+ let reset = ENTERED.with(|c| {
+ if let EnterContext::Entered {
+ allow_blocking: true,
+ } = c.get()
+ {
+ c.set(EnterContext::Entered {
+ allow_blocking: false,
+ });
+ true
+ } else {
+ false
+ }
+ });
+ DisallowBlockingGuard(reset)
+ }
+
+ pub(crate) struct DisallowBlockingGuard(bool);
+ impl Drop for DisallowBlockingGuard {
+ fn drop(&mut self) {
+ if self.0 {
+ // XXX: Do we want some kind of assertion here, or is "best effort" okay?
+ ENTERED.with(|c| {
+ if let EnterContext::Entered {
+ allow_blocking: false,
+ } = c.get()
+ {
+ c.set(EnterContext::Entered {
+ allow_blocking: true,
+ });
+ }
+ })
+ }
+ }
+ }
+}
+
+cfg_rt_multi_thread! {
+ /// Returns true if in a runtime context.
+ pub(crate) fn context() -> EnterContext {
+ ENTERED.with(|c| c.get())
+ }
+}
+
+cfg_rt! {
+ impl Enter {
+ /// Blocks the thread on the specified future, returning the value with
+ /// which that future completes.
+ pub(crate) fn block_on<F>(&mut self, f: F) -> Result<F::Output, ParkError>
+ where
+ F: std::future::Future,
+ {
+ use crate::park::thread::CachedParkThread;
+
+ let mut park = CachedParkThread::new();
+ park.block_on(f)
+ }
+
+ /// Blocks the thread on the specified future for **at most** `timeout`
+ ///
+ /// If the future completes before `timeout`, the result is returned. If
+ /// `timeout` elapses, then `Err` is returned.
+ pub(crate) fn block_on_timeout<F>(&mut self, f: F, timeout: Duration) -> Result<F::Output, ParkError>
+ where
+ F: std::future::Future,
+ {
+ use crate::park::Park;
+ use crate::park::thread::CachedParkThread;
+ use std::task::Context;
+ use std::task::Poll::Ready;
+ use std::time::Instant;
+
+ let mut park = CachedParkThread::new();
+ let waker = park.get_unpark()?.into_waker();
+ let mut cx = Context::from_waker(&waker);
+
+ pin!(f);
+ let when = Instant::now() + timeout;
+
+ loop {
+ if let Ready(v) = crate::coop::budget(|| f.as_mut().poll(&mut cx)) {
+ return Ok(v);
+ }
+
+ let now = Instant::now();
+
+ if now >= when {
+ return Err(());
+ }
+
+ park.park_timeout(when - now)?;
+ }
+ }
+ }
+}
+
+impl fmt::Debug for Enter {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("Enter").finish()
+ }
+}
+
+impl Drop for Enter {
+ fn drop(&mut self) {
+ ENTERED.with(|c| {
+ assert!(c.get().is_entered());
+ c.set(EnterContext::NotEntered);
+ });
+ }
+}